From 6651daca670664f3de8af9c7bcb74b1e7c6c6be9 Mon Sep 17 00:00:00 2001 From: Zach Berwaldt Date: Thu, 7 Mar 2024 19:16:07 -0500 Subject: Add CORS middleware and authentication middleware to the API server. The `setupRouter` function in `main.go` now includes a CORS middleware and a token authentication middleware. The CORS middleware allows cross-origin resource sharing by setting the appropriate response headers. The token authentication middleware checks for the presence of an `Authorization` header with a valid bearer token. If the token is missing or invalid, an unauthorized response is returned. In addition to these changes, a new test file `main_test.go` has been added to test the `/api/v1/auth` route. This test suite includes two test cases: one for successful authentication and one for failed authentication. Update go.mod to include new dependencies. The `go.mod` file has been modified to include two new dependencies: `github.com/spf13/viper` and `github.com/stretchr/testify`. Ignore go.sum changes. Ignore changes in the `go.sum` file, as they only include updates to existing dependencies. --- api/cmd/main.go | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 api/cmd/main.go (limited to 'api/cmd/main.go') diff --git a/api/cmd/main.go b/api/cmd/main.go new file mode 100644 index 0000000..1924556 --- /dev/null +++ b/api/cmd/main.go @@ -0,0 +1,104 @@ +package main + +import ( + "errors" + "log" + "net/http" + "strings" + "water/api/internal/database" + "water/api/internal/controllers" + + "github.com/gin-gonic/gin" + _ "github.com/mattn/go-sqlite3" +) + +func CORSMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") + + if c.Request.Method == "OPTIONS" { + log.Println(c.Request.Header) + c.AbortWithStatus(http.StatusNoContent) + return + } + + c.Next() + } +} + +func checkForTokenInContext(c *gin.Context) (string, error) { + authorizationHeader := c.GetHeader("Authorization") + if authorizationHeader == "" { + return "", errors.New("authorization header is missing") + } + + parts := strings.Split(authorizationHeader, " ") + + if len(parts) != 2 || parts[0] != "Bearer" { + return "", errors.New("invalid Authorization header format") + } + + return parts[1], nil +} + +func TokenRequired() gin.HandlerFunc { + return func(c *gin.Context) { + _, err := checkForTokenInContext(c) + + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + + c.Next() + } +} + +func setupRouter() *gin.Engine { + // Disable Console Color + // gin.DisableConsoleColor() + r := gin.Default() + r.Use(CORSMiddleware()) + r.Use(gin.Logger()) + r.Use(gin.Recovery()) + + api := r.Group("api/v1") + + api.POST("/auth", controllers.AuthHandler) + + user := api.Group("/user/:uuid") + user.Use(TokenRequired()) + { + user.GET("", controllers.GetUser) + user.GET("preferences", controllers.GetUserPreferences) + user.PATCH("preferences", controllers.UpdateUserPreferences) + } + + stats := api.Group("/stats") + stats.Use(TokenRequired()) + { + stats.GET("/", controllers.GetAllStatistics) + stats.POST("/", controllers.PostNewStatistic) + stats.GET("weekly/", controllers.GetWeeklyStatistics) + stats.GET("daily/", controllers.GetDailyUserStatistics) + stats.GET("user/:uuid", controllers.GetUserStatistics) + stats.PATCH("user/:uuid", controllers.UpdateUserStatistic) + stats.DELETE("user/:uuid", controllers.DeleteUserStatistic) + } + + return r +} + +func main() { + database.SetupDatabase() + r := setupRouter() + // Listen and Server in 0.0.0.0:8080 + err := r.Run(":8080") + if err != nil { + return + } +} -- cgit v1.1