1#[cfg(feature = "client")]
4mod api;
5
6#[cfg(any(feature = "client", all(test, feature = "server")))]
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
8
9#[cfg(any(feature = "client", feature = "server"))]
10use num_enum::{FromPrimitive, IntoPrimitive};
11#[cfg(any(feature = "client", all(test, feature = "server")))]
12use oinq::frame::{self};
13#[cfg(feature = "client")]
14pub use oinq::message::{send_err, send_ok, send_request};
15
16#[cfg(any(feature = "client", all(test, feature = "server")))]
17use crate::AgentInfo;
18
/// Numeric codes identifying the kind of a request.
///
/// Every variant carries an explicit `u32` discriminant, and the enum converts
/// to and from raw integers through `num_enum` (`IntoPrimitive` /
/// `FromPrimitive`). Any integer without a matching variant converts to
/// [`RequestCode::Unknown`]. These values are presumably exchanged with remote
/// peers, so they must stay stable: never renumber an existing variant.
///
/// Variants are listed in ascending numeric order. Codes 3 and 11 are not
/// assigned here (NOTE(review): possibly retired codes — confirm before
/// reusing them).
#[cfg(any(feature = "client", feature = "server"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, FromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub(crate) enum RequestCode {
    TrustedDomainList = 0,
    DnsStart = 1,
    DnsStop = 2,
    Reboot = 4,
    ReloadTi = 5,
    ReloadConfig = 6,
    ResourceUsage = 7,
    TorExitNodeList = 8,
    SamplingPolicyList = 9,
    ReloadFilterRule = 10,
    UpdateConfig = 12,
    DeleteSamplingPolicy = 13,
    InternalNetworkList = 14,
    Allowlist = 15,
    Blocklist = 16,
    EchoRequest = 17,
    TrustedUserAgentList = 18,
    ProcessList = 19,
    SemiSupervisedModels = 20,
    Shutdown = 21,

    /// Fallback produced by `FromPrimitive` for any integer that does not
    /// match one of the variants above.
    #[num_enum(default)]
    Unknown = u32::MAX,
}
88
#[cfg(feature = "client")]
#[derive(Debug)]
/// Builder holding everything needed to open a client [`Connection`].
///
/// It holds the remote endpoint, the local bind address, the agent identity
/// sent during the handshake, and the TLS material.
pub struct ConnectionBuilder {
    /// Server name passed to `quinn::Endpoint::connect` alongside `remote_addr`.
    remote_name: String,
    /// Socket address of the server to connect to.
    remote_addr: SocketAddr,
    /// Local IP the client endpoint binds to; the port is always 0.
    local_addr: IpAddr,
    /// Application name sent to the server in the `AgentInfo` handshake message.
    app_name: String,
    /// Application version sent in the `AgentInfo` handshake message.
    app_version: String,
    /// Protocol version sent in the `AgentInfo` handshake message.
    protocol_version: String,
    /// Status sent in the `AgentInfo` handshake message.
    status: crate::Status,
    /// Trust anchors used to verify the server's certificate.
    roots: rustls::RootCertStore,
    /// Client certificate presented for TLS client authentication.
    cert: rustls::pki_types::CertificateDer<'static>,
    /// Private key matching `cert`.
    key: rustls::pki_types::PrivateKeyDer<'static>,
}
104
#[cfg(feature = "client")]
impl ConnectionBuilder {
    /// Creates a builder for a connection to `remote_addr`.
    ///
    /// `remote_name` is the server name used when connecting. `app_name`,
    /// `app_version`, `protocol_version` and `status` are sent to the server
    /// during the handshake. `cert` and `key` are PEM-encoded, and only the
    /// first certificate found in `cert` is used. The local bind address
    /// defaults to the unspecified address of the same IP family as
    /// `remote_addr`. The root certificate store starts out empty; see
    /// [`Self::root_certs`] and [`Self::add_root_certs`].
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::InvalidData`] if `cert`
    /// contains no certificate or if `key` contains no private key. Malformed
    /// PEM in either input is also reported as an error.
    // The flat parameter list is kept for API compatibility.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        remote_name: &str,
        remote_addr: SocketAddr,
        app_name: &str,
        app_version: &str,
        protocol_version: &str,
        status: crate::Status,
        cert: &[u8],
        key: &[u8],
    ) -> std::io::Result<Self> {
        // Bind to the wildcard address of the family we will talk over.
        let local_addr = if remote_addr.is_ipv6() {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        } else {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        };
        let cert = Self::parse_first_cert(cert, "no certificate")?;
        let key = Self::parse_private_key(key)?;
        Ok(Self {
            remote_name: remote_name.to_string(),
            remote_addr,
            local_addr,
            app_name: app_name.to_string(),
            app_version: app_version.to_string(),
            protocol_version: protocol_version.to_string(),
            status,
            roots: rustls::RootCertStore::empty(),
            cert,
            key,
        })
    }

    /// Replaces the client certificate with the first certificate in the
    /// PEM-encoded `cert`.
    ///
    /// # Errors
    ///
    /// Returns an error if `cert` contains no certificate (kind
    /// [`std::io::ErrorKind::InvalidData`]) or is malformed. On error the
    /// current certificate is left unchanged.
    pub fn cert(&mut self, cert: &[u8]) -> std::io::Result<&mut Self> {
        self.cert = Self::parse_first_cert(cert, "no certificate")?;
        Ok(self)
    }

    /// Replaces the client private key with the one in the PEM-encoded `key`.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::InvalidData`] if `key`
    /// contains no private key or cannot be parsed. On error the current key
    /// is left unchanged.
    pub fn key(&mut self, key: &[u8]) -> std::io::Result<&mut Self> {
        self.key = Self::parse_private_key(key)?;
        Ok(self)
    }

    /// Replaces the trusted root certificates.
    ///
    /// Each item of `certs` is PEM-encoded, and only the first certificate in
    /// each item is added.
    ///
    /// # Errors
    ///
    /// Returns an error if an item contains no certificate, is malformed, or
    /// is rejected by the root store. On error the previously configured roots
    /// are kept intact.
    pub fn root_certs<I>(&mut self, certs: I) -> std::io::Result<&mut Self>
    where
        I: IntoIterator,
        I::Item: AsRef<[u8]>,
    {
        // Build into a fresh store and install it only once every certificate
        // has been accepted. A bad item then cannot leave the builder with an
        // empty or partially filled trust store.
        let mut roots = rustls::RootCertStore::empty();
        for pem in certs {
            let cert = Self::parse_first_cert(pem.as_ref(), "invalid certificate")?;
            roots
                .add(cert)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        }
        self.roots = roots;
        Ok(self)
    }

    /// Adds every PEM certificate read from `rd` to the trusted root store.
    ///
    /// # Errors
    ///
    /// Returns an error if reading or parsing fails, or if the root store
    /// rejects a certificate. On error the previously configured roots are
    /// left unchanged. `rd` may still have been partially consumed.
    pub fn add_root_certs(&mut self, rd: &mut dyn std::io::BufRead) -> std::io::Result<&mut Self> {
        // Work on a copy so a failure midway does not leave some new roots
        // installed and others missing.
        let mut roots = self.roots.clone();
        for cert in rustls_pemfile::certs(rd) {
            roots
                .add(cert?)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        }
        self.roots = roots;
        Ok(self)
    }

    /// Sets the local IP address the client endpoint binds to.
    ///
    /// The port is always 0 when the endpoint is built.
    pub fn local_addr(&mut self, addr: IpAddr) -> &mut Self {
        self.local_addr = addr;
        self
    }

    /// Connects to the server and performs the agent handshake.
    ///
    /// This binds a new QUIC endpoint and connects to `remote_addr` using
    /// `remote_name` as the server name. It then sends an `AgentInfo` message
    /// on a new bidirectional stream and waits for the server's reply.
    ///
    /// # Errors
    ///
    /// * [`std::io::ErrorKind::InvalidInput`] if the endpoint or TLS
    ///   configuration cannot be built, or the connect call is rejected.
    /// * A kind derived from the `quinn::ConnectionError` if the QUIC
    ///   connection fails.
    /// * [`std::io::ErrorKind::ConnectionRefused`] if the server rejects the
    ///   handshake. The message carries the server's reply as the required
    ///   protocol version.
    /// * Any I/O error raised while exchanging the handshake frames.
    pub async fn connect(&self) -> std::io::Result<Connection> {
        use std::io;

        use quinn::ConnectionError;

        let endpoint = self.build_endpoint()?;
        let connecting = endpoint
            .connect(self.remote_addr, &self.remote_name)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let connection = connecting.await.map_err(|e| match e {
            ConnectionError::ApplicationClosed(e) => {
                io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
            }
            ConnectionError::CidsExhausted => io::Error::other("connection IDs exhausted"),
            ConnectionError::ConnectionClosed(e) => {
                io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
            }
            ConnectionError::LocallyClosed => {
                io::Error::new(io::ErrorKind::NotConnected, "locally closed")
            }
            ConnectionError::Reset => io::Error::from(io::ErrorKind::ConnectionReset),
            ConnectionError::TimedOut => io::Error::from(io::ErrorKind::TimedOut),
            ConnectionError::TransportError(e) => {
                io::Error::new(io::ErrorKind::InvalidData, e.to_string())
            }
            ConnectionError::VersionMismatch => {
                io::Error::new(io::ErrorKind::ConnectionRefused, "version mismatch")
            }
        })?;

        // NOTE(review): the advertised address is always the unspecified
        // address (port 0) of the peer's IP family. Presumably the server
        // records the real peer address itself — confirm.
        let addr = if connection.remote_address().is_ipv6() {
            SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
        } else {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
        };

        let agent_info = AgentInfo {
            app_name: self.app_name.clone(),
            version: self.app_version.clone(),
            protocol_version: self.protocol_version.clone(),
            status: self.status,
            addr,
        };

        let (mut send, mut recv) = connection.open_bi().await?;
        let mut buf = Vec::new();
        frame::send(&mut send, &mut buf, &agent_info).await?;
        match frame::recv::<Result<&str, &str>>(&mut recv, &mut buf).await? {
            Ok(_) => Ok(Connection {
                endpoint,
                connection,
            }),
            Err(e) => Err(io::Error::new(
                io::ErrorKind::ConnectionRefused,
                format!("server requires protocol version {e}"),
            )),
        }
    }

    /// Builds a client QUIC endpoint bound to `local_addr` (port 0).
    ///
    /// The endpoint is configured with the builder's root store and client
    /// certificate/key, and with a 5-second keep-alive interval.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidInput`] if the TLS or QUIC crypto
    /// configuration is rejected. Binding errors are returned as-is.
    fn build_endpoint(&self) -> std::io::Result<quinn::Endpoint> {
        use std::sync::Arc;
        use std::time::Duration;

        const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(5_000);

        let tls_cfg = rustls::ClientConfig::builder()
            .with_root_certificates(self.roots.clone())
            .with_client_auth_cert(vec![self.cert.clone()], self.key.clone_key())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
        let mut transport = quinn::TransportConfig::default();
        transport.keep_alive_interval(Some(KEEP_ALIVE_INTERVAL));
        let mut config = quinn::ClientConfig::new(Arc::new(
            quinn::crypto::rustls::QuicClientConfig::try_from(tls_cfg)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?,
        ));
        config.transport_config(Arc::new(transport));

        let mut endpoint = quinn::Endpoint::client(SocketAddr::new(self.local_addr, 0))?;
        endpoint.set_default_client_config(config);
        Ok(endpoint)
    }

    /// Parses the first PEM certificate in `pem`.
    ///
    /// If `pem` contains no certificate at all, this returns an
    /// [`std::io::ErrorKind::InvalidData`] error carrying `missing_msg`.
    /// Malformed PEM is reported with the error produced by `rustls_pemfile`.
    fn parse_first_cert(
        pem: &[u8],
        missing_msg: &'static str,
    ) -> std::io::Result<rustls::pki_types::CertificateDer<'static>> {
        let cert = rustls_pemfile::certs(&mut std::io::Cursor::new(pem))
            .next()
            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, missing_msg))??;
        Ok(cert)
    }

    /// Parses the first PEM private key in `pem`.
    ///
    /// Both parse failures and a missing key are reported as
    /// [`std::io::ErrorKind::InvalidData`].
    fn parse_private_key(pem: &[u8]) -> std::io::Result<rustls::pki_types::PrivateKeyDer<'static>> {
        let key = rustls_pemfile::private_key(&mut std::io::Cursor::new(pem))
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?
            .ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::InvalidData, "no private key")
            })?;
        Ok(key)
    }
}
337
#[cfg(feature = "client")]
#[derive(Clone, Debug)]
/// A client-side QUIC connection to the server.
///
/// Returned by `ConnectionBuilder::connect` once the server has accepted the
/// agent handshake.
pub struct Connection {
    /// Endpoint the connection was opened from; queried by `local_addr`.
    /// NOTE(review): presumably also kept so the endpoint lives as long as the
    /// connection — confirm.
    endpoint: quinn::Endpoint,
    /// The underlying QUIC connection.
    connection: quinn::Connection,
}
345
#[cfg(feature = "client")]
impl Connection {
    /// Returns the local socket address of the endpoint this connection uses.
    ///
    /// # Errors
    ///
    /// Propagates any error from `quinn::Endpoint::local_addr`.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        let endpoint = &self.endpoint;
        endpoint.local_addr()
    }

    /// Returns the peer's socket address as currently known by QUIC.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let conn = &self.connection;
        conn.remote_address()
    }

    /// Returns why the connection was closed, converted into an I/O error.
    ///
    /// Yields `None` when `quinn::Connection::close_reason` reports no reason.
    #[must_use]
    pub fn close_reason(&self) -> Option<std::io::Error> {
        let reason = self.connection.close_reason();
        reason.map(std::io::Error::from)
    }

    /// Returns a future that opens a new outgoing bidirectional stream.
    ///
    /// Delegates to `quinn::Connection::open_bi`.
    #[must_use]
    pub fn open_bi(&self) -> quinn::OpenBi {
        let conn = &self.connection;
        conn.open_bi()
    }

    /// Returns a future that opens a new outgoing unidirectional stream.
    ///
    /// Delegates to `quinn::Connection::open_uni`.
    #[must_use]
    pub fn open_uni(&self) -> quinn::OpenUni {
        let conn = &self.connection;
        conn.open_uni()
    }

    /// Returns a future that accepts the next bidirectional stream opened by
    /// the peer.
    ///
    /// Delegates to `quinn::Connection::accept_bi`.
    #[must_use]
    pub fn accept_bi(&self) -> quinn::AcceptBi {
        let conn = &self.connection;
        conn.accept_bi()
    }
}
400
/// Test-only helper that runs the client side of the agent handshake on an
/// already-established connection.
///
/// It opens a bidirectional stream, sends an `AgentInfo` built from the
/// arguments, and interprets the server's `Result<&str, &str>` reply.
///
/// # Errors
///
/// * `HandshakeError::IncompatibleProtocol(ours, reply)` if the server answers
///   with `Err(reply)`.
/// * Failure to open the stream is converted into the error type with `?`.
/// * Frame send/receive I/O errors are converted by
///   `handle_handshake_send_io_error` / `handle_handshake_recv_io_error`.
#[cfg(test)]
#[cfg(feature = "server")]
pub(crate) async fn handshake(
    conn: &quinn::Connection,
    app_name: &str,
    app_version: &str,
    protocol_version: &str,
    status: crate::Status,
) -> Result<(), super::HandshakeError> {
    use crate::{handle_handshake_recv_io_error, handle_handshake_send_io_error};

    // Advertise the wildcard address (port 0) of the peer's IP family.
    let unspecified_ip = if conn.remote_address().is_ipv6() {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    } else {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    };

    let info = AgentInfo {
        app_name: app_name.to_string(),
        version: app_version.to_string(),
        protocol_version: protocol_version.to_string(),
        status,
        addr: SocketAddr::new(unspecified_ip, 0),
    };

    let (mut tx, mut rx) = conn.open_bi().await?;
    let mut scratch = Vec::new();
    frame::send(&mut tx, &mut scratch, &info)
        .await
        .map_err(handle_handshake_send_io_error)?;

    let reply = frame::recv::<Result<&str, &str>>(&mut rx, &mut scratch)
        .await
        .map_err(handle_handshake_recv_io_error)?;
    reply.map(|_| ()).map_err(|required| {
        super::HandshakeError::IncompatibleProtocol(
            protocol_version.to_string(),
            required.to_string(),
        )
    })
}