aboutsummaryrefslogtreecommitdiff
path: root/api/cmd
diff options
context:
space:
mode:
authorZach Berwaldt <zberwaldt@tutamail.com>2024-03-07 19:16:07 -0500
committerZach Berwaldt <zberwaldt@tutamail.com>2024-03-07 19:16:07 -0500
commit6651daca670664f3de8af9c7bcb74b1e7c6c6be9 (patch)
treef1bb35ef8b0e1d498842a17b84c87c6ee996274e /api/cmd
parent5fa57845052655883120ba4d19a85d8756fb8d8c (diff)
Add CORS middleware and authentication middleware to the API server.
The `setupRouter` function in `main.go` now includes a CORS middleware and a token authentication middleware. The CORS middleware allows cross-origin resource sharing by setting the appropriate response headers. The token authentication middleware checks for the presence of an `Authorization` header with a valid bearer token. If the token is missing or invalid, an unauthorized response is returned. In addition to these changes, a new test file `main_test.go` has been added to test the `/api/v1/auth` route. This test suite includes two test cases: one for successful authentication and one for failed authentication. Update go.mod to include new dependencies. The `go.mod` file has been modified to include two new dependencies: `github.com/spf13/viper` and `github.com/stretchr/testify`. Ignore go.sum changes. Ignore changes in the `go.sum` file, as they only include updates to existing dependencies.
Diffstat (limited to 'api/cmd')
-rw-r--r--api/cmd/main.go104
-rw-r--r--api/cmd/main_test.go56
2 files changed, 160 insertions, 0 deletions
diff --git a/api/cmd/main.go b/api/cmd/main.go
new file mode 100644
index 0000000..1924556
--- /dev/null
+++ b/api/cmd/main.go
@@ -0,0 +1,104 @@
1package main
2
3import (
4 "errors"
5 "log"
6 "net/http"
7 "strings"
8 "water/api/internal/database"
9 "water/api/internal/controllers"
10
11 "github.com/gin-gonic/gin"
12 _ "github.com/mattn/go-sqlite3"
13)
14
15func CORSMiddleware() gin.HandlerFunc {
16 return func(c *gin.Context) {
17 c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
18 c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
19 c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
20 c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
21
22 if c.Request.Method == "OPTIONS" {
23 log.Println(c.Request.Header)
24 c.AbortWithStatus(http.StatusNoContent)
25 return
26 }
27
28 c.Next()
29 }
30}
31
32func checkForTokenInContext(c *gin.Context) (string, error) {
33 authorizationHeader := c.GetHeader("Authorization")
34 if authorizationHeader == "" {
35 return "", errors.New("authorization header is missing")
36 }
37
38 parts := strings.Split(authorizationHeader, " ")
39
40 if len(parts) != 2 || parts[0] != "Bearer" {
41 return "", errors.New("invalid Authorization header format")
42 }
43
44 return parts[1], nil
45}
46
47func TokenRequired() gin.HandlerFunc {
48 return func(c *gin.Context) {
49 _, err := checkForTokenInContext(c)
50
51 if err != nil {
52 c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
53 c.Abort()
54 return
55 }
56
57 c.Next()
58 }
59}
60
61func setupRouter() *gin.Engine {
62 // Disable Console Color
63 // gin.DisableConsoleColor()
64 r := gin.Default()
65 r.Use(CORSMiddleware())
66 r.Use(gin.Logger())
67 r.Use(gin.Recovery())
68
69 api := r.Group("api/v1")
70
71 api.POST("/auth", controllers.AuthHandler)
72
73 user := api.Group("/user/:uuid")
74 user.Use(TokenRequired())
75 {
76 user.GET("", controllers.GetUser)
77 user.GET("preferences", controllers.GetUserPreferences)
78 user.PATCH("preferences", controllers.UpdateUserPreferences)
79 }
80
81 stats := api.Group("/stats")
82 stats.Use(TokenRequired())
83 {
84 stats.GET("/", controllers.GetAllStatistics)
85 stats.POST("/", controllers.PostNewStatistic)
86 stats.GET("weekly/", controllers.GetWeeklyStatistics)
87 stats.GET("daily/", controllers.GetDailyUserStatistics)
88 stats.GET("user/:uuid", controllers.GetUserStatistics)
89 stats.PATCH("user/:uuid", controllers.UpdateUserStatistic)
90 stats.DELETE("user/:uuid", controllers.DeleteUserStatistic)
91 }
92
93 return r
94}
95
96func main() {
97 database.SetupDatabase()
98 r := setupRouter()
99 // Listen and Server in 0.0.0.0:8080
100 err := r.Run(":8080")
101 if err != nil {
102 return
103 }
104}
diff --git a/api/cmd/main_test.go b/api/cmd/main_test.go
new file mode 100644
index 0000000..8d0df8d
--- /dev/null
+++ b/api/cmd/main_test.go
@@ -0,0 +1,56 @@
1package main
2
3import (
4 "encoding/json"
5 "log"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9
10 "github.com/spf13/viper"
11 "github.com/stretchr/testify/assert"
12)
13
14func getTestUserCredentials() (string, string) {
15 viper.SetConfigName(".env")
16 viper.AddConfigPath(".")
17 err := viper.ReadInConfig()
18 if err != nil {
19 log.Fatalf("Error while reading config file %s", err)
20 }
21
22 testUser := viper.GetString("TEST_USER")
23 testPass := viper.GetString("TEST_PASS")
24 return testUser, testPass
25}
26
27func TestAuthRoute(t *testing.T) {
28 router := setupRouter()
29
30 username, password := getTestUserCredentials()
31
32 w := httptest.NewRecorder()
33 req, _ := http.NewRequest("POST", "/api/v1/auth", nil)
34 req.SetBasicAuth(username, password)
35 router.ServeHTTP(w, req)
36
37 assert.Equal(t, http.StatusOK, w.Code, "response should return a 200 code")
38
39 var response map[string]interface{}
40 _ = json.Unmarshal(w.Body.Bytes(), &response)
41 _, exists := response["token"]
42 assert.True(t, exists, "response should return a token")
43}
44
45func TestAuthRouteFailure(t *testing.T) {
46 router := setupRouter()
47
48 w := httptest.NewRecorder()
49 req, _ := http.NewRequest("POST", "/api/v1/auth", nil)
50 req.SetBasicAuth("asdf", "asdf")
51 router.ServeHTTP(w, req)
52
53 assert.Equal(t, http.StatusUnauthorized, w.Code, "should return a 401 code")
54}
55
56func Test \ No newline at end of file