1pub mod client;
2#[cfg(feature = "client")]
3pub mod frame;
4#[cfg(feature = "client")]
5pub mod request;
6pub mod server;
7#[cfg(all(test, any(feature = "client", feature = "server")))]
8mod test;
9pub mod types;
10
11use std::net::SocketAddr;
12
13use serde::{Deserialize, Serialize};
14#[cfg(any(feature = "client", feature = "server"))]
15use thiserror::Error;
16
17use crate::types::Status;
18
#[cfg(any(feature = "client", feature = "server"))]
/// Errors that can occur while performing the client/server handshake.
#[derive(Debug, Error)]
pub enum HandshakeError {
    /// The peer closed the connection.
    #[error("connection closed by peer")]
    ConnectionClosed,
    /// The underlying QUIC connection failed; wraps the `quinn` error, which
    /// is also exposed as this error's `source()`.
    #[error("connection lost")]
    ConnectionLost(#[from] quinn::ConnectionError),
    /// Receiving a handshake message failed with the contained I/O error.
    #[error("cannot receive a message: {0}")]
    ReadError(std::io::Error),
    /// Sending a handshake message failed.
    ///
    /// The message text does not include the cause, so the I/O error is
    /// marked `#[source]` to keep it reachable through `Error::source()`
    /// (previously it was only visible via `Debug` or pattern matching).
    #[error("cannot send a message")]
    WriteError(#[source] std::io::Error),
    /// An outgoing message was rejected as too large to send.
    #[error("arguments are too long")]
    MessageTooLarge,
    /// A received message could not be interpreted.
    #[error("invalid message")]
    InvalidMessage,
    /// The peer's protocol version (first field) is not supported; the second
    /// field describes the version that is required.
    #[error("protocol version {0} is not supported; version {1} is required")]
    IncompatibleProtocol(String, String),
}
38
/// Converts an I/O error raised while *sending* a handshake message into a
/// [`HandshakeError`].
///
/// An error of kind `InvalidData` is reported as
/// [`HandshakeError::MessageTooLarge`]; every other error is wrapped unchanged
/// in [`HandshakeError::WriteError`].
// NOTE(review): this presumes the framing layer signals oversized messages
// with `InvalidData` — confirm against the `oinq` send path.
#[cfg(feature = "server")]
fn handle_handshake_send_io_error(e: std::io::Error) -> HandshakeError {
    match e.kind() {
        std::io::ErrorKind::InvalidData => HandshakeError::MessageTooLarge,
        _ => HandshakeError::WriteError(e),
    }
}
47
/// Converts an I/O error raised while *receiving* a handshake message into a
/// [`HandshakeError`].
///
/// * `InvalidData` becomes [`HandshakeError::InvalidMessage`].
/// * `UnexpectedEof` (the stream ended early) becomes
///   [`HandshakeError::ConnectionClosed`].
/// * Anything else is wrapped unchanged in [`HandshakeError::ReadError`].
#[cfg(feature = "server")]
fn handle_handshake_recv_io_error(e: std::io::Error) -> HandshakeError {
    use std::io::ErrorKind;

    let kind = e.kind();
    if kind == ErrorKind::InvalidData {
        HandshakeError::InvalidMessage
    } else if kind == ErrorKind::UnexpectedEof {
        HandshakeError::ConnectionClosed
    } else {
        HandshakeError::ReadError(e)
    }
}
56
/// Information about an agent, as produced by the server side of the
/// handshake.
// NOTE(review): if this struct is encoded with a positional
// (non-self-describing) serializer, field order is part of the wire format —
// confirm before reordering or inserting fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentInfo {
    /// Application name reported by the agent.
    pub app_name: String,
    /// Application version reported by the agent.
    pub version: String,
    /// Protocol version reported by the agent.
    pub protocol_version: String,
    /// Socket address associated with the agent; presumably the peer address
    /// supplied to the server-side handshake — TODO confirm.
    pub addr: SocketAddr,
    /// Agent status; presumably the `Status` the agent sent during the
    /// handshake — TODO confirm.
    pub status: Status,
}
66
/// Performs a single request/response exchange over an already-open stream
/// pair.
///
/// The request, identified by `code` and carrying `input`, is written to
/// `send` with `oinq::message::send_request`. One response frame is then read
/// from `recv` with `oinq::frame::recv` and deserialized into `O`.
///
/// # Errors
///
/// Returns an error if sending the request fails or if receiving or decoding
/// the response fails.
#[cfg(any(feature = "client", feature = "server"))]
pub async fn unary_request<I: serde::Serialize, O: serde::de::DeserializeOwned>(
    send: &mut quinn::SendStream,
    recv: &mut quinn::RecvStream,
    code: u32,
    input: I,
) -> std::io::Result<O> {
    // A single scratch buffer serves both directions: it is first used to
    // encode the outgoing request and then reused to read the response.
    let mut scratch = Vec::new();
    oinq::message::send_request(send, &mut scratch, code, input).await?;
    oinq::frame::recv(recv, &mut scratch).await
}
89
#[cfg(test)]
mod tests {
    #[cfg(feature = "server")]
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    #[cfg(feature = "server")]
    use crate::test::{TOKEN, channel};
    #[cfg(feature = "server")]
    use crate::Status;

    /// Application name the client announces in every handshake test.
    #[cfg(feature = "server")]
    const APP_NAME: &str = "oinq";
    /// Application version the client announces in every handshake test.
    #[cfg(feature = "server")]
    const APP_VERSION: &str = "1.0.0";
    /// This crate's own version, used as the baseline protocol version.
    #[cfg(feature = "server")]
    const PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");

    /// Peer address handed to `server::handshake`; none of the tests inspect it.
    #[cfg(feature = "server")]
    fn unspecified_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
    }

    /// Both sides use the same protocol version: the server reports the
    /// client's name and versions, and the client side succeeds.
    #[cfg(feature = "server")]
    #[tokio::test]
    async fn handshake() {
        let _lock = TOKEN.lock().await;
        let channel = channel().await;
        let (server, client) = (channel.server, channel.client);

        let client_task = tokio::spawn(async move {
            super::client::handshake(
                &client.conn,
                APP_NAME,
                APP_VERSION,
                PROTOCOL_VERSION,
                Status::Ready,
            )
            .await
        });

        let agent_info = super::server::handshake(
            &server.conn,
            unspecified_addr(),
            PROTOCOL_VERSION,
            PROTOCOL_VERSION,
        )
        .await
        .unwrap();

        assert_eq!(agent_info.app_name, APP_NAME);
        assert_eq!(agent_info.version, APP_VERSION);
        assert_eq!(agent_info.protocol_version, PROTOCOL_VERSION);

        let client_result = client_task.await.unwrap();
        assert!(client_result.is_ok());
    }

    /// The server's version requirement (`<PROTOCOL_VERSION`) excludes the
    /// client's version, so both sides must fail.
    #[cfg(feature = "server")]
    #[tokio::test]
    async fn handshake_version_incompatible_err() {
        let _lock = TOKEN.lock().await;
        let channel = channel().await;
        let (server, client) = (channel.server, channel.client);

        let client_task = tokio::spawn(async move {
            super::client::handshake(
                &client.conn,
                APP_NAME,
                APP_VERSION,
                PROTOCOL_VERSION,
                Status::Ready,
            )
            .await
        });

        let server_result = super::server::handshake(
            &server.conn,
            unspecified_addr(),
            &format!("<{PROTOCOL_VERSION}"),
            PROTOCOL_VERSION,
        )
        .await;
        assert!(server_result.is_err());

        let client_result = client_task.await.unwrap();
        assert!(client_result.is_err());
    }

    /// The client's protocol version satisfies `>=PROTOCOL_VERSION` but is
    /// newer than the highest version the server accepts, so both sides must
    /// fail.
    #[cfg(feature = "server")]
    #[tokio::test]
    async fn handshake_incompatible_err() {
        let version_req = semver::VersionReq::parse(&format!(">={PROTOCOL_VERSION}")).unwrap();

        // highest = baseline with patch + 1; client = highest with minor + 1,
        // which lies above `highest`.
        let mut highest_version = semver::Version::parse(PROTOCOL_VERSION).unwrap();
        highest_version.patch += 1;
        let mut protocol_version = highest_version.clone();
        protocol_version.minor += 1;

        let _lock = TOKEN.lock().await;
        let channel = channel().await;
        let (server, client) = (channel.server, channel.client);

        let client_task = tokio::spawn(async move {
            super::client::handshake(
                &client.conn,
                APP_NAME,
                APP_VERSION,
                &protocol_version.to_string(),
                Status::Ready,
            )
            .await
        });

        let server_result = super::server::handshake(
            &server.conn,
            unspecified_addr(),
            &version_req.to_string(),
            &highest_version.to_string(),
        )
        .await;
        assert!(server_result.is_err());

        let client_result = client_task.await.unwrap();
        assert!(client_result.is_err());
    }
}