review_protocol/
server.rs

1//! Server-specific protocol implementation.
2
3#[cfg(feature = "server")]
4mod api;
5#[cfg(feature = "server")]
6mod handler;
7
8#[cfg(feature = "server")]
9use std::net::SocketAddr;
10
11#[cfg(any(feature = "client", feature = "server"))]
12use num_enum::{FromPrimitive, IntoPrimitive};
13#[cfg(feature = "server")]
14use oinq::{
15    frame,
16    message::{send_err, send_ok},
17};
18#[cfg(feature = "server")]
19use semver::{Version, VersionReq};
20
21#[cfg(feature = "server")]
22pub use self::handler::{Handler, handle};
23#[cfg(feature = "server")]
24use crate::{
25    AgentInfo, HandshakeError, client, handle_handshake_recv_io_error,
26    handle_handshake_send_io_error, types::Tidb,
27};
28
/// Numeric codes identifying the requests a server is expected to handle.
///
/// The enum is `#[repr(u32)]` and converts to and from `u32` through
/// `num_enum`'s `IntoPrimitive`/`FromPrimitive`. Any `u32` that has no
/// matching variant converts to [`RequestCode::Unknown`].
///
/// Every variant pins its value explicitly, so the numbering does not depend
/// on declaration order. The values 16–19, 21–22 and 27–30 are not assigned
/// here.
#[cfg(any(feature = "client", feature = "server"))]
#[derive(Clone, Copy, Debug, Eq, FromPrimitive, IntoPrimitive, PartialEq)]
#[repr(u32)]
pub(crate) enum RequestCode {
    GetDataSource = 0,
    GetIndicator = 1,
    GetMaxEventIdNum = 2,
    GetModel = 3,
    GetModelNames = 4,
    InsertColumnStatistics = 5,
    InsertModel = 6,
    InsertTimeSeries = 7,
    RemoveModel = 8,
    RemoveOutliers = 9,
    UpdateClusters = 10,
    UpdateModel = 11,
    UpdateOutliers = 12,
    InsertEventLabels = 13,
    GetDataSourceList = 14,
    GetTidbPatterns = 15,
    InsertDataSource = 20,
    RenewCertificate = 23,
    GetTrustedDomainList = 24,
    GetOutliers = 25,
    GetTorExitNodeList = 26,
    GetInternalNetworkList = 31,
    GetAllowlist = 32,
    GetBlocklist = 33,
    GetPretrainedModel = 34,
    GetTrustedUserAgentList = 35,
    GetConfig = 36,

    /// Fallback for any `u32` that does not correspond to a known request.
    #[num_enum(default)]
    Unknown = u32::MAX,
}
66
/// A client connection, as seen from the server side.
///
/// This is a thin wrapper around a [`quinn::Connection`].
#[cfg(feature = "server")]
#[derive(Debug, Clone)]
pub struct Connection {
    conn: quinn::Connection,
}
73
#[cfg(feature = "server")]
impl Connection {
    /// Wraps a QUIC connection from the `quinn` crate.
    #[must_use]
    pub fn from_quinn(conn: quinn::Connection) -> Self {
        Connection { conn }
    }

    /// Borrows the wrapped [`quinn::Connection`].
    ///
    /// Only compiled for tests, where helpers that still take a raw `quinn`
    /// connection need access to it.
    #[cfg(test)]
    #[must_use]
    pub(crate) fn as_quinn(&self) -> &quinn::Connection {
        &self.conn
    }

    /// Returns the cryptographic identity of the peer, if any.
    ///
    /// Delegates to `quinn::Connection::peer_identity`. This may later be
    /// replaced by an implementation native to this crate.
    #[must_use]
    pub fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
        let conn = &self.conn;
        conn.peer_identity()
    }

    /// Starts opening an outgoing bidirectional stream.
    ///
    /// Delegates to `quinn::Connection::open_bi`. The returned future must be
    /// awaited to obtain the stream pair. This may later be replaced by an
    /// implementation native to this crate.
    #[must_use]
    pub fn open_bi(&self) -> quinn::OpenBi {
        let conn = &self.conn;
        conn.open_bi()
    }

    /// Closes the wrapped connection with error code 0 and an empty reason.
    #[cfg(test)]
    pub(crate) fn close(&self) {
        let code = 0_u32.into();
        self.conn.close(code, b"");
    }
}
117
/// Processes a handshake message and sends a response.
///
/// Accepts the next bidirectional stream opened by the peer, reads its
/// [`AgentInfo`] and stores `addr` in it. It then checks the advertised
/// protocol version against `version_req` and `highest_protocol_version`.
///
/// On success, `highest_protocol_version` is sent back to the peer and the
/// received `AgentInfo` is returned. On every rejection path, an error response
/// is sent to the peer and the send stream is finished before returning.
///
/// # Errors
///
/// Returns `HandshakeError` if the handshake failed:
///
/// * `HandshakeError::ConnectionLost` if no stream could be accepted.
/// * `HandshakeError::IncompatibleProtocol` if the peer's protocol version
///   cannot be parsed, does not satisfy `version_req`, or is newer than
///   `highest_protocol_version`.
/// * Errors mapped by `handle_handshake_recv_io_error` or
///   `handle_handshake_send_io_error` if reading from or writing to the
///   stream fails.
///
/// # Panics
///
/// * panic if it failed to parse version requirement string.
/// * panic if `highest_protocol_version` is not a valid semantic version. This
///   is checked only after the peer's version satisfies `version_req`.
#[cfg(feature = "server")]
pub async fn handshake(
    conn: &quinn::Connection,
    addr: SocketAddr,
    version_req: &str,
    highest_protocol_version: &str,
) -> Result<AgentInfo, HandshakeError> {
    let (mut send, mut recv) = conn
        .accept_bi()
        .await
        .map_err(HandshakeError::ConnectionLost)?;
    let mut buf = Vec::new();
    let mut agent_info = frame::recv::<AgentInfo>(&mut recv, &mut buf)
        .await
        .map_err(handle_handshake_recv_io_error)?;
    agent_info.addr = addr;
    let version_req = VersionReq::parse(version_req).expect("valid version requirement");

    let Ok(protocol_version) = Version::parse(&agent_info.protocol_version) else {
        // Reply before rejecting, as the other rejection paths do, so the peer
        // learns the required version instead of getting no response at all.
        send_err(&mut send, &mut buf, version_req.to_string())
            .await
            .map_err(handle_handshake_send_io_error)?;
        send.finish().ok();
        return Err(HandshakeError::IncompatibleProtocol(
            agent_info.protocol_version,
            version_req.to_string(),
        ));
    };

    if !version_req.matches(&protocol_version) {
        // Tell the peer which versions are accepted.
        send_err(&mut send, &mut buf, version_req.to_string())
            .await
            .map_err(handle_handshake_send_io_error)?;
        send.finish().ok();
        return Err(HandshakeError::IncompatibleProtocol(
            protocol_version.to_string(),
            version_req.to_string(),
        ));
    }

    let highest_protocol_version =
        Version::parse(highest_protocol_version).expect("valid semver");
    if protocol_version > highest_protocol_version {
        // The peer is newer than this server supports; report the ceiling.
        send_err(&mut send, &mut buf, &highest_protocol_version)
            .await
            .map_err(handle_handshake_send_io_error)?;
        send.finish().ok();
        return Err(HandshakeError::IncompatibleProtocol(
            protocol_version.to_string(),
            version_req.to_string(),
        ));
    }

    send_ok(&mut send, &mut buf, highest_protocol_version.to_string())
        .await
        .map_err(handle_handshake_send_io_error)?;
    Ok(agent_info)
}
179
/// Sends the patterns of a threat-intelligence database to the client as a
/// successful response on `send`.
///
/// # Errors
///
/// Returns an error if serialization failed or communication with the client failed.
#[cfg(feature = "server")]
#[deprecated(since = "0.8.1", note = "`handle` sends the response")]
pub async fn respond_with_tidb_patterns(
    send: &mut quinn::SendStream,
    patterns: &[(String, Option<Tidb>)],
) -> anyhow::Result<()> {
    use anyhow::Context;

    // The response is always the `Ok` arm; `&str` fixes the error type that
    // the wire format expects.
    let response: Result<&[(String, Option<Tidb>)], &str> = Ok(patterns);
    let mut scratch = Vec::new();
    oinq::frame::send(send, &mut scratch, response)
        .await
        .context("failed to send response")
}
198
/// Sends a list of trusted domains to the client.
///
/// Kept for backward compatibility. It wraps a clone of `conn` in a
/// [`Connection`] and delegates to `Connection::send_trusted_domain_list`.
///
/// # Errors
///
/// Returns an error if serialization failed or communication with the client failed.
#[cfg(feature = "server")]
#[deprecated(
    since = "0.8.1",
    note = "Use Connection::send_trusted_domain_list directly"
)]
pub async fn send_trusted_domain_list(
    conn: &quinn::Connection,
    list: &[String],
) -> anyhow::Result<()> {
    let wrapped = Connection::from_quinn(conn.clone());
    wrapped.send_trusted_domain_list(list).await
}
217
/// Notifies the client that it should update its configuration.
///
/// Opens a new bidirectional stream on `conn` and sends the
/// `client::RequestCode::UpdateConfig` code as a raw frame (a bincode-encoded
/// `u32`). It then waits for the client's `Result<(), String>` reply.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// * the stream could not be opened;
/// * sending the request or receiving the reply failed;
/// * the client replied with an error, whose message becomes the returned
///   error.
#[cfg(feature = "server")]
pub async fn notify_config_update(conn: &quinn::Connection) -> anyhow::Result<()> {
    use anyhow::{Context, anyhow};

    let Ok(msg) = bincode::serialize::<u32>(&client::RequestCode::UpdateConfig.into()) else {
        unreachable!("serialization of u32 into memory buffer should not fail")
    };

    // Attach context to each network step so failures say which stage broke.
    let (mut send, mut recv) = conn
        .open_bi()
        .await
        .context("failed to open a stream for the config update request")?;
    frame::send_raw(&mut send, &msg)
        .await
        .context("failed to send config update request")?;

    let mut response = vec![];
    frame::recv::<Result<(), String>>(&mut recv, &mut response)
        .await
        .context("failed to receive config update response")?
        .map_err(|e| anyhow!(e))
}
239
#[cfg(test)]
mod tests {
    #[cfg(all(feature = "client", feature = "server"))]
    #[tokio::test]
    async fn trusted_domain_list() {
        use crate::test::TEST_ENV;

        const TRUSTED_DOMAIN_LIST: &[&str] = &["example.com", "example.org"];

        // Client-side handler that accepts only `TRUSTED_DOMAIN_LIST`.
        struct Handler {}

        #[async_trait::async_trait]
        impl crate::request::Handler for Handler {
            async fn trusted_domain_list(&mut self, domains: &[&str]) -> Result<(), String> {
                if domains != TRUSTED_DOMAIN_LIST {
                    return Err("unexpected domain list".to_string());
                }
                Ok(())
            }
        }

        let env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = env.setup().await;

        // Exercise `Connection::send_trusted_domain_list` from the server side.
        let owned_domains: Vec<String> = TRUSTED_DOMAIN_LIST
            .iter()
            .copied()
            .map(String::from)
            .collect();

        let accepting_conn = client_conn.clone();
        let client_task = tokio::spawn(async move {
            let mut handler = Handler {};
            let (mut send, mut recv) = accepting_conn.accept_bi().await.unwrap();
            crate::request::handle(&mut handler, &mut send, &mut recv).await
        });

        let server_result = server_conn.send_trusted_domain_list(&owned_domains).await;
        assert!(server_result.is_ok());
        let client_result = client_task.await.unwrap();
        assert!(client_result.is_ok());

        env.teardown(&server_conn);
    }

    #[cfg(all(feature = "client", feature = "server"))]
    #[tokio::test]
    async fn notify_config_update() {
        use crate::test::TEST_ENV;

        // Client-side handler that accepts every configuration update.
        struct Handler {}

        #[async_trait::async_trait]
        impl crate::request::Handler for Handler {
            async fn update_config(&mut self) -> Result<(), String> {
                Ok(())
            }
        }

        let env = TEST_ENV.lock().await;
        let (server_conn, client_conn) = env.setup().await;

        let accepting_conn = client_conn.clone();
        let client_task = tokio::spawn(async move {
            let mut handler = Handler {};
            let (mut send, mut recv) = accepting_conn.accept_bi().await.unwrap();
            crate::request::handle(&mut handler, &mut send, &mut recv).await
        });

        let server_result = crate::server::notify_config_update(server_conn.as_quinn()).await;
        assert!(server_result.is_ok());
        let client_result = client_task.await.unwrap();
        assert!(client_result.is_ok());

        env.teardown(&server_conn);
    }
}