review_protocol/
client.rs

1//! Client-specific protocol implementation.
2
3#[cfg(feature = "client")]
4mod api;
5
6#[cfg(any(feature = "client", all(test, feature = "server")))]
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
8
9#[cfg(any(feature = "client", feature = "server"))]
10use num_enum::{FromPrimitive, IntoPrimitive};
11#[cfg(any(feature = "client", all(test, feature = "server")))]
12use oinq::frame::{self};
13#[cfg(feature = "client")]
14pub use oinq::message::{send_err, send_ok, send_request};
15
16#[cfg(any(feature = "client", all(test, feature = "server")))]
17use crate::AgentInfo;
18
/// Numeric representation of the message types that a client should handle.
///
/// The discriminants are part of the wire protocol: a request code travels as
/// its `u32` value, so every peer must agree on the numbering. Never renumber
/// an existing variant; assign a fresh value to new variants instead. Any
/// value that does not match a listed variant decodes as
/// [`RequestCode::Unknown`]. Values 3 and 11 are not assigned here.
#[cfg(any(feature = "client", feature = "server"))]
#[derive(Clone, Copy, Debug, Eq, FromPrimitive, IntoPrimitive, PartialEq)]
#[repr(u32)]
pub(crate) enum RequestCode {
    /// Update the list of trusted domains
    TrustedDomainList = 0,

    /// Start DNS filtering
    DnsStart = 1,

    /// Stop DNS filtering
    DnsStop = 2,

    /// Reboot the host
    Reboot = 4,

    /// Reload the configuration
    ReloadConfig = 5,

    /// Fetch the TI database and reload it
    ReloadTi = 6,

    /// Collect resource usage stats
    ResourceUsage = 7,

    /// Update the list of tor exit nodes
    TorExitNodeList = 8,

    /// Update the list of sampling policies
    SamplingPolicyList = 9,

    /// Update traffic filter rules
    ReloadFilterRule = 10,

    /// Update Configuration
    UpdateConfig = 12,

    /// Delete the list of sampling policies
    DeleteSamplingPolicy = 13,

    /// Update the list of Internal network
    InternalNetworkList = 14,

    /// Update the list of allow
    Allowlist = 15,

    /// Update the list of block
    Blocklist = 16,

    /// Request Echo (for ping)
    EchoRequest = 17,

    /// Update the list of trusted User-agent
    TrustedUserAgentList = 18,

    /// Collect process list
    ProcessList = 19,

    /// Update the semi-supervised models
    SemiSupervisedModels = 20,

    /// Shutdown the host
    Shutdown = 21,

    /// Unknown request
    #[num_enum(default)]
    Unknown = u32::MAX,
}
88
/// A builder that configures and establishes a [`Connection`] to a server.
///
/// Create one with [`ConnectionBuilder::new`], optionally adjust it with the
/// setter methods, and finish with [`ConnectionBuilder::connect`].
#[cfg(feature = "client")]
#[derive(Debug)]
pub struct ConnectionBuilder {
    /// Server name passed to the QUIC endpoint when connecting.
    remote_name: String,
    /// Address of the server to connect to.
    remote_addr: SocketAddr,
    /// Local IP address the client endpoint binds to (with port 0).
    local_addr: IpAddr,
    /// Application name sent to the server during the handshake.
    app_name: String,
    /// Application version sent to the server during the handshake.
    app_version: String,
    /// Protocol version sent to the server during the handshake.
    protocol_version: String,
    /// Agent status sent to the server during the handshake.
    status: crate::Status,
    /// Trusted root certificates used to verify the server.
    roots: rustls::RootCertStore,
    /// Client certificate presented for client authentication.
    cert: rustls::pki_types::CertificateDer<'static>,
    /// Private key matching `cert`.
    key: rustls::pki_types::PrivateKeyDer<'static>,
}
104
#[cfg(feature = "client")]
impl ConnectionBuilder {
    /// Creates a new builder with the remote address, certificate, and key.
    ///
    /// The local address defaults to the unspecified address of the same
    /// family as `remote_addr` (`::` for IPv6, `0.0.0.0` for IPv4), and the
    /// root certificate store starts out empty.
    ///
    /// # Errors
    ///
    /// Returns an error if the certificate or key is invalid.
    // The argument list is part of the public API; splitting it would break
    // callers.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        remote_name: &str,
        remote_addr: SocketAddr,
        app_name: &str,
        app_version: &str,
        protocol_version: &str,
        status: crate::Status,
        cert: &[u8],
        key: &[u8],
    ) -> std::io::Result<Self> {
        let local_addr = if remote_addr.is_ipv6() {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        } else {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        };
        let cert = Self::parse_first_cert(cert, "no certificate")?;
        let key = Self::parse_private_key(key)?;
        Ok(Self {
            remote_name: remote_name.to_string(),
            remote_addr,
            local_addr,
            app_name: app_name.to_string(),
            app_version: app_version.to_string(),
            protocol_version: protocol_version.to_string(),
            status,
            roots: rustls::RootCertStore::empty(),
            cert,
            key,
        })
    }

    /// Sets the certificate for the connection.
    ///
    /// Only the first certificate in the PEM input is used.
    ///
    /// # Errors
    ///
    /// Returns an error if the certificate is invalid.
    pub fn cert(&mut self, cert: &[u8]) -> std::io::Result<&mut Self> {
        self.cert = Self::parse_first_cert(cert, "no certificate")?;
        Ok(self)
    }

    /// Sets the private key for the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the key is invalid.
    pub fn key(&mut self, key: &[u8]) -> std::io::Result<&mut Self> {
        self.key = Self::parse_private_key(key)?;
        Ok(self)
    }

    /// Sets the root certificates for the connection, replacing any roots
    /// configured before.
    ///
    /// Each item is a PEM document; only its first certificate is used.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the certificates are invalid. In that case
    /// the previously configured root certificates are left unchanged.
    pub fn root_certs<I>(&mut self, certs: I) -> std::io::Result<&mut Self>
    where
        I: IntoIterator,
        I::Item: AsRef<[u8]>,
    {
        // Build into a fresh store and commit only on success, so a failure
        // part-way through cannot leave the builder with a truncated root set.
        let mut roots = rustls::RootCertStore::empty();
        for cert in certs {
            let cert = Self::parse_first_cert(cert.as_ref(), "invalid certificate")?;
            roots
                .add(cert)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        }
        self.roots = roots;
        Ok(self)
    }

    /// Adds root certificates to the certificate store.
    ///
    /// It reads every certificate from the given reader, skipping any PEM
    /// sections that are not certificates.
    ///
    /// # Errors
    ///
    /// Returns an error if the reader is invalid or the certificates are
    /// invalid. In that case no certificate from `rd` is added.
    pub fn add_root_certs(&mut self, rd: &mut dyn std::io::BufRead) -> std::io::Result<&mut Self> {
        // Work on a copy and commit only on success (see `root_certs`).
        let mut roots = self.roots.clone();
        for cert in rustls_pemfile::certs(rd) {
            roots
                .add(cert?)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        }
        self.roots = roots;
        Ok(self)
    }

    /// Sets the local address to bind to.
    ///
    /// This is only necessary if the unspecified address (:: for IPv6 and
    /// 0.0.0.0 for IPv4) is not desired.
    pub fn local_addr(&mut self, addr: IpAddr) -> &mut Self {
        self.local_addr = addr;
        self
    }

    /// Connects to the server and performs a handshake.
    ///
    /// A new QUIC endpoint is bound for every call. After the QUIC
    /// connection is established, this agent's information is sent over a
    /// fresh bidirectional stream, and the server's reply decides whether
    /// the connection is accepted.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection fails or the server requires a different
    /// protocol version.
    pub async fn connect(&self) -> std::io::Result<Connection> {
        use std::io;

        let endpoint = self.build_endpoint()?;
        let connecting = endpoint
            .connect(self.remote_addr, &self.remote_name)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let connection = connecting.await.map_err(|e| {
            // quinn (as of 0.11) provides automatic conversion from
            // `ConnectionError` to `ReadError`, and from `ReadError` to
            // `io::Error`. However, the conversion treats all `ConnectionError`
            // variants as `NotConnected`, which is too generic. We need to provide
            // more specific error messages.
            use quinn::ConnectionError;
            match e {
                ConnectionError::ApplicationClosed(e) => {
                    std::io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
                }
                ConnectionError::CidsExhausted => io::Error::other("connection IDs exhausted"),
                ConnectionError::ConnectionClosed(e) => {
                    std::io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
                }
                ConnectionError::LocallyClosed => {
                    io::Error::new(io::ErrorKind::NotConnected, "locally closed")
                }
                ConnectionError::Reset => io::Error::from(io::ErrorKind::ConnectionReset),
                ConnectionError::TimedOut => io::Error::from(io::ErrorKind::TimedOut),
                ConnectionError::TransportError(e) => {
                    std::io::Error::new(io::ErrorKind::InvalidData, e.to_string())
                }
                ConnectionError::VersionMismatch => {
                    io::Error::new(io::ErrorKind::ConnectionRefused, "version mismatch")
                }
            }
        })?;

        // A placeholder for the address of this agent. Will be replaced by the
        // server.
        //
        // TODO: This is unnecessary in handshake, and thus should be removed in the
        // future.
        let addr = if connection.remote_address().is_ipv6() {
            SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
        } else {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
        };

        let agent_info = AgentInfo {
            app_name: self.app_name.clone(),
            version: self.app_version.clone(),
            protocol_version: self.protocol_version.clone(),
            status: self.status,
            addr,
        };

        let (mut send, mut recv) = connection.open_bi().await?;
        let mut buf = Vec::new();
        frame::send(&mut send, &mut buf, &agent_info).await?;
        // The server replies `Ok(_)` to accept, or `Err(version)` naming the
        // protocol version it requires.
        match frame::recv::<Result<&str, &str>>(&mut recv, &mut buf).await? {
            Ok(_) => Ok(Connection {
                endpoint,
                connection,
            }),
            Err(e) => Err(io::Error::new(
                io::ErrorKind::ConnectionRefused,
                format!("server requires protocol version {e}"),
            )),
        }
    }

    /// Creates a new endpoint.
    ///
    /// The endpoint binds to `local_addr` with port 0, uses the configured
    /// roots and client certificate for TLS, and sends keep-alives every
    /// five seconds.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored TLS configuration is invalid or the
    /// local socket cannot be bound.
    fn build_endpoint(&self) -> std::io::Result<quinn::Endpoint> {
        use std::sync::Arc;
        use std::time::Duration;

        const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(5_000);

        let tls_cfg = rustls::ClientConfig::builder()
            .with_root_certificates(self.roots.clone())
            .with_client_auth_cert(vec![self.cert.clone()], self.key.clone_key())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
        let mut transport = quinn::TransportConfig::default();
        transport.keep_alive_interval(Some(KEEP_ALIVE_INTERVAL));
        let mut config = quinn::ClientConfig::new(Arc::new(
            quinn::crypto::rustls::QuicClientConfig::try_from(tls_cfg)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?,
        ));
        config.transport_config(Arc::new(transport));

        let mut endpoint = quinn::Endpoint::client(SocketAddr::new(self.local_addr, 0))?;
        endpoint.set_default_client_config(config);
        Ok(endpoint)
    }

    /// Parses the first PEM-encoded certificate in `pem`.
    ///
    /// Returns an `InvalidData` error with `missing_msg` if `pem` contains no
    /// certificate, or the error reported by `rustls_pemfile` if the first
    /// certificate cannot be extracted.
    fn parse_first_cert(
        pem: &[u8],
        missing_msg: &'static str,
    ) -> std::io::Result<rustls::pki_types::CertificateDer<'static>> {
        let mut reader = std::io::Cursor::new(pem);
        let first = rustls_pemfile::certs(&mut reader).next();
        first.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, missing_msg))?
    }

    /// Parses the first PEM-encoded private key in `pem`.
    ///
    /// Returns an `InvalidData` error if the input cannot be parsed or
    /// contains no private key.
    fn parse_private_key(pem: &[u8]) -> std::io::Result<rustls::pki_types::PrivateKeyDer<'static>> {
        let mut reader = std::io::Cursor::new(pem);
        let parsed = rustls_pemfile::private_key(&mut reader)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        parsed.ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, "no private key")
        })
    }
}
337
/// A connection to a server.
///
/// Obtained from [`ConnectionBuilder::connect`] after a successful handshake.
#[cfg(feature = "client")]
#[derive(Clone, Debug)]
pub struct Connection {
    /// The QUIC endpoint this connection was established through.
    endpoint: quinn::Endpoint,
    /// The established QUIC connection to the server.
    connection: quinn::Connection,
}
345
#[cfg(feature = "client")]
impl Connection {
    /// Gets the local address the connection's endpoint socket is bound to.
    ///
    /// # Errors
    ///
    /// Returns an error if the call to the underlying
    /// [`local_addr`](quinn::Endpoint::local_addr) fails.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.endpoint.local_addr()
    }

    /// Gets the remote address of the connection.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.connection.remote_address()
    }

    /// If the connection is closed, returns the reason; otherwise, returns `None`.
    ///
    /// The underlying `quinn::ConnectionError` is converted into a
    /// [`std::io::Error`].
    #[must_use]
    pub fn close_reason(&self) -> Option<std::io::Error> {
        self.connection.close_reason().map(Into::into)
    }

    /// Initiates an outgoing bidirectional stream.
    ///
    /// This directly corresponds to the `open_bi` method of the underlying
    /// `quinn::Connection`. In the future, this method may be removed in favor
    /// of this crate's own implementation to provide additional features.
    #[must_use]
    pub fn open_bi(&self) -> quinn::OpenBi {
        self.connection.open_bi()
    }

    /// Initiates an outgoing unidirectional stream.
    ///
    /// This directly corresponds to the `open_uni` method of the underlying
    /// `quinn::Connection`. In the future, this method may be removed in favor
    /// of this crate's own implementation to provide additional features.
    #[must_use]
    pub fn open_uni(&self) -> quinn::OpenUni {
        self.connection.open_uni()
    }

    /// Accepts an incoming bidirectional stream.
    ///
    /// This directly corresponds to the `accept_bi` method of the underlying
    /// `quinn::Connection`. In the future, this method may be removed in favor
    /// of this crate's own implementation to provide additional features.
    #[must_use]
    pub fn accept_bi(&self) -> quinn::AcceptBi {
        self.connection.accept_bi()
    }
}
400
/// Sends a handshake request over `conn` and processes the server's reply.
///
/// Opens a bidirectional stream, sends this agent's [`AgentInfo`], then
/// reads the server's verdict: `Ok(_)` means accepted, and `Err(msg)` means
/// the server rejected the protocol version.
///
/// # Errors
///
/// Returns `HandshakeError` if the handshake failed. This includes
/// `HandshakeError::IncompatibleProtocol`, which carries our protocol version
/// and the server's message, when the server rejects the version.
#[cfg(all(test, feature = "server"))]
pub(crate) async fn handshake(
    conn: &quinn::Connection,
    app_name: &str,
    app_version: &str,
    protocol_version: &str,
    status: crate::Status,
) -> Result<(), super::HandshakeError> {
    use crate::{handle_handshake_recv_io_error, handle_handshake_send_io_error};

    // A placeholder for the address of this agent. Will be replaced by the
    // server.
    //
    // TODO: This is unnecessary in handshake, and thus should be removed in the
    // future.
    let placeholder_ip = if conn.remote_address().is_ipv6() {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    } else {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    };

    let agent_info = AgentInfo {
        app_name: app_name.to_string(),
        version: app_version.to_string(),
        protocol_version: protocol_version.to_string(),
        status,
        addr: SocketAddr::new(placeholder_ip, 0),
    };

    let (mut send, mut recv) = conn.open_bi().await?;
    let mut buf = Vec::new();
    frame::send(&mut send, &mut buf, &agent_info)
        .await
        .map_err(handle_handshake_send_io_error)?;

    let verdict = frame::recv::<Result<&str, &str>>(&mut recv, &mut buf)
        .await
        .map_err(handle_handshake_recv_io_error)?;
    match verdict {
        Ok(_) => Ok(()),
        Err(server_msg) => Err(super::HandshakeError::IncompatibleProtocol(
            protocol_version.to_string(),
            server_msg.to_string(),
        )),
    }
}