Private
Public Access
1
0

prepare for tower middleware adoption

This commit is contained in:
2024-06-01 18:16:25 -07:00
parent cf4195858e
commit a2caa2ddca
3 changed files with 936 additions and 55 deletions

922
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,10 +7,17 @@ edition = "2021"
[dependencies] [dependencies]
async-trait = "0.1.80" async-trait = "0.1.80"
ctor = "0.2.8"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
hyper-tls = "0.5.0" hyper-tls = "0.5.0"
log = { version = "0.4.21", features = [] }
pretty_env_logger = "0.5.0"
serde = { version = "1.0.152", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91" serde_json = "1.0.91"
serde_plain = "1.0.2"
time = { version = "0.3.17", features = ["parsing", "serde"] } time = { version = "0.3.17", features = ["parsing", "serde"] }
tokio = { version = "1.37.0", features = ["full"] } tokio = { version = "1.37.0", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["trace"] }
tower-hyper = "0.1.1"
uuid = { version = "1.6.1", features = ["v4", "fast-rng", "macro-diagnostics"] } uuid = { version = "1.6.1", features = ["v4", "fast-rng", "macro-diagnostics"] }

View File

@@ -1,12 +1,22 @@
extern crate hyper; extern crate hyper;
extern crate serde; extern crate serde;
use std::path::PathBuf; use std::{path::PathBuf, str::FromStr, str};
use log::{info, warn, error, trace};
use hyper::{Body, Client, Method, Request, Uri, body}; use hyper::{Body, Client, Method, Request, Uri, body};
use tower_hyper::client::Client as TowerClient;
use tower_http::{
trace::TraceLayer,
classify::StatusInRangeAsFailures,
};
use tower::{ServiceBuilder, Service};
use async_trait::async_trait; use async_trait::async_trait;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde_json::Error as SerdeError;
use serde_plain::{Deserializer, derive_deserialize_from_fromstr};
use crate::{APIInterface, model::Conversation}; use crate::{APIInterface, model::Conversation};
type HttpClient = Client<hyper::client::HttpConnector>; type HttpClient = Client<hyper::client::HttpConnector>;
@@ -21,6 +31,7 @@ pub enum Error {
ClientError(String), ClientError(String),
HTTPError(hyper::Error), HTTPError(hyper::Error),
SerdeError(serde_json::Error), SerdeError(serde_json::Error),
DecodeError,
} }
impl From <hyper::Error> for Error { impl From <hyper::Error> for Error {
@@ -35,15 +46,6 @@ impl From <serde_json::Error> for Error {
} }
} }
impl HTTPClient {
pub fn new(base_url: Uri) -> HTTPClient {
HTTPClient {
base_url: base_url,
client: Client::new(),
}
}
}
#[async_trait] #[async_trait]
impl APIInterface for HTTPClient { impl APIInterface for HTTPClient {
type Error = Error; type Error = Error;
@@ -60,6 +62,16 @@ impl APIInterface for HTTPClient {
} }
impl HTTPClient { impl HTTPClient {
pub fn new(base_url: Uri) -> HTTPClient {
let mut client = ServiceBuilder::new()
.service(Client::new());
HTTPClient {
base_url: base_url,
client: client,
}
}
fn uri_for_endpoint(&self, endpoint: &str) -> Uri { fn uri_for_endpoint(&self, endpoint: &str) -> Uri {
let mut parts = self.base_url.clone().into_parts(); let mut parts = self.base_url.clone().into_parts();
let root_path: PathBuf = parts.path_and_query.unwrap().path().into(); let root_path: PathBuf = parts.path_and_query.unwrap().path().into();
@@ -92,7 +104,15 @@ impl HTTPClient {
// Read and parse response body // Read and parse response body
let body = hyper::body::to_bytes(res.into_body()).await?; let body = hyper::body::to_bytes(res.into_body()).await?;
let parsed: T = serde_json::from_slice(&body)?; let parsed: T = match serde_json::from_slice(&body) {
Ok(result) => Ok(result),
Err(json_err) => {
// If JSON deserialization fails, try to interpret it as plain text
// Unfortunately the server does return things like this...
let s = str::from_utf8(&body).map_err(|_| Error::DecodeError)?;
serde_plain::from_str(s).map_err(|_| json_err)
}
}?;
Ok(parsed) Ok(parsed)
} }
@@ -100,6 +120,13 @@ impl HTTPClient {
mod test { mod test {
use super::*; use super::*;
use ctor::ctor;
#[ctor]
fn init() {
pretty_env_logger::init();
log::set_max_level(log::LevelFilter::Trace);
}
fn local_mock_client() -> HTTPClient { fn local_mock_client() -> HTTPClient {
let base_url = "http://localhost:5738".parse().unwrap(); let base_url = "http://localhost:5738".parse().unwrap();
@@ -109,13 +136,20 @@ mod test {
async fn mock_client_is_reachable() -> bool { async fn mock_client_is_reachable() -> bool {
let client = local_mock_client(); let client = local_mock_client();
let version = client.get_version().await; let version = client.get_version().await;
version.is_ok()
match version {
Ok(_) => true,
Err(e) => {
error!("Mock client error: {:?}", e);
false
}
}
} }
#[tokio::test] #[tokio::test]
async fn test_version() { async fn test_version() {
if !mock_client_is_reachable().await { if !mock_client_is_reachable().await {
println!("Skipping http_client tests (mock server not reachable)"); log::warn!("Skipping http_client tests (mock server not reachable)");
return; return;
} }
@@ -127,7 +161,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_conversations() { async fn test_conversations() {
if !mock_client_is_reachable().await { if !mock_client_is_reachable().await {
println!("Skipping http_client tests (mock server not reachable)"); log::warn!("Skipping http_client tests (mock server not reachable)");
return; return;
} }