diff --git a/go.mod b/go.mod index ab8e074..164d9ae 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module code.severnaya.net/kordophone-mock/v2 go 1.17 -require github.com/google/uuid v1.3.0 // indirect +require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect + github.com/google/uuid v1.3.0 // indirect +) diff --git a/go.sum b/go.sum index 3dfe1c9..f712d1b 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/main.go b/main.go index e5d7f82..9381445 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,11 @@ import ( func main() { log.Println("Initializing") - s := web.NewMockHTTPServer() + c := web.MockHTTPServerConfiguration{ + AuthEnabled: false, + } + + s := web.NewMockHTTPServer(c) httpServer := &http.Server{ Addr: ":5738", Handler: s, diff --git a/model/authtoken.go b/model/authtoken.go new file mode 100644 index 0000000..fca9ec8 --- /dev/null +++ b/model/authtoken.go @@ -0,0 +1,59 @@ +package model + +import ( + "encoding/base64" + "log" + "time" + + "github.com/dgrijalva/jwt-go" +) + +type AuthToken struct { + SignedToken string `json:"jwt"` + token jwt.Token +} + +type TokenGenerationError struct { + message string +} + +func (e *TokenGenerationError) Error() string { + return e.message +} + +// Create a struct to hold your custom claims +type customClaims struct { + Username string `json:"username"` + jwt.StandardClaims +} + +const signingKey = "nDjYmTjoPrAGzuyhHz6Dq5bqcRrEZJc5Ls3SQcdylBI=" + +func NewAuthToken(username string) (*AuthToken, error) { + claims := customClaims{ + Username: username, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Hour * 24 * 5).Unix(), // 5 days + }, + } + + // Create a new JWT token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + if token == nil { + log.Printf("Error creating Jwt Token") + return nil, &TokenGenerationError{"Error creating Jwt Token"} + } + + // Sign the token with the specified signing key + decodedSigningKey, _ := base64.StdEncoding.DecodeString(signingKey) + signedToken, err := token.SignedString(decodedSigningKey) + if err != nil { + log.Printf("Error signing Jwt Token: %s", err) + return nil, &TokenGenerationError{"Error signing Jwt Token"} + } + + return &AuthToken{ + SignedToken: signedToken, + token: *token, + }, nil +} diff --git a/server/server.go b/server/server.go index ef2f441..bcb2b58 100644 --- a/server/server.go +++ b/server/server.go @@ -7,9 +7,23 @@ import ( const VERSION = "Kordophone-2.0" +const ( + AUTH_USERNAME = "test" + AUTH_PASSWORD = "test" +) + type Server struct { version string conversations []model.Conversation + authTokens []model.AuthToken +} + +type AuthError struct { + message string +} + +func (e *AuthError) Error() string { + return e.message } func NewServer() *Server { @@ -40,3 +54,39 @@ func (s *Server) PopulateWithTestData() { s.conversations = cs } + +func (s *Server) Authenticate(username string, password string) (*model.AuthToken, error) { + if username != AUTH_USERNAME || password != AUTH_PASSWORD { + return nil, &AuthError{"Invalid username or password"} + } + + token, err := model.NewAuthToken(username) + if err != nil { + return nil, err + } + + // Register for future auth + s.registerAuthToken(token) + + return token, nil +} + +func (s *Server) CheckBearerToken(token string) bool { + return s.authenticateToken(token) +} + +// Private + +func (s *Server) registerAuthToken(token *model.AuthToken) { + s.authTokens = append(s.authTokens, *token) +} + +func (s *Server) authenticateToken(token string) bool { + for _, t := range s.authTokens { + if t.SignedToken == token { + return true + } + } + + return false +} diff --git a/web/request_types.go b/web/request_types.go new file mode 100644 index 0000000..ee92e49 --- /dev/null +++ b/web/request_types.go @@ -0,0 +1,6 @@ +package web + +type AuthenticationRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} diff --git a/web/server.go b/web/server.go index 5bacb8f..4febbd0 100644 --- a/web/server.go +++ b/web/server.go @@ -9,9 +9,47 @@ import ( "code.severnaya.net/kordophone-mock/v2/server" ) +type MockHTTPServerConfiguration struct { + AuthEnabled bool +} + type MockHTTPServer struct { - Server server.Server - mux http.ServeMux + Server server.Server + mux http.ServeMux + authEnabled bool +} + +type AuthError struct { + message string +} + +func (e *AuthError) Error() string { + return e.message +} + +func (m *MockHTTPServer) checkAuthentication(r *http.Request) error { + if !m.authEnabled { + return nil + } + + // Check for Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return &AuthError{"Missing Authorization header"} + } + + // Check for "Bearer" prefix + if authHeader[:7] != "Bearer " { + return &AuthError{"Invalid Authorization header"} + } + + // Check for valid token + token := authHeader[7:] + if !m.Server.CheckBearerToken(token) { + return &AuthError{"Invalid token"} + } + + return nil } func (m *MockHTTPServer) handleVersion(w http.ResponseWriter, r *http.Request) { @@ -19,6 +57,12 @@ func (m *MockHTTPServer) handleVersion(w http.ResponseWriter, r *http.Request) { } func (m *MockHTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { + if err := m.checkAuthentication(r); err != nil { + log.Printf("Status: Error checking authentication: %s", err) + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + fmt.Fprintf(w, "OK") } @@ -38,19 +82,54 @@ func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Requ w.Write(jsonData) } +func (m *MockHTTPServer) handleAuthenticate(w http.ResponseWriter, r *http.Request) { + // Decode request body as AuthenticationRequest + var authReq AuthenticationRequest + err := json.NewDecoder(r.Body).Decode(&authReq) + if err != nil { + log.Printf("Authenticate: Error decoding request body: %s", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Authenticate + token, err := m.Server.Authenticate(authReq.Username, authReq.Password) + if err != nil { + log.Printf("Authenticate: Error authenticating: %s", err) + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + // Write response + w.Header().Set("Content-Type", "application/json") + + // Encode token as JSON + jsonData, err := json.Marshal(token) + if err != nil { + log.Printf("Error marshalling token: %s", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Write JSON to response + w.Write(jsonData) +} + func (m *MockHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.mux.ServeHTTP(w, r) } -func NewMockHTTPServer() *MockHTTPServer { +func NewMockHTTPServer(config MockHTTPServerConfiguration) *MockHTTPServer { this := MockHTTPServer{ - Server: *server.NewServer(), - mux: *http.NewServeMux(), + Server: *server.NewServer(), + mux: *http.NewServeMux(), + authEnabled: config.AuthEnabled, } this.mux.Handle("/version", http.HandlerFunc(this.handleVersion)) this.mux.Handle("/conversations", http.HandlerFunc(this.handleConversations)) this.mux.Handle("/status", http.HandlerFunc(this.handleStatus)) + this.mux.Handle("/authenticate", http.HandlerFunc(this.handleAuthenticate)) return &this } diff --git a/web/server_test.go b/web/server_test.go index 17a7c00..2eb595d 100644 --- a/web/server_test.go +++ b/web/server_test.go @@ -1,6 +1,7 @@ package web_test import ( + "bytes" "encoding/json" "io" "net/http" @@ -14,7 +15,7 @@ import ( ) func TestVersion(t *testing.T) { - s := httptest.NewServer(web.NewMockHTTPServer()) + s := httptest.NewServer(web.NewMockHTTPServer(web.MockHTTPServerConfiguration{})) resp, err := http.Get(s.URL + "/version") if err != nil { @@ -32,7 +33,7 @@ func TestVersion(t *testing.T) { } func TestStatus(t *testing.T) { - s := httptest.NewServer(web.NewMockHTTPServer()) + s := httptest.NewServer(web.NewMockHTTPServer(web.MockHTTPServerConfiguration{})) resp, err := http.Get(s.URL + "/status") if err != nil { @@ -50,7 +51,7 @@ func TestStatus(t *testing.T) { } func TestConversations(t *testing.T) { - server := web.NewMockHTTPServer() + server := web.NewMockHTTPServer(web.MockHTTPServerConfiguration{}) httpServer := httptest.NewServer(server) conversation := model.Conversation{ @@ -105,3 +106,91 @@ func TestConversations(t *testing.T) { t.Fatalf("Unexpected conversation Date: %s (expected %s)", convos[0].Date, conversation.Date) } } + +func TestAuthentication(t *testing.T) { + s := web.NewMockHTTPServer(web.MockHTTPServerConfiguration{AuthEnabled: true}) + httpServer := httptest.NewServer(s) + + // First, try authenticated request and make sure it fails + resp, err := http.Get(httpServer.URL + "/status") + if err != nil { + t.Fatalf("TestAuthentication status error: %s", err) + } + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("Unexpected status code: %d (expected %d)", resp.StatusCode, http.StatusUnauthorized) + } + + tryAuthenticate := func(username string, password string) *http.Response { + authRequest := web.AuthenticationRequest{ + Username: username, + Password: password, + } + + authRequestJSON, err := json.Marshal(authRequest) + if err != nil { + t.Fatalf("Error marshalling JSON: %s", err) + } + + resp, err := http.Post(httpServer.URL+"/authenticate", "application/json", io.NopCloser(bytes.NewReader(authRequestJSON))) + if err != nil { + t.Fatalf("TestAuthentication error: %s", err) + } + + return resp + } + + // Send authentication request with bad credentials + resp = tryAuthenticate("bad", "credentials") + if resp.StatusCode == http.StatusOK { + t.Fatalf("Unexpected status code: %d (expected %d)", resp.StatusCode, http.StatusUnauthorized) + } + + // Now try good credentials + resp = tryAuthenticate(server.AUTH_USERNAME, server.AUTH_PASSWORD) + if resp.StatusCode != http.StatusOK { + t.Fatalf("Unexpected status code: %d (expected %d)", resp.StatusCode, http.StatusOK) + } + + // Decode the token from the body. + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error decoding body: %s", body) + } + + var authToken model.AuthToken + err = json.Unmarshal(body, &authToken) + if err != nil { + t.Fatalf("Error unmarshalling JSON: %s, body: %s", err, body) + } + + if authToken.SignedToken == "" { + t.Fatalf("Unexpected empty signed token") + } + + // Send a request with the signed token + req, err := http.NewRequest(http.MethodGet, httpServer.URL+"/status", nil) + if err != nil { + t.Fatalf("Error creating request: %s", err) + } + + req.Header.Set("Authorization", "Bearer "+authToken.SignedToken) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Error sending request: %s", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Unexpected status code: %d (expected %d)", resp.StatusCode, http.StatusUnauthorized) + } + + body, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error decoding body: %s", body) + } + + if string(body) != "OK" { + t.Fatalf("Unexpected body: %s (expected %s)", body, "OK") + } +}