From 06a2c55da50414cb6b38d3100755e4b6edc7b78b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BENEDEK=20L=C3=A1szl=C3=B3?= Date: Fri, 6 Jun 2025 16:17:02 +0200 Subject: [PATCH] message notifications --- api/auth.go | 18 +++++- api/chat.go | 90 ++++++++++++++++++++++++++ api/endpoints.go | 34 +++++++++- controller/AuthController.go | 8 +++ controller/ChatController.go | 62 ++++++++++++++++-- dao/Factory.go | 14 ++++ dao/{IChannelDAD.go => IChannelDAO.go} | 0 dao/{IMessageDAD.go => IMessageDAO.go} | 2 +- dao/INotificationDAO.go | 6 +- dao/postgres/MessageDAO.go | 46 +++++++++++-- dao/valkey/NotificationDAO.go | 33 ++++++---- go.mod | 1 + go.sum | 2 + model/Message.go | 6 +- util/errors.go | 5 ++ 15 files changed, 292 insertions(+), 35 deletions(-) rename dao/{IChannelDAD.go => IChannelDAO.go} (100%) rename dao/{IMessageDAD.go => IMessageDAO.go} (86%) diff --git a/api/auth.go b/api/auth.go index b8704fb..d27663c 100644 --- a/api/auth.go +++ b/api/auth.go @@ -7,10 +7,14 @@ import ( "github.com/gin-gonic/gin" ) -const SESSION_COOKIE string = "session" +const ( + SESSION_COOKIE string = "session" + USER_ID string = "user_id" +) func isLoggedIn(c *gin.Context) { token, err := c.Cookie(SESSION_COOKIE) + if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": "missing token", @@ -19,7 +23,15 @@ func isLoggedIn(c *gin.Context) { return } + id, authErr := authController.IsLoggedIn(token) + if authErr != nil { + sendError(c, authErr) + c.Abort() + return + } + c.Set(SESSION_COOKIE, token) + c.Set(USER_ID, id) c.Next() } @@ -68,7 +80,7 @@ func login(c *gin.Context) { return } - c.SetCookie(SESSION_COOKIE, token, config.GetConfig().API.TokenLife, "", "", false, false) + c.SetCookie(SESSION_COOKIE, token, config.GetConfig().API.TokenLife, "/", "", false, false) c.JSON(http.StatusOK, gin.H{ "message": "sucessful login", "session": token, @@ -83,7 +95,7 @@ func logout(c *gin.Context) { return } - c.SetCookie(SESSION_COOKIE, "", 0, "", "", false, false) + c.SetCookie(SESSION_COOKIE, "", 0, "/", "", false, false) c.JSON(http.StatusOK, gin.H{ "message": "sucessful logout", }) diff --git a/api/chat.go b/api/chat.go index f878998..00e0d58 100644 --- a/api/chat.go +++ b/api/chat.go @@ -1,11 +1,23 @@ package api import ( + "fmt" "net/http" + "strconv" + "git.tek.govt.hu/dowerx/chat/server/model" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) +const CHANNEL_ID string = "id" + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + func listAvailableChannels(c *gin.Context) { token, _ := c.Get(SESSION_COOKIE) @@ -20,3 +32,81 @@ func listAvailableChannels(c *gin.Context) { "channels": channels, }) } + +func getMessages(c *gin.Context) { + id, err := strconv.Atoi(c.Param(CHANNEL_ID)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid channel id", + }) + return + } + + messages, msgErr := chatController.GetMessages(id) + if msgErr != nil { + sendError(c, msgErr) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "messages listed", + "messages": messages, + }) +} + +func sendMessage(c *gin.Context) { + token, _ := c.Get(SESSION_COOKIE) + + message := model.Message{} + if err := c.Bind(&message); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + }) + return + } + + err := chatController.SendMessage(token.(string), message.Channel, message.Content) + if err != nil { + fmt.Println(err.Error()) + + sendError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "message sent", + }) +} + +func subscribeToChannel(c *gin.Context) { + // TODO: check if the user has right to subscribe to the given channel + id, err := strconv.Atoi(c.Param(CHANNEL_ID)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid channel id", + }) + return + } + + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + fmt.Println(err.Error()) + return + } + defer conn.Close() + + messages, chanErr := chatController.SubscribeToChannel(c, id) + + for { + select { + case msg := <-messages: + if err := conn.WriteJSON(msg); err != nil { + return + } + + case err := <-chanErr: + fmt.Println(err.Error()) + return + } + } +} diff --git a/api/endpoints.go b/api/endpoints.go index 46b1a0d..bd64798 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -30,9 +30,24 @@ func initControlles() *util.ChatError { return err } + messageDAO, err := dao.GetMessageDAO() + if err != nil { + return err + } + + notifcationDAO, err := dao.GetNotificationDAO() + if err != nil { + return err + } + userController = controller.MakeUserController(userDAO) authController = controller.MakeAuthController(userDAO, sessionDAO) - chatController = controller.MakeChatController(channelDAO, sessionDAO) + chatController = controller.MakeChatController( + channelDAO, + sessionDAO, + messageDAO, + notifcationDAO, + userDAO) return nil } @@ -44,6 +59,20 @@ func Listen(address string, base string) error { router := gin.Default() + if gin.DebugMode == "debug" { + router.Use(func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization") + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + c.Next() + }) + } + api := router.Group(base) auth := api.Group("auth") @@ -59,6 +88,9 @@ func Listen(address string, base string) error { chat := api.Group("chat") chat.Use(isLoggedIn) chat.GET("channels", listAvailableChannels) + chat.GET("messages/:id", getMessages) + chat.POST("send", sendMessage) + chat.GET("subscribe/:id", subscribeToChannel) server := &http.Server{ Addr: address, diff --git a/controller/AuthController.go b/controller/AuthController.go index bbecd45..08f4d29 100644 --- a/controller/AuthController.go +++ b/controller/AuthController.go @@ -23,6 +23,14 @@ const ( TOKEN_LENGTH int = 32 ) +func (c AuthController) IsLoggedIn(token string) (int, *util.ChatError) { + id, err := c.sessionDAO.Get(token) + if err != nil { + err = &util.ChatError{Message: "", Code: util.NOT_LOGGED_IN} + } + return id, err +} + func (c AuthController) Register(username string, password string, repeatPassword string) *util.ChatError { if len(username) < MIN_USERNAME_LENGTH { return &util.ChatError{Message: "", Code: util.USERNAME_TOO_SHORT} diff --git a/controller/ChatController.go b/controller/ChatController.go index 8d39205..ed2c7df 100644 --- a/controller/ChatController.go +++ b/controller/ChatController.go @@ -1,14 +1,20 @@ package controller import ( + "context" + "time" + "git.tek.govt.hu/dowerx/chat/server/dao" "git.tek.govt.hu/dowerx/chat/server/model" "git.tek.govt.hu/dowerx/chat/server/util" ) type ChatController struct { - channelDAO dao.IChannelDAO - sessionDAO dao.ISessionDAO + channelDAO dao.IChannelDAO + sessionDAO dao.ISessionDAO + messageDAO dao.IMessageDAO + notifcationDAO dao.INotificationDAO + userDAO dao.IUserDAO } func (c ChatController) ListAvailableChannels(token string) ([]model.Channel, *util.ChatError) { @@ -20,7 +26,55 @@ func (c ChatController) ListAvailableChannels(token string) ([]model.Channel, *u return c.channelDAO.ListAvailableChannels(userID) } -func MakeChatController(channelDAO dao.IChannelDAO, sessionDAO dao.ISessionDAO) ChatController { - controller := ChatController{channelDAO: channelDAO, sessionDAO: sessionDAO} +func (c ChatController) GetMessages(channel int) ([]model.Message, *util.ChatError) { + return c.messageDAO.List(model.Channel{ID: channel}) +} + +func (c ChatController) SendMessage(token string, channel int, content string) *util.ChatError { + sender_id, err := c.sessionDAO.Get(token) + if err != nil { + return err + } + + user, err := c.userDAO.Read(model.User{ID: sender_id}) + if err != nil { + return err + } + + // TODO: check if user has right to send message in the given channel + + message := model.Message{ + SenderID: user.ID, + SenderName: user.Username, + Channel: channel, + Time: time.Now().UTC(), + Content: content, + } + + message.ID, err = c.messageDAO.Create(message) + if err != nil { + return err + } + + return c.notifcationDAO.SendMessage(message) +} + +func (c ChatController) SubscribeToChannel(ctx context.Context, id int) (<-chan model.Message, <-chan *util.ChatError) { + return c.notifcationDAO.SubscribeToChannel(ctx, id) +} + +func MakeChatController( + channelDAO dao.IChannelDAO, + sessionDAO dao.ISessionDAO, + messageDAO dao.IMessageDAO, + notificationDAO dao.INotificationDAO, + userDAO dao.IUserDAO) ChatController { + controller := ChatController{ + channelDAO: channelDAO, + sessionDAO: sessionDAO, + messageDAO: messageDAO, + notifcationDAO: notificationDAO, + userDAO: userDAO, + } return controller } diff --git a/dao/Factory.go b/dao/Factory.go index ef9a27f..7277052 100644 --- a/dao/Factory.go +++ b/dao/Factory.go @@ -10,6 +10,7 @@ var userDAO IUserDAO var channelDAO IChannelDAO var sessionDAO ISessionDAO var messageDAO IMessageDAO +var notifcationDAO INotificationDAO func GetUserDAO() (IUserDAO, *util.ChatError) { if userDAO == nil { @@ -62,3 +63,16 @@ func GetSessionDAO() (ISessionDAO, *util.ChatError) { return sessionDAO, nil } + +func GetNotificationDAO() (INotificationDAO, *util.ChatError) { + if notifcationDAO == nil { + dao, err := valkey.MakeNotificationDAO() + if err != nil { + return notifcationDAO, err + } + + notifcationDAO = dao + } + + return notifcationDAO, nil +} diff --git a/dao/IChannelDAD.go b/dao/IChannelDAO.go similarity index 100% rename from dao/IChannelDAD.go rename to dao/IChannelDAO.go diff --git a/dao/IMessageDAD.go b/dao/IMessageDAO.go similarity index 86% rename from dao/IMessageDAD.go rename to dao/IMessageDAO.go index 639c3e1..fe70d41 100644 --- a/dao/IMessageDAD.go +++ b/dao/IMessageDAO.go @@ -6,7 +6,7 @@ import ( ) type IMessageDAO interface { - Create(message model.Message) *util.ChatError + Create(message model.Message) (int, *util.ChatError) Read(id int) (model.Message, *util.ChatError) List(channel model.Channel) ([]model.Message, *util.ChatError) Update(message model.Message) *util.ChatError diff --git a/dao/INotificationDAO.go b/dao/INotificationDAO.go index 8f78313..f0ce780 100644 --- a/dao/INotificationDAO.go +++ b/dao/INotificationDAO.go @@ -7,7 +7,7 @@ import ( "git.tek.govt.hu/dowerx/chat/server/util" ) -type INotification interface { - SendMessage(ctx context.Context, message model.Message) *util.ChatError - SubscribeToChannel(ctx context.Context, id int) (<-chan model.Message, func(), *util.ChatError) +type INotificationDAO interface { + SendMessage(message model.Message) *util.ChatError + SubscribeToChannel(ctx context.Context, id int) (<-chan model.Message, <-chan *util.ChatError) } diff --git a/dao/postgres/MessageDAO.go b/dao/postgres/MessageDAO.go index a951eed..23199e4 100644 --- a/dao/postgres/MessageDAO.go +++ b/dao/postgres/MessageDAO.go @@ -11,9 +11,22 @@ type MessageDAO struct { } // Create a new message -func (d MessageDAO) Create(message model.Message) *util.ChatError { - _, err := d.db.NamedExec(`insert into "message" ("sender_id", "channel_id", "time", "content") values (:sender_id, :channel_id, :time, :content)`, &message) - return util.MakeError(err, util.DATABASE_QUERY_FAULT) +func (d MessageDAO) Create(message model.Message) (int, *util.ChatError) { + rows, err := d.db.NamedQuery(`insert into "message" ("sender_id", "channel_id", "time", "content") values (:sender_id, :channel_id, :time, :content) returning "id"`, &message) + if err != nil { + return 0, util.MakeError(err, util.DATABASE_QUERY_FAULT) + } + + if !rows.Next() { + return 0, &util.ChatError{Message: "failed to insert new message", Code: util.DATABASE_QUERY_FAULT} + } + + var id int + if rows.Scan(&id) != nil { + return 0, &util.ChatError{Message: "failed to return new message id", Code: util.DATABASE_QUERY_FAULT} + } + + return int(id), util.MakeError(err, util.DATABASE_QUERY_FAULT) } // Read returns a message by ID @@ -41,9 +54,30 @@ func (d MessageDAO) List(channel model.Channel) ([]model.Message, *util.ChatErro var rows *sqlx.Rows var err error if channel.ID != 0 { - rows, err = d.db.Queryx(`select * from "message" where "id" = :id order by "time"`, &channel) + rows, err = d.db.NamedQuery( + `select + "m"."id" as "id", + "u"."username" as "sender_name", + "m"."channel_id" as "channel_id", + "m"."time" as "time", + "m"."content" as "content" + from "message" as "m" + inner join "user" "u" on "u"."id" = "m"."sender_id" + where "m"."channel_id" = :id order by "time"`, + &channel) } else { - rows, err = d.db.Queryx(`select * from "message" where "name" = :name order by "time"`, &channel) + rows, err = d.db.NamedQuery( + `select + "m"."id" as "id", + "u"."username" as "sender_name", + "m"."channel_id" as "channel_id", + "m"."time" as "time", + "m"."content" as "content" + from "message" as "m" + inner join "user" "u" on "u"."id" = "m"."sender_id" + inner join "channel" "c" on "c"."id" = "m"."channel_id" + where "c"."name" = :name order by "time"`, + &channel) } if err != nil { @@ -56,7 +90,7 @@ func (d MessageDAO) List(channel model.Channel) ([]model.Message, *util.ChatErro for rows.Next() { message := model.Message{} - err = rows.StructScan(&channel) + err = rows.StructScan(&message) if err != nil { break } diff --git a/dao/valkey/NotificationDAO.go b/dao/valkey/NotificationDAO.go index e5ceaef..b2a73f7 100644 --- a/dao/valkey/NotificationDAO.go +++ b/dao/valkey/NotificationDAO.go @@ -27,25 +27,30 @@ func (d NotificationDAOVK) SendMessage(message model.Message) *util.ChatError { return util.MakeError((*d.vk).Do(context.Background(), cmd).Error(), util.DATABASE_QUERY_FAULT) } -func (d NotificationDAOVK) SubscribeToChannel(ctx context.Context, id int) (<-chan model.Message, *util.ChatError) { +func (d NotificationDAOVK) SubscribeToChannel(ctx context.Context, id int) (<-chan model.Message, <-chan *util.ChatError) { cmd := (*d.vk).B().Subscribe().Channel(CHANNEL_PREFIX + strconv.Itoa(id)).Build() - var messages chan model.Message = make(chan model.Message) + var messages chan model.Message = make(chan model.Message, 1) + var errChan chan *util.ChatError = make(chan *util.ChatError, 1) - err := (*d.vk).Receive(ctx, cmd, func(msg valkey.PubSubMessage) { - go func() { - message := model.Message{} - err := json.Unmarshal([]byte(msg.Message), &message) - if err != nil { - return - } + go func() { + err := (*d.vk).Receive(ctx, cmd, func(msg valkey.PubSubMessage) { + go func() { + message := model.Message{} + err := json.Unmarshal([]byte(msg.Message), &message) + if err != nil { + return + } - messages <- message - }() - }) - defer close(messages) + messages <- message + }() + }) + defer close(messages) - return messages, util.MakeError(err, util.GENERAL_ERROR) + errChan <- util.MakeError(err, util.DATABASE_QUERY_FAULT) + }() + + return messages, errChan } func MakeNotificationDAO() (NotificationDAOVK, *util.ChatError) { diff --git a/go.mod b/go.mod index debc521..99994fd 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index 03378c0..8c7d9c0 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqw github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= diff --git a/model/Message.go b/model/Message.go index 0ec27ee..bbb4a97 100644 --- a/model/Message.go +++ b/model/Message.go @@ -4,9 +4,9 @@ import "time" type Message struct { ID int `db:"id" json:"id"` - SenderID string `db:"sender_id" json:"-"` + SenderID int `db:"sender_id" json:"-"` SenderName string `db:"sender_name" json:"sender_name"` - Channel int `db:"channel_id" json:"channel_id"` + Channel int `db:"channel_id" form:"channel_id" json:"channel_id"` Time time.Time `db:"time" json:"time"` - Content string `db:"content" json:"content"` + Content string `db:"content" form:"content" json:"content"` } diff --git a/util/errors.go b/util/errors.go index 3caaa0a..baf9bf0 100644 --- a/util/errors.go +++ b/util/errors.go @@ -18,6 +18,8 @@ const ( USERNAME_TOO_SHORT PASSWORD_TOO_SHORT PASSWORDS_DONT_MATCH + + NOT_LOGGED_IN ) var codeToMessage = map[ChatErrorCode]string{ @@ -29,6 +31,7 @@ var codeToMessage = map[ChatErrorCode]string{ USERNAME_TOO_SHORT: "username is too short", PASSWORD_TOO_SHORT: "password is too short", PASSWORDS_DONT_MATCH: "passwords do not match", + NOT_LOGGED_IN: "not logged in", } type ChatError struct { @@ -68,6 +71,8 @@ func (e *ChatError) Status() int { fallthrough case PASSWORDS_DONT_MATCH: return http.StatusOK + case NOT_LOGGED_IN: + return http.StatusUnauthorized default: return http.StatusInternalServerError }