review_protocol/server/
handler.rs1use std::{collections::HashSet, io};
4
5use num_enum::FromPrimitive;
6use oinq::request::parse_args;
7
8use super::RequestCode;
9use crate::types::{DataSource, DataSourceKey, HostNetworkGroup, Tidb};
10
11#[async_trait::async_trait]
13pub trait Handler {
14 async fn get_allowlist(&self) -> Result<HostNetworkGroup, String> {
15 Err("not supported".to_string())
16 }
17
18 async fn get_blocklist(&self) -> Result<HostNetworkGroup, String> {
19 Err("not supported".to_string())
20 }
21
22 async fn get_data_source(
23 &self,
24 _key: &DataSourceKey<'_>,
25 ) -> Result<Option<DataSource>, String> {
26 Err("not supported".to_string())
27 }
28
29 async fn get_indicator(&self, _name: &str) -> Result<HashSet<Vec<String>>, String> {
30 Err("not supported".to_string())
31 }
32
33 async fn get_model_names(&self) -> Result<Vec<String>, String> {
34 Err("not supported".to_string())
35 }
36
37 async fn get_tidb_patterns(
38 &self,
39 _db_names: &[(&str, &str)],
40 ) -> Result<Vec<(String, Option<Tidb>)>, String> {
41 Err("not supported".to_string())
42 }
43
44 async fn get_tor_exit_node_list(&self) -> Result<Vec<String>, String> {
45 Err("not supported".to_string())
46 }
47
48 async fn get_trusted_domain_list(&self) -> Result<Vec<String>, String> {
49 Err("not supported".to_string())
50 }
51
52 async fn get_trusted_user_agent_list(&self) -> Result<Vec<String>, String> {
53 Err("not supported".to_string())
54 }
55}
56
57pub async fn handle<H>(
69 handler: &mut H,
70 send: &mut quinn::SendStream,
71 recv: &mut quinn::RecvStream,
72) -> io::Result<Option<(u32, Vec<u8>)>>
73where
74 H: Handler + Sync,
75{
76 let mut buf = Vec::new();
77 loop {
78 let (code, body) = match oinq::message::recv_request_raw(recv, &mut buf).await {
79 Ok(res) => res,
80 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
81 Err(e) => return Err(e),
82 };
83
84 match RequestCode::from_primitive(code) {
85 RequestCode::GetAllowlist => {
86 parse_args::<()>(body)?;
87 let result = handler.get_allowlist().await;
88 oinq::request::send_response(send, &mut buf, result).await?;
89 }
90 RequestCode::GetBlocklist => {
91 parse_args::<()>(body)?;
92 let result = handler.get_blocklist().await;
93 oinq::request::send_response(send, &mut buf, result).await?;
94 }
95 RequestCode::GetDataSource => {
96 let data_source_key = parse_args::<DataSourceKey>(body)?;
97 let result = handler.get_data_source(&data_source_key).await;
98 oinq::request::send_response(send, &mut buf, result).await?;
99 }
100 RequestCode::GetIndicator => {
101 let name = parse_args::<String>(body)?;
102 let result = handler.get_indicator(&name).await;
103 oinq::request::send_response(send, &mut buf, result).await?;
104 }
105 RequestCode::GetModelNames => {
106 parse_args::<()>(body)?;
107 let result = handler.get_model_names().await;
108 oinq::request::send_response(send, &mut buf, result).await?;
109 }
110 RequestCode::GetTidbPatterns => {
111 let db_names = parse_args::<Vec<(&str, &str)>>(body)?;
112 let result = handler.get_tidb_patterns(&db_names).await;
113 oinq::request::send_response(send, &mut buf, result).await?;
114 }
115 RequestCode::GetTorExitNodeList => {
116 parse_args::<()>(body)?;
117 let result = handler.get_tor_exit_node_list().await;
118 oinq::request::send_response(send, &mut buf, result).await?;
119 }
120 RequestCode::GetTrustedDomainList => {
121 parse_args::<()>(body)?;
122 let result = handler.get_trusted_domain_list().await;
123 oinq::request::send_response(send, &mut buf, result).await?;
124 }
125 RequestCode::GetTrustedUserAgentList => {
126 parse_args::<()>(body)?;
127 let result = handler.get_trusted_user_agent_list().await;
128 oinq::request::send_response(send, &mut buf, result).await?;
129 }
130 RequestCode::Unknown => {
131 oinq::frame::send(
132 send,
133 &mut buf,
134 Err("unknown request code") as Result<(), &str>,
135 )
136 .await?;
137 return Err(io::Error::new(
138 io::ErrorKind::InvalidData,
139 "unknown request code",
140 ));
141 }
142 _ => {
143 return Ok(Some((code, body.into())));
144 }
145 }
146 }
147 Ok(None)
148}