From 6651daca670664f3de8af9c7bcb74b1e7c6c6be9 Mon Sep 17 00:00:00 2001 From: Zach Berwaldt Date: Thu, 7 Mar 2024 19:16:07 -0500 Subject: Add CORS middleware and authentication middleware to the API server. The `setupRouter` function in `main.go` now includes a CORS middleware and a token authentication middleware. The CORS middleware allows cross-origin resource sharing by setting the appropriate response headers. The token authentication middleware checks for the presence of an `Authorization` header with a valid bearer token. If the token is missing or invalid, an unauthorized response is returned. In addition to these changes, a new test file `main_test.go` has been added to test the `/api/v1/auth` route. This test suite includes two test cases: one for successful authentication and one for failed authentication. Update go.mod to include new dependencies. The `go.mod` file has been modified to include two new dependencies: `github.com/spf13/viper` and `github.com/stretchr/testify`. Ignore go.sum changes. Ignore changes in the `go.sum` file, as they only include updates to existing dependencies. --- api/main.go | 301 ------------------------------------------------------------ 1 file changed, 301 deletions(-) delete mode 100644 api/main.go (limited to 'api/main.go') diff --git a/api/main.go b/api/main.go deleted file mode 100644 index 17a3c3a..0000000 --- a/api/main.go +++ /dev/null @@ -1,301 +0,0 @@ -package main - -import ( - "crypto/rand" - "database/sql" - "encoding/base64" - "errors" - "log" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - _ "github.com/mattn/go-sqlite3" - "golang.org/x/crypto/bcrypt" -) - -func CORSMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") - - if c.Request.Method == "OPTIONS" { - log.Println(c.Request.Header) - c.AbortWithStatus(http.StatusNoContent) - return - } - - c.Next() - } -} - -// generatToken will g -func generateToken() string { - token := make([]byte, 32) - _, err := rand.Read(token) - if err != nil { - return "" - } - return base64.StdEncoding.EncodeToString(token) -} - -func establishDBConnection() *sql.DB { - db, err := sql.Open("sqlite3", "../db/water.sqlite3") - if err != nil { - panic(err) - } - return db -} - -func checkForTokenInContext(c *gin.Context) (string, error) { - authorizationHeader := c.GetHeader("Authorization") - if authorizationHeader == "" { - return "", errors.New("authorization header is missing") - } - - parts := strings.Split(authorizationHeader, " ") - - if len(parts) != 2 || parts[0] != "Bearer" { - return "", errors.New("invalid Authorization header format") - } - - return parts[1], nil -} - -func TokenRequired() gin.HandlerFunc { - return func(c *gin.Context) { - _, err := checkForTokenInContext(c) - - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } - - c.Next() - } -} - -func setupRouter() *gin.Engine { - // Disable Console Color - // gin.DisableConsoleColor() - r := gin.Default() - r.Use(CORSMiddleware()) - r.Use(gin.Logger()) - r.Use(gin.Recovery()) - - api := r.Group("api/v1") - - api.POST("/auth", func(c *gin.Context) { - username, password, ok := c.Request.BasicAuth() - if !ok { - c.Header("WWW-Authenticate", `Basic realm="Please enter your username and password."`) - c.AbortWithStatus(http.StatusUnauthorized) - return - } - - db := establishDBConnection() - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(db) - - var user User - var preference Preference - var size Size - - row := db.QueryRow("SELECT name, uuid, password, color, size, unit FROM Users u INNER JOIN Preferences p ON p.user_id = u.id INNER JOIN Sizes s ON p.size_id = s.id WHERE u.name = ?", username) - if err := row.Scan(&user.Name, &user.UUID, &user.Password, &preference.Color, &size.Size, &size.Unit); err != nil { - if errors.Is(err, sql.ErrNoRows) { - c.AbortWithStatus(http.StatusUnauthorized) - return - } - } - - if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - c.AbortWithStatus(http.StatusUnauthorized) - return - } - - preference.Size = size - - // Generate a simple API token - apiToken := generateToken() - c.JSON(http.StatusOK, gin.H{"token": apiToken, "user": user, "preferences": preference}) - }) - - stats := api.Group("/stats") - stats.Use(TokenRequired()) - { - stats.GET("/", func(c *gin.Context) { - db := establishDBConnection() - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(db) - - rows, err := db.Query("SELECT s.date, s.quantity, u.uuid, u.name FROM Statistics s INNER JOIN Users u ON u.id = s.user_id") - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer func(rows *sql.Rows) { - err := rows.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(rows) - - var data []Statistic - - for rows.Next() { - var stat Statistic - var user User - if err := rows.Scan(&stat.Date, &stat.Quantity, &user.UUID, &user.Name); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - stat.User = user - data = append(data, stat) - } - - c.JSON(http.StatusOK, data) - }) - - stats.POST("/", func(c *gin.Context) { - var stat StatisticPost - - if err := c.BindJSON(&stat); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - db := establishDBConnection() - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(db) - - result, err := db.Exec("INSERT INTO statistics (date, user_id, quantity) values (?, ?, ?)", stat.Date, stat.UserID, stat.Quantity) - - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } - - id, err := result.LastInsertId() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } - - c.JSON(http.StatusCreated, gin.H{"status": "created", "id": id}) - }) - - stats.GET("weekly/", func(c *gin.Context) { - db := establishDBConnection() - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(db) - - rows, err := db.Query("SELECT date, total FROM `WeeklyStatisticsView`") - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer func(rows *sql.Rows) { - err := rows.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(rows) - - var data []WeeklyStatistic - for rows.Next() { - var weeklyStat WeeklyStatistic - if err := rows.Scan(&weeklyStat.Date, &weeklyStat.Total); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - data = append(data, weeklyStat) - } - - c.JSON(http.StatusOK, data) - }) - - stats.GET("totals/", func(c *gin.Context) { - db := establishDBConnection() - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(db) - - rows, err := db.Query("SELECT name, total FROM DailyUserStatistics") - - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer func(rows *sql.Rows) { - err := rows.Close() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }(rows) - - var data []DailyUserTotals - for rows.Next() { - var stat DailyUserTotals - if err := rows.Scan(&stat.Name, &stat.Total); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - data = append(data, stat) - } - - c.JSON(http.StatusOK, data) - - }) - - stats.GET("user/:uuid", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"status": "ok", "uuid": c.Param("uuid")}) - }) - - stats.PATCH("user/:uuid", func(c *gin.Context) { - c.JSON(http.StatusNoContent, gin.H{"status": "No Content"}) - }) - - stats.DELETE("user/:uuid", func(c *gin.Context) { - c.JSON(http.StatusNoContent, gin.H{"status": "No Content"}) - }) - } - - return r -} - -func main() { - r := setupRouter() - // Listen and Server in 0.0.0.0:8080 - err := r.Run(":8080") - if err != nil { - return - } -} -- cgit v1.1