review_protocol/server/
handler.rs

//! Request handler for the server.
2
3use std::{collections::HashSet, io};
4
5use num_enum::FromPrimitive;
6use oinq::request::parse_args;
7
8use super::RequestCode;
9use crate::types::{DataSource, DataSourceKey, HostNetworkGroup, Tidb};
10
11/// A request handler that can handle a request to the server.
12#[async_trait::async_trait]
13pub trait Handler {
14    async fn get_allowlist(&self) -> Result<HostNetworkGroup, String> {
15        Err("not supported".to_string())
16    }
17
18    async fn get_blocklist(&self) -> Result<HostNetworkGroup, String> {
19        Err("not supported".to_string())
20    }
21
22    async fn get_data_source(
23        &self,
24        _key: &DataSourceKey<'_>,
25    ) -> Result<Option<DataSource>, String> {
26        Err("not supported".to_string())
27    }
28
29    async fn get_indicator(&self, _name: &str) -> Result<HashSet<Vec<String>>, String> {
30        Err("not supported".to_string())
31    }
32
33    async fn get_model_names(&self) -> Result<Vec<String>, String> {
34        Err("not supported".to_string())
35    }
36
37    async fn get_tidb_patterns(
38        &self,
39        _db_names: &[(&str, &str)],
40    ) -> Result<Vec<(String, Option<Tidb>)>, String> {
41        Err("not supported".to_string())
42    }
43
44    async fn get_tor_exit_node_list(&self) -> Result<Vec<String>, String> {
45        Err("not supported".to_string())
46    }
47
48    async fn get_trusted_domain_list(&self) -> Result<Vec<String>, String> {
49        Err("not supported".to_string())
50    }
51
52    async fn get_trusted_user_agent_list(&self) -> Result<Vec<String>, String> {
53        Err("not supported".to_string())
54    }
55}
56
57/// Handles requests to the server.
58///
59/// This handles only a subset of the requests that the server can receive. If
60/// the request is not supported, the request code is returned to the caller.
61///
62/// # Errors
63///
64/// - There was an error reading from the stream.
65/// - There was an error writing to the stream.
66/// - An unknown request code was received.
67/// - The arguments to the request were invalid.
68pub async fn handle<H>(
69    handler: &mut H,
70    send: &mut quinn::SendStream,
71    recv: &mut quinn::RecvStream,
72) -> io::Result<Option<(u32, Vec<u8>)>>
73where
74    H: Handler + Sync,
75{
76    let mut buf = Vec::new();
77    loop {
78        let (code, body) = match oinq::message::recv_request_raw(recv, &mut buf).await {
79            Ok(res) => res,
80            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
81            Err(e) => return Err(e),
82        };
83
84        match RequestCode::from_primitive(code) {
85            RequestCode::GetAllowlist => {
86                parse_args::<()>(body)?;
87                let result = handler.get_allowlist().await;
88                oinq::request::send_response(send, &mut buf, result).await?;
89            }
90            RequestCode::GetBlocklist => {
91                parse_args::<()>(body)?;
92                let result = handler.get_blocklist().await;
93                oinq::request::send_response(send, &mut buf, result).await?;
94            }
95            RequestCode::GetDataSource => {
96                let data_source_key = parse_args::<DataSourceKey>(body)?;
97                let result = handler.get_data_source(&data_source_key).await;
98                oinq::request::send_response(send, &mut buf, result).await?;
99            }
100            RequestCode::GetIndicator => {
101                let name = parse_args::<String>(body)?;
102                let result = handler.get_indicator(&name).await;
103                oinq::request::send_response(send, &mut buf, result).await?;
104            }
105            RequestCode::GetModelNames => {
106                parse_args::<()>(body)?;
107                let result = handler.get_model_names().await;
108                oinq::request::send_response(send, &mut buf, result).await?;
109            }
110            RequestCode::GetTidbPatterns => {
111                let db_names = parse_args::<Vec<(&str, &str)>>(body)?;
112                let result = handler.get_tidb_patterns(&db_names).await;
113                oinq::request::send_response(send, &mut buf, result).await?;
114            }
115            RequestCode::GetTorExitNodeList => {
116                parse_args::<()>(body)?;
117                let result = handler.get_tor_exit_node_list().await;
118                oinq::request::send_response(send, &mut buf, result).await?;
119            }
120            RequestCode::GetTrustedDomainList => {
121                parse_args::<()>(body)?;
122                let result = handler.get_trusted_domain_list().await;
123                oinq::request::send_response(send, &mut buf, result).await?;
124            }
125            RequestCode::GetTrustedUserAgentList => {
126                parse_args::<()>(body)?;
127                let result = handler.get_trusted_user_agent_list().await;
128                oinq::request::send_response(send, &mut buf, result).await?;
129            }
130            RequestCode::Unknown => {
131                oinq::frame::send(
132                    send,
133                    &mut buf,
134                    Err("unknown request code") as Result<(), &str>,
135                )
136                .await?;
137                return Err(io::Error::new(
138                    io::ErrorKind::InvalidData,
139                    "unknown request code",
140                ));
141            }
142            _ => {
143                return Ok(Some((code, body.into())));
144            }
145        }
146    }
147    Ok(None)
148}