diff --git a/main.go b/main.go index 2c03d65..a228984 100644 --- a/main.go +++ b/main.go @@ -21,16 +21,13 @@ func (t *LoggingHook) Run(e *zerolog.Event, level zerolog.Level, message string) t.prompt.CleanAndRefreshForLogging() } -func setupLogging() { - debug := flag.Bool("debug", false, "enable debug logging") - flag.Parse() - +func setupLogging(debug bool) { // Pretty logging log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) // Default level for this example is info, unless debug flag is present zerolog.SetGlobalLevel(zerolog.InfoLevel) - if *debug { + if debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) } } @@ -47,12 +44,16 @@ func printWelcomeMessage() { } func main() { - setupLogging() + debugLogging := flag.Bool("debug", false, "enable debug logging") + authEnabled := flag.Bool("auth", false, "enable authentication") + flag.Parse() + + setupLogging(*debugLogging) printWelcomeMessage() c := web.MockHTTPServerConfiguration{ - AuthEnabled: false, + AuthEnabled: *authEnabled, } addr := ":5738" diff --git a/web/server.go b/web/server.go index 7b5619f..92988aa 100644 --- a/web/server.go +++ b/web/server.go @@ -62,14 +62,26 @@ func (m *MockHTTPServer) checkAuthentication(r *http.Request) error { return nil } +func (m *MockHTTPServer) requireAuthentication(w http.ResponseWriter, r *http.Request) bool { + if !m.authEnabled { + return true + } + + if err := m.checkAuthentication(r); err != nil { + log.Error().Err(err).Msg("Error checking authentication") + http.Error(w, err.Error(), http.StatusUnauthorized) + return false + } + + return true +} + func (m *MockHTTPServer) handleVersion(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "%s", m.Server.Version()) } func (m *MockHTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { - if err := m.checkAuthentication(r); err != nil { - log.Error().Err(err).Msg("Status: Error checking authentication") - http.Error(w, err.Error(), http.StatusUnauthorized) + if !m.requireAuthentication(w, r) { return } @@ -77,6 +89,10 @@ func (m *MockHTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { } func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Request) { + if !m.requireAuthentication(w, r) { + return + } + convos := m.Server.Conversations() // Encode convos as JSON @@ -93,6 +109,10 @@ func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Requ } func (m *MockHTTPServer) handleMessages(w http.ResponseWriter, r *http.Request) { + if !m.requireAuthentication(w, r) { + return + } + guid := r.URL.Query().Get("guid") if len(guid) == 0 { log.Error().Msg("handleMessage: Got empty guid parameter") @@ -196,6 +216,10 @@ func (m *MockHTTPServer) handleNotFound(w http.ResponseWriter, r *http.Request) } func (m *MockHTTPServer) handleSendMessage(w http.ResponseWriter, r *http.Request) { + if !m.requireAuthentication(w, r) { + return + } + // Decode request body as SendMessageRequest var sendMessageReq SendMessageRequest err := json.NewDecoder(r.Body).Decode(&sendMessageReq) @@ -243,6 +267,10 @@ func (m *MockHTTPServer) handleSendMessage(w http.ResponseWriter, r *http.Reques } func (m *MockHTTPServer) handlePollUpdates(w http.ResponseWriter, r *http.Request) { + if !m.requireAuthentication(w, r) { + return + } + // TODO: This should block if we don't have updates for that seq yet. seq := -1 @@ -280,6 +308,10 @@ func (m *MockHTTPServer) handlePollUpdates(w http.ResponseWriter, r *http.Reques } func (m *MockHTTPServer) handleMarkConversation(w http.ResponseWriter, r *http.Request) { + if !m.requireAuthentication(w, r) { + return + } + guid := r.URL.Query().Get("guid") if len(guid) == 0 { log.Error().Msg("handleMarkConversation: Got empty guid parameter")