Retry auth automatically, remove tower dep
This commit is contained in:
886
Cargo.lock
generated
886
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,8 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.80"
|
||||
base64 = "0.22.1"
|
||||
chrono = "0.4.38"
|
||||
ctor = "0.2.8"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
hyper-tls = "0.5.0"
|
||||
@@ -17,7 +19,4 @@ serde_json = "1.0.91"
|
||||
serde_plain = "1.0.2"
|
||||
time = { version = "0.3.17", features = ["parsing", "serde"] }
|
||||
tokio = { version = "1.37.0", features = ["full"] }
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.5.2", features = ["trace"] }
|
||||
tower-hyper = "0.1.1"
|
||||
uuid = { version = "1.6.1", features = ["v4", "fast-rng", "macro-diagnostics"] }
|
||||
|
||||
@@ -1,24 +1,31 @@
|
||||
extern crate hyper;
|
||||
extern crate serde;
|
||||
|
||||
use std::{path::PathBuf, str};
|
||||
use std::{borrow::Cow, default, path::PathBuf, str};
|
||||
use log::{error};
|
||||
|
||||
use hyper::{Body, Client, Method, Request, Uri};
|
||||
use tower::{ServiceBuilder};
|
||||
use hyper::{client::ResponseFuture, Body, Client, Method, Request, Uri};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
||||
use crate::{APIInterface, model::Conversation};
|
||||
use crate::{model::{Conversation, JwtToken}, APIInterface};
|
||||
|
||||
type HttpClient = Client<hyper::client::HttpConnector>;
|
||||
|
||||
pub struct HTTPClient {
|
||||
pub struct HTTPAPIClient {
|
||||
pub base_url: Uri,
|
||||
credentials: Option<Credentials>,
|
||||
auth_token: Option<JwtToken>,
|
||||
client: HttpClient,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct Credentials {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
ClientError(String),
|
||||
@@ -39,29 +46,65 @@ impl From <serde_json::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
trait AuthBuilder {
|
||||
fn with_auth(self, token: &Option<JwtToken>) -> Self;
|
||||
}
|
||||
|
||||
impl AuthBuilder for hyper::http::request::Builder {
|
||||
fn with_auth(self, token: &Option<JwtToken>) -> Self {
|
||||
if let Some(token) = &token {
|
||||
self.header("Authorization", token.to_header_value())
|
||||
} else { self }
|
||||
}
|
||||
}
|
||||
|
||||
trait AuthSetting {
|
||||
fn authenticate(&mut self, token: &Option<JwtToken>);
|
||||
}
|
||||
|
||||
impl<B> AuthSetting for hyper::http::Request<B> {
|
||||
fn authenticate(&mut self, token: &Option<JwtToken>) {
|
||||
if let Some(token) = &token {
|
||||
self.headers_mut().insert("Authorization", token.to_header_value());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl APIInterface for HTTPClient {
|
||||
impl APIInterface for HTTPAPIClient {
|
||||
type Error = Error;
|
||||
|
||||
async fn get_version(&self) -> Result<String, Self::Error> {
|
||||
async fn get_version(&mut self) -> Result<String, Self::Error> {
|
||||
let version: String = self.request("/version", Method::GET).await?;
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
|
||||
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
|
||||
let conversations: Vec<Conversation> = self.request("/conversations", Method::GET).await?;
|
||||
Ok(conversations)
|
||||
}
|
||||
|
||||
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct AuthResponse {
|
||||
jwt: String,
|
||||
}
|
||||
|
||||
let body = || -> Body { serde_json::to_string(&credentials).unwrap().into() };
|
||||
let token: AuthResponse = self.request_with_body_retry("/authenticate", Method::POST, body, false).await?;
|
||||
let token = JwtToken::new(&token.jwt).map_err(|_| Error::DecodeError)?;
|
||||
self.auth_token = Some(token.clone());
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
impl HTTPClient {
|
||||
pub fn new(base_url: Uri) -> HTTPClient {
|
||||
let client = ServiceBuilder::new()
|
||||
.service(Client::new());
|
||||
|
||||
HTTPClient {
|
||||
base_url,
|
||||
client,
|
||||
impl HTTPAPIClient {
|
||||
pub fn new(base_url: Uri, credentials: Option<Credentials>) -> HTTPAPIClient {
|
||||
HTTPAPIClient {
|
||||
base_url: base_url,
|
||||
credentials: credentials,
|
||||
auth_token: Option::None,
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,29 +117,67 @@ impl HTTPClient {
|
||||
Uri::try_from(parts).unwrap()
|
||||
}
|
||||
|
||||
async fn request<T: DeserializeOwned>(&self, endpoint: &str, method: Method) -> Result<T, Error> {
|
||||
self.request_with_body(endpoint, method, Body::empty()).await
|
||||
async fn request<T: DeserializeOwned>(&mut self, endpoint: &str, method: Method) -> Result<T, Error> {
|
||||
self.request_with_body(endpoint, method, || { Body::empty() }).await
|
||||
}
|
||||
|
||||
async fn request_with_body<T: DeserializeOwned>(&self, endpoint: &str, method: Method, body: Body) -> Result<T, Error> {
|
||||
async fn request_with_body<T, B>(&mut self, endpoint: &str, method: Method, body_fn: B) -> Result<T, Error>
|
||||
where T: DeserializeOwned, B: Fn() -> Body
|
||||
{
|
||||
self.request_with_body_retry(endpoint, method, body_fn, true).await
|
||||
}
|
||||
|
||||
async fn request_with_body_retry<T, B>(
|
||||
&mut self,
|
||||
endpoint: &str,
|
||||
method: Method,
|
||||
body_fn: B,
|
||||
retry_auth: bool) -> Result<T, Error>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
B: Fn() -> Body
|
||||
{
|
||||
use hyper::StatusCode;
|
||||
|
||||
let uri = self.uri_for_endpoint(endpoint);
|
||||
let request = Request::builder()
|
||||
.method(method)
|
||||
.uri(uri)
|
||||
.body(body)
|
||||
.unwrap();
|
||||
let build_request = move |auth: &Option<JwtToken>| {
|
||||
let body = body_fn();
|
||||
Request::builder()
|
||||
.method(&method)
|
||||
.uri(&uri)
|
||||
.with_auth(auth)
|
||||
.body(body)
|
||||
.expect("Unable to build request")
|
||||
};
|
||||
|
||||
let future = self.client.request(request);
|
||||
let res = future.await?;
|
||||
let status = res.status();
|
||||
let request = build_request(&self.auth_token);
|
||||
let mut response = self.client.request(request).await?;
|
||||
match response.status() {
|
||||
StatusCode::OK => { /* cool */ },
|
||||
|
||||
if status != hyper::StatusCode::OK {
|
||||
let message = format!("Request failed ({:})", status);
|
||||
return Err(Error::ClientError(message));
|
||||
// 401: Unauthorized. Token may have expired or is invalid. Attempt to renew.
|
||||
StatusCode::UNAUTHORIZED => {
|
||||
if !retry_auth {
|
||||
return Err(Error::ClientError("Unauthorized".into()));
|
||||
}
|
||||
|
||||
if let Some(credentials) = &self.credentials {
|
||||
self.authenticate(credentials.clone()).await?;
|
||||
|
||||
let request = build_request(&self.auth_token);
|
||||
response = self.client.request(request).await?;
|
||||
}
|
||||
},
|
||||
|
||||
// Other errors: bubble up.
|
||||
_ => {
|
||||
let message = format!("Request failed ({:})", response.status());
|
||||
return Err(Error::ClientError(message));
|
||||
}
|
||||
}
|
||||
|
||||
// Read and parse response body
|
||||
let body = hyper::body::to_bytes(res.into_body()).await?;
|
||||
let body = hyper::body::to_bytes(response.into_body()).await?;
|
||||
let parsed: T = match serde_json::from_slice(&body) {
|
||||
Ok(result) => Ok(result),
|
||||
Err(json_err) => {
|
||||
@@ -121,13 +202,18 @@ mod test {
|
||||
log::set_max_level(log::LevelFilter::Trace);
|
||||
}
|
||||
|
||||
fn local_mock_client() -> HTTPClient {
|
||||
fn local_mock_client() -> HTTPAPIClient {
|
||||
let base_url = "http://localhost:5738".parse().unwrap();
|
||||
HTTPClient::new(base_url)
|
||||
let credentials = Credentials {
|
||||
username: "test".to_string(),
|
||||
password: "test".to_string(),
|
||||
};
|
||||
|
||||
HTTPAPIClient::new(base_url, credentials.into())
|
||||
}
|
||||
|
||||
async fn mock_client_is_reachable() -> bool {
|
||||
let client = local_mock_client();
|
||||
let mut client = local_mock_client();
|
||||
let version = client.get_version().await;
|
||||
|
||||
match version {
|
||||
@@ -146,7 +232,7 @@ mod test {
|
||||
return;
|
||||
}
|
||||
|
||||
let client = local_mock_client();
|
||||
let mut client = local_mock_client();
|
||||
let version = client.get_version().await.unwrap();
|
||||
assert!(version.starts_with("KordophoneMock-"));
|
||||
}
|
||||
@@ -158,7 +244,7 @@ mod test {
|
||||
return;
|
||||
}
|
||||
|
||||
let client = local_mock_client();
|
||||
let mut client = local_mock_client();
|
||||
let conversations = client.get_conversations().await.unwrap();
|
||||
assert!(!conversations.is_empty());
|
||||
}
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
use async_trait::async_trait;
|
||||
pub use crate::model::Conversation;
|
||||
use crate::model::JwtToken;
|
||||
|
||||
pub mod http_client;
|
||||
pub use http_client::HTTPClient;
|
||||
pub use http_client::HTTPAPIClient;
|
||||
|
||||
use self::http_client::Credentials;
|
||||
|
||||
#[async_trait]
|
||||
pub trait APIInterface {
|
||||
type Error;
|
||||
|
||||
// (GET) /version
|
||||
async fn get_version(&self) -> Result<String, Self::Error>;
|
||||
async fn get_version(&mut self) -> Result<String, Self::Error>;
|
||||
|
||||
// (GET) /conversations
|
||||
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error>;
|
||||
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error>;
|
||||
|
||||
// (POST) /authenticate
|
||||
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error>;
|
||||
}
|
||||
|
||||
|
||||
112
kordophone/src/model/jwt.rs
Normal file
112
kordophone/src/model/jwt.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use std::error::Error;
|
||||
|
||||
use base64::{
|
||||
engine::{self, general_purpose},
|
||||
Engine,
|
||||
};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use hyper::http::HeaderValue;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct JwtHeader {
|
||||
alg: String,
|
||||
typ: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
enum ExpValue {
|
||||
Integer(i64),
|
||||
String(String),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct JwtPayload {
|
||||
exp: serde_json::Value,
|
||||
iss: Option<String>,
|
||||
user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JwtToken {
|
||||
header: JwtHeader,
|
||||
payload: JwtPayload,
|
||||
signature: Vec<u8>,
|
||||
expiration_date: DateTime<Utc>,
|
||||
token: String,
|
||||
}
|
||||
|
||||
impl JwtToken {
|
||||
fn decode_token_using_engine(
|
||||
token: &str,
|
||||
engine: engine::GeneralPurpose,
|
||||
) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
let mut parts = token.split('.');
|
||||
let header = parts.next().unwrap();
|
||||
let payload = parts.next().unwrap();
|
||||
let signature = parts.next().unwrap();
|
||||
|
||||
let header = engine.decode(header)?;
|
||||
let payload = engine.decode(payload)?;
|
||||
let signature = engine.decode(signature)?;
|
||||
|
||||
// Parse jwt header
|
||||
let header: JwtHeader = serde_json::from_slice(&header)?;
|
||||
|
||||
// Parse jwt payload
|
||||
let payload: JwtPayload = serde_json::from_slice(&payload)?;
|
||||
|
||||
// Parse jwt expiration date
|
||||
// Annoyingly, because of my own fault, this could be either an integer or string.
|
||||
let exp: i64 = payload.exp.as_i64().unwrap_or_else(|| {
|
||||
let exp: String = payload.exp.as_str().unwrap().to_string();
|
||||
exp.parse().unwrap()
|
||||
});
|
||||
|
||||
let timestamp = chrono::NaiveDateTime::from_timestamp_opt(exp, 0).unwrap();
|
||||
let expiration_date = DateTime::<Utc>::from_utc(timestamp, Utc);
|
||||
|
||||
Ok(JwtToken {
|
||||
header,
|
||||
payload,
|
||||
signature,
|
||||
expiration_date,
|
||||
token: token.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(token: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
// STUPID: My mock server uses a different encoding than the real server, so we have to
|
||||
// try both encodings here.
|
||||
|
||||
Self::decode_token_using_engine(token, general_purpose::STANDARD).or(
|
||||
Self::decode_token_using_engine(token, general_purpose::URL_SAFE_NO_PAD),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn dummy() -> Self {
|
||||
JwtToken {
|
||||
header: JwtHeader {
|
||||
alg: "none".to_string(),
|
||||
typ: "JWT".to_string(),
|
||||
},
|
||||
payload: JwtPayload {
|
||||
exp: serde_json::Value::Null,
|
||||
iss: None,
|
||||
user: None,
|
||||
},
|
||||
signature: vec![],
|
||||
expiration_date: Utc::now(),
|
||||
token: "".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_valid(&self) -> bool {
|
||||
self.expiration_date > Utc::now()
|
||||
}
|
||||
|
||||
pub fn to_header_value(&self) -> HeaderValue {
|
||||
format!("Bearer {}", self.token).parse().unwrap()
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
|
||||
pub mod conversation;
|
||||
pub use conversation::Conversation;
|
||||
|
||||
pub mod jwt;
|
||||
pub use jwt::JwtToken;
|
||||
@@ -9,7 +9,7 @@ pub mod api_interface {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_version() {
|
||||
let client = TestClient::new();
|
||||
let mut client = TestClient::new();
|
||||
let version = client.get_version().await.unwrap();
|
||||
assert_eq!(version, client.version);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
pub use crate::APIInterface;
|
||||
use crate::model::Conversation;
|
||||
use crate::{api::http_client::Credentials, model::{Conversation, JwtToken}};
|
||||
|
||||
pub struct TestClient {
|
||||
pub version: &'static str,
|
||||
@@ -24,11 +24,15 @@ impl TestClient {
|
||||
impl APIInterface for TestClient {
|
||||
type Error = TestError;
|
||||
|
||||
async fn get_version(&self) -> Result<String, Self::Error> {
|
||||
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
|
||||
Ok(JwtToken::dummy())
|
||||
}
|
||||
|
||||
async fn get_version(&mut self) -> Result<String, Self::Error> {
|
||||
Ok(self.version.to_string())
|
||||
}
|
||||
|
||||
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
|
||||
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
|
||||
Ok(self.conversations.clone())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user