diff --git a/server/server.go b/server/server.go index 4588037..d5787ca 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,7 @@ package server import ( "sort" + "time" "code.severnaya.net/kordophone-mock/v2/data" "code.severnaya.net/kordophone-mock/v2/model" @@ -25,6 +26,14 @@ type Server struct { updateItemSeq int } +type MessagesQuery struct { + ConversationGUID string + BeforeDate *time.Time + AfterGUID *string + BeforeGUID *string + Limit *int +} + type AuthError struct { message string } @@ -136,15 +145,58 @@ func (s *Server) CheckBearerToken(token string) bool { return s.authenticateToken(token) } -func (s *Server) MessagesForConversation(conversation *model.Conversation) []model.Message { - messages := s.messageStore[conversation.Guid] +func (s *Server) PerformMessageQuery(query *MessagesQuery) []model.Message { + messages := s.messageStore[query.ConversationGUID] + + // Sort sort.Slice(messages, func(i int, j int) bool { return messages[i].Date.Before(messages[j].Date) }) + // Apply before/after filters + if query.BeforeGUID != nil { + beforeGUID := *query.BeforeGUID + for i := range messages { + if messages[i].Guid == beforeGUID { + messages = messages[:i] + break + } + } + } else if query.AfterGUID != nil { + afterGUID := *query.AfterGUID + for i := range messages { + if messages[i].Guid == afterGUID { + messages = messages[i+1:] + break + } + } + } else if query.BeforeDate != nil { + beforeDate := *query.BeforeDate + for i := range messages { + if messages[i].Date.Before(beforeDate) { + messages = messages[:i] + break + } + } + } + + // Limit + if query.Limit != nil { + limit := *query.Limit + if len(messages) > limit { + messages = messages[:limit] + } + } + return messages } +func (s *Server) MessagesForConversation(conversation *model.Conversation) []model.Message { + return s.PerformMessageQuery(&MessagesQuery{ + ConversationGUID: conversation.Guid, + }) +} + func (s *Server) AppendMessageToConversation(conversation *model.Conversation, message model.Message) { s.messageStore[conversation.Guid] = append(s.messageStore[conversation.Guid], message) } diff --git a/web/server.go b/web/server.go index c6fd1c2..2aad9bd 100644 --- a/web/server.go +++ b/web/server.go @@ -92,8 +92,6 @@ func (m *MockHTTPServer) handleConversations(w http.ResponseWriter, r *http.Requ } func (m *MockHTTPServer) handleMessages(w http.ResponseWriter, r *http.Request) { - // TODO handle optional "limit", "beforeDate", "beforeMessageGUID", and "afterMessageGUID" parameters - guid := r.URL.Query().Get("guid") if len(guid) == 0 { log.Error().Msg("handleMessage: Got empty guid parameter") @@ -108,7 +106,43 @@ func (m *MockHTTPServer) handleMessages(w http.ResponseWriter, r *http.Request) return } - messages := m.Server.MessagesForConversation(conversation) + beforeDate := r.URL.Query().Get("beforeDate") + beforeGUID := r.URL.Query().Get("beforeMessageGUID") + afterGUID := r.URL.Query().Get("afterMessageGUID") + limit := r.URL.Query().Get("limit") + + stringOrNil := func(s string) *string { + if len(s) == 0 { + return nil + } + return &s + } + + dateOrNil := func(s string) *time.Time { + if len(s) == 0 { + return nil + } + t, _ := time.Parse(time.RFC3339, s) + return &t + } + + intOrNil := func(s string) *int { + if len(s) == 0 { + return nil + } + i, _ := strconv.Atoi(s) + return &i + } + + query := server.MessagesQuery{ + Limit: intOrNil(limit), + BeforeDate: dateOrNil(beforeDate), + BeforeGUID: stringOrNil(beforeGUID), + AfterGUID: stringOrNil(afterGUID), + ConversationGUID: conversation.Guid, + } + + messages := m.Server.PerformMessageQuery(&query) jsonData, err := json.Marshal(messages) if err != nil { diff --git a/web/server_test.go b/web/server_test.go index 589beac..1831a01 100644 --- a/web/server_test.go +++ b/web/server_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "code.severnaya.net/kordophone-mock/v2/data" "code.severnaya.net/kordophone-mock/v2/model" "code.severnaya.net/kordophone-mock/v2/server" "code.severnaya.net/kordophone-mock/v2/web" @@ -342,3 +343,100 @@ func TestMarkConversation(t *testing.T) { t.Fatalf("Unexpected unread count: %d (expected %d)", convo.UnreadCount, 0) } } + +func TestMessageQueries(t *testing.T) { + s := web.NewMockHTTPServer(web.MockHTTPServerConfiguration{AuthEnabled: true}) + httpServer := httptest.NewServer(s) + + // Mock conversation + guid := "1234567890" + conversation := model.Conversation{ + Date: time.Now(), + Participants: []string{"Alice"}, + UnreadCount: 0, + Guid: guid, + } + + s.Server.AddConversation(conversation) + + // Mock messages + numMessages := 20 + for i := 0; i < numMessages; i++ { + message := data.GenerateRandomMessage(conversation.Participants) + s.Server.AppendMessageToConversation(&conversation, message) + } + + // Pick a pivot message from the sorted list + sortedMessages := s.Server.MessagesForConversation(&conversation) + pivotMessage := sortedMessages[len(sortedMessages)/2] + + // Query messages before the pivot, test limit also + limitMessageCount := 5 + resp, err := http.Get(httpServer.URL + fmt.Sprintf("/messages?guid=%s&beforeMessageGUID=%s&limit=%d", guid, pivotMessage.Guid, limitMessageCount)) + if err != nil { + t.Fatalf("TestMessageQueries error: %s", err) + } + + // Decode response + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error decoding body: %s", body) + } + + var messages []model.Message + err = json.Unmarshal(body, &messages) + if err != nil { + t.Fatalf("Error unmarshalling JSON: %s", err) + } + + if len(messages) != limitMessageCount { + t.Fatalf("Unexpected num messages: %d (expected %d)", len(messages), limitMessageCount) + } + + // Make sure before query is exclusive of the pivot message + for _, message := range messages { + if message.Guid == pivotMessage.Guid { + t.Fatalf("Found pivot guid in before query: %s (expected != %s)", message.Guid, pivotMessage.Guid) + } + } + + // Make sure messages are actually before the pivot + for _, message := range messages { + if message.Date.After(pivotMessage.Date) { + t.Fatalf("Unexpected message date: %s (expected before %s)", message.Date, pivotMessage.Date) + } + } + + // Query messages after the pivot + resp, err = http.Get(httpServer.URL + fmt.Sprintf("/messages?guid=%s&afterMessageGUID=%s", guid, pivotMessage.Guid)) + if err != nil { + t.Fatalf("TestMessageQueries error: %s", err) + } + + // Decode response + body, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error decoding body: %s", body) + } + + messages = []model.Message{} + err = json.Unmarshal(body, &messages) + if err != nil { + t.Fatalf("Error unmarshalling JSON: %s", err) + } + + // Make sure after query is exclusive of the pivot message + for _, message := range messages { + if message.Guid == pivotMessage.Guid { + t.Fatalf("Found pivot guid in after query: %s (expected != %s)", message.Guid, pivotMessage.Guid) + } + } + + // Make sure messages are actually after the pivot + for _, message := range messages { + if message.Date.Before(pivotMessage.Date) { + t.Fatalf("Unexpected message date: %s (expected after %s)", message.Date, pivotMessage.Date) + } + } + +}