Private
Public Access
1
0

Retry auth automatically, remove tower dep

This commit is contained in:
2024-06-14 20:23:44 -07:00
parent 0dde0b9c53
commit cabd3b502a
8 changed files with 413 additions and 778 deletions

886
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,8 @@ edition = "2021"
[dependencies]
async-trait = "0.1.80"
base64 = "0.22.1"
chrono = "0.4.38"
ctor = "0.2.8"
hyper = { version = "0.14", features = ["full"] }
hyper-tls = "0.5.0"
@@ -17,7 +19,4 @@ serde_json = "1.0.91"
serde_plain = "1.0.2"
time = { version = "0.3.17", features = ["parsing", "serde"] }
tokio = { version = "1.37.0", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["trace"] }
tower-hyper = "0.1.1"
uuid = { version = "1.6.1", features = ["v4", "fast-rng", "macro-diagnostics"] }

View File

@@ -1,24 +1,31 @@
extern crate hyper;
extern crate serde;
use std::{path::PathBuf, str};
use std::{borrow::Cow, default, path::PathBuf, str};
use log::{error};
use hyper::{Body, Client, Method, Request, Uri};
use tower::{ServiceBuilder};
use hyper::{client::ResponseFuture, Body, Client, Method, Request, Uri};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::{APIInterface, model::Conversation};
use crate::{model::{Conversation, JwtToken}, APIInterface};
type HttpClient = Client<hyper::client::HttpConnector>;
pub struct HTTPClient {
pub struct HTTPAPIClient {
pub base_url: Uri,
credentials: Option<Credentials>,
auth_token: Option<JwtToken>,
client: HttpClient,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Credentials {
pub username: String,
pub password: String,
}
#[derive(Debug)]
pub enum Error {
ClientError(String),
@@ -39,29 +46,65 @@ impl From <serde_json::Error> for Error {
}
}
trait AuthBuilder {
fn with_auth(self, token: &Option<JwtToken>) -> Self;
}
impl AuthBuilder for hyper::http::request::Builder {
fn with_auth(self, token: &Option<JwtToken>) -> Self {
if let Some(token) = &token {
self.header("Authorization", token.to_header_value())
} else { self }
}
}
trait AuthSetting {
fn authenticate(&mut self, token: &Option<JwtToken>);
}
impl<B> AuthSetting for hyper::http::Request<B> {
fn authenticate(&mut self, token: &Option<JwtToken>) {
if let Some(token) = &token {
self.headers_mut().insert("Authorization", token.to_header_value());
}
}
}
#[async_trait]
impl APIInterface for HTTPClient {
impl APIInterface for HTTPAPIClient {
type Error = Error;
async fn get_version(&self) -> Result<String, Self::Error> {
async fn get_version(&mut self) -> Result<String, Self::Error> {
let version: String = self.request("/version", Method::GET).await?;
Ok(version)
}
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
let conversations: Vec<Conversation> = self.request("/conversations", Method::GET).await?;
Ok(conversations)
}
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
#[derive(Deserialize, Debug)]
struct AuthResponse {
jwt: String,
}
impl HTTPClient {
pub fn new(base_url: Uri) -> HTTPClient {
let client = ServiceBuilder::new()
.service(Client::new());
let body = || -> Body { serde_json::to_string(&credentials).unwrap().into() };
let token: AuthResponse = self.request_with_body_retry("/authenticate", Method::POST, body, false).await?;
let token = JwtToken::new(&token.jwt).map_err(|_| Error::DecodeError)?;
self.auth_token = Some(token.clone());
Ok(token)
}
}
HTTPClient {
base_url,
client,
impl HTTPAPIClient {
pub fn new(base_url: Uri, credentials: Option<Credentials>) -> HTTPAPIClient {
HTTPAPIClient {
base_url: base_url,
credentials: credentials,
auth_token: Option::None,
client: Client::new(),
}
}
@@ -74,29 +117,67 @@ impl HTTPClient {
Uri::try_from(parts).unwrap()
}
async fn request<T: DeserializeOwned>(&self, endpoint: &str, method: Method) -> Result<T, Error> {
self.request_with_body(endpoint, method, Body::empty()).await
async fn request<T: DeserializeOwned>(&mut self, endpoint: &str, method: Method) -> Result<T, Error> {
self.request_with_body(endpoint, method, || { Body::empty() }).await
}
async fn request_with_body<T: DeserializeOwned>(&self, endpoint: &str, method: Method, body: Body) -> Result<T, Error> {
async fn request_with_body<T, B>(&mut self, endpoint: &str, method: Method, body_fn: B) -> Result<T, Error>
where T: DeserializeOwned, B: Fn() -> Body
{
self.request_with_body_retry(endpoint, method, body_fn, true).await
}
async fn request_with_body_retry<T, B>(
&mut self,
endpoint: &str,
method: Method,
body_fn: B,
retry_auth: bool) -> Result<T, Error>
where
T: DeserializeOwned,
B: Fn() -> Body
{
use hyper::StatusCode;
let uri = self.uri_for_endpoint(endpoint);
let request = Request::builder()
.method(method)
.uri(uri)
let build_request = move |auth: &Option<JwtToken>| {
let body = body_fn();
Request::builder()
.method(&method)
.uri(&uri)
.with_auth(auth)
.body(body)
.unwrap();
.expect("Unable to build request")
};
let future = self.client.request(request);
let res = future.await?;
let status = res.status();
let request = build_request(&self.auth_token);
let mut response = self.client.request(request).await?;
match response.status() {
StatusCode::OK => { /* cool */ },
if status != hyper::StatusCode::OK {
let message = format!("Request failed ({:})", status);
// 401: Unauthorized. Token may have expired or is invalid. Attempt to renew.
StatusCode::UNAUTHORIZED => {
if !retry_auth {
return Err(Error::ClientError("Unauthorized".into()));
}
if let Some(credentials) = &self.credentials {
self.authenticate(credentials.clone()).await?;
let request = build_request(&self.auth_token);
response = self.client.request(request).await?;
}
},
// Other errors: bubble up.
_ => {
let message = format!("Request failed ({:})", response.status());
return Err(Error::ClientError(message));
}
}
// Read and parse response body
let body = hyper::body::to_bytes(res.into_body()).await?;
let body = hyper::body::to_bytes(response.into_body()).await?;
let parsed: T = match serde_json::from_slice(&body) {
Ok(result) => Ok(result),
Err(json_err) => {
@@ -121,13 +202,18 @@ mod test {
log::set_max_level(log::LevelFilter::Trace);
}
fn local_mock_client() -> HTTPClient {
fn local_mock_client() -> HTTPAPIClient {
let base_url = "http://localhost:5738".parse().unwrap();
HTTPClient::new(base_url)
let credentials = Credentials {
username: "test".to_string(),
password: "test".to_string(),
};
HTTPAPIClient::new(base_url, credentials.into())
}
async fn mock_client_is_reachable() -> bool {
let client = local_mock_client();
let mut client = local_mock_client();
let version = client.get_version().await;
match version {
@@ -146,7 +232,7 @@ mod test {
return;
}
let client = local_mock_client();
let mut client = local_mock_client();
let version = client.get_version().await.unwrap();
assert!(version.starts_with("KordophoneMock-"));
}
@@ -158,7 +244,7 @@ mod test {
return;
}
let client = local_mock_client();
let mut client = local_mock_client();
let conversations = client.get_conversations().await.unwrap();
assert!(!conversations.is_empty());
}

View File

@@ -1,17 +1,23 @@
use async_trait::async_trait;
pub use crate::model::Conversation;
use crate::model::JwtToken;
pub mod http_client;
pub use http_client::HTTPClient;
pub use http_client::HTTPAPIClient;
use self::http_client::Credentials;
#[async_trait]
pub trait APIInterface {
type Error;
// (GET) /version
async fn get_version(&self) -> Result<String, Self::Error>;
async fn get_version(&mut self) -> Result<String, Self::Error>;
// (GET) /conversations
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error>;
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error>;
// (POST) /authenticate
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error>;
}

112
kordophone/src/model/jwt.rs Normal file
View File

@@ -0,0 +1,112 @@
use std::error::Error;
use base64::{
engine::{self, general_purpose},
Engine,
};
use chrono::{DateTime, Utc};
use hyper::http::HeaderValue;
use serde::Deserialize;
#[derive(Deserialize, Debug, Clone)]
struct JwtHeader {
alg: String,
typ: String,
}
#[derive(Deserialize, Debug, Clone)]
enum ExpValue {
Integer(i64),
String(String),
}
#[derive(Deserialize, Debug, Clone)]
struct JwtPayload {
exp: serde_json::Value,
iss: Option<String>,
user: Option<String>,
}
#[derive(Debug, Clone)]
pub struct JwtToken {
header: JwtHeader,
payload: JwtPayload,
signature: Vec<u8>,
expiration_date: DateTime<Utc>,
token: String,
}
impl JwtToken {
fn decode_token_using_engine(
token: &str,
engine: engine::GeneralPurpose,
) -> Result<Self, Box<dyn Error + Send + Sync>> {
let mut parts = token.split('.');
let header = parts.next().unwrap();
let payload = parts.next().unwrap();
let signature = parts.next().unwrap();
let header = engine.decode(header)?;
let payload = engine.decode(payload)?;
let signature = engine.decode(signature)?;
// Parse jwt header
let header: JwtHeader = serde_json::from_slice(&header)?;
// Parse jwt payload
let payload: JwtPayload = serde_json::from_slice(&payload)?;
// Parse jwt expiration date
// Annoyingly, because of my own fault, this could be either an integer or string.
let exp: i64 = payload.exp.as_i64().unwrap_or_else(|| {
let exp: String = payload.exp.as_str().unwrap().to_string();
exp.parse().unwrap()
});
let timestamp = chrono::NaiveDateTime::from_timestamp_opt(exp, 0).unwrap();
let expiration_date = DateTime::<Utc>::from_utc(timestamp, Utc);
Ok(JwtToken {
header,
payload,
signature,
expiration_date,
token: token.to_string(),
})
}
pub fn new(token: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
// STUPID: My mock server uses a different encoding than the real server, so we have to
// try both encodings here.
Self::decode_token_using_engine(token, general_purpose::STANDARD).or(
Self::decode_token_using_engine(token, general_purpose::URL_SAFE_NO_PAD),
)
}
pub fn dummy() -> Self {
JwtToken {
header: JwtHeader {
alg: "none".to_string(),
typ: "JWT".to_string(),
},
payload: JwtPayload {
exp: serde_json::Value::Null,
iss: None,
user: None,
},
signature: vec![],
expiration_date: Utc::now(),
token: "".to_string(),
}
}
pub fn is_valid(&self) -> bool {
self.expiration_date > Utc::now()
}
pub fn to_header_value(&self) -> HeaderValue {
format!("Bearer {}", self.token).parse().unwrap()
}
}

View File

@@ -1,3 +1,5 @@
pub mod conversation;
pub use conversation::Conversation;
pub mod jwt;
pub use jwt::JwtToken;

View File

@@ -9,7 +9,7 @@ pub mod api_interface {
#[tokio::test]
async fn test_version() {
let client = TestClient::new();
let mut client = TestClient::new();
let version = client.get_version().await.unwrap();
assert_eq!(version, client.version);
}

View File

@@ -1,7 +1,7 @@
use async_trait::async_trait;
pub use crate::APIInterface;
use crate::model::Conversation;
use crate::{api::http_client::Credentials, model::{Conversation, JwtToken}};
pub struct TestClient {
pub version: &'static str,
@@ -24,11 +24,15 @@ impl TestClient {
impl APIInterface for TestClient {
type Error = TestError;
async fn get_version(&self) -> Result<String, Self::Error> {
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
Ok(JwtToken::dummy())
}
async fn get_version(&mut self) -> Result<String, Self::Error> {
Ok(self.version.to_string())
}
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
Ok(self.conversations.clone())
}
}