Retry auth automatically, remove tower dep
This commit is contained in:
886
Cargo.lock
generated
886
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,8 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.80"
|
async-trait = "0.1.80"
|
||||||
|
base64 = "0.22.1"
|
||||||
|
chrono = "0.4.38"
|
||||||
ctor = "0.2.8"
|
ctor = "0.2.8"
|
||||||
hyper = { version = "0.14", features = ["full"] }
|
hyper = { version = "0.14", features = ["full"] }
|
||||||
hyper-tls = "0.5.0"
|
hyper-tls = "0.5.0"
|
||||||
@@ -17,7 +19,4 @@ serde_json = "1.0.91"
|
|||||||
serde_plain = "1.0.2"
|
serde_plain = "1.0.2"
|
||||||
time = { version = "0.3.17", features = ["parsing", "serde"] }
|
time = { version = "0.3.17", features = ["parsing", "serde"] }
|
||||||
tokio = { version = "1.37.0", features = ["full"] }
|
tokio = { version = "1.37.0", features = ["full"] }
|
||||||
tower = "0.4.13"
|
|
||||||
tower-http = { version = "0.5.2", features = ["trace"] }
|
|
||||||
tower-hyper = "0.1.1"
|
|
||||||
uuid = { version = "1.6.1", features = ["v4", "fast-rng", "macro-diagnostics"] }
|
uuid = { version = "1.6.1", features = ["v4", "fast-rng", "macro-diagnostics"] }
|
||||||
|
|||||||
@@ -1,24 +1,31 @@
|
|||||||
extern crate hyper;
|
extern crate hyper;
|
||||||
extern crate serde;
|
extern crate serde;
|
||||||
|
|
||||||
use std::{path::PathBuf, str};
|
use std::{borrow::Cow, default, path::PathBuf, str};
|
||||||
use log::{error};
|
use log::{error};
|
||||||
|
|
||||||
use hyper::{Body, Client, Method, Request, Uri};
|
use hyper::{client::ResponseFuture, Body, Client, Method, Request, Uri};
|
||||||
use tower::{ServiceBuilder};
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{APIInterface, model::Conversation};
|
use crate::{model::{Conversation, JwtToken}, APIInterface};
|
||||||
|
|
||||||
type HttpClient = Client<hyper::client::HttpConnector>;
|
type HttpClient = Client<hyper::client::HttpConnector>;
|
||||||
|
|
||||||
pub struct HTTPClient {
|
pub struct HTTPAPIClient {
|
||||||
pub base_url: Uri,
|
pub base_url: Uri,
|
||||||
|
credentials: Option<Credentials>,
|
||||||
|
auth_token: Option<JwtToken>,
|
||||||
client: HttpClient,
|
client: HttpClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||||
|
pub struct Credentials {
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
ClientError(String),
|
ClientError(String),
|
||||||
@@ -39,29 +46,65 @@ impl From <serde_json::Error> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait AuthBuilder {
|
||||||
|
fn with_auth(self, token: &Option<JwtToken>) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthBuilder for hyper::http::request::Builder {
|
||||||
|
fn with_auth(self, token: &Option<JwtToken>) -> Self {
|
||||||
|
if let Some(token) = &token {
|
||||||
|
self.header("Authorization", token.to_header_value())
|
||||||
|
} else { self }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait AuthSetting {
|
||||||
|
fn authenticate(&mut self, token: &Option<JwtToken>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B> AuthSetting for hyper::http::Request<B> {
|
||||||
|
fn authenticate(&mut self, token: &Option<JwtToken>) {
|
||||||
|
if let Some(token) = &token {
|
||||||
|
self.headers_mut().insert("Authorization", token.to_header_value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl APIInterface for HTTPClient {
|
impl APIInterface for HTTPAPIClient {
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
|
|
||||||
async fn get_version(&self) -> Result<String, Self::Error> {
|
async fn get_version(&mut self) -> Result<String, Self::Error> {
|
||||||
let version: String = self.request("/version", Method::GET).await?;
|
let version: String = self.request("/version", Method::GET).await?;
|
||||||
Ok(version)
|
Ok(version)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
|
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
|
||||||
let conversations: Vec<Conversation> = self.request("/conversations", Method::GET).await?;
|
let conversations: Vec<Conversation> = self.request("/conversations", Method::GET).await?;
|
||||||
Ok(conversations)
|
Ok(conversations)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct AuthResponse {
|
||||||
|
jwt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
let body = || -> Body { serde_json::to_string(&credentials).unwrap().into() };
|
||||||
|
let token: AuthResponse = self.request_with_body_retry("/authenticate", Method::POST, body, false).await?;
|
||||||
|
let token = JwtToken::new(&token.jwt).map_err(|_| Error::DecodeError)?;
|
||||||
|
self.auth_token = Some(token.clone());
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HTTPClient {
|
impl HTTPAPIClient {
|
||||||
pub fn new(base_url: Uri) -> HTTPClient {
|
pub fn new(base_url: Uri, credentials: Option<Credentials>) -> HTTPAPIClient {
|
||||||
let client = ServiceBuilder::new()
|
HTTPAPIClient {
|
||||||
.service(Client::new());
|
base_url: base_url,
|
||||||
|
credentials: credentials,
|
||||||
HTTPClient {
|
auth_token: Option::None,
|
||||||
base_url,
|
client: Client::new(),
|
||||||
client,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,29 +117,67 @@ impl HTTPClient {
|
|||||||
Uri::try_from(parts).unwrap()
|
Uri::try_from(parts).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn request<T: DeserializeOwned>(&self, endpoint: &str, method: Method) -> Result<T, Error> {
|
async fn request<T: DeserializeOwned>(&mut self, endpoint: &str, method: Method) -> Result<T, Error> {
|
||||||
self.request_with_body(endpoint, method, Body::empty()).await
|
self.request_with_body(endpoint, method, || { Body::empty() }).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn request_with_body<T: DeserializeOwned>(&self, endpoint: &str, method: Method, body: Body) -> Result<T, Error> {
|
async fn request_with_body<T, B>(&mut self, endpoint: &str, method: Method, body_fn: B) -> Result<T, Error>
|
||||||
|
where T: DeserializeOwned, B: Fn() -> Body
|
||||||
|
{
|
||||||
|
self.request_with_body_retry(endpoint, method, body_fn, true).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn request_with_body_retry<T, B>(
|
||||||
|
&mut self,
|
||||||
|
endpoint: &str,
|
||||||
|
method: Method,
|
||||||
|
body_fn: B,
|
||||||
|
retry_auth: bool) -> Result<T, Error>
|
||||||
|
where
|
||||||
|
T: DeserializeOwned,
|
||||||
|
B: Fn() -> Body
|
||||||
|
{
|
||||||
|
use hyper::StatusCode;
|
||||||
|
|
||||||
let uri = self.uri_for_endpoint(endpoint);
|
let uri = self.uri_for_endpoint(endpoint);
|
||||||
let request = Request::builder()
|
let build_request = move |auth: &Option<JwtToken>| {
|
||||||
.method(method)
|
let body = body_fn();
|
||||||
.uri(uri)
|
Request::builder()
|
||||||
|
.method(&method)
|
||||||
|
.uri(&uri)
|
||||||
|
.with_auth(auth)
|
||||||
.body(body)
|
.body(body)
|
||||||
.unwrap();
|
.expect("Unable to build request")
|
||||||
|
};
|
||||||
|
|
||||||
let future = self.client.request(request);
|
let request = build_request(&self.auth_token);
|
||||||
let res = future.await?;
|
let mut response = self.client.request(request).await?;
|
||||||
let status = res.status();
|
match response.status() {
|
||||||
|
StatusCode::OK => { /* cool */ },
|
||||||
|
|
||||||
if status != hyper::StatusCode::OK {
|
// 401: Unauthorized. Token may have expired or is invalid. Attempt to renew.
|
||||||
let message = format!("Request failed ({:})", status);
|
StatusCode::UNAUTHORIZED => {
|
||||||
|
if !retry_auth {
|
||||||
|
return Err(Error::ClientError("Unauthorized".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(credentials) = &self.credentials {
|
||||||
|
self.authenticate(credentials.clone()).await?;
|
||||||
|
|
||||||
|
let request = build_request(&self.auth_token);
|
||||||
|
response = self.client.request(request).await?;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Other errors: bubble up.
|
||||||
|
_ => {
|
||||||
|
let message = format!("Request failed ({:})", response.status());
|
||||||
return Err(Error::ClientError(message));
|
return Err(Error::ClientError(message));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Read and parse response body
|
// Read and parse response body
|
||||||
let body = hyper::body::to_bytes(res.into_body()).await?;
|
let body = hyper::body::to_bytes(response.into_body()).await?;
|
||||||
let parsed: T = match serde_json::from_slice(&body) {
|
let parsed: T = match serde_json::from_slice(&body) {
|
||||||
Ok(result) => Ok(result),
|
Ok(result) => Ok(result),
|
||||||
Err(json_err) => {
|
Err(json_err) => {
|
||||||
@@ -121,13 +202,18 @@ mod test {
|
|||||||
log::set_max_level(log::LevelFilter::Trace);
|
log::set_max_level(log::LevelFilter::Trace);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn local_mock_client() -> HTTPClient {
|
fn local_mock_client() -> HTTPAPIClient {
|
||||||
let base_url = "http://localhost:5738".parse().unwrap();
|
let base_url = "http://localhost:5738".parse().unwrap();
|
||||||
HTTPClient::new(base_url)
|
let credentials = Credentials {
|
||||||
|
username: "test".to_string(),
|
||||||
|
password: "test".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
HTTPAPIClient::new(base_url, credentials.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mock_client_is_reachable() -> bool {
|
async fn mock_client_is_reachable() -> bool {
|
||||||
let client = local_mock_client();
|
let mut client = local_mock_client();
|
||||||
let version = client.get_version().await;
|
let version = client.get_version().await;
|
||||||
|
|
||||||
match version {
|
match version {
|
||||||
@@ -146,7 +232,7 @@ mod test {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let client = local_mock_client();
|
let mut client = local_mock_client();
|
||||||
let version = client.get_version().await.unwrap();
|
let version = client.get_version().await.unwrap();
|
||||||
assert!(version.starts_with("KordophoneMock-"));
|
assert!(version.starts_with("KordophoneMock-"));
|
||||||
}
|
}
|
||||||
@@ -158,7 +244,7 @@ mod test {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let client = local_mock_client();
|
let mut client = local_mock_client();
|
||||||
let conversations = client.get_conversations().await.unwrap();
|
let conversations = client.get_conversations().await.unwrap();
|
||||||
assert!(!conversations.is_empty());
|
assert!(!conversations.is_empty());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,23 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
pub use crate::model::Conversation;
|
pub use crate::model::Conversation;
|
||||||
|
use crate::model::JwtToken;
|
||||||
|
|
||||||
pub mod http_client;
|
pub mod http_client;
|
||||||
pub use http_client::HTTPClient;
|
pub use http_client::HTTPAPIClient;
|
||||||
|
|
||||||
|
use self::http_client::Credentials;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait APIInterface {
|
pub trait APIInterface {
|
||||||
type Error;
|
type Error;
|
||||||
|
|
||||||
// (GET) /version
|
// (GET) /version
|
||||||
async fn get_version(&self) -> Result<String, Self::Error>;
|
async fn get_version(&mut self) -> Result<String, Self::Error>;
|
||||||
|
|
||||||
// (GET) /conversations
|
// (GET) /conversations
|
||||||
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error>;
|
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error>;
|
||||||
|
|
||||||
|
// (POST) /authenticate
|
||||||
|
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
112
kordophone/src/model/jwt.rs
Normal file
112
kordophone/src/model/jwt.rs
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
use base64::{
|
||||||
|
engine::{self, general_purpose},
|
||||||
|
Engine,
|
||||||
|
};
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use hyper::http::HeaderValue;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
|
struct JwtHeader {
|
||||||
|
alg: String,
|
||||||
|
typ: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
|
enum ExpValue {
|
||||||
|
Integer(i64),
|
||||||
|
String(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
|
struct JwtPayload {
|
||||||
|
exp: serde_json::Value,
|
||||||
|
iss: Option<String>,
|
||||||
|
user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct JwtToken {
|
||||||
|
header: JwtHeader,
|
||||||
|
payload: JwtPayload,
|
||||||
|
signature: Vec<u8>,
|
||||||
|
expiration_date: DateTime<Utc>,
|
||||||
|
token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JwtToken {
|
||||||
|
fn decode_token_using_engine(
|
||||||
|
token: &str,
|
||||||
|
engine: engine::GeneralPurpose,
|
||||||
|
) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||||
|
let mut parts = token.split('.');
|
||||||
|
let header = parts.next().unwrap();
|
||||||
|
let payload = parts.next().unwrap();
|
||||||
|
let signature = parts.next().unwrap();
|
||||||
|
|
||||||
|
let header = engine.decode(header)?;
|
||||||
|
let payload = engine.decode(payload)?;
|
||||||
|
let signature = engine.decode(signature)?;
|
||||||
|
|
||||||
|
// Parse jwt header
|
||||||
|
let header: JwtHeader = serde_json::from_slice(&header)?;
|
||||||
|
|
||||||
|
// Parse jwt payload
|
||||||
|
let payload: JwtPayload = serde_json::from_slice(&payload)?;
|
||||||
|
|
||||||
|
// Parse jwt expiration date
|
||||||
|
// Annoyingly, because of my own fault, this could be either an integer or string.
|
||||||
|
let exp: i64 = payload.exp.as_i64().unwrap_or_else(|| {
|
||||||
|
let exp: String = payload.exp.as_str().unwrap().to_string();
|
||||||
|
exp.parse().unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
let timestamp = chrono::NaiveDateTime::from_timestamp_opt(exp, 0).unwrap();
|
||||||
|
let expiration_date = DateTime::<Utc>::from_utc(timestamp, Utc);
|
||||||
|
|
||||||
|
Ok(JwtToken {
|
||||||
|
header,
|
||||||
|
payload,
|
||||||
|
signature,
|
||||||
|
expiration_date,
|
||||||
|
token: token.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(token: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||||
|
// STUPID: My mock server uses a different encoding than the real server, so we have to
|
||||||
|
// try both encodings here.
|
||||||
|
|
||||||
|
Self::decode_token_using_engine(token, general_purpose::STANDARD).or(
|
||||||
|
Self::decode_token_using_engine(token, general_purpose::URL_SAFE_NO_PAD),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dummy() -> Self {
|
||||||
|
JwtToken {
|
||||||
|
header: JwtHeader {
|
||||||
|
alg: "none".to_string(),
|
||||||
|
typ: "JWT".to_string(),
|
||||||
|
},
|
||||||
|
payload: JwtPayload {
|
||||||
|
exp: serde_json::Value::Null,
|
||||||
|
iss: None,
|
||||||
|
user: None,
|
||||||
|
},
|
||||||
|
signature: vec![],
|
||||||
|
expiration_date: Utc::now(),
|
||||||
|
token: "".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_valid(&self) -> bool {
|
||||||
|
self.expiration_date > Utc::now()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_header_value(&self) -> HeaderValue {
|
||||||
|
format!("Bearer {}", self.token).parse().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
|
||||||
pub mod conversation;
|
pub mod conversation;
|
||||||
pub use conversation::Conversation;
|
pub use conversation::Conversation;
|
||||||
|
|
||||||
|
pub mod jwt;
|
||||||
|
pub use jwt::JwtToken;
|
||||||
@@ -9,7 +9,7 @@ pub mod api_interface {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_version() {
|
async fn test_version() {
|
||||||
let client = TestClient::new();
|
let mut client = TestClient::new();
|
||||||
let version = client.get_version().await.unwrap();
|
let version = client.get_version().await.unwrap();
|
||||||
assert_eq!(version, client.version);
|
assert_eq!(version, client.version);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
pub use crate::APIInterface;
|
pub use crate::APIInterface;
|
||||||
use crate::model::Conversation;
|
use crate::{api::http_client::Credentials, model::{Conversation, JwtToken}};
|
||||||
|
|
||||||
pub struct TestClient {
|
pub struct TestClient {
|
||||||
pub version: &'static str,
|
pub version: &'static str,
|
||||||
@@ -24,11 +24,15 @@ impl TestClient {
|
|||||||
impl APIInterface for TestClient {
|
impl APIInterface for TestClient {
|
||||||
type Error = TestError;
|
type Error = TestError;
|
||||||
|
|
||||||
async fn get_version(&self) -> Result<String, Self::Error> {
|
async fn authenticate(&mut self, credentials: Credentials) -> Result<JwtToken, Self::Error> {
|
||||||
|
Ok(JwtToken::dummy())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_version(&mut self) -> Result<String, Self::Error> {
|
||||||
Ok(self.version.to_string())
|
Ok(self.version.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_conversations(&self) -> Result<Vec<Conversation>, Self::Error> {
|
async fn get_conversations(&mut self) -> Result<Vec<Conversation>, Self::Error> {
|
||||||
Ok(self.conversations.clone())
|
Ok(self.conversations.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user