review_protocol/client/
api.rs

1use std::{collections::HashSet, io};
2
3use serde::{Serialize, de::DeserializeOwned};
4
5use super::Connection;
6use crate::{
7    server,
8    types::{DataSource, DataSourceKey, HostNetworkGroup},
9    unary_request,
10};
11
12/// The client API.
13impl Connection {
14    /// Fetches the configuration from the server.
15    ///
16    /// The format of the configuration is up to the caller to interpret.
17    ///
18    /// # Errors
19    ///
20    /// Returns an error if the request fails or the response is invalid.
21    pub async fn get_config(&self) -> io::Result<String> {
22        let res: Result<String, String> = request(self, server::RequestCode::GetConfig, ()).await?;
23        res.map_err(io::Error::other)
24    }
25
26    /// Fetches the list of allowed networks from the server.
27    ///
28    /// # Errors
29    ///
30    /// Returns an error if the request fails or the response is invalid.
31    pub async fn get_allowlist(&self) -> io::Result<HostNetworkGroup> {
32        let res: Result<HostNetworkGroup, String> =
33            request(self, server::RequestCode::GetAllowlist, ()).await?;
34        res.map_err(io::Error::other)
35    }
36
37    /// Fetches the list of blocked networks from the server.
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if the request fails or the response is invalid.
42    pub async fn get_blocklist(&self) -> io::Result<HostNetworkGroup> {
43        let res: Result<HostNetworkGroup, String> =
44            request(self, server::RequestCode::GetBlocklist, ()).await?;
45        res.map_err(io::Error::other)
46    }
47
48    /// Fetches a data source from the server.
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if the request fails or the response is invalid.
53    pub async fn get_data_source(&self, key: &DataSourceKey<'_>) -> io::Result<DataSource> {
54        let res: Result<Option<DataSource>, String> =
55            request(self, server::RequestCode::GetDataSource, key).await?;
56        res.map_err(io::Error::other)
57            .and_then(|res| res.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)))
58    }
59
60    /// Fetches an indicator from the server.
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if the request fails or the response is invalid.
65    pub async fn get_indicator(&self, name: &str) -> io::Result<HashSet<Vec<String>>> {
66        let res: Result<HashSet<Vec<String>>, String> =
67            request(self, server::RequestCode::GetIndicator, name).await?;
68        res.map_err(io::Error::other)
69    }
70
71    /// Fetches the list of internal networks from the server.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the request fails or the response is invalid.
76    pub async fn get_internal_network_list(&self) -> io::Result<HostNetworkGroup> {
77        let res: Result<HostNetworkGroup, String> =
78            request(self, server::RequestCode::GetInternalNetworkList, ()).await?;
79        res.map_err(io::Error::other)
80    }
81
82    /// Fetches the patterns from the threat-intelligence database.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the request fails or the response is invalid.
87    pub async fn get_tidb_patterns(
88        &self,
89        tidbs: &[(&str, &str)],
90    ) -> io::Result<Vec<(String, Option<crate::types::Tidb>)>> {
91        let res: Result<Vec<(String, Option<crate::types::Tidb>)>, String> =
92            request(self, server::RequestCode::GetTidbPatterns, tidbs).await?;
93        res.map_err(io::Error::other)
94    }
95
96    /// Fetches the list of Tor exit nodes from the server.
97    ///
98    /// # Errors
99    ///
100    /// Returns an error if the request fails or the response is invalid.
101    pub async fn get_tor_exit_node_list(&self) -> io::Result<Vec<String>> {
102        let res: Result<Vec<String>, String> =
103            request(self, server::RequestCode::GetTorExitNodeList, ()).await?;
104        res.map_err(io::Error::other)
105    }
106
107    /// Fetches the list of trusted domains from the server.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if the request fails or the response is invalid.
112    pub async fn get_trusted_domain_list(&self) -> io::Result<Vec<String>> {
113        let res: Result<Vec<String>, String> =
114            request(self, server::RequestCode::GetTrustedDomainList, ()).await?;
115        res.map_err(io::Error::other)
116    }
117
118    /// Fetches the list of trusted user agents from the server.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if the request fails or the response is invalid.
123    pub async fn get_trusted_user_agent_list(&self) -> io::Result<Vec<String>> {
124        let res: Result<Vec<String>, String> =
125            request(self, server::RequestCode::GetTrustedUserAgentList, ()).await?;
126        res.map_err(io::Error::other)
127    }
128
129    /// Fetches the pretrained model from the server.
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the request fails or the response is invalid.
134    pub async fn get_pretrained_model(&self, name: &str) -> io::Result<Vec<u8>> {
135        let res: Result<Vec<u8>, String> =
136            request(self, server::RequestCode::GetPretrainedModel, name).await?;
137        res.map_err(io::Error::other)
138    }
139
140    /// Obtains a new certificate from the server.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the request fails or the response is invalid.
145    pub async fn renew_certificate(&self, cert: &[u8]) -> io::Result<(String, String)> {
146        let res: Result<(String, String), String> =
147            request(self, server::RequestCode::RenewCertificate, cert).await?;
148        res.map_err(io::Error::other)
149    }
150}
151
152async fn request<I, O>(conn: &Connection, code: server::RequestCode, input: I) -> io::Result<O>
153where
154    I: Serialize,
155    O: DeserializeOwned,
156{
157    let (mut send, mut recv) = conn.open_bi().await?;
158    unary_request(&mut send, &mut recv, u32::from(code), input).await
159}
160
// Client API tests that run against `TestServerHandler` over a connection
// pair supplied by the shared test environment.
#[cfg(all(test, feature = "server"))]
mod tests {
    use crate::{
        server::handle,
        test::{TEST_ENV, TestServerHandler},
        types::DataSourceKey,
    };

    /// Requests data source id 5 and checks the name the test handler returns.
    #[tokio::test]
    async fn get_data_source() {
        // The guard on the shared test environment is held until the test ends.
        let test_env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = test_env.setup().await;

        // Server side: accept one stream and run the handler on it.
        let handler_conn = server_conn.clone();
        let server_handle = tokio::spawn(async move {
            let mut handler = TestServerHandler;
            let (mut send, mut recv) = handler_conn.as_quinn().accept_bi().await.unwrap();
            handle(&mut handler, &mut send, &mut recv).await?;
            Ok::<(), std::io::Error>(())
        });

        // `expect` reports the underlying error on failure, unlike a bare
        // `assert!(res.is_ok())`.
        let received_data_source = client_conn
            .get_data_source(&DataSourceKey::Id(5))
            .await
            .expect("client request should succeed");
        assert_eq!(received_data_source.name, "name5");

        server_handle
            .await
            .expect("server task should not panic")
            .expect("server handler should succeed");

        test_env.teardown(&server_conn);
    }

    /// Requests two databases and checks that the reply keeps their order,
    /// with a pattern set present for `db1` and absent for `db2`.
    #[tokio::test]
    async fn get_tidb_patterns() {
        // The guard on the shared test environment is held until the test ends.
        let test_env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = test_env.setup().await;

        // Server side: accept one stream and run the handler on it.
        let handler_conn = server_conn.clone();
        let server_handle = tokio::spawn(async move {
            let mut handler = TestServerHandler;
            let (mut send, mut recv) = handler_conn.as_quinn().accept_bi().await.unwrap();
            handle(&mut handler, &mut send, &mut recv).await?;
            Ok::<(), std::io::Error>(())
        });

        // A fixed-size array is enough here; it coerces to the slice the API takes.
        let db_names = [("db1", "1.0.0"), ("db2", "2.0.0")];
        let received_patterns = client_conn
            .get_tidb_patterns(&db_names)
            .await
            .expect("client request should succeed");
        assert_eq!(received_patterns.len(), db_names.len());
        assert_eq!(received_patterns[0].0, "db1");
        assert!(received_patterns[0].1.is_some());
        assert_eq!(received_patterns[1].0, "db2");
        assert!(received_patterns[1].1.is_none());

        server_handle
            .await
            .expect("server task should not panic")
            .expect("server handler should succeed");

        test_env.teardown(&server_conn);
    }
}