1#[cfg(feature = "server")]
4mod api;
5#[cfg(feature = "server")]
6mod handler;
7
8#[cfg(feature = "server")]
9use std::net::SocketAddr;
10
11#[cfg(any(feature = "client", feature = "server"))]
12use num_enum::{FromPrimitive, IntoPrimitive};
13#[cfg(feature = "server")]
14use oinq::{
15 frame,
16 message::{send_err, send_ok},
17};
18#[cfg(feature = "server")]
19use semver::{Version, VersionReq};
20
21#[cfg(feature = "server")]
22pub use self::handler::{Handler, handle};
23#[cfg(feature = "server")]
24use crate::{
25 AgentInfo, HandshakeError, client, handle_handshake_recv_io_error,
26 handle_handshake_send_io_error, types::Tidb,
27};
28
#[cfg(any(feature = "client", feature = "server"))]
/// Request codes handled by this module, convertible to and from `u32`.
///
/// `#[repr(u32)]`, together with the `IntoPrimitive` / `FromPrimitive`
/// derives, makes each explicit discriminant below the variant's numeric
/// value. These numbers presumably identify requests on the wire
/// (NOTE(review): confirm against the peer implementation). If they do,
/// existing values must never be renumbered.
///
/// The unassigned ranges (16–19, 21–22, 27–30) are not explained here.
/// Confirm why they are skipped before reusing any of them.
#[derive(Clone, Copy, Debug, Eq, FromPrimitive, IntoPrimitive, PartialEq)]
#[repr(u32)]
pub(crate) enum RequestCode {
    GetDataSource = 0,
    GetIndicator = 1,
    GetMaxEventIdNum = 2,
    GetModel = 3,
    GetModelNames = 4,
    InsertColumnStatistics = 5,
    InsertModel = 6,
    InsertTimeSeries = 7,
    RemoveModel = 8,
    RemoveOutliers = 9,
    UpdateClusters = 10,
    UpdateModel = 11,
    UpdateOutliers = 12,
    InsertEventLabels = 13,
    GetDataSourceList = 14,
    GetTidbPatterns = 15,
    InsertDataSource = 20,
    RenewCertificate = 23,
    GetTrustedDomainList = 24,
    GetOutliers = 25,
    GetTorExitNodeList = 26,
    GetInternalNetworkList = 31,
    GetAllowlist = 32,
    GetBlocklist = 33,
    GetPretrainedModel = 34,
    GetTrustedUserAgentList = 35,
    GetConfig = 36,

    /// Fallback for any `u32` that matches no code above.
    ///
    /// `#[num_enum(default)]` makes the derived `FromPrimitive` conversion
    /// yield this variant instead of failing on unrecognized values.
    #[num_enum(default)]
    Unknown = u32::MAX,
}
66
#[cfg(feature = "server")]
#[derive(Clone, Debug)]
/// Server-side wrapper around a [`quinn::Connection`].
///
/// Request-sending methods such as `send_trusted_domain_list` are called on
/// this type. NOTE(review): those methods are not defined alongside the
/// struct and are presumably implemented in the `api` submodule — confirm.
pub struct Connection {
    /// The wrapped QUIC connection.
    conn: quinn::Connection,
}
73
#[cfg(feature = "server")]
impl Connection {
    /// Wraps an already-established [`quinn::Connection`].
    #[must_use]
    pub fn from_quinn(conn: quinn::Connection) -> Self {
        Connection { conn }
    }

    /// Borrows the wrapped [`quinn::Connection`]. Test-only.
    #[cfg(test)]
    #[must_use]
    pub(crate) fn as_quinn(&self) -> &quinn::Connection {
        let Self { conn } = self;
        conn
    }

    /// Returns the peer's identity as reported by
    /// [`quinn::Connection::peer_identity`], or `None` if there is none.
    #[must_use]
    pub fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
        quinn::Connection::peer_identity(&self.conn)
    }

    /// Starts opening a new bidirectional stream on the wrapped connection.
    ///
    /// The returned future resolves to the `(send, recv)` stream pair.
    #[must_use]
    pub fn open_bi(&self) -> quinn::OpenBi {
        quinn::Connection::open_bi(&self.conn)
    }

    /// Closes the wrapped connection with error code 0 and an empty reason.
    /// Test-only.
    #[cfg(test)]
    pub(crate) fn close(&self) {
        let code = quinn::VarInt::from_u32(0);
        self.conn.close(code, &[]);
    }
}
117
#[cfg(feature = "server")]
/// Performs the server side of the handshake with a connecting agent.
///
/// Steps:
/// 1. Accepts the first bidirectional stream on `conn`.
/// 2. Reads the agent's [`AgentInfo`] and records `addr` as its address.
/// 3. Checks the protocol version the agent reports.
///
/// The version is accepted when it parses as semver, satisfies `version_req`,
/// and is not newer than `highest_protocol_version`. On acceptance,
/// `highest_protocol_version` is sent back to the agent. On rejection, an
/// error message is sent and the send stream is finished.
///
/// # Errors
///
/// * [`HandshakeError::ConnectionLost`] if no stream can be accepted.
/// * The error produced by `handle_handshake_recv_io_error` if the agent's
///   info cannot be read.
/// * The error produced by `handle_handshake_send_io_error` if the response
///   cannot be written.
/// * [`HandshakeError::IncompatibleProtocol`] if the agent's version cannot be
///   parsed, does not satisfy `version_req`, or is newer than
///   `highest_protocol_version`.
///
/// # Panics
///
/// * Panics if `version_req` is not a valid semver requirement.
/// * Panics if `highest_protocol_version` is not a valid semver version. That
///   value is only parsed after the agent's version has satisfied
///   `version_req`.
pub async fn handshake(
    conn: &quinn::Connection,
    addr: SocketAddr,
    version_req: &str,
    highest_protocol_version: &str,
) -> Result<AgentInfo, HandshakeError> {
    let (mut send, mut recv) = conn
        .accept_bi()
        .await
        .map_err(HandshakeError::ConnectionLost)?;
    let mut buf = Vec::new();
    let mut agent_info = frame::recv::<AgentInfo>(&mut recv, &mut buf)
        .await
        .map_err(handle_handshake_recv_io_error)?;
    agent_info.addr = addr;
    let version_req = VersionReq::parse(version_req).expect("valid version requirement");

    let Ok(protocol_version) = Version::parse(&agent_info.protocol_version) else {
        // Tell the agent why it is rejected instead of dropping the stream
        // without a response, consistent with the other rejection paths.
        send_handshake_rejection(&mut send, &mut buf, version_req.to_string()).await?;
        return Err(HandshakeError::IncompatibleProtocol(
            agent_info.protocol_version.clone(),
            version_req.to_string(),
        ));
    };

    if !version_req.matches(&protocol_version) {
        send_handshake_rejection(&mut send, &mut buf, version_req.to_string()).await?;
        return Err(HandshakeError::IncompatibleProtocol(
            protocol_version.to_string(),
            version_req.to_string(),
        ));
    }

    let highest_protocol_version =
        Version::parse(highest_protocol_version).expect("valid semver");
    if protocol_version > highest_protocol_version {
        // The agent reports a newer version than this server accepts; send
        // back the highest version the server supports.
        send_handshake_rejection(&mut send, &mut buf, &highest_protocol_version).await?;
        return Err(HandshakeError::IncompatibleProtocol(
            protocol_version.to_string(),
            version_req.to_string(),
        ));
    }

    send_ok(&mut send, &mut buf, highest_protocol_version.to_string())
        .await
        .map_err(handle_handshake_send_io_error)?;
    Ok(agent_info)
}

/// Sends `reason` to the agent as a handshake error, then finishes `send`.
///
/// Finishing the stream is best-effort. Any failure is ignored because the
/// caller reports the rejection as an error regardless.
///
/// # Errors
///
/// Returns the error produced by `handle_handshake_send_io_error` if the
/// error response cannot be written.
#[cfg(feature = "server")]
async fn send_handshake_rejection(
    send: &mut quinn::SendStream,
    buf: &mut Vec<u8>,
    reason: impl std::fmt::Display,
) -> Result<(), HandshakeError> {
    send_err(&mut *send, buf, reason)
        .await
        .map_err(handle_handshake_send_io_error)?;
    send.finish().ok();
    Ok(())
}
179
#[cfg(feature = "server")]
#[deprecated(since = "0.8.1", note = "`handle` sends the response")]
/// Writes `patterns` to `send` as a successful response.
///
/// The response is framed as `Ok(patterns)` with a `&str` error type.
///
/// # Errors
///
/// Returns an error with the context "failed to send response" if the
/// frame cannot be written to `send`.
pub async fn respond_with_tidb_patterns(
    send: &mut quinn::SendStream,
    patterns: &[(String, Option<Tidb>)],
) -> anyhow::Result<()> {
    use anyhow::Context as _;

    let response: Result<&[(String, Option<Tidb>)], &str> = Ok(patterns);
    let mut scratch = Vec::new();
    frame::send(send, &mut scratch, response)
        .await
        .context("failed to send response")
}
198
#[cfg(feature = "server")]
#[deprecated(
    since = "0.8.1",
    note = "Use Connection::send_trusted_domain_list directly"
)]
/// Sends the trusted-domain `list` over `conn`.
///
/// This is a compatibility shim: it wraps a clone of `conn` in a
/// [`Connection`] and forwards to `Connection::send_trusted_domain_list`.
///
/// # Errors
///
/// Returns whatever error `Connection::send_trusted_domain_list` returns.
pub async fn send_trusted_domain_list(
    conn: &quinn::Connection,
    list: &[String],
) -> anyhow::Result<()> {
    let wrapped = Connection::from_quinn(conn.clone());
    wrapped.send_trusted_domain_list(list).await
}
217
#[cfg(feature = "server")]
/// Notifies the peer on `conn` of a configuration update.
///
/// Opens a bidirectional stream and sends the bincode-encoded
/// `client::RequestCode::UpdateConfig` code as a raw frame. It then waits for
/// the peer's `Result<(), String>` reply.
///
/// # Errors
///
/// Returns an error in either of these cases:
/// * The stream cannot be opened, the request cannot be sent, or the reply
///   cannot be received.
/// * The peer replies with `Err(message)`. The message becomes the error.
pub async fn notify_config_update(conn: &quinn::Connection) -> anyhow::Result<()> {
    let code: u32 = client::RequestCode::UpdateConfig.into();
    let msg = match bincode::serialize(&code) {
        Ok(bytes) => bytes,
        Err(_) => unreachable!("serialization of u32 into memory buffer should not fail"),
    };

    let (mut send, mut recv) = conn.open_bi().await?;
    frame::send_raw(&mut send, &msg).await?;

    let mut reply_buf = Vec::new();
    let reply: Result<(), String> = frame::recv(&mut recv, &mut reply_buf).await?;
    reply.map_err(|reason| anyhow::anyhow!(reason))
}
239
#[cfg(test)]
/// Round-trip tests that pair server-side calls in this module with the
/// client-side request handler (`crate::request::handle`). Each test uses a
/// connection pair obtained from `crate::test::TEST_ENV`.
mod tests {
    /// Sends a trusted-domain list from the server and checks that the
    /// client-side handler receives exactly that list.
    #[cfg(all(feature = "client", feature = "server"))]
    #[tokio::test]
    async fn trusted_domain_list() {
        use crate::test::TEST_ENV;

        // Handler that succeeds only when it receives the expected list, so a
        // mismatch shows up as an error on the client side.
        struct Handler {}

        #[async_trait::async_trait]
        impl crate::request::Handler for Handler {
            async fn trusted_domain_list(&mut self, domains: &[&str]) -> Result<(), String> {
                if domains == TRUSTED_DOMAIN_LIST {
                    Ok(())
                } else {
                    Err("unexpected domain list".to_string())
                }
            }
        }

        const TRUSTED_DOMAIN_LIST: &[&str] = &["example.com", "example.org"];

        // The shared test environment is held locked for the whole test.
        // Presumably this keeps tests that use it from interleaving.
        let test_env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = test_env.setup().await;

        let domains_to_send = TRUSTED_DOMAIN_LIST
            .iter()
            .map(|&domain| domain.to_string())
            .collect::<Vec<_>>();

        // Run the client-side handler in its own task so it can accept the
        // stream that the server opens below.
        let mut handler = Handler {};
        let handler_conn = client_conn.clone();
        let client_handle = tokio::spawn(async move {
            let (mut send, mut recv) = handler_conn.accept_bi().await.unwrap();

            crate::request::handle(&mut handler, &mut send, &mut recv).await
        });
        let server_res = server_conn.send_trusted_domain_list(&domains_to_send).await;
        assert!(server_res.is_ok());
        let client_res = client_handle.await.unwrap();
        assert!(client_res.is_ok());

        test_env.teardown(&server_conn);
    }

    /// Sends a config-update notification from the server. Checks that the
    /// client-side handler processes it and that the server sees the handler's
    /// `Ok(())`.
    #[cfg(all(feature = "client", feature = "server"))]
    #[tokio::test]
    async fn notify_config_update() {
        use crate::test::TEST_ENV;

        struct Handler {}

        #[async_trait::async_trait]
        impl crate::request::Handler for Handler {
            async fn update_config(&mut self) -> Result<(), String> {
                Ok(())
            }
        }

        let test_env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = test_env.setup().await;

        // As above, the client-side handler runs concurrently with the
        // server-side request.
        let mut handler = Handler {};
        let handler_conn = client_conn.clone();
        let client_handle = tokio::spawn(async move {
            let (mut send, mut recv) = handler_conn.accept_bi().await.unwrap();

            crate::request::handle(&mut handler, &mut send, &mut recv).await
        });
        let server_res = crate::server::notify_config_update(server_conn.as_quinn()).await;
        assert!(server_res.is_ok());
        let client_res = client_handle.await.unwrap();
        assert!(client_res.is_ok());

        test_env.teardown(&server_conn);
    }
}