//! HTTP/WebSocket implementation of `APIInterface`.
//!
//! `HTTPAPIClient` issues REST requests against `base_url` using hyper and attaches the
//! token held by its `AuthenticationStore` as an `Authorization` header. When a request
//! gets a 401 and the store holds credentials, the client calls `authenticate` and
//! resends the request once. `open_event_socket` opens a WebSocket to the `updates`
//! endpoint and wraps it in `WebsocketEventSocket`.
//!
//! NOTE(review): this copy appears to have lost its `<...>` generic arguments. Examples:
//! `type HttpClient = Client;`, `&Option` with no parameter, three identical
//! `impl From for Error` headers, and `Result` with no parameters. As shown it cannot
//! compile. Restore the type parameters from version control rather than guessing them.

extern crate hyper;
extern crate serde;

use std::{path::PathBuf, pin::Pin, str, task::Poll};

use crate::api::event_socket::EventSocket;
use crate::api::AuthenticationStore;
use bytes::Bytes;
use hyper::body::HttpBody;
use hyper::{Body, Client, Method, Request, Uri};

use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use tokio::net::TcpStream;

use tokio_util::io::ReaderStream;

use futures_util::stream::{BoxStream, Stream};
use futures_util::task::Context;
use futures_util::{StreamExt, TryStreamExt};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};

use crate::{
    model::{
        Conversation, ConversationID, Event, JwtToken, Message, MessageID, OutgoingMessage,
        UpdateItem,
    },
    APIInterface,
};

/// hyper client type used for every REST request made by `HTTPAPIClient`.
type HttpClient = Client;

/// REST + WebSocket client for the server rooted at `base_url`.
pub struct HTTPAPIClient {
    /// Root URI. Endpoint names are joined onto its path by `uri_for_endpoint`.
    pub base_url: Uri,
    /// Holds the current token and, optionally, the credentials used to renew it.
    pub auth_store: K,
    // Created once in `new` and reused for all requests.
    client: HttpClient,
}

/// Username/password pair, serialized as the JSON body of the `authenticate` request.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Credentials {
    pub username: String,
    pub password: String,
}

/// Errors produced by this client.
#[derive(Debug)]
pub enum Error {
    /// Textual client-side error: rejected status codes (with the response body),
    /// stringified WebSocket errors, missing credentials, closed sockets.
    ClientError(String),
    /// Transport-level error reported by hyper.
    HTTPError(hyper::Error),
    /// JSON serialization/deserialization failure.
    SerdeError(serde_json::Error),
    /// A response could not be decoded (invalid UTF-8, or a JWT that `JwtToken::new` rejects).
    DecodeError(String),
    /// Returned by `open_event_socket` after it renewed the token following a rejected
    /// handshake; the caller is expected to connect again.
    Unauthorized,
}

impl std::error::Error for Error {
    /// Only `HTTPError` exposes an underlying source error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::HTTPError(ref err) => Some(err),
            _ => None,
        }
    }
}

impl std::fmt::Display for Error {
    // Display reuses the Debug representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

/// Wraps hyper errors as `Error::HTTPError`.
impl From for Error {
    fn from(err: hyper::Error) -> Error {
        Error::HTTPError(err)
    }
}

/// Wraps serde_json errors as `Error::SerdeError`.
impl From for Error {
    fn from(err: serde_json::Error) -> Error {
        Error::SerdeError(err)
    }
}

/// Stringifies tungstenite errors into `Error::ClientError`.
///
/// `open_event_socket` later matches on this text to detect a 401 handshake.
impl From for Error {
    fn from(err: tungstenite::Error) -> Error {
        Error::ClientError(err.to_string())
    }
}

/// Adds an `Authorization` header to a request builder when a token is present.
trait AuthBuilder {
    /// Uses the token's own `to_header_value()` encoding.
    /// NOTE(review): not called anywhere in this file.
    fn with_auth(self, token: &Option) -> Self;
    /// Formats the header as `Bearer: <token>`.
    fn with_auth_string(self, token: &Option) -> Self;
}

impl AuthBuilder for hyper::http::request::Builder {
    fn with_auth(self, token: &Option) -> Self {
        if let Some(token) = &token {
            self.header("Authorization", token.to_header_value())
        } else {
            self
        }
    }

    fn with_auth_string(self, token: &Option) -> Self {
        if let Some(token) = &token {
            // NOTE(review): RFC 6750 defines the scheme as `Bearer <token>` (a space, no
            // colon). A strictly compliant server would reject `Bearer: `. It is left
            // unchanged because this project's server may rely on it — confirm.
            self.header("Authorization", format!("Bearer: {}", token))
        } else {
            self
        }
    }
}

/// Test-only helper that sets `Authorization` on an already-built request.
#[cfg(test)]
#[allow(dead_code)]
trait AuthSetting {
    fn authenticate(&mut self, token: &Option);
}

#[cfg(test)]
impl AuthSetting for hyper::http::Request {
    fn authenticate(&mut self, token: &Option) {
        if let Some(token) = &token {
            self.headers_mut()
                .insert("Authorization", token.to_header_value());
        }
    }
}

/// Event socket backed by the WebSocket that `open_event_socket` connects to `updates`.
pub struct WebsocketEventSocket {
    socket: WebSocketStream>,
}

impl WebsocketEventSocket {
    pub fn new(socket: WebSocketStream>) -> Self {
        Self { socket }
    }
}

impl WebsocketEventSocket {
    /// Turns incoming WebSocket messages into a stream of update batches.
    ///
    /// Each message type is handled as follows:
    /// * Text frames are parsed as JSON into a batch of updates. A parse failure becomes
    ///   an `Error::SerdeError` item.
    /// * Ping frames are skipped.
    /// * A Close frame becomes an `Err(ClientError("WebSocket connection closed"))` item.
    /// * All other frame types are skipped.
    ///
    /// The write half returned by `split()` is discarded, so this stream cannot send
    /// anything back to the server.
    fn raw_update_stream(self) -> impl Stream, Error>> {
        let (_, stream) = self.socket.split();

        stream
            .map_err(Error::from)
            .try_filter_map(|msg| async move {
                match msg {
                    tungstenite::Message::Text(text) => {
                        serde_json::from_str::>(&text)
                            .map(Some)
                            .map_err(Error::from)
                    }
                    tungstenite::Message::Ping(_) => {
                        // Borrowing issue here with the sink, need to handle pings at the client level (whomever
                        // is consuming these updateitems, should be a union type of updateitem | ping).
                        Ok(None)
                    }
                    tungstenite::Message::Close(_) => {
                        // Connection was closed cleanly
                        Err(Error::ClientError("WebSocket connection closed".into()))
                    }
                    _ => Ok(None),
                }
            })
    }
}

#[async_trait]
impl EventSocket for WebsocketEventSocket {
    type Error = Error;
    type EventStream = BoxStream<'static, Result>;
    type UpdateStream = BoxStream<'static, Result, Error>>;

    /// Flattens each update batch into individual `Event`s via `Event::from`.
    async fn events(self) -> Self::EventStream {
        use futures_util::stream::iter;
        self.raw_update_stream()
            .map_ok(|updates| iter(updates.into_iter().map(|update| Ok(Event::from(update)))))
            .try_flatten()
            .boxed()
    }

    /// Yields the update batches exactly as they were received.
    async fn raw_updates(self) -> Self::UpdateStream {
        self.raw_update_stream().boxed()
    }
}

/// Byte stream over a hyper response body; body errors become `Error::HTTPError`.
pub struct ResponseStream {
    body: hyper::Body,
}

impl Stream for ResponseStream {
    type Item = Result;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
        self.body
            .poll_next_unpin(cx)
            .map_err(|e| Error::HTTPError(e))
    }
}

impl From for ResponseStream {
    fn from(value: hyper::Body) -> Self {
        ResponseStream { body: value }
    }
}

#[async_trait]
impl APIInterface for HTTPAPIClient {
    type Error = Error;
    type ResponseStream = ResponseStream;

    /// GET `version`; returns the server's version string.
    async fn get_version(&mut self) -> Result {
        let version: String = self.deserialized_response("version", Method::GET).await?;
        Ok(version)
    }

    /// GET `conversations`.
    async fn get_conversations(&mut self) -> Result, Self::Error> {
        let conversations: Vec = self
            .deserialized_response("conversations", Method::GET)
            .await?;
        Ok(conversations)
    }

    /// POSTs `credentials` as JSON to `authenticate` and decodes the returned `jwt`.
    ///
    /// The token is stored in `auth_store` before it is returned.
    ///
    /// # Errors
    /// Fails on transport/status errors, or with `DecodeError` when the JWT cannot be parsed.
    async fn authenticate(&mut self, credentials: Credentials) -> Result {
        #[derive(Deserialize, Debug)]
        struct AuthResponse {
            jwt: String,
        }

        log::debug!("Authenticating with username: {:?}", credentials.username);

        let body = || -> Body { serde_json::to_string(&credentials).unwrap().into() };
        // retry_auth = false: `response_with_body_retry` calls `authenticate` on a 401,
        // so allowing a retry here could recurse.
        let token: AuthResponse = self
            .deserialized_response_with_body_retry("authenticate", Method::POST, body, false)
            .await?;
        let token = JwtToken::new(&token.jwt).map_err(|e| Error::DecodeError(e.to_string()))?;

        log::debug!("Saving token: {:?}", token);
        self.auth_store.set_token(token.to_string()).await;

        Ok(token)
    }

    /// GET `messages?guid=<conversation_id>`.
    ///
    /// Optional `limit`, `beforeMessageGUID` and `afterMessageGUID` query parameters are
    /// appended only when they are provided.
    // NOTE(review): values are interpolated without percent-encoding. IDs containing
    // reserved characters (`&`, `#`, space) would produce a wrong or unparseable URI.
    async fn get_messages(
        &mut self,
        conversation_id: &ConversationID,
        limit: Option,
        before: Option,
        after: Option,
    ) -> Result, Self::Error> {
        let mut endpoint = format!("messages?guid={}", conversation_id);

        if let Some(limit_val) = limit {
            endpoint.push_str(&format!("&limit={}", limit_val));
        }

        if let Some(before_id) = before {
            endpoint.push_str(&format!("&beforeMessageGUID={}", before_id));
        }

        if let Some(after_id) = after {
            endpoint.push_str(&format!("&afterMessageGUID={}", after_id));
        }

        let messages: Vec = self.deserialized_response(&endpoint, Method::GET).await?;
        Ok(messages)
    }

    /// POSTs `outgoing_message` as JSON to `sendMessage`; returns the server's `Message`.
    async fn send_message(
        &mut self,
        outgoing_message: &OutgoingMessage,
    ) -> Result {
        let message: Message = self
            .deserialized_response_with_body("sendMessage", Method::POST, || {
                serde_json::to_string(&outgoing_message).unwrap().into()
            })
            .await?;

        Ok(message)
    }

    /// GET `attachment?guid=..&preview=..`.
    ///
    /// The response body is returned as a stream rather than buffered in memory.
    async fn fetch_attachment_data(
        &mut self,
        guid: &String,
        preview: bool,
    ) -> Result {
        let endpoint = format!("attachment?guid={}&preview={}", guid, preview);
        self.response_with_body_retry(&endpoint, Method::GET, Body::empty, true)
            .await
            .map(hyper::Response::into_body)
            .map(ResponseStream::from)
    }

    /// Streams `data` as the POST body of `uploadAttachment?filename=..`.
    ///
    /// Returns the `fileTransferGUID` from the response.
    // NOTE(review): `filename` is not percent-encoded. A name with a space or reserved
    // characters would yield a wrong URI, and `uri_for_endpoint` unwraps the parse
    // (presumably a panic) — TODO encode it.
    async fn upload_attachment(
        &mut self,
        data: tokio::io::BufReader,
        filename: &str,
    ) -> Result
    where
        R: tokio::io::AsyncRead + Unpin + Send + Sync + 'static,
    {
        #[derive(Deserialize, Debug)]
        struct UploadAttachmentResponse {
            #[serde(rename = "fileTransferGUID")]
            guid: String,
        }

        let endpoint = format!("uploadAttachment?filename={}", filename);
        // The body factory is `FnMut`, but the reader can only be consumed once. With
        // retry_auth = false, a 401 returns before the request is rebuilt, so the factory
        // runs exactly once and the `expect` is not reached.
        let mut data_opt = Some(data);

        let response: UploadAttachmentResponse = self
            .deserialized_response_with_body_retry(
                &endpoint,
                Method::POST,
                move || {
                    let stream = ReaderStream::new(
                        data_opt.take().expect("Stream already consumed during retry"),
                    );
                    Body::wrap_stream(stream)
                },
                false, // don't retry auth for streaming body
            )
            .await?;

        Ok(response.guid)
    }

    /// Opens a WebSocket to `updates` (or `updates?seq=<seq>`) using `ws`/`wss` to match `base_url`.
    ///
    /// # Errors
    /// If the handshake fails with a 401-looking error:
    /// * with stored credentials, the token is renewed and `Error::Unauthorized` is
    ///   returned so the caller can connect again;
    /// * without credentials, a `ClientError` is returned.
    ///
    /// Other failures are returned unchanged.
    async fn open_event_socket(
        &mut self,
        update_seq: Option,
    ) -> Result {
        use tungstenite::handshake::client::generate_key;
        use tungstenite::handshake::client::Request as TungsteniteRequest;

        let endpoint = match update_seq {
            Some(seq) => format!("updates?seq={}", seq),
            None => "updates".to_string(),
        };

        let uri = self.uri_for_endpoint(&endpoint, Some(self.websocket_scheme()));

        log::debug!("Connecting to websocket: {:?}", uri);

        let auth = self.auth_store.get_token().await;
        // NOTE(review): `host()` drops the port. RFC 7230 §5.4 expects `Host` to include a
        // non-default port (e.g. `localhost:5738`). Confirm whether the server cares.
        let host = uri.authority().unwrap().host();
        // The handshake headers are set by hand because a custom `Authorization` header
        // is added below.
        let mut request = TungsteniteRequest::builder()
            .header("Host", host)
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header("Sec-WebSocket-Key", generate_key())
            .uri(uri.to_string())
            .body(())
            .expect("Unable to build websocket request");

        match &auth {
            Some(token) => {
                request.headers_mut().insert(
                    "Authorization",
                    format!("Bearer: {}", token).parse().unwrap(),
                );
            }
            None => {
                log::warn!(target: "websocket", "Proceeding without auth token.");
            }
        }

        log::debug!("Websocket request: {:?}", request);

        match connect_async(request).await.map_err(Error::from) {
            Ok((socket, response)) => {
                log::debug!("Websocket connected: {:?}", response.status());
                Ok(WebsocketEventSocket::new(socket))
            }
            // NOTE(review): 401 detection compares tungstenite's Display text, which the
            // `From` impl above stringified. Matching `tungstenite::Error::Http(resp)` on
            // `resp.status()` before conversion would survive wording changes.
            Err(e) => match &e {
                Error::ClientError(ce) => match ce.as_str() {
                    "HTTP error: 401 Unauthorized" | "Unauthorized" => {
                        // Try to authenticate
                        if let Some(credentials) = &self.auth_store.get_credentials().await {
                            log::warn!("Websocket connection failed, attempting to authenticate");
                            let new_token = self.authenticate(credentials.clone()).await?;
                            // NOTE(review): redundant. `authenticate` already calls
                            // `auth_store.set_token` with this token.
                            self.auth_store.set_token(new_token.to_string()).await;

                            // try again on the next attempt.
                            return Err(Error::Unauthorized);
                        } else {
                            log::error!("Websocket unauthorized, no credentials provided");
                            return Err(Error::ClientError(
                                "Unauthorized, no credentials provided".into(),
                            ));
                        }
                    }
                    _ => Err(e),
                },

                _ => Err(e),
            },
        }
    }
}

impl HTTPAPIClient {
    /// Creates a client for `base_url` with a fresh hyper `Client`.
    pub fn new(base_url: Uri, auth_store: K) -> HTTPAPIClient {
        HTTPAPIClient {
            base_url,
            auth_store,
            client: Client::new(),
        }
    }

    /// Joins `endpoint` (which may carry a `?query`) onto the path of `base_url`.
    ///
    /// The scheme is replaced when `scheme` is `Some`.
    ///
    /// # Panics
    /// Panics if the joined path, scheme or resulting URI fails to parse.
    // NOTE(review): `PathBuf` is a filesystem type. On Windows `join` would insert `\`,
    // and an endpoint starting with `/` would replace the base path entirely. A plain
    // string join would be portable.
    fn uri_for_endpoint(&self, endpoint: &str, scheme: Option<&str>) -> Uri {
        let mut parts = self.base_url.clone().into_parts();
        let root_path: PathBuf = parts.path_and_query.unwrap().path().into();
        let path = root_path.join(endpoint);
        parts.path_and_query = Some(path.to_str().unwrap().parse().unwrap());

        if let Some(scheme) = scheme {
            parts.scheme = Some(scheme.parse().unwrap());
        }

        Uri::try_from(parts).unwrap()
    }

    /// Returns `"wss"` when `base_url` uses https, otherwise `"ws"`.
    ///
    /// Panics if `base_url` has no scheme.
    fn websocket_scheme(&self) -> &str {
        if self.base_url.scheme().unwrap() == "https" {
            "wss"
        } else {
            "ws"
        }
    }

    /// Sends a request with an empty body (auth retry enabled) and deserializes the reply.
    async fn deserialized_response(
        &mut self,
        endpoint: &str,
        method: Method,
    ) -> Result {
        self.deserialized_response_with_body(endpoint, method, Body::empty)
            .await
    }

    /// Like `deserialized_response`, with a caller-supplied body factory (auth retry enabled).
    async fn deserialized_response_with_body(
        &mut self,
        endpoint: &str,
        method: Method,
        body_fn: impl FnMut() -> Body,
    ) -> Result
    where
        T: DeserializeOwned,
    {
        self.deserialized_response_with_body_retry(endpoint, method, body_fn, true)
            .await
    }

    /// Sends the request, buffers the entire response body, and deserializes it as JSON.
    ///
    /// If JSON parsing fails, the body is retried as a plain-text value via `serde_plain`.
    /// When that also fails, the original JSON error is returned.
    async fn deserialized_response_with_body_retry(
        &mut self,
        endpoint: &str,
        method: Method,
        body_fn: impl FnMut() -> Body,
        retry_auth: bool,
    ) -> Result
    where
        T: DeserializeOwned,
    {
        let response = self
            .response_with_body_retry(endpoint, method, body_fn, retry_auth)
            .await?;

        // Read and parse response body
        let body = hyper::body::to_bytes(response.into_body()).await?;
        let parsed: T = match serde_json::from_slice(&body) {
            Ok(result) => Ok(result),
            Err(json_err) => {
                log::error!("Error deserializing JSON: {:?}", json_err);
                log::error!("Body: {:?}", String::from_utf8_lossy(&body));

                // If JSON deserialization fails, try to interpret it as plain text
                // Unfortunately the server does return things like this...
                let s = str::from_utf8(&body).map_err(|e| Error::DecodeError(e.to_string()))?;
                serde_plain::from_str(s).map_err(|_| json_err)
            }
        }?;

        Ok(parsed)
    }

    /// Sends `method endpoint` with the stored token and returns the raw response.
    ///
    /// `body_fn` is called each time the request is built. Status handling:
    /// * `200 OK` is returned as-is.
    /// * `401`: when `retry_auth` is true and credentials are stored, the client
    ///   re-authenticates and resends once. Otherwise it returns a `ClientError`.
    /// * Any other status becomes a `ClientError` containing the status and body.
    // NOTE(review): the status of the resent request is never checked, so a second 401 or
    // a 5xx after renewal is returned as success. Also, only 200 counts as success, so
    // other 2xx codes (201, 204) take the error branch. `status().is_success()` plus a
    // status check after the retry would address both.
    async fn response_with_body_retry(
        &mut self,
        endpoint: &str,
        method: Method,
        mut body_fn: impl FnMut() -> Body,
        retry_auth: bool,
    ) -> Result, Error> {
        use hyper::StatusCode;

        let uri = self.uri_for_endpoint(endpoint, None);
        log::debug!("Requesting {:?} {:?}", method, uri);

        // Rebuilt from scratch for the retry because a hyper request body is single-use.
        let mut build_request = |auth: &Option| {
            let body = body_fn();
            Request::builder()
                .method(&method)
                .uri(&uri)
                .with_auth_string(auth)
                .body(body)
                .expect("Unable to build request")
        };

        log::trace!("Obtaining token from auth store");
        let token = self.auth_store.get_token().await;
        log::trace!("Token: {:?}", token);

        let request = build_request(&token);
        log::trace!("Request: {:?}. Sending request...", request);

        let mut response = self.client.request(request).await?;
        log::debug!("-> Response: {:}", response.status());

        match response.status() {
            StatusCode::OK => { /* cool */ }

            // 401: Unauthorized. Token may have expired or is invalid. Attempt to renew.
            StatusCode::UNAUTHORIZED => {
                if !retry_auth {
                    return Err(Error::ClientError("Unauthorized".into()));
                }

                if let Some(credentials) = &self.auth_store.get_credentials().await {
                    log::debug!(
                        "Renewing token using credentials: u: {:?}",
                        credentials.username
                    );
                    let new_token = self.authenticate(credentials.clone()).await?;

                    let request = build_request(&Some(new_token.to_string()));
                    response = self.client.request(request).await?;
                } else {
                    return Err(Error::ClientError(
                        "Unauthorized, no credentials provided".into(),
                    ));
                }
            }

            // Other errors: bubble up.
            _ => {
                let status = response.status();
                let body_str = hyper::body::to_bytes(response.into_body()).await?;
                let message = format!("Request failed ({:}). Response body: {:?}", status, String::from_utf8_lossy(&body_str));
                return Err(Error::ClientError(message));
            }
        }

        Ok(response)
    }
}

/// Integration tests against a mock server at `http://localhost:5738`.
///
/// Each test returns early (and therefore passes) when the mock server is unreachable.
#[cfg(test)]
mod test {
    use super::*;
    use crate::api::InMemoryAuthenticationStore;

    // Client pointed at the local mock server, with test/test credentials so 401s can renew.
    #[cfg(test)]
    fn local_mock_client() -> HTTPAPIClient {
        let base_url = "http://localhost:5738".parse().unwrap();
        let credentials = Credentials {
            username: "test".to_string(),
            password: "test".to_string(),
        };

        HTTPAPIClient::new(
            base_url,
            InMemoryAuthenticationStore::new(Some(credentials)),
        )
    }

    // Probes `get_version`; any error is logged and treated as "not reachable".
    #[cfg(test)]
    async fn mock_client_is_reachable() -> bool {
        let mut client = local_mock_client();
        let version = client.get_version().await;

        match version {
            Ok(_) => true,
            Err(e) => {
                log::error!("Mock client error: {:?}", e);
                false
            }
        }
    }

    #[tokio::test]
    async fn test_version() {
        if !mock_client_is_reachable().await {
            log::warn!("Skipping http_client tests (mock server not reachable)");
            return;
        }

        let mut client = local_mock_client();
        let version = client.get_version().await.unwrap();
        assert!(version.starts_with("KordophoneMock-"));
    }

    #[tokio::test]
    async fn test_conversations() {
        if !mock_client_is_reachable().await {
            log::warn!("Skipping http_client tests (mock server not reachable)");
            return;
        }

        let mut client = local_mock_client();
        let conversations = client.get_conversations().await.unwrap();
        assert!(!conversations.is_empty());
    }

    #[tokio::test]
    async fn test_messages() {
        if !mock_client_is_reachable().await {
            log::warn!("Skipping http_client tests (mock server not reachable)");
            return;
        }

        let mut client = local_mock_client();
        let conversations = client.get_conversations().await.unwrap();
        let conversation = conversations.first().unwrap();
        let messages = client
            .get_messages(&conversation.guid, None, None, None)
            .await
            .unwrap();
        assert!(!messages.is_empty());
    }

    #[tokio::test]
    async fn test_updates() {
        if !mock_client_is_reachable().await {
            log::warn!("Skipping http_client tests (mock server not reachable)");
            return;
        }

        let mut client = local_mock_client();

        // We just want to see if the connection is established, we won't wait for any events
        let _ = client.open_event_socket(None).await.unwrap();
        assert!(true);
    }
}