Private
Public Access
1
0

client: implements event/updates websocket

This commit is contained in:
2025-05-01 18:07:18 -07:00
parent 13a78ccd47
commit f6ac3b5a58
14 changed files with 561 additions and 67 deletions

View File

@@ -4,13 +4,25 @@ extern crate serde;
use std::{path::PathBuf, str};
use crate::api::AuthenticationStore;
use crate::api::event_socket::EventSocket;
use hyper::{Body, Client, Method, Request, Uri};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::net::TcpStream;
use futures_util::{StreamExt, TryStreamExt};
use futures_util::stream::{SplitStream, SplitSink, TryFilterMap, MapErr, Stream};
use futures_util::stream::Map;
use futures_util::stream::BoxStream;
use std::future::Future;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use crate::{
model::{Conversation, ConversationID, JwtToken, Message, MessageID},
model::{Conversation, ConversationID, JwtToken, Message, MessageID, UpdateItem, Event},
APIInterface
};
@@ -63,6 +75,12 @@ impl From <serde_json::Error> for Error {
}
}
impl From <tungstenite::Error> for Error {
fn from(err: tungstenite::Error) -> Error {
Error::ClientError(err.to_string())
}
}
trait AuthBuilder {
fn with_auth(self, token: &Option<JwtToken>) -> Self;
}
@@ -90,6 +108,58 @@ impl<B> AuthSetting for hyper::http::Request<B> {
}
}
type WebsocketSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>;
type WebsocketStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
pub struct WebsocketEventSocket {
_sink: WebsocketSink,
stream: WebsocketStream,
}
impl WebsocketEventSocket {
pub fn new(socket: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
let (sink, stream) = socket.split();
Self { _sink: sink, stream }
}
}
impl WebsocketEventSocket {
fn raw_update_stream(self) -> impl Stream<Item = Result<Vec<UpdateItem>, Error>> {
self.stream
.map_err(Error::from)
.try_filter_map(|msg| async move {
match msg {
tungstenite::Message::Text(text) => {
serde_json::from_str::<Vec<UpdateItem>>(&text)
.map(Some)
.map_err(Error::from)
}
_ => Ok(None)
}
})
}
}
#[async_trait]
impl EventSocket for WebsocketEventSocket {
type Error = Error;
type EventStream = BoxStream<'static, Result<Event, Error>>;
type UpdateStream = BoxStream<'static, Result<Vec<UpdateItem>, Error>>;
async fn events(self) -> Self::EventStream {
use futures_util::stream::iter;
self.raw_update_stream()
.map_ok(|updates| iter(updates.into_iter().map(|update| Ok(Event::from(update)))))
.try_flatten()
.boxed()
}
async fn raw_updates(self) -> Self::UpdateStream {
self.raw_update_stream().boxed()
}
}
#[async_trait]
impl<K: AuthenticationStore + Send + Sync> APIInterface for HTTPAPIClient<K> {
type Error = Error;
@@ -146,6 +216,44 @@ impl<K: AuthenticationStore + Send + Sync> APIInterface for HTTPAPIClient<K> {
let messages: Vec<Message> = self.request(&endpoint, Method::GET).await?;
Ok(messages)
}
async fn open_event_socket(&mut self) -> Result<WebsocketEventSocket, Self::Error> {
use tungstenite::http::StatusCode;
use tungstenite::handshake::client::Request as TungsteniteRequest;
use tungstenite::handshake::client::generate_key;
let uri = self.uri_for_endpoint("updates", Some(self.websocket_scheme()));
log::debug!("Connecting to websocket: {:?}", uri);
let auth = self.auth_store.get_token().await;
let host = uri.authority().unwrap().host();
let mut request = TungsteniteRequest::builder()
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", generate_key())
.uri(uri.to_string())
.body(())
.expect("Unable to build websocket request");
log::debug!("Websocket request: {:?}", request);
if let Some(token) = &auth {
let header_value = token.to_header_value().to_str().unwrap().parse().unwrap(); // ugh
request.headers_mut().insert("Authorization", header_value);
}
let (socket, response) = connect_async(request).await.unwrap();
log::debug!("Websocket connected: {:?}", response.status());
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::ClientError("Websocket connection failed".into()));
}
Ok(WebsocketEventSocket::new(socket))
}
}
impl<K: AuthenticationStore + Send + Sync> HTTPAPIClient<K> {
@@ -157,15 +265,27 @@ impl<K: AuthenticationStore + Send + Sync> HTTPAPIClient<K> {
}
}
fn uri_for_endpoint(&self, endpoint: &str) -> Uri {
fn uri_for_endpoint(&self, endpoint: &str, scheme: Option<&str>) -> Uri {
let mut parts = self.base_url.clone().into_parts();
let root_path: PathBuf = parts.path_and_query.unwrap().path().into();
let path = root_path.join(endpoint);
parts.path_and_query = Some(path.to_str().unwrap().parse().unwrap());
if let Some(scheme) = scheme {
parts.scheme = Some(scheme.parse().unwrap());
}
Uri::try_from(parts).unwrap()
}
fn websocket_scheme(&self) -> &str {
if self.base_url.scheme().unwrap() == "https" {
"wss"
} else {
"ws"
}
}
async fn request<T: DeserializeOwned>(&mut self, endpoint: &str, method: Method) -> Result<T, Error> {
self.request_with_body(endpoint, method, || { Body::empty() }).await
}
@@ -188,7 +308,7 @@ impl<K: AuthenticationStore + Send + Sync> HTTPAPIClient<K> {
{
use hyper::StatusCode;
let uri = self.uri_for_endpoint(endpoint);
let uri = self.uri_for_endpoint(endpoint, None);
log::debug!("Requesting {:?} {:?}", method, uri);
let build_request = move |auth: &Option<JwtToken>| {
@@ -320,4 +440,18 @@ mod test {
let messages = client.get_messages(&conversation.guid, None, None, None).await.unwrap();
assert!(!messages.is_empty());
}
#[tokio::test]
async fn test_updates() {
if !mock_client_is_reachable().await {
log::warn!("Skipping http_client tests (mock server not reachable)");
return;
}
let mut client = local_mock_client();
// We just want to see if the connection is established, we won't wait for any events
let _ = client.open_event_socket().await.unwrap();
assert!(true);
}
}