Private
Public Access
1
0

Retry auth automatically, remove tower dep

This commit is contained in:
2024-06-14 20:23:44 -07:00
parent 0dde0b9c53
commit cabd3b502a
8 changed files with 413 additions and 778 deletions

View File

@@ -1,24 +1,31 @@
extern crate hyper;
extern crate serde;
use std::{path::PathBuf, str};
use std::{borrow::Cow, default, path::PathBuf, str};
use log::{error};
use hyper::{Body, Client, Method, Request, Uri};
use tower::{ServiceBuilder};
use hyper::{client::ResponseFuture, Body, Client, Method, Request, Uri};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::{APIInterface, model::Conversation};
use crate::{model::{Conversation, JwtToken}, APIInterface};
type HttpClient = Client<hyper::client::HttpConnector>;
pub struct HTTPClient {
pub struct HTTPAPIClient {
pub base_url: Uri,
credentials: Option<Credentials>,
auth_token: Option<JwtToken>,
client: HttpClient,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Credentials {
pub username: String,
pub password: String,
}
#[derive(Debug)]
pub enum Error {
ClientError(String),
@@ -39,29 +46,65 @@ impl From <serde_json::Error> for Error {
}
}
trait AuthBuilder {
fn with_auth(self, token: &Option<JwtToken>) -> Self;
}
impl AuthBuilder for hyper::http::request::Builder {
fn with_auth(self, token: &Option<JwtToken>) -> Self {
if let Some(token) = &token {
self.header("Authorization", token.to_header_value())
} else { self }
}
}
trait AuthSetting {
fn authenticate(&mut self, token: &Option<JwtToken>);
}
impl<B> AuthSetting for hyper::http::Request<B> {
fn authenticate(&mut self, token: &Option<JwtToken>) {
if let Some(token) = &token {
self.headers_mut().insert("Authorization", token.to_header_value());
}
}
}
#[async_trait]
impl APIInterface for HTTPClient {
impl APIInterface for HTTPAPIClient {
type Error = Error;
async fn get_version(&self) -> Result<String, Self::Error> {
async fn get_version(&mut self) -> Result<String, Self::Error> {
let version: String = self.request("/version", Method::GET).await?;
Ok(version)
}
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
let conversations: Vec<Conversation> = self.request("/conversations", Method::GET).await?;
Ok(conversations)
}
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
#[derive(Deserialize, Debug)]
struct AuthResponse {
jwt: String,
}
let body = || -> Body { serde_json::to_string(&credentials).unwrap().into() };
let token: AuthResponse = self.request_with_body_retry("/authenticate", Method::POST, body, false).await?;
let token = JwtToken::new(&token.jwt).map_err(|_| Error::DecodeError)?;
self.auth_token = Some(token.clone());
Ok(token)
}
}
impl HTTPClient {
pub fn new(base_url: Uri) -> HTTPClient {
let client = ServiceBuilder::new()
.service(Client::new());
HTTPClient {
base_url,
client,
impl HTTPAPIClient {
pub fn new(base_url: Uri, credentials: Option<Credentials>) -> HTTPAPIClient {
HTTPAPIClient {
base_url: base_url,
credentials: credentials,
auth_token: Option::None,
client: Client::new(),
}
}
@@ -74,29 +117,67 @@ impl HTTPClient {
Uri::try_from(parts).unwrap()
}
async fn request<T: DeserializeOwned>(&self, endpoint: &str, method: Method) -> Result<T, Error> {
self.request_with_body(endpoint, method, Body::empty()).await
async fn request<T: DeserializeOwned>(&mut self, endpoint: &str, method: Method) -> Result<T, Error> {
self.request_with_body(endpoint, method, || { Body::empty() }).await
}
async fn request_with_body<T: DeserializeOwned>(&self, endpoint: &str, method: Method, body: Body) -> Result<T, Error> {
async fn request_with_body<T, B>(&mut self, endpoint: &str, method: Method, body_fn: B) -> Result<T, Error>
where T: DeserializeOwned, B: Fn() -> Body
{
self.request_with_body_retry(endpoint, method, body_fn, true).await
}
async fn request_with_body_retry<T, B>(
&mut self,
endpoint: &str,
method: Method,
body_fn: B,
retry_auth: bool) -> Result<T, Error>
where
T: DeserializeOwned,
B: Fn() -> Body
{
use hyper::StatusCode;
let uri = self.uri_for_endpoint(endpoint);
let request = Request::builder()
.method(method)
.uri(uri)
.body(body)
.unwrap();
let build_request = move |auth: &Option<JwtToken>| {
let body = body_fn();
Request::builder()
.method(&method)
.uri(&uri)
.with_auth(auth)
.body(body)
.expect("Unable to build request")
};
let future = self.client.request(request);
let res = future.await?;
let status = res.status();
let request = build_request(&self.auth_token);
let mut response = self.client.request(request).await?;
match response.status() {
StatusCode::OK => { /* cool */ },
if status != hyper::StatusCode::OK {
let message = format!("Request failed ({:})", status);
return Err(Error::ClientError(message));
// 401: Unauthorized. Token may have expired or is invalid. Attempt to renew.
StatusCode::UNAUTHORIZED => {
if !retry_auth {
return Err(Error::ClientError("Unauthorized".into()));
}
if let Some(credentials) = &self.credentials {
self.authenticate(credentials.clone()).await?;
let request = build_request(&self.auth_token);
response = self.client.request(request).await?;
}
},
// Other errors: bubble up.
_ => {
let message = format!("Request failed ({:})", response.status());
return Err(Error::ClientError(message));
}
}
// Read and parse response body
let body = hyper::body::to_bytes(res.into_body()).await?;
let body = hyper::body::to_bytes(response.into_body()).await?;
let parsed: T = match serde_json::from_slice(&body) {
Ok(result) => Ok(result),
Err(json_err) => {
@@ -121,13 +202,18 @@ mod test {
log::set_max_level(log::LevelFilter::Trace);
}
fn local_mock_client() -> HTTPClient {
fn local_mock_client() -> HTTPAPIClient {
let base_url = "http://localhost:5738".parse().unwrap();
HTTPClient::new(base_url)
let credentials = Credentials {
username: "test".to_string(),
password: "test".to_string(),
};
HTTPAPIClient::new(base_url, credentials.into())
}
async fn mock_client_is_reachable() -> bool {
let client = local_mock_client();
let mut client = local_mock_client();
let version = client.get_version().await;
match version {
@@ -146,7 +232,7 @@ mod test {
return;
}
let client = local_mock_client();
let mut client = local_mock_client();
let version = client.get_version().await.unwrap();
assert!(version.starts_with("KordophoneMock-"));
}
@@ -158,7 +244,7 @@ mod test {
return;
}
let client = local_mock_client();
let mut client = local_mock_client();
let conversations = client.get_conversations().await.unwrap();
assert!(!conversations.is_empty());
}