diff --git a/backend/.env.sample b/backend/.env.sample index e692901..d929ca0 100644 --- a/backend/.env.sample +++ b/backend/.env.sample @@ -1,3 +1,6 @@ +# Port +PORT=3003 + # Db address DB_ADDRESS=data/sqlite3.db @@ -22,3 +25,8 @@ FRONTEND_URL=https://c2r5p11.hive.fi:5173 # 2FA TWO_FA_URL_PREFIX=otpauth://totp/aaa?secret= + +# Rate Limiter +RATE_LIMITER_DURATION_IN_SECONDS=60 +RATE_LIMITER_REQUEST_LIMIT=1000 +RATE_LIMITER_CLEANUP_INTERVAL_IN_SECONDS=300 \ No newline at end of file diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 739ac45..78a791e 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -1,9 +1,13 @@ package main import ( + "context" + "fmt" "net/http" "os" + "os/signal" "strings" + "syscall" "time" swaggerfiles "github.com/swaggo/files" @@ -12,10 +16,9 @@ import ( "github.com/gin-gonic/gin" "github.com/joho/godotenv" _ "github.com/paularynty/transcendence/auth-service-go/docs" - "github.com/paularynty/transcendence/auth-service-go/internal/config" - "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/routers" + "github.com/paularynty/transcendence/auth-service-go/internal/service" "github.com/paularynty/transcendence/auth-service-go/internal/util" "log/slog" @@ -56,7 +59,7 @@ func SetupRouter(dep *dependency.Dependency) *gin.Engine { MaxAge: 12 * time.Hour, })) - rateLimiter := middleware.NewRateLimiter(60*time.Second, 1000) + rateLimiter := middleware.NewRateLimiter(time.Duration(dep.Cfg.RateLimiterDurationInSec)*time.Second, dep.Cfg.RateLimiterRequestLimit, time.Duration(dep.Cfg.RateLimiterCleanupIntervalInSec)*time.Second) r.Use(rateLimiter.RateLimit()) r.Use(middleware.PanicHandler()) @@ -66,15 +69,6 @@ func SetupRouter(dep *dependency.Dependency) *gin.Engine { return r } -func initDependency() *dependency.Dependency { - logger := util.GetLogger(slog.LevelInfo) - cfg := config.LoadConfigFromEnv() - myDB := db.GetDB(cfg.DbAddress, logger) - redis := db.GetRedis(cfg.RedisURL, cfg, logger) - - return dependency.NewDependency(cfg, myDB, redis, logger) -} - // @title Auth Service API // @version 1.0 // @description Auth service @@ -83,17 +77,28 @@ func main() { // config _ = godotenv.Load() + // logger + logger := util.GetLogger(slog.LevelDebug) + // init dependency - dep := initDependency() - defer db.CloseDB(dep.DB, dep.Logger) - defer db.CloseRedis(dep.Redis, dep.Logger) + dep, err := dependency.InitDependency(logger) + if err != nil { + util.LogFatalErr(logger, err, "failed to create dependency") + } + defer dependency.CloseDependency(dep) // validator dto.InitValidator() + // create services + userService, err := service.NewUserService(dep) + if err != nil { + util.LogFatalErr(logger, err, "failed to create user service") + } + // router r := SetupRouter(dep) - routers.UsersRouter(r.Group("/api/users"), dep) + routers.UsersRouter(r.Group("/api/users"), userService) // Health check r.GET("/api/ping", func(c *gin.Context) { @@ -105,8 +110,29 @@ func main() { // Swagger r.GET("/api/docs/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) - if err := r.Run(":3003"); err != nil { - dep.Logger.Error("failed to start server", "err", err) - os.Exit(1) + // http server + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", dep.Cfg.Port), + Handler: r, + } + + // Start server + go func() { + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + util.LogFatalErr(logger, err, "failed to start server") + } + }() + + // Graceful shutdown + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit // consume the signal, blocking here + logger.Info("shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + util.LogFatalErr(logger, err, "server forced to shutdown") } + logger.Info("server exiting") } diff --git a/backend/internal/auth_error/auth_error.go b/backend/internal/auth_error/auth_error.go new file mode 100644 index 0000000..692a048 --- /dev/null +++ b/backend/internal/auth_error/auth_error.go @@ -0,0 +1,17 @@ +package authError + +type AuthError struct { + Status int + Message string +} + +func (e *AuthError) Error() string { + return e.Message +} + +func NewAuthError(status int, message string) *AuthError { + return &AuthError{ + Status: status, + Message: message, + } +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 28cfe08..2c144fa 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1,25 +1,30 @@ package config import ( + "fmt" "os" "strconv" ) type Config struct { - GinMode string - DbAddress string - JwtSecret string - UserTokenExpiry int - OauthStateTokenExpiry int - GoogleClientId string - GoogleClientSecret string - GoogleRedirectUri string - FrontendUrl string - TwoFaUrlPrefix string - TwoFaTokenExpiry int - RedisURL string - IsRedisEnabled bool - UserTokenAbsoluteExpiry int + GinMode string + DbAddress string + JwtSecret string + UserTokenExpiry int + OauthStateTokenExpiry int + GoogleClientId string + GoogleClientSecret string + GoogleRedirectUri string + FrontendUrl string + TwoFaUrlPrefix string + TwoFaTokenExpiry int + RedisURL string + IsRedisEnabled bool + UserTokenAbsoluteExpiry int + Port int + RateLimiterDurationInSec int + RateLimiterRequestLimit int + RateLimiterCleanupIntervalInSec int } func getEnvStrOrDefault(key string, defaultValue string) string { @@ -32,14 +37,14 @@ func getEnvStrOrDefault(key string, defaultValue string) string { return value } -func getEnvStrOrPanic(key string) string { +func getEnvStrOrError(key string) (string, error) { value := os.Getenv(key) if value == "" { - panic("environment variable " + key + " is required but not set") + return "", fmt.Errorf("environment variable %s is required but not set", key) } - return value + return value, nil } func getEnvIntOrDefault(key string, defaultValue int) int { @@ -53,21 +58,40 @@ func getEnvIntOrDefault(key string, defaultValue int) int { return intValue } -func LoadConfigFromEnv() *Config { - return &Config{ - GinMode: getEnvStrOrDefault("GIN_MODE", "debug"), - DbAddress: getEnvStrOrDefault("DB_ADDRESS", "data/auth_service_db.sqlite"), - JwtSecret: getEnvStrOrPanic("JWT_SECRET"), - UserTokenExpiry: getEnvIntOrDefault("USER_TOKEN_EXPIRY", 3600), - OauthStateTokenExpiry: getEnvIntOrDefault("OAUTH_STATE_TOKEN_EXPIRY", 600), - GoogleClientId: getEnvStrOrPanic("GOOGLE_CLIENT_ID"), - GoogleClientSecret: getEnvStrOrPanic("GOOGLE_CLIENT_SECRET"), - GoogleRedirectUri: getEnvStrOrDefault("GOOGLE_REDIRECT_URI", "test-google-redirect-uri"), - FrontendUrl: getEnvStrOrDefault("FRONTEND_URL", "http://localhost:5173"), - TwoFaUrlPrefix: getEnvStrOrDefault("TWO_FA_URL_PREFIX", "otpauth://totp/Transcendence?secret="), - TwoFaTokenExpiry: getEnvIntOrDefault("TWO_FA_TOKEN_EXPIRY", 600), - RedisURL: getEnvStrOrDefault("REDIS_URL", ""), - IsRedisEnabled: getEnvStrOrDefault("REDIS_URL", "") != "", - UserTokenAbsoluteExpiry: getEnvIntOrDefault("USER_TOKEN_ABSOLUTE_EXPIRY", 2592000), +func LoadConfigFromEnv() (*Config, error) { + jwtSecret, err := getEnvStrOrError("JWT_SECRET") + if err != nil { + return nil, err + } + + GoogleClientId, err := getEnvStrOrError("GOOGLE_CLIENT_ID") + if err != nil { + return nil, err } + + GoogleClientSecret, err := getEnvStrOrError("GOOGLE_CLIENT_SECRET") + if err != nil { + return nil, err + } + + return &Config{ + GinMode: getEnvStrOrDefault("GIN_MODE", "debug"), + DbAddress: getEnvStrOrDefault("DB_ADDRESS", "data/auth_service_db.sqlite"), + JwtSecret: jwtSecret, + UserTokenExpiry: getEnvIntOrDefault("USER_TOKEN_EXPIRY", 3600), + OauthStateTokenExpiry: getEnvIntOrDefault("OAUTH_STATE_TOKEN_EXPIRY", 600), + GoogleClientId: GoogleClientId, + GoogleClientSecret: GoogleClientSecret, + GoogleRedirectUri: getEnvStrOrDefault("GOOGLE_REDIRECT_URI", "test-google-redirect-uri"), + FrontendUrl: getEnvStrOrDefault("FRONTEND_URL", "http://localhost:5173"), + TwoFaUrlPrefix: getEnvStrOrDefault("TWO_FA_URL_PREFIX", "otpauth://totp/Transcendence?secret="), + TwoFaTokenExpiry: getEnvIntOrDefault("TWO_FA_TOKEN_EXPIRY", 600), + RedisURL: getEnvStrOrDefault("REDIS_URL", ""), + IsRedisEnabled: getEnvStrOrDefault("REDIS_URL", "") != "", + UserTokenAbsoluteExpiry: getEnvIntOrDefault("USER_TOKEN_ABSOLUTE_EXPIRY", 2592000), + Port: getEnvIntOrDefault("PORT", 3003), + RateLimiterDurationInSec: getEnvIntOrDefault("RATE_LIMITER_DURATION_IN_SECONDS", 60), + RateLimiterRequestLimit: getEnvIntOrDefault("RATE_LIMITER_REQUEST_LIMIT", 1000), + RateLimiterCleanupIntervalInSec: getEnvIntOrDefault("RATE_LIMITER_CLEANUP_INTERVAL_IN_SECONDS", 300), + }, nil } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 87f57d7..51cabcc 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -4,24 +4,18 @@ import ( "testing" ) -func assertPanics(t *testing.T, fn func(), name string) { +func assertError(t *testing.T, err error, name string) { t.Helper() - defer func() { - if r := recover(); r == nil { - t.Fatalf("expected panic for %s", name) - } - }() - fn() + if err == nil { + t.Fatalf("expected error for %s", name) + } } -func assertNotPanics(t *testing.T, fn func(), name string) { +func assertNoError(t *testing.T, err error, name string) { t.Helper() - defer func() { - if r := recover(); r != nil { - t.Fatalf("unexpected panic for %s: %v", name, r) - } - }() - fn() + if err != nil { + t.Fatalf("unexpected error for %s: %v", name, err) + } } func TestGetEnvStrOrDefault(t *testing.T) { @@ -36,18 +30,17 @@ func TestGetEnvStrOrDefault(t *testing.T) { } } -func TestGetEnvStrOrPanic(t *testing.T) { +func TestGetEnvStrOrError(t *testing.T) { t.Setenv("TEST_PANIC", "") - assertPanics(t, func() { - _ = getEnvStrOrPanic("TEST_PANIC") - }, "empty env") + _, err := getEnvStrOrError("TEST_PANIC") + assertError(t, err, "empty env") t.Setenv("TEST_PANIC", "value") - assertNotPanics(t, func() { - if got := getEnvStrOrPanic("TEST_PANIC"); got != "value" { - t.Fatalf("expected env value, got %q", got) - } - }, "set env") + got, err := getEnvStrOrError("TEST_PANIC") + assertNoError(t, err, "set env") + if got != "value" { + t.Fatalf("expected env value, got %q", got) + } } func TestGetEnvIntOrDefault(t *testing.T) { @@ -67,29 +60,25 @@ func TestGetEnvIntOrDefault(t *testing.T) { } } -func TestLoadConfigFromEnv_PanicsOnMissingRequired(t *testing.T) { +func TestLoadConfigFromEnv_ErrsOnMissingRequired(t *testing.T) { t.Setenv("JWT_SECRET", "jwt") t.Setenv("GOOGLE_CLIENT_ID", "client") t.Setenv("GOOGLE_CLIENT_SECRET", "secret") - assertNotPanics(t, func() { - _ = LoadConfigFromEnv() - }, "all required set") + _, err := LoadConfigFromEnv() + assertNoError(t, err, "all required set") t.Setenv("JWT_SECRET", "") - assertPanics(t, func() { - _ = LoadConfigFromEnv() - }, "JWT_SECRET unset") + _, err = LoadConfigFromEnv() + assertError(t, err, "JWT_SECRET unset") t.Setenv("JWT_SECRET", "jwt") t.Setenv("GOOGLE_CLIENT_ID", "") - assertPanics(t, func() { - _ = LoadConfigFromEnv() - }, "GOOGLE_CLIENT_ID unset") + _, err = LoadConfigFromEnv() + assertError(t, err, "GOOGLE_CLIENT_ID unset") t.Setenv("GOOGLE_CLIENT_ID", "client") t.Setenv("GOOGLE_CLIENT_SECRET", "") - assertPanics(t, func() { - _ = LoadConfigFromEnv() - }, "GOOGLE_CLIENT_SECRET unset") + _, err = LoadConfigFromEnv() + assertError(t, err, "GOOGLE_CLIENT_SECRET unset") } diff --git a/backend/internal/db/db.go b/backend/internal/db/db.go index 923361f..b0f57b2 100644 --- a/backend/internal/db/db.go +++ b/backend/internal/db/db.go @@ -2,18 +2,19 @@ package db import ( "context" + "fmt" "log/slog" "gorm.io/driver/sqlite" "gorm.io/gorm" ) -func GetDB(dbName string, logger *slog.Logger) *gorm.DB { +func GetDB(dbName string, logger *slog.Logger) (*gorm.DB, error) { var err error db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) if err != nil { - panic("failed to connect to db: " + dbName) + return nil, fmt.Errorf("failed to connect to db: %w", err) } db.Exec("PRAGMA foreign_keys = ON") @@ -25,13 +26,13 @@ func GetDB(dbName string, logger *slog.Logger) *gorm.DB { &HeartBeat{}, } { if err := db.AutoMigrate(model); err != nil { - panic("failed to migrate model: " + err.Error()) + return nil, fmt.Errorf("failed to migrate model: %w", err) } } logger.Info("connected to db") - return db + return db, nil } func CloseDB(db *gorm.DB, logger *slog.Logger) { diff --git a/backend/internal/db/redis.go b/backend/internal/db/redis.go index b96d1f8..8196307 100644 --- a/backend/internal/db/redis.go +++ b/backend/internal/db/redis.go @@ -2,22 +2,22 @@ package db import ( "context" + "fmt" "log/slog" "github.com/paularynty/transcendence/auth-service-go/internal/config" "github.com/redis/go-redis/v9" ) -func GetRedis(redisURL string, cfg *config.Config, logger *slog.Logger) *redis.Client { +func GetRedis(redisURL string, cfg *config.Config, logger *slog.Logger) (*redis.Client, error) { if !cfg.IsRedisEnabled { - logger.Info("redis is disabled by config") - return nil + return nil, nil } opt, err := redis.ParseURL(redisURL) if err != nil { - panic("failed to parse redis url, err: " + err.Error()) + return nil, fmt.Errorf("failed to parse redis url, err: %w", err) } client := redis.NewClient(opt) @@ -26,12 +26,12 @@ func GetRedis(redisURL string, cfg *config.Config, logger *slog.Logger) *redis.C _, err = client.Ping(ctx).Result() if err != nil { - panic("failed to connect to redis: " + err.Error()) + return nil, fmt.Errorf("failed to connect to redis: %w", err) } logger.Info("connected to redis") - return client + return client, nil } func CloseRedis(client *redis.Client, logger *slog.Logger) { diff --git a/backend/internal/dependency/dependency.go b/backend/internal/dependency/dependency.go index 0bad228..d3c76da 100644 --- a/backend/internal/dependency/dependency.go +++ b/backend/internal/dependency/dependency.go @@ -1,10 +1,12 @@ package dependency import ( + "log/slog" + "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/redis/go-redis/v9" "gorm.io/gorm" - "log/slog" ) type Dependency struct { @@ -22,3 +24,27 @@ func NewDependency(cfg *config.Config, db *gorm.DB, redis *redis.Client, logger Logger: logger, } } + +func InitDependency(logger *slog.Logger) (*Dependency, error) { + cfg, err := config.LoadConfigFromEnv() + if err != nil { + return nil, err + } + + myDB, err := db.GetDB(cfg.DbAddress, logger) + if err != nil { + return nil, err + } + + redis, err := db.GetRedis(cfg.RedisURL, cfg, logger) + if err != nil { + return nil, err + } + + return NewDependency(cfg, myDB, redis, logger), nil +} + +func CloseDependency(dep *Dependency) { + db.CloseDB(dep.DB, dep.Logger) + db.CloseRedis(dep.Redis, dep.Logger) +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 95d4279..26fc584 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -5,8 +5,8 @@ import ( "github.com/gin-gonic/gin" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/service" ) @@ -15,7 +15,7 @@ type UserHandler struct { } func handleError(c *gin.Context, err error) { - var authErr *middleware.AuthError + var authErr *authError.AuthError if errors.As(err, &authErr) { _ = c.AbortWithError(authErr.Status, err) } else { @@ -23,18 +23,6 @@ func handleError(c *gin.Context, err error) { } } -func (h *UserHandler) validateToken(c *gin.Context) (uint, error) { - userID := c.MustGet("userID").(uint) - token := c.MustGet("token").(string) - - err := h.Service.ValidateUserToken(c.Request.Context(), token, userID) - if err != nil { - return 0, err - } - - return userID, nil -} - // @BasePath /users // CreateUserHandler godoc @@ -100,13 +88,9 @@ func (h *UserHandler) LoginUserHandler(c *gin.Context) { // @Success 204 {object} nil // @Router /logout [delete] func (h *UserHandler) LogoutUserHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) - err = h.Service.LogoutUser(c.Request.Context(), userID) + err := h.Service.LogoutUser(c.Request.Context(), userID) if err != nil { handleError(c, err) return @@ -124,11 +108,7 @@ func (h *UserHandler) LogoutUserHandler(c *gin.Context) { // @Success 200 {object} dto.UserWithoutTokenResponse // @Router /me [get] func (h *UserHandler) GetLoggedUserProfileHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) user, err := h.Service.GetUserByID(c.Request.Context(), userID) if err != nil { @@ -150,11 +130,7 @@ func (h *UserHandler) GetLoggedUserProfileHandler(c *gin.Context) { // @Success 200 {object} dto.UserWithTokenResponse // @Router /password [put] func (h *UserHandler) UpdateLoggedUserPasswordHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) request := c.MustGet("validatedBody").(dto.UpdateUserPasswordRequest) @@ -178,15 +154,11 @@ func (h *UserHandler) UpdateLoggedUserPasswordHandler(c *gin.Context) { // @Success 200 {object} dto.UserWithoutTokenResponse // @Router /me [put] func (h *UserHandler) UpdateLoggedUserProfileHandler(c *gin.Context) { - userId, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) request := c.MustGet("validatedBody").(dto.UpdateUserRequest) - user, err := h.Service.UpdateUserProfile(c.Request.Context(), userId, &request) + user, err := h.Service.UpdateUserProfile(c.Request.Context(), userID, &request) if err != nil { handleError(c, err) return @@ -204,13 +176,9 @@ func (h *UserHandler) UpdateLoggedUserProfileHandler(c *gin.Context) { // @Success 204 {object} nil // @Router /me [delete] func (h *UserHandler) DeleteLoggedUserHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) - err = h.Service.DeleteUser(c.Request.Context(), userID) + err := h.Service.DeleteUser(c.Request.Context(), userID) if err != nil { handleError(c, err) return @@ -228,11 +196,7 @@ func (h *UserHandler) DeleteLoggedUserHandler(c *gin.Context) { // @Success 200 {object} dto.TwoFASetupResponse // @Router /2fa/setup [post] func (h *UserHandler) StartTwoFaSetupHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) response, err := h.Service.StartTwoFaSetup(c.Request.Context(), userID) if err != nil { @@ -254,11 +218,7 @@ func (h *UserHandler) StartTwoFaSetupHandler(c *gin.Context) { // @Success 200 {object} dto.UserWithTokenResponse // @Router /2fa/confirm [post] func (h *UserHandler) ConfirmTwoFaSetupHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) request := c.MustGet("validatedBody").(dto.TwoFAConfirmRequest) @@ -282,11 +242,7 @@ func (h *UserHandler) ConfirmTwoFaSetupHandler(c *gin.Context) { // @Success 200 {object} dto.UserWithTokenResponse // @Router /2fa/disable [put] func (h *UserHandler) DisableTwoFaHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) request := c.MustGet("validatedBody").(dto.DisableTwoFARequest) @@ -329,11 +285,7 @@ func (h *UserHandler) TwoFaSubmitHandler(c *gin.Context) { // @Success 200 {array} dto.SimpleUser // @Router / [get] func (h *UserHandler) GetUsersWithLimitedInfoHandler(c *gin.Context) { - _, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + _ = c.MustGet("userID").(uint) users, err := h.Service.GetAllUsersLimitedInfo(c.Request.Context()) if err != nil { @@ -353,11 +305,7 @@ func (h *UserHandler) GetUsersWithLimitedInfoHandler(c *gin.Context) { // @Success 200 {array} dto.FriendResponse // @Router /friends [get] func (h *UserHandler) GetLoggedUsersFriendsHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) friends, err := h.Service.GetUserFriends(c.Request.Context(), userID) if err != nil { @@ -379,15 +327,11 @@ func (h *UserHandler) GetLoggedUsersFriendsHandler(c *gin.Context) { // @Success 201 {object} nil // @Router /friends [post] func (h *UserHandler) AddFriendHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) request := c.MustGet("validatedBody").(dto.AddNewFriendRequest) - err = h.Service.AddNewFriend(c.Request.Context(), userID, &request) + err := h.Service.AddNewFriend(c.Request.Context(), userID, &request) if err != nil { handleError(c, err) return @@ -405,11 +349,7 @@ func (h *UserHandler) AddFriendHandler(c *gin.Context) { // @Success 200 {object} dto.UserValidationResponse // @Router /validate [post] func (h *UserHandler) ValidateUserHandler(c *gin.Context) { - userID, err := h.validateToken(c) - if err != nil { - handleError(c, err) - return - } + userID := c.MustGet("userID").(uint) c.JSON(200, dto.UserValidationResponse{UserID: userID}) } @@ -443,14 +383,14 @@ func (h *UserHandler) GoogleCallbackHandler(c *gin.Context) { state := c.Query("state") if code == "" || state == "" { - handleError(c, middleware.NewAuthError(400, "Missing code or state in callback")) + handleError(c, authError.NewAuthError(400, "Missing code or state in callback")) return } url := h.Service.HandleGoogleOAuthCallback(c.Request.Context(), code, state) if url == "" { - handleError(c, middleware.NewAuthError(500, "Failed to process Google OAuth callback")) + handleError(c, authError.NewAuthError(500, "Failed to process Google OAuth callback")) return } diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index 2e77242..09a47ef 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -1,30 +1,44 @@ package middleware import ( + "errors" "strings" "github.com/gin-gonic/gin" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" + "github.com/paularynty/transcendence/auth-service-go/internal/service" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) const PrefixBearer = "Bearer " -func Auth(dep *dependency.Dependency) gin.HandlerFunc { +func Auth(userService *service.UserService) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" || !strings.HasPrefix(authHeader, PrefixBearer) { - _ = c.AbortWithError(401, NewAuthError(401, "Invalid or expired token")) + _ = c.AbortWithError(401, authError.NewAuthError(401, "Invalid or expired token")) return } tokenString := authHeader[len(PrefixBearer):] - userJwtPayload, err := jwt.ValidateUserTokenGeneric(dep, tokenString) + userJwtPayload, err := jwt.ValidateUserTokenGeneric(userService.Dep, tokenString) if err != nil { - _ = c.AbortWithError(401, NewAuthError(401, "Invalid or expired token")) + _ = c.AbortWithError(401, authError.NewAuthError(401, "Invalid or expired token")) + return + } + + err = userService.ValidateUserToken(c.Request.Context(), tokenString, userJwtPayload.UserID) + + var authError *authError.AuthError + if err != nil { + if errors.As(err, &authError) && authError.Status == 401 { + _ = c.AbortWithError(401, authError) + return + } + _ = c.AbortWithError(500, err) return } diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go index 6f77ec9..417cc5a 100644 --- a/backend/internal/middleware/auth_test.go +++ b/backend/internal/middleware/auth_test.go @@ -8,29 +8,46 @@ import ( "github.com/gin-gonic/gin" + model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/service" "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) -func setupAuthDep(t *testing.T) *dependency.Dependency { +func setupAuthDeps(t *testing.T) (*dependency.Dependency, *service.UserService) { t.Helper() cfg := testutil.NewTestConfig() cfg.JwtSecret = "test-secret-key" cfg.UserTokenExpiry = 3600 cfg.OauthStateTokenExpiry = 120 cfg.TwoFaTokenExpiry = 300 - return testutil.NewTestDependency(cfg, nil, nil, nil) + dbName := "file:" + t.Name() + "?mode=memory&cache=shared" + db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect to db: %v", err) + } + if err := db.AutoMigrate(&model.User{}, &model.Token{}); err != nil { + t.Fatalf("failed to migrate db: %v", err) + } + dep := testutil.NewTestDependency(cfg, db, nil, nil) + userService, err := service.NewUserService(dep) + if err != nil { + t.Fatalf("failed to create user service: %v", err) + } + return dep, userService } func TestAuthMiddlewareRejectsMissingToken(t *testing.T) { gin.SetMode(gin.TestMode) - dep := setupAuthDep(t) + _, userService := setupAuthDeps(t) r := gin.New() r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth(dep)) + r.Use(middleware.Auth(userService)) r.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) @@ -56,16 +73,22 @@ func TestAuthMiddlewareRejectsMissingToken(t *testing.T) { func TestAuthMiddlewareAllowsValidToken(t *testing.T) { gin.SetMode(gin.TestMode) - dep := setupAuthDep(t) + dep, userService := setupAuthDeps(t) token, err := jwt.SignUserToken(dep, 99) if err != nil { t.Fatalf("failed to sign user token: %v", err) } + if err := dep.DB.Create(&model.User{Model: gorm.Model{ID: 99}, Username: "u99", Email: "u99@example.com"}).Error; err != nil { + t.Fatalf("failed to create user: %v", err) + } + if err := dep.DB.Create(&model.Token{UserID: 99, Token: token}).Error; err != nil { + t.Fatalf("failed to create token: %v", err) + } r := gin.New() r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth(dep)) + r.Use(middleware.Auth(userService)) r.GET("/protected", func(c *gin.Context) { userID, ok := c.Get("userID") if !ok { @@ -97,7 +120,7 @@ func TestAuthMiddlewareAllowsValidToken(t *testing.T) { func TestAuthMiddlewareRejectsInvalidToken(t *testing.T) { gin.SetMode(gin.TestMode) - dep := setupAuthDep(t) + dep, userService := setupAuthDeps(t) token, err := jwt.SignTwoFAToken(dep, 10) if err != nil { @@ -106,7 +129,7 @@ func TestAuthMiddlewareRejectsInvalidToken(t *testing.T) { r := gin.New() r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth(dep)) + r.Use(middleware.Auth(userService)) r.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) diff --git a/backend/internal/middleware/error_handler.go b/backend/internal/middleware/error_handler.go index ea3c992..62088a2 100644 --- a/backend/internal/middleware/error_handler.go +++ b/backend/internal/middleware/error_handler.go @@ -5,24 +5,9 @@ import ( "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" ) -type AuthError struct { - Status int - Message string -} - -func (e *AuthError) Error() string { - return e.Message -} - -func NewAuthError(status int, message string) *AuthError { - return &AuthError{ - Status: status, - Message: message, - } -} - func ErrorHandler() gin.HandlerFunc { return func(c *gin.Context) { c.Next() @@ -35,7 +20,7 @@ func ErrorHandler() gin.HandlerFunc { err := c.Errors.Last().Err - var authErr *AuthError + var authErr *authError.AuthError // Handle AuthError specifically if errors.As(err, &authErr) { diff --git a/backend/internal/middleware/error_handler_test.go b/backend/internal/middleware/error_handler_test.go index 241a019..e55585c 100644 --- a/backend/internal/middleware/error_handler_test.go +++ b/backend/internal/middleware/error_handler_test.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" ) @@ -19,7 +20,7 @@ func TestErrorHandlerReturnsAuthErrorPayload(t *testing.T) { r := gin.New() r.Use(middleware.ErrorHandler()) r.GET("/auth", func(c *gin.Context) { - _ = c.AbortWithError(http.StatusUnauthorized, middleware.NewAuthError(http.StatusUnauthorized, "Invalid or expired token")) + _ = c.AbortWithError(http.StatusUnauthorized, authError.NewAuthError(http.StatusUnauthorized, "Invalid or expired token")) }) req := httptest.NewRequest(http.MethodGet, "/auth", nil) diff --git a/backend/internal/middleware/rate_limiter.go b/backend/internal/middleware/rate_limiter.go index 4649e65..2c47f07 100644 --- a/backend/internal/middleware/rate_limiter.go +++ b/backend/internal/middleware/rate_limiter.go @@ -6,22 +6,28 @@ import ( "time" "github.com/gin-gonic/gin" + + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" ) type RateLimiter struct { - mu sync.Mutex - limit int - duration time.Duration - requestCounts map[string]int - requestExpiry map[string]time.Time + mu sync.Mutex + limit int + duration time.Duration + requestCounts map[string]int + requestExpiry map[string]time.Time + lastCleanup time.Time + cleanupInterval time.Duration } -func NewRateLimiter(duration time.Duration, limit int) *RateLimiter { +func NewRateLimiter(duration time.Duration, limit int, cleanupInterval time.Duration) *RateLimiter { return &RateLimiter{ - limit: limit, - duration: duration, - requestCounts: make(map[string]int), - requestExpiry: make(map[string]time.Time), + limit: limit, + duration: duration, + requestCounts: make(map[string]int), + requestExpiry: make(map[string]time.Time), + lastCleanup: time.Now(), + cleanupInterval: cleanupInterval, } } @@ -30,6 +36,12 @@ func (rl *RateLimiter) AllowRequest(clientID string) bool { rl.mu.Lock() defer rl.mu.Unlock() + + if ts.Sub(rl.lastCleanup) > rl.cleanupInterval { + unSafeClearExpiredEntries(ts, rl) + rl.lastCleanup = ts + } + expiry, exists := rl.requestExpiry[clientID] if !exists || ts.After(expiry) { rl.requestCounts[clientID] = 1 @@ -55,10 +67,20 @@ func (rl *RateLimiter) RateLimit() gin.HandlerFunc { clientID := c.ClientIP() if !rl.AllowRequest(clientID) { - _ = c.AbortWithError(429, NewAuthError(429, "Too many requests")) + _ = c.AbortWithError(429, authError.NewAuthError(429, "Too many requests")) return } c.Next() } } + +// unSafeClearExpiredEntries Not thread-safe; caller must hold rl.mu lock. +func unSafeClearExpiredEntries(ts time.Time, rl *RateLimiter) { + for clientID, expiry := range rl.requestExpiry { + if ts.After(expiry) { + delete(rl.requestCounts, clientID) + delete(rl.requestExpiry, clientID) + } + } +} diff --git a/backend/internal/middleware/rate_limiter_internal_test.go b/backend/internal/middleware/rate_limiter_internal_test.go new file mode 100644 index 0000000..24046f6 --- /dev/null +++ b/backend/internal/middleware/rate_limiter_internal_test.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "testing" + "time" +) + +func TestAllowRequestCleansExpiredEntriesAtInterval(t *testing.T) { + rl := NewRateLimiter(10*time.Millisecond, 1, 5*time.Millisecond) + + now := time.Now() + rl.requestCounts["old"] = 2 + rl.requestExpiry["old"] = now.Add(-time.Second) + rl.lastCleanup = now.Add(-rl.cleanupInterval - time.Second) + + _ = rl.AllowRequest("new-client") + + if _, exists := rl.requestCounts["old"]; exists { + t.Fatalf("expected expired request count to be removed during cleanup") + } + if _, exists := rl.requestExpiry["old"]; exists { + t.Fatalf("expected expired request expiry to be removed during cleanup") + } +} + +func TestUnsafeClearExpiredEntriesRemovesOnlyExpired(t *testing.T) { + rl := NewRateLimiter(10*time.Millisecond, 1, time.Minute) + + now := time.Now() + rl.requestCounts["expired"] = 1 + rl.requestExpiry["expired"] = now.Add(-time.Second) + rl.requestCounts["active"] = 1 + rl.requestExpiry["active"] = now.Add(time.Second) + + unSafeClearExpiredEntries(now, rl) + + if _, exists := rl.requestCounts["expired"]; exists { + t.Fatalf("expected expired request count to be removed") + } + if _, exists := rl.requestExpiry["expired"]; exists { + t.Fatalf("expected expired request expiry to be removed") + } + if _, exists := rl.requestCounts["active"]; !exists { + t.Fatalf("expected active request count to remain") + } + if _, exists := rl.requestExpiry["active"]; !exists { + t.Fatalf("expected active request expiry to remain") + } +} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index f63ba93..3b544a9 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -15,7 +15,7 @@ import ( func TestAllowRequestResetsAfterWindow(t *testing.T) { gin.SetMode(gin.TestMode) - rl := middleware.NewRateLimiter(30*time.Millisecond, 2) + rl := middleware.NewRateLimiter(30*time.Millisecond, 2, time.Minute) clientID := "client-1" if !rl.AllowRequest(clientID) { @@ -38,7 +38,7 @@ func TestAllowRequestResetsAfterWindow(t *testing.T) { func TestRateLimitMiddlewareBlocksAfterLimit(t *testing.T) { gin.SetMode(gin.TestMode) - rl := middleware.NewRateLimiter(50*time.Millisecond, 1) + rl := middleware.NewRateLimiter(50*time.Millisecond, 1, time.Minute) r := gin.New() r.Use(middleware.ErrorHandler()) @@ -76,7 +76,7 @@ func TestRateLimitMiddlewareBlocksAfterLimit(t *testing.T) { func TestRateLimitMiddlewareSkipsOptions(t *testing.T) { gin.SetMode(gin.TestMode) - rl := middleware.NewRateLimiter(50*time.Millisecond, 1) + rl := middleware.NewRateLimiter(50*time.Millisecond, 1, time.Minute) r := gin.New() r.Use(middleware.ErrorHandler()) @@ -116,7 +116,7 @@ func TestRateLimitMiddlewareSkipsOptions(t *testing.T) { func TestRateLimitMiddlewareUsesClientSpecificCounters(t *testing.T) { gin.SetMode(gin.TestMode) - rl := middleware.NewRateLimiter(100*time.Millisecond, 1) + rl := middleware.NewRateLimiter(100*time.Millisecond, 1, time.Minute) r := gin.New() r.Use(middleware.ErrorHandler()) diff --git a/backend/internal/middleware/validation.go b/backend/internal/middleware/validation.go index 75d5974..0b009ab 100644 --- a/backend/internal/middleware/validation.go +++ b/backend/internal/middleware/validation.go @@ -2,6 +2,7 @@ package middleware import ( "github.com/gin-gonic/gin" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" "github.com/paularynty/transcendence/auth-service-go/internal/dto" ) @@ -9,7 +10,7 @@ func ValidateBody[T any]() gin.HandlerFunc { return func(c *gin.Context) { var body T if err := c.ShouldBindJSON(&body); err != nil { - _ = c.AbortWithError(400, NewAuthError(400, err.Error())) + _ = c.AbortWithError(400, authError.NewAuthError(400, err.Error())) return } diff --git a/backend/internal/routers/users_router.go b/backend/internal/routers/users_router.go index d5b015f..033d527 100644 --- a/backend/internal/routers/users_router.go +++ b/backend/internal/routers/users_router.go @@ -3,15 +3,14 @@ package routers import ( "github.com/gin-gonic/gin" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/handler" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/service" ) -func UsersRouter(r *gin.RouterGroup, dep *dependency.Dependency) { - h := &handler.UserHandler{Service: service.NewUserService(dep)} +func UsersRouter(r *gin.RouterGroup, userService *service.UserService) { + h := &handler.UserHandler{Service: userService} // Public endpoints r.POST("/", middleware.ValidateBody[dto.CreateUserRequest](), h.CreateUserHandler) @@ -22,7 +21,7 @@ func UsersRouter(r *gin.RouterGroup, dep *dependency.Dependency) { // Authenticated endpoints auth := r.Group("") - auth.Use(middleware.Auth(dep)) + auth.Use(middleware.Auth(userService)) auth.GET("/me", h.GetLoggedUserProfileHandler) auth.PUT("/password", middleware.ValidateBody[dto.UpdateUserPasswordRequest](), h.UpdateLoggedUserPasswordHandler) diff --git a/backend/internal/routers/users_router_test.go b/backend/internal/routers/users_router_test.go index d4232f4..342b7e3 100644 --- a/backend/internal/routers/users_router_test.go +++ b/backend/internal/routers/users_router_test.go @@ -29,6 +29,15 @@ import ( "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) +func mustNewUserService(t *testing.T, dep *dependency.Dependency) *service.UserService { + t.Helper() + svc, err := service.NewUserService(dep) + if err != nil { + t.Fatalf("failed to create user service: %v", err) + } + return svc +} + type usersRouterEnv struct { router *gin.Engine dep *dependency.Dependency @@ -78,7 +87,8 @@ func setupUsersRouterTest(t *testing.T, useRedis bool) *usersRouterEnv { dep := dependency.NewDependency(cfg, dbConn, redisClient, logger) router := gin.New() - UsersRouter(router.Group("/users"), dep) + userService := mustNewUserService(t, dep) + UsersRouter(router.Group("/users"), userService) if sqlDB, err := dbConn.DB(); err == nil && sqlDB != nil { sqlDB.SetMaxOpenConns(1) @@ -124,7 +134,7 @@ func addUserToken(t *testing.T, dep *dependency.Dependency, userID uint) string func createUser(t *testing.T, dep *dependency.Dependency, username, email, password string) *dto.UserWithoutTokenResponse { t.Helper() - svc := service.NewUserService(dep) + svc := mustNewUserService(t, dep) user, err := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: username}, Email: email}, Password: dto.Password{Password: password}, @@ -357,7 +367,7 @@ func TestUsersRouter_UpdateUser_Failures(t *testing.T) { env := setupUsersRouterTest(t, false) defer env.cleanup() - svc := service.NewUserService(env.dep) + svc := mustNewUserService(t, env.dep) u, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, Password: dto.Password{Password: "pass123"}, @@ -399,7 +409,7 @@ func TestUsersRouter_UpdateUserPassword(t *testing.T) { env := setupUsersRouterTest(t, false) defer env.cleanup() - svc := service.NewUserService(env.dep) + svc := mustNewUserService(t, env.dep) userResp, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "pw"}, Email: "pw@e.com"}, Password: dto.Password{Password: "oldpass"}, @@ -493,7 +503,7 @@ func TestUsersRouter_Friends(t *testing.T) { env := setupUsersRouterTest(t, false) defer env.cleanup() - svc := service.NewUserService(env.dep) + svc := mustNewUserService(t, env.dep) u1 := createUser(t, env.dep, "f1", "f1@e.com", "pass123") u2 := createUser(t, env.dep, "f2", "f2@e.com", "pass123") _ = svc @@ -531,7 +541,7 @@ func TestUsersRouter_Friends_Failures(t *testing.T) { env := setupUsersRouterTest(t, false) defer env.cleanup() - svc := service.NewUserService(env.dep) + svc := mustNewUserService(t, env.dep) u1 := createUser(t, env.dep, "f1", "f1@e.com", "pass123") u2 := createUser(t, env.dep, "f2", "f2@e.com", "pass123") token := addUserToken(t, env.dep, u1.ID) @@ -636,7 +646,7 @@ func TestUsersRouter_2FA_Failures(t *testing.T) { env := setupUsersRouterTest(t, false) defer env.cleanup() - svc := service.NewUserService(env.dep) + svc := mustNewUserService(t, env.dep) u := createUser(t, env.dep, "2fafail", "2fafail@e.com", "pass123") token := addUserToken(t, env.dep, u.ID) diff --git a/backend/internal/service/friend_service.go b/backend/internal/service/friend_service.go index 72d29ad..d93c744 100644 --- a/backend/internal/service/friend_service.go +++ b/backend/internal/service/friend_service.go @@ -4,9 +4,9 @@ import ( "context" "errors" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "gorm.io/gorm" ) @@ -51,7 +51,7 @@ func (s *UserService) GetUserFriends(ctx context.Context, userID uint) ([]dto.Fr func (s *UserService) AddNewFriend(ctx context.Context, userID uint, request *dto.AddNewFriendRequest) error { if userID == request.UserID { - return middleware.NewAuthError(400, "cannot add yourself as a friend") + return authError.NewAuthError(400, "cannot add yourself as a friend") } newFriend := model.Friend{ @@ -62,10 +62,10 @@ func (s *UserService) AddNewFriend(ctx context.Context, userID uint, request *dt err := gorm.G[model.Friend](s.Dep.DB).Create(ctx, &newFriend) if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return middleware.NewAuthError(409, "friend already added") + return authError.NewAuthError(409, "friend already added") } if errors.Is(err, gorm.ErrForeignKeyViolated) { - return middleware.NewAuthError(404, "user not found") + return authError.NewAuthError(404, "user not found") } return err } diff --git a/backend/internal/service/friend_service_test.go b/backend/internal/service/friend_service_test.go index 93c9e1a..0534e2b 100644 --- a/backend/internal/service/friend_service_test.go +++ b/backend/internal/service/friend_service_test.go @@ -11,7 +11,7 @@ import ( func TestGetAllUsersLimitedInfo(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() // Create users @@ -46,7 +46,7 @@ func TestGetAllUsersLimitedInfo(t *testing.T) { func TestAddNewFriend(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u1, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -111,7 +111,7 @@ func TestAddNewFriend(t *testing.T) { func TestGetUserFriends(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u1, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ diff --git a/backend/internal/service/google_oauth_service.go b/backend/internal/service/google_oauth_service.go index 0c48117..277f721 100644 --- a/backend/internal/service/google_oauth_service.go +++ b/backend/internal/service/google_oauth_service.go @@ -11,10 +11,10 @@ import ( "cloud.google.com/go/auth/credentials/idtoken" "github.com/google/uuid" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "gorm.io/gorm" ) @@ -109,18 +109,18 @@ var ExchangeCodeForTokens = func(dep *dependency.Dependency, ctx context.Context var FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { sub := payload.Subject if sub == "" { - return nil, middleware.NewAuthError(400, "google id token missing subject") + return nil, authError.NewAuthError(400, "google id token missing subject") } jsonClaims, err := json.Marshal(payload.Claims) if err != nil { - return nil, middleware.NewAuthError(500, "failed to Marshal google jwt token") + return nil, authError.NewAuthError(500, "failed to Marshal google jwt token") } var claims dto.GoogleClaims err = json.Unmarshal(jsonClaims, &claims) if err != nil { - return nil, middleware.NewAuthError(500, "failed to Unmarshal google jwt token") + return nil, authError.NewAuthError(500, "failed to Unmarshal google jwt token") } googleUserInfo := &dto.GoogleUserData{ @@ -138,7 +138,7 @@ var FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, e // This feature does not work unless we can verify the user's password/email ownership. func (s *UserService) linkGoogleAccountToExistingUser(ctx context.Context, modelUser *model.User, googleUserInfo *dto.GoogleUserData) error { - return middleware.NewAuthError(409, "same email exists") + return authError.NewAuthError(409, "same email exists") /** @@ -176,7 +176,7 @@ func (s *UserService) createNewUserFromGoogleInfo(ctx context.Context, googleUse if isRetry { uuidUsername, err := uuid.NewRandom() if err != nil { - return nil, middleware.NewAuthError(500, "failed to generate UUID for Google user") + return nil, authError.NewAuthError(500, "failed to generate UUID for Google user") } username = "G_" + uuidUsername.String() } else { @@ -202,7 +202,7 @@ func (s *UserService) createNewUserFromGoogleInfo(ctx context.Context, googleUse if !isRetry { return s.createNewUserFromGoogleInfo(ctx, googleUserInfo, true) } - return nil, middleware.NewAuthError(409, "username or email already in use") + return nil, authError.NewAuthError(409, "username or email already in use") } return nil, err } @@ -244,10 +244,12 @@ func (s *UserService) HandleGoogleOAuthCallback(ctx context.Context, code string modelUser, err = gorm.G[model.User](s.Dep.DB).Where("email = ?", googleUserInfo.Email).First(ctx) if err == nil { // User with this email exists, link Google account + // Autolinking is disabled and always returns 409 for now, as we cannot verify ownership of the email/password here. err = s.linkGoogleAccountToExistingUser(ctx, &modelUser, googleUserInfo) if err != nil { // Failed to link Google account return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to link google account to existing user") } + // Successfully linked Google account finalUserID = modelUser.ID } else if !errors.Is(err, gorm.ErrRecordNotFound) { diff --git a/backend/internal/service/google_oauth_service_test.go b/backend/internal/service/google_oauth_service_test.go index ed052aa..5141c55 100644 --- a/backend/internal/service/google_oauth_service_test.go +++ b/backend/internal/service/google_oauth_service_test.go @@ -7,10 +7,10 @@ import ( "testing" "cloud.google.com/go/auth/credentials/idtoken" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) @@ -18,7 +18,7 @@ import ( func TestGetGoogleOAuthURL(t *testing.T) { db := setupTestDB(t.Name()) cfg := testutil.NewTestConfig() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, nil)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, nil)) ctx := context.Background() t.Run("Success", func(t *testing.T) { @@ -47,7 +47,7 @@ func TestGetGoogleOAuthURL(t *testing.T) { func TestHandleGoogleOAuthCallback_InvalidState(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() // Helper to parse redirect URL @@ -82,7 +82,7 @@ func TestHandleGoogleOAuthCallback_InvalidState(t *testing.T) { func TestHandleGoogleOAuthCallback_Success(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() // Mock dependencies @@ -207,7 +207,7 @@ func TestHandleGoogleOAuthCallback_Success(t *testing.T) { func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() origExchange := ExchangeCodeForTokens @@ -249,7 +249,7 @@ func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { func TestLinkGoogleAccountToExistingUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -273,7 +273,7 @@ func TestLinkGoogleAccountToExistingUser(t *testing.T) { if err == nil { t.Fatal("expected linking to be blocked") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok { t.Fatalf("expected AuthError, got %T: %v", err, err) } @@ -297,7 +297,7 @@ func TestLinkGoogleAccountToExistingUser(t *testing.T) { Email: "other@e.com", } err := svc.linkGoogleAccountToExistingUser(ctx, &modelUser, googleInfo) - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if err == nil || !ok || authErr.Status != 409 { t.Errorf("expected 409 AuthError, got %v", err) } @@ -310,7 +310,7 @@ func TestLinkGoogleAccountToExistingUser(t *testing.T) { Email: "link@e.com", } err := svc.linkGoogleAccountToExistingUser(ctx, &modelUser, googleInfo) - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if err == nil || !ok || authErr.Status != 409 { t.Errorf("expected 409 AuthError, got %v", err) } @@ -319,7 +319,7 @@ func TestLinkGoogleAccountToExistingUser(t *testing.T) { func TestCreateNewUserFromGoogleInfo(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() t.Run("Success", func(t *testing.T) { @@ -369,7 +369,7 @@ func TestCreateNewUserFromGoogleInfo(t *testing.T) { func TestHandleGoogleOAuthCallback_DBError(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() origExchange := ExchangeCodeForTokens @@ -405,7 +405,7 @@ func TestHandleGoogleOAuthCallback_DBError(t *testing.T) { func TestHandleGoogleOAuthCallback_LinkError(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() origExchange := ExchangeCodeForTokens diff --git a/backend/internal/service/helper.go b/backend/internal/service/helper.go index 6d93b42..b6b7868 100644 --- a/backend/internal/service/helper.go +++ b/backend/internal/service/helper.go @@ -18,19 +18,19 @@ import ( const HeartBeatPrefix = "heartbeat:" -func NewUserService(dep *dependency.Dependency) *UserService { +func NewUserService(dep *dependency.Dependency) (*UserService, error) { if dep.DB == nil { - panic("UserService: db is nil") + return nil, fmt.Errorf("UserService: db is nil") } if dep.Cfg.IsRedisEnabled && dep.Redis == nil { - panic("UserService: redis is enabled but redis client is nil") + return nil, fmt.Errorf("UserService: redis is enabled but redis client is nil") } return &UserService{ Dep: dep, - } + }, nil } func isTwoFAEnabled(twoFAToken *string) bool { diff --git a/backend/internal/service/helper_test.go b/backend/internal/service/helper_test.go index 3a2b640..d13dd95 100644 --- a/backend/internal/service/helper_test.go +++ b/backend/internal/service/helper_test.go @@ -6,9 +6,19 @@ import ( "time" model "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" ) +func mustNewUserService(t *testing.T, dep *dependency.Dependency) *UserService { + t.Helper() + svc, err := NewUserService(dep) + if err != nil { + t.Fatalf("failed to create user service: %v", err) + } + return svc +} + func TestHelperFunctions(t *testing.T) { t.Run("isTwoFAEnabled", func(t *testing.T) { token := "pre-secret" @@ -61,7 +71,7 @@ func TestHelperFunctions(t *testing.T) { t.Run("UpdateHeartBeat", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) // Create user first to satisfy FK _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ @@ -83,7 +93,7 @@ func TestHelperFunctions(t *testing.T) { t.Run("IssueNewTokenForUser", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) // Create user first _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ @@ -113,7 +123,7 @@ func TestHelperFunctions(t *testing.T) { t.Run("IssueNewTokenForUser_DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) sqlDB, _ := db.DB() _ = sqlDB.Close() diff --git a/backend/internal/service/redis_service_test.go b/backend/internal/service/redis_service_test.go index 0396fd0..34deed0 100644 --- a/backend/internal/service/redis_service_test.go +++ b/backend/internal/service/redis_service_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" "github.com/paularynty/transcendence/auth-service-go/internal/config" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/redis/go-redis/v9" ) @@ -27,7 +27,7 @@ func TestRedisTokenLifecycle(t *testing.T) { mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -80,7 +80,7 @@ func TestRedisTokenLifecycle(t *testing.T) { if err == nil { t.Fatal("expected token to be invalid after logout") } - var authErr *middleware.AuthError + var authErr *authError.AuthError if !strings.Contains(err.Error(), "invalid token") || !errors.As(err, &authErr) { t.Fatalf("expected auth error for invalid token, got %v", err) } @@ -92,7 +92,7 @@ func TestRedisHeartbeatOnlineStatusAndCleanup(t *testing.T) { _, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() u1, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -153,7 +153,7 @@ func TestRedisLoginUpdatesHeartbeat(t *testing.T) { _, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() created, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -194,7 +194,7 @@ func TestRedisLogoutRevokesAllTokens(t *testing.T) { mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -242,7 +242,7 @@ func TestRedisDeleteUserRevokesAllTokens(t *testing.T) { mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -290,7 +290,7 @@ func TestRedisUpdatePasswordRevokesOldTokens(t *testing.T) { mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) + svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ diff --git a/backend/internal/service/twofa_service.go b/backend/internal/service/twofa_service.go index c0f8e8b..7b180fb 100644 --- a/backend/internal/service/twofa_service.go +++ b/backend/internal/service/twofa_service.go @@ -5,9 +5,9 @@ import ( "errors" "strings" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "github.com/pquerna/otp/totp" "golang.org/x/crypto/bcrypt" @@ -18,17 +18,17 @@ func (s *UserService) StartTwoFaSetup(ctx context.Context, userID uint) (*dto.Tw modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } if isTwoFAEnabled(modelUser.TwoFAToken) { - return nil, middleware.NewAuthError(400, "2FA is already enabled") + return nil, authError.NewAuthError(400, "2FA is already enabled") } if modelUser.GoogleOauthID != nil { - return nil, middleware.NewAuthError(400, "2FA cannot be enabled for Google OAuth users") + return nil, authError.NewAuthError(400, "2FA cannot be enabled for Google OAuth users") } secret, err := totp.Generate(totp.GenerateOpts{ @@ -61,37 +61,37 @@ func (s *UserService) StartTwoFaSetup(ctx context.Context, userID uint) (*dto.Tw func (s *UserService) ConfirmTwoFaSetup(ctx context.Context, userID uint, request *dto.TwoFAConfirmRequest) (*dto.UserWithTokenResponse, error) { claims, err := jwt.ValidateTwoFASetupToken(s.Dep, request.SetupToken) if err != nil || claims.Type != jwt.TwoFASetupType { - return nil, middleware.NewAuthError(400, "invalid setup token") + return nil, authError.NewAuthError(400, "invalid setup token") } if claims.UserID != userID { - return nil, middleware.NewAuthError(400, "setup token does not match user") + return nil, authError.NewAuthError(400, "setup token does not match user") } modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } if modelUser.TwoFAToken == nil { - return nil, middleware.NewAuthError(400, "2FA setup was not initiated") + return nil, authError.NewAuthError(400, "2FA setup was not initiated") } if isTwoFAEnabled(modelUser.TwoFAToken) { - return nil, middleware.NewAuthError(400, "2FA is already enabled") + return nil, authError.NewAuthError(400, "2FA is already enabled") } if modelUser.GoogleOauthID != nil { - return nil, middleware.NewAuthError(400, "2FA cannot be enabled for Google OAuth users") + return nil, authError.NewAuthError(400, "2FA cannot be enabled for Google OAuth users") } twoFaSecret := strings.TrimPrefix(*modelUser.TwoFAToken, TwoFAPrePrefix) valid := totp.Validate(request.TwoFACode, twoFaSecret) if !valid { - return nil, middleware.NewAuthError(400, "invalid 2FA code") + return nil, authError.NewAuthError(400, "invalid 2FA code") } _, err = gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", twoFaSecret) @@ -112,23 +112,23 @@ func (s *UserService) DisableTwoFA(ctx context.Context, userID uint, request *dt modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } if modelUser.PasswordHash == nil { - return nil, middleware.NewAuthError(400, "2FA cannot be disabled for OAuth users") + return nil, authError.NewAuthError(400, "2FA cannot be disabled for OAuth users") } if !isTwoFAEnabled(modelUser.TwoFAToken) { - return nil, middleware.NewAuthError(400, "2FA is not enabled") + return nil, authError.NewAuthError(400, "2FA is not enabled") } err = bcrypt.CompareHashAndPassword([]byte(*modelUser.PasswordHash), []byte(request.Password.Password)) if err != nil { if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return nil, middleware.NewAuthError(401, "invalid credentials") + return nil, authError.NewAuthError(401, "invalid credentials") } return nil, err } @@ -150,24 +150,24 @@ func (s *UserService) DisableTwoFA(ctx context.Context, userID uint, request *dt func (s *UserService) SubmitTwoFAChallenge(ctx context.Context, request *dto.TwoFAChallengeRequest) (*dto.UserWithTokenResponse, error) { claims, err := jwt.ValidateTwoFAToken(s.Dep, request.SessionToken) if err != nil || claims.Type != jwt.TwoFATokenType { - return nil, middleware.NewAuthError(400, "invalid session token") + return nil, authError.NewAuthError(400, "invalid session token") } modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", claims.UserID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } if !isTwoFAEnabled(modelUser.TwoFAToken) || modelUser.TwoFAToken == nil { - return nil, middleware.NewAuthError(400, "2FA is not enabled for this user") + return nil, authError.NewAuthError(400, "2FA is not enabled for this user") } valid := totp.Validate(request.TwoFACode, *modelUser.TwoFAToken) if !valid { - return nil, middleware.NewAuthError(400, "invalid 2FA code") + return nil, authError.NewAuthError(400, "invalid 2FA code") } userToken, err := s.issueNewTokenForUser(ctx, modelUser.ID, false) diff --git a/backend/internal/service/twofa_service_test.go b/backend/internal/service/twofa_service_test.go index ee42bd1..b086534 100644 --- a/backend/internal/service/twofa_service_test.go +++ b/backend/internal/service/twofa_service_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "github.com/pquerna/otp/totp" ) @@ -16,7 +16,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, Password: dto.Password{Password: "p"}, @@ -33,7 +33,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("ConfirmSetup_Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, Password: dto.Password{Password: "p"}, @@ -61,7 +61,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_AlreadyEnabled", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u3"}, Email: "u3@e.com"}, Password: dto.Password{Password: "p"}, @@ -77,7 +77,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { if err == nil { t.Fatal("expected error") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 400 { t.Errorf("expected 400 error, got %v", err) } @@ -85,7 +85,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_OAuthUser", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) // Mock OAuth user oauthUser := dto.GoogleUserData{ ID: "oauth123", @@ -100,7 +100,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { if err == nil { t.Fatal("expected error for oauth user") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 400 { t.Errorf("expected 400 error, got %v", err) } @@ -108,7 +108,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u4"}, Email: "u4@e.com"}, Password: dto.Password{Password: "p"}, @@ -125,7 +125,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { func TestConfirmTwoFaSetup_Errors(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -193,7 +193,7 @@ func TestConfirmTwoFaSetup_Errors(t *testing.T) { t.Run("NotInitiated", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) // User with no 2FA token u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ni"}, Email: "ni@e.com"}, @@ -221,7 +221,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ch1"}, Email: "ch1@e.com"}, Password: dto.Password{Password: "p"}, @@ -253,7 +253,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("InvalidCode", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ch2"}, Email: "ch2@e.com"}, Password: dto.Password{Password: "p"}, @@ -281,7 +281,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("NotEnabled", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "chne"}, Email: "chne@e.com"}, Password: dto.Password{Password: "p"}, @@ -304,7 +304,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ch3"}, Email: "ch3@e.com"}, Password: dto.Password{Password: "p"}, @@ -338,7 +338,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "dis1"}, Email: "dis1@e.com"}, Password: dto.Password{Password: "p"}, @@ -361,7 +361,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("AlreadyDisabled", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "dis2"}, Email: "dis2@e.com"}, Password: dto.Password{Password: "p"}, @@ -378,7 +378,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("OAuthUser", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) // Mock OAuth user oauthUser := dto.GoogleUserData{ ID: "oauth456", @@ -393,7 +393,7 @@ func TestDisableTwoFA(t *testing.T) { if err == nil { t.Fatal("expected error for oauth user") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 400 { t.Errorf("expected 400 error, got %v", err) } @@ -401,7 +401,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "dis3"}, Email: "dis3@e.com"}, Password: dto.Password{Password: "p"}, @@ -421,7 +421,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("InvalidPassword", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "disinv"}, Email: "disinv@e.com"}, Password: dto.Password{Password: "correct"}, @@ -438,7 +438,7 @@ func TestDisableTwoFA(t *testing.T) { if err == nil { t.Fatal("expected error for invalid password") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 401 { t.Errorf("expected 401 error, got %v", err) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index d9c4480..a14761f 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -9,10 +9,10 @@ import ( "golang.org/x/crypto/bcrypt" "gorm.io/gorm" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "github.com/redis/go-redis/v9" ) @@ -47,7 +47,7 @@ func (s *UserService) CreateUser(ctx context.Context, request *dto.CreateUserReq err = gorm.G[model.User](s.Dep.DB).Create(ctx, &modelUser) if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return nil, middleware.NewAuthError(409, "username or email already in use") + return nil, authError.NewAuthError(409, "username or email already in use") } return nil, err } @@ -72,7 +72,7 @@ func (s *UserService) LoginUser(ctx context.Context, request *dto.LoginUserReque modelUser, err := gorm.G[model.User](s.Dep.DB).Where(identifierField+" = ?", request.Identifier.Identifier).First(ctx) if err != nil || modelUser.PasswordHash == nil { if errors.Is(err, gorm.ErrRecordNotFound) || modelUser.PasswordHash == nil { - return nil, middleware.NewAuthError(401, "invalid credentials") + return nil, authError.NewAuthError(401, "invalid credentials") } return nil, err } @@ -80,7 +80,7 @@ func (s *UserService) LoginUser(ctx context.Context, request *dto.LoginUserReque err = bcrypt.CompareHashAndPassword([]byte(*modelUser.PasswordHash), []byte(request.Password.Password)) if err != nil { if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return nil, middleware.NewAuthError(401, "invalid credentials") + return nil, authError.NewAuthError(401, "invalid credentials") } return nil, err } @@ -114,7 +114,7 @@ func (s *UserService) GetUserByID(ctx context.Context, userID uint) (*dto.UserWi modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } @@ -126,19 +126,19 @@ func (s *UserService) UpdateUserPassword(ctx context.Context, userID uint, reque modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } if modelUser.PasswordHash == nil { - return nil, middleware.NewAuthError(400, "password cannot be changed for OAuth users") + return nil, authError.NewAuthError(400, "password cannot be changed for OAuth users") } err = bcrypt.CompareHashAndPassword([]byte(*modelUser.PasswordHash), []byte(request.OldPassword.OldPassword)) if err != nil { if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return nil, middleware.NewAuthError(401, "invalid credentials") + return nil, authError.NewAuthError(401, "invalid credentials") } return nil, err } @@ -165,7 +165,7 @@ func (s *UserService) UpdateUserProfile(ctx context.Context, userID uint, reques modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, middleware.NewAuthError(404, "user not found") + return nil, authError.NewAuthError(404, "user not found") } return nil, err } @@ -178,7 +178,7 @@ func (s *UserService) UpdateUserProfile(ctx context.Context, userID uint, reques if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { - return nil, middleware.NewAuthError(409, "username or email already in use") + return nil, authError.NewAuthError(409, "username or email already in use") } return nil, err } @@ -239,13 +239,13 @@ func (s *UserService) validateUserTokenDB(ctx context.Context, token string, use modelToken, err := gorm.G[model.Token](s.Dep.DB).Where("token = ?", token).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return middleware.NewAuthError(401, "invalid token") + return authError.NewAuthError(401, "invalid token") } return err } if modelToken.UserID != userId { - return middleware.NewAuthError(401, "token does not match user") + return authError.NewAuthError(401, "token does not match user") } s.updateHeartBeat(userId) @@ -256,7 +256,7 @@ func (s *UserService) validateUserTokenRedis(ctx context.Context, token string, _, err := s.Dep.Redis.Get(ctx, buildTokenKey(userId, token)).Result() if err != nil { if errors.Is(err, redis.Nil) { - return middleware.NewAuthError(401, "invalid token") + return authError.NewAuthError(401, "invalid token") } return err } diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index d3ee9d9..8a35244 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -5,14 +5,14 @@ import ( "strings" "testing" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" ) func requireAuthStatus(t *testing.T, err error, status int) { t.Helper() - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != status { t.Fatalf("expected %d error, got %v", status, err) } @@ -20,7 +20,7 @@ func requireAuthStatus(t *testing.T, err error, status int) { func TestCreateUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() cases := []struct { @@ -112,7 +112,7 @@ func TestCreateUser(t *testing.T) { func TestLoginUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() // Setup user @@ -165,7 +165,7 @@ func TestLoginUser(t *testing.T) { if err == nil { t.Fatal("expected error") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 401 { t.Errorf("expected 401 error, got %v", err) } @@ -181,7 +181,7 @@ func TestLoginUser(t *testing.T) { if err == nil { t.Fatal("expected error") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 401 { t.Errorf("expected 401 error, got %v", err) } @@ -228,7 +228,7 @@ func TestLoginUser(t *testing.T) { if err == nil { t.Fatal("expected error") } - authErr, ok := err.(*middleware.AuthError) + authErr, ok := err.(*authError.AuthError) if !ok || authErr.Status != 401 { t.Errorf("expected 401 error, got %v", err) } @@ -253,7 +253,7 @@ func TestLoginUser(t *testing.T) { t.Fatal("expected error") } // Should return raw error, not AuthError - if _, ok := err.(*middleware.AuthError); ok { + if _, ok := err.(*authError.AuthError); ok { t.Error("expected raw error for invalid hash") } }) @@ -261,7 +261,7 @@ func TestLoginUser(t *testing.T) { func TestGetUserByID(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -303,7 +303,7 @@ func TestGetUserByID(t *testing.T) { func TestUpdateUserPassword(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -399,7 +399,7 @@ func TestUpdateUserPassword(t *testing.T) { if err == nil { t.Fatal("expected error") } - if _, ok := err.(*middleware.AuthError); ok { + if _, ok := err.(*authError.AuthError); ok { t.Error("expected raw error") } }) @@ -407,7 +407,7 @@ func TestUpdateUserPassword(t *testing.T) { func TestUpdateUserProfile(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -480,7 +480,7 @@ func TestUpdateUserProfile(t *testing.T) { func TestDeleteUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -506,7 +506,7 @@ func TestDeleteUser(t *testing.T) { func TestValidateUserToken(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() createReq := &dto.CreateUserRequest{ @@ -563,7 +563,7 @@ func TestValidateUserToken(t *testing.T) { func TestLogoutUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) ctx := context.Background() createReq := &dto.CreateUserRequest{ @@ -673,7 +673,7 @@ func TestDBErrors(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(newTestDependency(db, nil)) + svc := mustNewUserService(t, newTestDependency(db, nil)) sqlDB, _ := db.DB() _ = sqlDB.Close() diff --git a/backend/internal/util/log.go b/backend/internal/util/log.go index eee3fcb..d484250 100644 --- a/backend/internal/util/log.go +++ b/backend/internal/util/log.go @@ -18,3 +18,11 @@ func GetLogger(level slog.Leveler) *slog.Logger { slog.SetDefault(logger) return logger } + +// LogFatalErr logs the error message and exits the program if err is not nil. +func LogFatalErr(logger *slog.Logger, err error, msg string) { + if err != nil { + logger.Error(msg, "err", err) + os.Exit(1) + } +}