review_protocol/client/
api.rs1use std::{collections::HashSet, io};
2
3use serde::{Serialize, de::DeserializeOwned};
4
5use super::Connection;
6use crate::{
7 server,
8 types::{DataSource, DataSourceKey, HostNetworkGroup},
9 unary_request,
10};
11
12impl Connection {
14 pub async fn get_config(&self) -> io::Result<String> {
22 let res: Result<String, String> = request(self, server::RequestCode::GetConfig, ()).await?;
23 res.map_err(io::Error::other)
24 }
25
26 pub async fn get_allowlist(&self) -> io::Result<HostNetworkGroup> {
32 let res: Result<HostNetworkGroup, String> =
33 request(self, server::RequestCode::GetAllowlist, ()).await?;
34 res.map_err(io::Error::other)
35 }
36
37 pub async fn get_blocklist(&self) -> io::Result<HostNetworkGroup> {
43 let res: Result<HostNetworkGroup, String> =
44 request(self, server::RequestCode::GetBlocklist, ()).await?;
45 res.map_err(io::Error::other)
46 }
47
48 pub async fn get_data_source(&self, key: &DataSourceKey<'_>) -> io::Result<DataSource> {
54 let res: Result<Option<DataSource>, String> =
55 request(self, server::RequestCode::GetDataSource, key).await?;
56 res.map_err(io::Error::other)
57 .and_then(|res| res.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)))
58 }
59
60 pub async fn get_indicator(&self, name: &str) -> io::Result<HashSet<Vec<String>>> {
66 let res: Result<HashSet<Vec<String>>, String> =
67 request(self, server::RequestCode::GetIndicator, name).await?;
68 res.map_err(io::Error::other)
69 }
70
71 pub async fn get_internal_network_list(&self) -> io::Result<HostNetworkGroup> {
77 let res: Result<HostNetworkGroup, String> =
78 request(self, server::RequestCode::GetInternalNetworkList, ()).await?;
79 res.map_err(io::Error::other)
80 }
81
82 pub async fn get_tidb_patterns(
88 &self,
89 tidbs: &[(&str, &str)],
90 ) -> io::Result<Vec<(String, Option<crate::types::Tidb>)>> {
91 let res: Result<Vec<(String, Option<crate::types::Tidb>)>, String> =
92 request(self, server::RequestCode::GetTidbPatterns, tidbs).await?;
93 res.map_err(io::Error::other)
94 }
95
96 pub async fn get_tor_exit_node_list(&self) -> io::Result<Vec<String>> {
102 let res: Result<Vec<String>, String> =
103 request(self, server::RequestCode::GetTorExitNodeList, ()).await?;
104 res.map_err(io::Error::other)
105 }
106
107 pub async fn get_trusted_domain_list(&self) -> io::Result<Vec<String>> {
113 let res: Result<Vec<String>, String> =
114 request(self, server::RequestCode::GetTrustedDomainList, ()).await?;
115 res.map_err(io::Error::other)
116 }
117
118 pub async fn get_trusted_user_agent_list(&self) -> io::Result<Vec<String>> {
124 let res: Result<Vec<String>, String> =
125 request(self, server::RequestCode::GetTrustedUserAgentList, ()).await?;
126 res.map_err(io::Error::other)
127 }
128
129 pub async fn get_pretrained_model(&self, name: &str) -> io::Result<Vec<u8>> {
135 let res: Result<Vec<u8>, String> =
136 request(self, server::RequestCode::GetPretrainedModel, name).await?;
137 res.map_err(io::Error::other)
138 }
139
140 pub async fn renew_certificate(&self, cert: &[u8]) -> io::Result<(String, String)> {
146 let res: Result<(String, String), String> =
147 request(self, server::RequestCode::RenewCertificate, cert).await?;
148 res.map_err(io::Error::other)
149 }
150}
151
152async fn request<I, O>(conn: &Connection, code: server::RequestCode, input: I) -> io::Result<O>
153where
154 I: Serialize,
155 O: DeserializeOwned,
156{
157 let (mut send, mut recv) = conn.open_bi().await?;
158 unary_request(&mut send, &mut recv, u32::from(code), input).await
159}
160
#[cfg(all(test, feature = "server"))]
mod tests {
    use crate::{
        server::handle,
        test::{TEST_ENV, TestServerHandler},
        types::DataSourceKey,
    };

    /// Round-trips a `GetDataSource` request through the test server handler.
    #[tokio::test]
    async fn get_data_source() {
        let test_env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = test_env.setup().await;

        // Serve exactly one request on the server side of the connection.
        let handler_conn = server_conn.clone();
        let server_handle = tokio::spawn(async move {
            let mut handler = TestServerHandler;
            let (mut send, mut recv) = handler_conn.as_quinn().accept_bi().await.unwrap();
            handle(&mut handler, &mut send, &mut recv).await?;
            Ok::<(), std::io::Error>(())
        });

        // `expect` (instead of `assert!(is_ok())` + `unwrap`) prints the
        // actual error when the request fails.
        let received_data_source = client_conn
            .get_data_source(&DataSourceKey::Id(5))
            .await
            .expect("client request should succeed");
        assert_eq!(received_data_source.name, "name5");

        server_handle
            .await
            .expect("server task should complete")
            .expect("server should handle the request");

        test_env.teardown(&server_conn);
    }

    /// Round-trips a `GetTidbPatterns` request through the test server handler.
    #[tokio::test]
    async fn get_tidb_patterns() {
        let test_env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = test_env.setup().await;

        // Serve exactly one request on the server side of the connection.
        let handler_conn = server_conn.clone();
        let server_handle = tokio::spawn(async move {
            let mut handler = TestServerHandler;
            let (mut send, mut recv) = handler_conn.as_quinn().accept_bi().await.unwrap();
            handle(&mut handler, &mut send, &mut recv).await?;
            Ok::<(), std::io::Error>(())
        });

        // Only used as a slice, so an array is enough (clippy `useless_vec`).
        let db_names = [("db1", "1.0.0"), ("db2", "2.0.0")];
        let received_patterns = client_conn
            .get_tidb_patterns(&db_names)
            .await
            .expect("client request should succeed");
        assert_eq!(received_patterns.len(), db_names.len());
        assert_eq!(received_patterns[0].0, "db1");
        assert!(received_patterns[0].1.is_some());
        assert_eq!(received_patterns[1].0, "db2");
        assert!(received_patterns[1].1.is_none());

        server_handle
            .await
            .expect("server task should complete")
            .expect("server should handle the request");

        test_env.teardown(&server_conn);
    }
}