Private
Public Access
1
0

Actually check authentication

This commit is contained in:
2023-12-10 19:51:18 -08:00
parent 416949095c
commit d8ca07f92a
2 changed files with 43 additions and 10 deletions

15
main.go
View File

@@ -21,16 +21,13 @@ func (t *LoggingHook) Run(e *zerolog.Event, level zerolog.Level, message string)
t.prompt.CleanAndRefreshForLogging() t.prompt.CleanAndRefreshForLogging()
} }
func setupLogging() { func setupLogging(debug bool) {
debug := flag.Bool("debug", false, "enable debug logging")
flag.Parse()
// Pretty logging // Pretty logging
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
// Default level for this example is info, unless debug flag is present // Default level for this example is info, unless debug flag is present
zerolog.SetGlobalLevel(zerolog.InfoLevel) zerolog.SetGlobalLevel(zerolog.InfoLevel)
if *debug { if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel) zerolog.SetGlobalLevel(zerolog.DebugLevel)
} }
} }
@@ -47,12 +44,16 @@ func printWelcomeMessage() {
} }
func main() { func main() {
setupLogging() debugLogging := flag.Bool("debug", false, "enable debug logging")
authEnabled := flag.Bool("auth", false, "enable authentication")
flag.Parse()
setupLogging(*debugLogging)
printWelcomeMessage() printWelcomeMessage()
c := web.MockHTTPServerConfiguration{ c := web.MockHTTPServerConfiguration{
AuthEnabled: false, AuthEnabled: *authEnabled,
} }
addr := ":5738" addr := ":5738"

View File

@@ -62,14 +62,26 @@ func (m *MockHTTPServer) checkAuthentication(r *http.Request) error {
return nil return nil
} }
func (m *MockHTTPServer) requireAuthentication(w http.ResponseWriter, r *http.Request) bool {
if !m.authEnabled {
return true
}
if err := m.checkAuthentication(r); err != nil {
log.Error().Err(err).Msg("Error checking authentication")
http.Error(w, err.Error(), http.StatusUnauthorized)
return false
}
return true
}
func (m *MockHTTPServer) handleVersion(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handleVersion(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", m.Server.Version()) fmt.Fprintf(w, "%s", m.Server.Version())
} }
func (m *MockHTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
if err := m.checkAuthentication(r); err != nil { if !m.requireAuthentication(w, r) {
log.Error().Err(err).Msg("Status: Error checking authentication")
http.Error(w, err.Error(), http.StatusUnauthorized)
return return
} }
@@ -77,6 +89,10 @@ func (m *MockHTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
} }
func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Request) {
if !m.requireAuthentication(w, r) {
return
}
convos := m.Server.Conversations() convos := m.Server.Conversations()
// Encode convos as JSON // Encode convos as JSON
@@ -93,6 +109,10 @@ func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Requ
} }
func (m *MockHTTPServer) handleMessages(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handleMessages(w http.ResponseWriter, r *http.Request) {
if !m.requireAuthentication(w, r) {
return
}
guid := r.URL.Query().Get("guid") guid := r.URL.Query().Get("guid")
if len(guid) == 0 { if len(guid) == 0 {
log.Error().Msg("handleMessage: Got empty guid parameter") log.Error().Msg("handleMessage: Got empty guid parameter")
@@ -196,6 +216,10 @@ func (m *MockHTTPServer) handleNotFound(w http.ResponseWriter, r *http.Request)
} }
func (m *MockHTTPServer) handleSendMessage(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handleSendMessage(w http.ResponseWriter, r *http.Request) {
if !m.requireAuthentication(w, r) {
return
}
// Decode request body as SendMessageRequest // Decode request body as SendMessageRequest
var sendMessageReq SendMessageRequest var sendMessageReq SendMessageRequest
err := json.NewDecoder(r.Body).Decode(&sendMessageReq) err := json.NewDecoder(r.Body).Decode(&sendMessageReq)
@@ -243,6 +267,10 @@ func (m *MockHTTPServer) handleSendMessage(w http.ResponseWriter, r *http.Reques
} }
func (m *MockHTTPServer) handlePollUpdates(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handlePollUpdates(w http.ResponseWriter, r *http.Request) {
if !m.requireAuthentication(w, r) {
return
}
// TODO: This should block if we don't have updates for that seq yet. // TODO: This should block if we don't have updates for that seq yet.
seq := -1 seq := -1
@@ -280,6 +308,10 @@ func (m *MockHTTPServer) handlePollUpdates(w http.ResponseWriter, r *http.Reques
} }
func (m *MockHTTPServer) handleMarkConversation(w http.ResponseWriter, r *http.Request) { func (m *MockHTTPServer) handleMarkConversation(w http.ResponseWriter, r *http.Request) {
if !m.requireAuthentication(w, r) {
return
}
guid := r.URL.Query().Get("guid") guid := r.URL.Query().Get("guid")
if len(guid) == 0 { if len(guid) == 0 {
log.Error().Msg("handleMarkConversation: Got empty guid parameter") log.Error().Msg("handleMarkConversation: Got empty guid parameter")