review_protocol/
lib.rs

1pub mod client;
2#[cfg(feature = "client")]
3pub mod frame;
4#[cfg(feature = "client")]
5pub mod request;
6pub mod server;
7#[cfg(all(test, any(feature = "client", feature = "server")))]
8mod test;
9pub mod types;
10
11use std::net::SocketAddr;
12
13use serde::{Deserialize, Serialize};
14#[cfg(any(feature = "client", feature = "server"))]
15use thiserror::Error;
16
17use crate::types::Status;
18
/// The error type for a handshake failure.
///
/// Variants that wrap a [`std::io::Error`] include the underlying error in
/// their display message so that the cause is not lost when the error is
/// logged or shown to a user.
#[cfg(any(feature = "client", feature = "server"))]
#[derive(Debug, Error)]
pub enum HandshakeError {
    /// The peer closed the connection before the handshake completed.
    ///
    /// On the server side this is what an `UnexpectedEof` I/O error while
    /// receiving is mapped to.
    #[error("connection closed by peer")]
    ConnectionClosed,
    /// The underlying QUIC connection failed.
    ///
    /// Converted automatically from [`quinn::ConnectionError`].
    #[error("connection lost")]
    ConnectionLost(#[from] quinn::ConnectionError),
    /// An I/O error occurred while receiving a handshake message.
    #[error("cannot receive a message: {0}")]
    ReadError(std::io::Error),
    /// An I/O error occurred while sending a handshake message.
    // The cause is formatted into the message, matching `ReadError`;
    // otherwise the wrapped error would be dropped from the output entirely.
    #[error("cannot send a message: {0}")]
    WriteError(std::io::Error),
    /// The outgoing handshake message was rejected as too large.
    ///
    /// On the server side this is what an `InvalidData` I/O error while
    /// sending is mapped to.
    #[error("arguments are too long")]
    MessageTooLarge,
    /// The received handshake message could not be interpreted.
    ///
    /// On the server side this is what an `InvalidData` I/O error while
    /// receiving is mapped to.
    #[error("invalid message")]
    InvalidMessage,
    /// The protocol versions of the two peers are incompatible.
    ///
    /// The first field is the unsupported version and the second is the
    /// required version, in the order they appear in the message.
    #[error("protocol version {0} is not supported; version {1} is required")]
    IncompatibleProtocol(String, String),
}
38
/// Translates an I/O error raised while sending a handshake message into a
/// [`HandshakeError`].
///
/// An error of kind `InvalidData` becomes [`HandshakeError::MessageTooLarge`];
/// every other error is wrapped unchanged in [`HandshakeError::WriteError`].
#[cfg(feature = "server")]
fn handle_handshake_send_io_error(e: std::io::Error) -> HandshakeError {
    match e.kind() {
        std::io::ErrorKind::InvalidData => HandshakeError::MessageTooLarge,
        _ => HandshakeError::WriteError(e),
    }
}
47
/// Translates an I/O error raised while receiving a handshake message into a
/// [`HandshakeError`].
///
/// * `InvalidData` becomes [`HandshakeError::InvalidMessage`].
/// * `UnexpectedEof` becomes [`HandshakeError::ConnectionClosed`].
/// * Anything else is wrapped unchanged in [`HandshakeError::ReadError`].
#[cfg(feature = "server")]
fn handle_handshake_recv_io_error(e: std::io::Error) -> HandshakeError {
    use std::io::ErrorKind;

    // Read the kind up front so `e` can still be moved into the fallback.
    let kind = e.kind();
    if kind == ErrorKind::InvalidData {
        return HandshakeError::InvalidMessage;
    }
    if kind == ErrorKind::UnexpectedEof {
        return HandshakeError::ConnectionClosed;
    }
    HandshakeError::ReadError(e)
}
56
57/// Properties of an agent.
58#[derive(Clone, Debug, Deserialize, Serialize)]
59pub struct AgentInfo {
60    pub app_name: String,
61    pub version: String,
62    pub protocol_version: String,
63    pub addr: SocketAddr,
64    pub status: Status,
65}
66
/// Performs a single request/response exchange over a pair of QUIC streams.
///
/// The request identified by `code` and carrying `input` is written to `send`
/// with `oinq::message::send_request`; the reply is then read from `recv` with
/// `oinq::frame::recv` and returned as `O`. One scratch buffer is shared by
/// both steps.
///
/// # Errors
///
/// Returns an error if the request could not be sent, or if the response could
/// not be received.
#[cfg(any(feature = "client", feature = "server"))]
pub async fn unary_request<I, O>(
    send: &mut quinn::SendStream,
    recv: &mut quinn::RecvStream,
    code: u32,
    input: I,
) -> std::io::Result<O>
where
    I: serde::Serialize,
    O: serde::de::DeserializeOwned,
{
    let mut scratch = Vec::new();

    // Outbound leg: bail out early if the request cannot be written.
    oinq::message::send_request(send, &mut scratch, code, input).await?;

    // Inbound leg: the same buffer is handed to the receiver.
    let response = oinq::frame::recv(recv, &mut scratch).await;
    response
}
89
#[cfg(test)]
mod tests {
    #[cfg(feature = "server")]
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    #[cfg(feature = "server")]
    use crate::Status;
    #[cfg(feature = "server")]
    use crate::test::{TOKEN, channel};

    /// Application name the client announces in every test.
    #[cfg(feature = "server")]
    const APP_NAME: &str = "oinq";
    /// Application version the client announces in every test.
    #[cfg(feature = "server")]
    const APP_VERSION: &str = "1.0.0";
    /// Protocol version used as the baseline: this crate's own version.
    #[cfg(feature = "server")]
    const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");

    /// Placeholder address (`0.0.0.0:0`) passed to the server-side handshake.
    #[cfg(feature = "server")]
    fn unspecified_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
    }

    /// Matching versions on both sides: the handshake succeeds and the server
    /// sees the values the client sent.
    #[cfg(feature = "server")]
    #[tokio::test]
    async fn handshake() {
        let _lock = TOKEN.lock().await;
        let endpoints = channel().await;
        let (server, client) = (endpoints.server, endpoints.client);

        let client_task = tokio::spawn(async move {
            super::client::handshake(
                &client.conn,
                APP_NAME,
                APP_VERSION,
                PROTOCOL_VERSION,
                Status::Ready,
            )
            .await
        });

        let agent_info = super::server::handshake(
            &server.conn,
            unspecified_addr(),
            PROTOCOL_VERSION,
            PROTOCOL_VERSION,
        )
        .await
        .unwrap();

        assert_eq!(agent_info.app_name, APP_NAME);
        assert_eq!(agent_info.version, APP_VERSION);
        assert_eq!(agent_info.protocol_version, PROTOCOL_VERSION);

        let client_result = client_task.await.unwrap();
        assert!(client_result.is_ok());
    }

    /// The server requires a version strictly below the one the client sends:
    /// both sides must report an error.
    #[cfg(feature = "server")]
    #[tokio::test]
    async fn handshake_version_incompatible_err() {
        let _lock = TOKEN.lock().await;
        let endpoints = channel().await;
        let (server, client) = (endpoints.server, endpoints.client);

        let client_task = tokio::spawn(async move {
            super::client::handshake(
                &client.conn,
                APP_NAME,
                APP_VERSION,
                PROTOCOL_VERSION,
                Status::Ready,
            )
            .await
        });

        let required = format!("<{PROTOCOL_VERSION}");
        let server_result = super::server::handshake(
            &server.conn,
            unspecified_addr(),
            &required,
            PROTOCOL_VERSION,
        )
        .await;

        assert!(server_result.is_err());

        let client_result = client_task.await.unwrap();
        assert!(client_result.is_err());
    }

    /// The client announces a version above the highest one the server is
    /// given: both sides must report an error.
    #[cfg(feature = "server")]
    #[tokio::test]
    async fn handshake_incompatible_err() {
        let version_req = semver::VersionReq::parse(&format!(">={PROTOCOL_VERSION}")).unwrap();

        // Highest version handed to the server: baseline with the patch bumped.
        let highest_version = {
            let mut version = semver::Version::parse(PROTOCOL_VERSION).unwrap();
            version.patch += 1;
            version
        };
        // Version the client announces: one minor release above that.
        let protocol_version = {
            let mut version = highest_version.clone();
            version.minor += 1;
            version
        };

        let _lock = TOKEN.lock().await;
        let endpoints = channel().await;
        let (server, client) = (endpoints.server, endpoints.client);

        let client_task = tokio::spawn(async move {
            super::client::handshake(
                &client.conn,
                APP_NAME,
                APP_VERSION,
                &protocol_version.to_string(),
                Status::Ready,
            )
            .await
        });

        let server_result = super::server::handshake(
            &server.conn,
            unspecified_addr(),
            &version_req.to_string(),
            &highest_version.to_string(),
        )
        .await;

        assert!(server_result.is_err());

        let client_result = client_task.await.unwrap();
        assert!(client_result.is_err());
    }
}