diff --git a/api/auth/middleware.go b/api/auth/middleware.go index e2400db..1531c6c 100644 --- a/api/auth/middleware.go +++ b/api/auth/middleware.go @@ -33,5 +33,24 @@ func LoggedIn(c *gin.Context) { } func IsAdmin(c *gin.Context) { + neptun, exists := c.Get("neptun") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{ + "status": http.StatusUnauthorized, + "error": "not logged in", + }) + c.Abort() + return + } + + err := auth.IsAdmin(neptun.(string)) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "status": http.StatusUnauthorized, + "error": "not an admin", + }) + c.Abort() + return + } } diff --git a/api/endpotins.go b/api/endpotins.go index d1792dc..5fecd2e 100644 --- a/api/endpotins.go +++ b/api/endpotins.go @@ -26,7 +26,7 @@ func Listen(address string, path string) { apiAuth.GET("login", auth.Login) } - apiTest := api.Group("test").Use(auth.LoggedIn) + apiTest := api.Group("test").Use(auth.LoggedIn).Use(auth.IsAdmin) { apiTest.GET("logged_in", func(c *gin.Context) { neptun, _ := c.Get("neptun") @@ -37,6 +37,13 @@ func Listen(address string, path string) { "neptun": neptun, }) }) + + apiTest.GET("is_admin", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": http.StatusOK, + "message": "if you see this you are an admin", + }) + }) } } diff --git a/database/auth/middleware.go b/database/auth/middleware.go index 2596d8c..ff4f821 100644 --- a/database/auth/middleware.go +++ b/database/auth/middleware.go @@ -1,6 +1,8 @@ package auth import ( + "errors" + "git.tek.govt.hu/dowerx/szoe-pontok/database" "github.com/redis/go-redis/v9" ) @@ -17,7 +19,27 @@ func LoggedIn(token string) (string, error) { } func IsAdmin(neptun string) error { - // db := database.GetDB() + db := database.GetDB() + + rows, err := db.NamedQuery(`select count(*) from "admin" inner join "user" on "user"."id" = "admin"."user" where "user"."neptun" = :neptun`, + map[string]interface{}{ + "neptun": neptun, + }) + + if err != nil { + return err + } + + var count int + if !rows.Next() { + return errors.New("not an admin") + } + + rows.Scan(&count) + + if count != 1 { + return errors.New("not an admin") + } return nil }