From 24c792d58d4c2ef548df0499e1b69bc5a0f1f9a8 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:04:48 +0200 Subject: [PATCH 1/5] fix: update password validation rules to require a minimum length of 6 characters --- backend/internal/dto/schemas.go | 6 +++--- backend/internal/dto/schemas_test.go | 28 +++++++++++++++++++++------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/backend/internal/dto/schemas.go b/backend/internal/dto/schemas.go index 0c831c1..adb7f90 100644 --- a/backend/internal/dto/schemas.go +++ b/backend/internal/dto/schemas.go @@ -80,15 +80,15 @@ func registerUsernameTranslation(v *validator.Validate, trans ut.Translator) { // Password type Password struct { - Password string `json:"password" validate:"required,trim,min=3,max=20,password"` + Password string `json:"password" validate:"required,trim,min=6,max=20,password"` } type OldPassword struct { - OldPassword string `json:"oldPassword" validate:"required,trim,password,min=3,max=20"` + OldPassword string `json:"oldPassword" validate:"required,trim,password,min=6,max=20"` } type NewPassword struct { - NewPassword string `json:"newPassword" validate:"required,trim,password,min=3,max=20"` + NewPassword string `json:"newPassword" validate:"required,trim,password,min=6,max=20"` } // Contains only letters, numbers, ".", "_" or "-" diff --git a/backend/internal/dto/schemas_test.go b/backend/internal/dto/schemas_test.go index 2643dd8..fade7ee 100644 --- a/backend/internal/dto/schemas_test.go +++ b/backend/internal/dto/schemas_test.go @@ -82,22 +82,27 @@ func TestUsersResponseMarshalsAsObjectWithSlice(t *testing.T) { func TestTrimValidationStripsWhitespace(t *testing.T) { type payload struct { - Value string `validate:"required,trim,min=3"` + Value string `validate:"required,trim,min=6"` } - data := &payload{Value: " foo "} + data := &payload{Value: " foobar "} if err := dto.Validate.Struct(data); err != nil { t.Fatalf("expected trimmed value to pass validation, got error: %v", err) } - if data.Value != "foo" { + if data.Value != "foobar" { t.Fatalf("expected trim validator to remove outer spaces, got %q", data.Value) } - tooShort := &payload{Value: " a "} + tooShort := &payload{Value: " abcde "} if err := dto.Validate.Struct(tooShort); err == nil { t.Fatalf("expected trimmed value shorter than min to fail validation") } + + emptyAfterTrim := &payload{Value: " "} + if err := dto.Validate.Struct(emptyAfterTrim); err == nil { + t.Fatalf("expected whitespace-only value to fail validation after trim") + } } func TestUsernameValidatorRules(t *testing.T) { @@ -108,8 +113,11 @@ func TestUsernameValidatorRules(t *testing.T) { }{ {"Valid", "valid_user", false}, {"ValidTrimmed", " valid-user ", false}, - {"TooShort", "ab", true}, - {"TooShortAfterTrim", " ab ", true}, + {"ValidTrimmedRight", "valid-user ", false}, + {"ValidTrimmedLeft", " valid-user", false}, + {"EmptyAfterTrim", " ", true}, + {"TooShort", "abcde", true}, + {"TooShortAfterTrim", " abcde ", true}, {"ContainsSpace", "user name", true}, {"IllegalChars", "user@name", true}, } @@ -144,6 +152,9 @@ func TestPasswordValidatorRules(t *testing.T) { }{ {"ValidBasic", "Abc123", false}, {"ValidSymbols", "pass,#$%", false}, + {"ValidTrimmedRight", "Abc123 ", false}, + {"ValidTrimmedLeft", " Abc123", false}, + {"EmptyAfterTrim", " ", true}, {"TooShort", "ab", true}, {"TooShortAfterTrim", " ab ", true}, {"ContainsSpace", "pass word", true}, @@ -181,8 +192,11 @@ func TestIdentifierValidatorAcceptsUsernameOrEmail(t *testing.T) { {"Username", "valid_user", false}, {"Email", "user@example.com", false}, {"TrimmedEmail", " user@example.com ", false}, + {"TrimmedEmailRight", "user@example.com ", false}, + {"TrimmedEmailLeft", " user@example.com", false}, + {"EmptyAfterTrim", " ", true}, {"Invalid", "???", true}, - {"TooShort", "ab", true}, + {"TooShort", "abcde", true}, } for _, tc := range cases { From 3257db140952e749fe40319061b937d302c1696f Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:12:31 +0200 Subject: [PATCH 2/5] fix: require essential environment variables for JWT and Google OAuth --- backend/internal/config/config.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index a2058ba..234a930 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -34,6 +34,16 @@ func getEnvStrOrDefault(key string, defaultValue string) string { return value } +func getEnvStrOrPanic(key string) string { + value := os.Getenv(key) + + if value == "" { + panic("environment variable " + key + " is required but not set") + } + + return value +} + func getEnvIntOrDefault(key string, defaultValue int) int { strValue := os.Getenv(key) @@ -49,11 +59,11 @@ func LoadConfig() { Cfg = &Config{ GinMode: getEnvStrOrDefault("GIN_MODE", "debug"), DbAddress: getEnvStrOrDefault("DB_ADDRESS", "data/auth_service_db.sqlite"), - JwtSecret: getEnvStrOrDefault("JWT_SECRET", "test-secret"), + JwtSecret: getEnvStrOrPanic("JWT_SECRET"), UserTokenExpiry: getEnvIntOrDefault("USER_TOKEN_EXPIRY", 3600), OauthStateTokenExpiry: getEnvIntOrDefault("OAUTH_STATE_TOKEN_EXPIRY", 600), - GoogleClientId: getEnvStrOrDefault("GOOGLE_CLIENT_ID", "test-google-client-id"), - GoogleClientSecret: getEnvStrOrDefault("GOOGLE_CLIENT_SECRET", "test-google-client-secret"), + GoogleClientId: getEnvStrOrPanic("GOOGLE_CLIENT_ID"), + GoogleClientSecret: getEnvStrOrPanic("GOOGLE_CLIENT_SECRET"), GoogleRedirectUri: getEnvStrOrDefault("GOOGLE_REDIRECT_URI", "test-google-redirect-uri"), FrontendUrl: getEnvStrOrDefault("FRONTEND_URL", "http://localhost:5173"), TwoFaUrlPrefix: getEnvStrOrDefault("TWO_FA_URL_PREFIX", "otpauth://totp/Transcendence?secret="), From a0259301da6ef8367539852b071c661cf7750b97 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Thu, 29 Jan 2026 23:19:09 +0200 Subject: [PATCH 3/5] Refactor/backend: apply dependency injection --- backend/cmd/server/main.go | 38 ++++---- backend/internal/config/config.go | 6 +- backend/internal/db/db.go | 38 ++++---- backend/internal/db/redis.go | 26 +++--- backend/internal/dependency/dependency.go | 24 +++++ backend/internal/middleware/auth.go | 5 +- backend/internal/routers/dev_router.go | 20 ---- backend/internal/routers/dev_router_test.go | 91 ------------------- backend/internal/routers/users_router.go | 8 +- backend/internal/service/friend_service.go | 6 +- .../internal/service/google_oauth_service.go | 65 +++++++------ backend/internal/service/helper.go | 48 +++++----- backend/internal/service/twofa_service.go | 20 ++-- backend/internal/service/user_service.go | 41 ++++----- backend/internal/util/jwt/jwt.go | 50 +++++----- backend/internal/util/log.go | 9 +- 16 files changed, 200 insertions(+), 295 deletions(-) create mode 100644 backend/internal/dependency/dependency.go delete mode 100644 backend/internal/routers/dev_router.go delete mode 100644 backend/internal/routers/dev_router_test.go diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index f5ea8b5..0bfaa04 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -25,9 +25,10 @@ import ( "github.com/gin-contrib/cors" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" ) -func SetupRouter(logger *slog.Logger) *gin.Engine { +func SetupRouter(dep *dependency.Dependency) *gin.Engine { r := gin.New() logConfig := sloggin.Config{ @@ -59,39 +60,40 @@ func SetupRouter(logger *slog.Logger) *gin.Engine { r.Use(rateLimiter.RateLimit()) r.Use(middleware.PanicHandler()) - r.Use(sloggin.NewWithConfig(logger, logConfig)) + r.Use(sloggin.NewWithConfig(dep.Logger, logConfig)) r.Use(middleware.ErrorHandler()) return r } +func initDependency() *dependency.Dependency { + logger := util.GetLogger(slog.LevelInfo) + cfg := config.LoadConfigFromEnv() + myDB := db.GetDB(cfg.DbAddress, logger) + redis := db.GetRedis(cfg.RedisURL, cfg, logger) + + return dependency.NewDependency(cfg, myDB, redis, logger) +} + // @title Auth Service API // @version 1.0 -// @description Auth service for Transcendence +// @description Auth service // @BasePath /api func main() { // config _ = godotenv.Load() - config.LoadConfig() - - // logger - util.InitLogger(slog.LevelInfo) + // init dependency + dep := initDependency() + defer db.CloseDB(dep.DB, dep.Logger) + defer db.CloseRedis(dep.Redis, dep.Logger) // validator dto.InitValidator() - // database - db.ConnectDB(config.Cfg.DbAddress) - defer db.CloseDB() - - db.ConnectRedis(config.Cfg.RedisURL) - defer db.CloseRedis() - // router - r := SetupRouter(util.Logger) - routers.UsersRouter(r.Group("/api/users")) - routers.DevRouter(r.Group("/api/dev")) + r := SetupRouter(dep) + routers.UsersRouter(r.Group("/api/users"), dep) // Health check r.GET("/api/ping", func(c *gin.Context) { @@ -104,7 +106,7 @@ func main() { r.GET("/api/docs/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) if err := r.Run(":3003"); err != nil { - util.Logger.Error("failed to start server", "err", err) + dep.Logger.Error("failed to start server", "err", err) os.Exit(1) } } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 234a930..28cfe08 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -22,8 +22,6 @@ type Config struct { UserTokenAbsoluteExpiry int } -var Cfg *Config - func getEnvStrOrDefault(key string, defaultValue string) string { value := os.Getenv(key) @@ -55,8 +53,8 @@ func getEnvIntOrDefault(key string, defaultValue int) int { return intValue } -func LoadConfig() { - Cfg = &Config{ +func LoadConfigFromEnv() *Config { + return &Config{ GinMode: getEnvStrOrDefault("GIN_MODE", "debug"), DbAddress: getEnvStrOrDefault("DB_ADDRESS", "data/auth_service_db.sqlite"), JwtSecret: getEnvStrOrPanic("JWT_SECRET"), diff --git a/backend/internal/db/db.go b/backend/internal/db/db.go index 7fe413e..8379022 100644 --- a/backend/internal/db/db.go +++ b/backend/internal/db/db.go @@ -2,24 +2,21 @@ package db import ( "context" + "log/slog" "gorm.io/driver/sqlite" "gorm.io/gorm" - - "github.com/paularynty/transcendence/auth-service-go/internal/util" ) -var DB *gorm.DB - -func ConnectDB(dbName string) { +func GetDB(dbName string, logger *slog.Logger) *gorm.DB { var err error - DB, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) + db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) if err != nil { panic("failed to connect to db: " + dbName) } - DB.Exec("PRAGMA foreign_keys = ON") + db.Exec("PRAGMA foreign_keys = ON") for _, model := range []any{ &User{}, @@ -27,32 +24,33 @@ func ConnectDB(dbName string) { &Token{}, &HeartBeat{}, } { - if err := DB.AutoMigrate(model); err != nil { + if err := db.AutoMigrate(model); err != nil { panic("failed to migrate model: " + err.Error()) } } - util.Logger.Info("connected to db") + logger.Info("connected to db") + + return db } -func CloseDB() { - sqlDB, err := DB.DB() +func CloseDB(db *gorm.DB, logger *slog.Logger) { + sqlDB, err := db.DB() if err != nil { - util.Logger.Error("failed to get db instance", "err", err) + logger.Error("failed to get db instance", "err", err) return } if err := sqlDB.Close(); err != nil { - util.Logger.Error("failed to close db", "err", err) + logger.Error("failed to close db", "err", err) return } - util.Logger.Info("db connection closed") + logger.Info("db connection closed") } -func ResetDB() { - util.Logger.Warn("resetting db...") - +func ResetDB(db *gorm.DB, logger *slog.Logger) { + logger.Warn("resetting db...") ctx := context.Background() tables := []string{ "heart_beats", @@ -62,11 +60,11 @@ func ResetDB() { } for _, table := range tables { - err := gorm.G[any](DB).Exec(ctx, "DELETE FROM "+table) + err := gorm.G[any](db).Exec(ctx, "DELETE FROM "+table) if err != nil { - util.Logger.Error("failed to reset table", table, err.Error()) + logger.Error("failed to reset table", table, err.Error()) } } - util.Logger.Info("db is reset") + logger.Info("db is reset") } diff --git a/backend/internal/db/redis.go b/backend/internal/db/redis.go index 76c935e..b96d1f8 100644 --- a/backend/internal/db/redis.go +++ b/backend/internal/db/redis.go @@ -2,18 +2,16 @@ package db import ( "context" + "log/slog" "github.com/paularynty/transcendence/auth-service-go/internal/config" - "github.com/paularynty/transcendence/auth-service-go/internal/util" "github.com/redis/go-redis/v9" ) -var Redis *redis.Client - -func ConnectRedis(redisURL string) { - if !config.Cfg.IsRedisEnabled { - util.Logger.Info("redis is disabled by config") - return +func GetRedis(redisURL string, cfg *config.Config, logger *slog.Logger) *redis.Client { + if !cfg.IsRedisEnabled { + logger.Info("redis is disabled by config") + return nil } opt, err := redis.ParseURL(redisURL) @@ -31,20 +29,20 @@ func ConnectRedis(redisURL string) { panic("failed to connect to redis: " + err.Error()) } - Redis = client + logger.Info("connected to redis") - util.Logger.Info("connected to redis") + return client } -func CloseRedis() { - if Redis == nil { +func CloseRedis(client *redis.Client, logger *slog.Logger) { + if client == nil { return } - err := Redis.Close() + err := client.Close() if err != nil { - util.Logger.Error("failed to close redis connection", "error", err) + logger.Error("failed to close redis connection", "error", err) } else { - util.Logger.Info("redis connection closed") + logger.Info("redis connection closed") } } diff --git a/backend/internal/dependency/dependency.go b/backend/internal/dependency/dependency.go new file mode 100644 index 0000000..1d77735 --- /dev/null +++ b/backend/internal/dependency/dependency.go @@ -0,0 +1,24 @@ +package dependency + +import ( + "gorm.io/gorm" + "log/slog" + "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/redis/go-redis/v9" +) + +type Dependency struct { + Cfg *config.Config + DB *gorm.DB + Redis *redis.Client + Logger *slog.Logger +} + +func NewDependency(cfg *config.Config, db *gorm.DB, redis *redis.Client, logger *slog.Logger) *Dependency { + return &Dependency{ + Cfg: cfg, + DB: db, + Redis: redis, + Logger: logger, + } +} diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index af1db9b..2e77242 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -5,12 +5,13 @@ import ( "github.com/gin-gonic/gin" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) const PrefixBearer = "Bearer " -func Auth() gin.HandlerFunc { +func Auth(dep *dependency.Dependency) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") @@ -21,7 +22,7 @@ func Auth() gin.HandlerFunc { tokenString := authHeader[len(PrefixBearer):] - userJwtPayload, err := jwt.ValidateUserTokenGeneric(tokenString) + userJwtPayload, err := jwt.ValidateUserTokenGeneric(dep, tokenString) if err != nil { _ = c.AbortWithError(401, NewAuthError(401, "Invalid or expired token")) return diff --git a/backend/internal/routers/dev_router.go b/backend/internal/routers/dev_router.go deleted file mode 100644 index 72834a6..0000000 --- a/backend/internal/routers/dev_router.go +++ /dev/null @@ -1,20 +0,0 @@ -package routers - -import ( - "net/http" - - "github.com/gin-gonic/gin" - "github.com/paularynty/transcendence/auth-service-go/internal/config" - "github.com/paularynty/transcendence/auth-service-go/internal/db" -) - -func DevRouter(r *gin.RouterGroup) { - if config.Cfg.GinMode != "debug" { - return - } - - r.GET("/reset", func(c *gin.Context) { - db.ResetDB() - c.JSON(http.StatusOK, gin.H{"message": "ok"}) - }) -} diff --git a/backend/internal/routers/dev_router_test.go b/backend/internal/routers/dev_router_test.go deleted file mode 100644 index 9f385af..0000000 --- a/backend/internal/routers/dev_router_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package routers - -import ( - "encoding/json" - "log/slog" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" - - "github.com/paularynty/transcendence/auth-service-go/internal/config" - "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/util" -) - -func TestDevRouterResetDebugMode(t *testing.T) { - gin.SetMode(gin.TestMode) - util.InitLogger(slog.LevelError) - - prevCfg := config.Cfg - config.Cfg = &config.Config{GinMode: "debug"} - t.Cleanup(func() { - config.Cfg = prevCfg - }) - - db.ConnectDB("file::memory:?cache=shared") - t.Cleanup(func() { - if db.DB != nil { - sqlDB, err := db.DB.DB() - if err == nil { - _ = sqlDB.Close() - } - db.DB = nil - } - }) - - if err := db.DB.Create(&db.User{Username: "tester", Email: "tester@example.com"}).Error; err != nil { - t.Fatalf("failed to seed user: %v", err) - } - - router := gin.New() - DevRouter(router.Group("/api/dev")) - - req := httptest.NewRequest(http.MethodGet, "/api/dev/reset", nil) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } - - var payload map[string]string - if err := json.Unmarshal(resp.Body.Bytes(), &payload); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if payload["message"] != "ok" { - t.Fatalf("unexpected response message: %v", payload) - } - - var count int64 - if err := db.DB.Model(&db.User{}).Count(&count).Error; err != nil { - t.Fatalf("failed to count users: %v", err) - } - - if count != 0 { - t.Fatalf("expected user table to be empty, found %d rows", count) - } -} - -func TestDevRouterNoopOutsideDebug(t *testing.T) { - gin.SetMode(gin.TestMode) - - prevCfg := config.Cfg - config.Cfg = &config.Config{GinMode: "release"} - t.Cleanup(func() { - config.Cfg = prevCfg - }) - - router := gin.New() - DevRouter(router.Group("/api/dev")) - - req := httptest.NewRequest(http.MethodGet, "/api/dev/reset", nil) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusNotFound { - t.Fatalf("expected status 404 when debug routes disabled, got %d", resp.Code) - } -} diff --git a/backend/internal/routers/users_router.go b/backend/internal/routers/users_router.go index 9ca89f3..d5b015f 100644 --- a/backend/internal/routers/users_router.go +++ b/backend/internal/routers/users_router.go @@ -3,15 +3,15 @@ package routers import ( "github.com/gin-gonic/gin" - "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/handler" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/service" ) -func UsersRouter(r *gin.RouterGroup) { - h := &handler.UserHandler{Service: service.NewUserService(db.DB, db.Redis)} +func UsersRouter(r *gin.RouterGroup, dep *dependency.Dependency) { + h := &handler.UserHandler{Service: service.NewUserService(dep)} // Public endpoints r.POST("/", middleware.ValidateBody[dto.CreateUserRequest](), h.CreateUserHandler) @@ -22,7 +22,7 @@ func UsersRouter(r *gin.RouterGroup) { // Authenticated endpoints auth := r.Group("") - auth.Use(middleware.Auth()) + auth.Use(middleware.Auth(dep)) auth.GET("/me", h.GetLoggedUserProfileHandler) auth.PUT("/password", middleware.ValidateBody[dto.UpdateUserPasswordRequest](), h.UpdateLoggedUserPasswordHandler) diff --git a/backend/internal/service/friend_service.go b/backend/internal/service/friend_service.go index c0a41a1..72d29ad 100644 --- a/backend/internal/service/friend_service.go +++ b/backend/internal/service/friend_service.go @@ -11,7 +11,7 @@ import ( ) func (s *UserService) GetAllUsersLimitedInfo(ctx context.Context) ([]dto.SimpleUser, error) { - modelUsers, err := gorm.G[model.User](s.DB).Find(ctx) + modelUsers, err := gorm.G[model.User](s.Dep.DB).Find(ctx) if err != nil { return nil, err } @@ -25,7 +25,7 @@ func (s *UserService) GetAllUsersLimitedInfo(ctx context.Context) ([]dto.SimpleU } func (s *UserService) GetUserFriends(ctx context.Context, userID uint) ([]dto.FriendResponse, error) { - friends, err := gorm.G[model.Friend](s.DB).Preload("Friend", nil).Where("user_id = ?", userID).Find(ctx) + friends, err := gorm.G[model.Friend](s.Dep.DB).Preload("Friend", nil).Where("user_id = ?", userID).Find(ctx) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func (s *UserService) AddNewFriend(ctx context.Context, userID uint, request *dt FriendID: request.UserID, } - err := gorm.G[model.Friend](s.DB).Create(ctx, &newFriend) + err := gorm.G[model.Friend](s.Dep.DB).Create(ctx, &newFriend) if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return middleware.NewAuthError(409, "friend already added") diff --git a/backend/internal/service/google_oauth_service.go b/backend/internal/service/google_oauth_service.go index 6c84158..0c48117 100644 --- a/backend/internal/service/google_oauth_service.go +++ b/backend/internal/service/google_oauth_service.go @@ -11,31 +11,30 @@ import ( "cloud.google.com/go/auth/credentials/idtoken" "github.com/google/uuid" - "github.com/paularynty/transcendence/auth-service-go/internal/config" model "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" - "github.com/paularynty/transcendence/auth-service-go/internal/util" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "gorm.io/gorm" ) func (s *UserService) GetGoogleOAuthURL(ctx context.Context) (string, error) { - state, err := jwt.SignOauthStateToken() + state, err := jwt.SignOauthStateToken(s.Dep) if err != nil { - util.Logger.Error("failed to sign oauth state token:", "err", err) + s.Dep.Logger.Error("failed to sign oauth state token:", "err", err) return "", err } u, err := url.Parse(BaseGoogleOAuthURL) if err != nil { - util.Logger.Error("failed to parse google oauth base url:", "err", err) + s.Dep.Logger.Error("failed to parse google oauth base url:", "err", err) return "", err } q := u.Query() - q.Set("client_id", config.Cfg.GoogleClientId) - q.Set("redirect_uri", config.Cfg.GoogleRedirectUri) + q.Set("client_id", s.Dep.Cfg.GoogleClientId) + q.Set("redirect_uri", s.Dep.Cfg.GoogleRedirectUri) q.Set("response_type", "code") q.Set("scope", "openid email profile") q.Set("state", state) @@ -45,10 +44,10 @@ func (s *UserService) GetGoogleOAuthURL(ctx context.Context) (string, error) { return u.String(), nil } -func assembleFrontendRedirectURL(token *string, errMsg *string) string { - u, err := url.Parse(config.Cfg.FrontendUrl + "/user/oauth-callback-google") +func assembleFrontendRedirectURL(dep *dependency.Dependency, token *string, errMsg *string) string { + u, err := url.Parse(dep.Cfg.FrontendUrl + "/user/oauth-callback-google") if err != nil { - util.Logger.Error("failed to parse frontend redirect url:", "err", err) + dep.Logger.Error("failed to parse frontend redirect url:", "err", err) return "/unrecovered-error" } @@ -64,12 +63,12 @@ func assembleFrontendRedirectURL(token *string, errMsg *string) string { return u.String() } -var ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { +var ExchangeCodeForTokens = func(dep *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { data := url.Values{} data.Set("code", code) - data.Set("client_id", config.Cfg.GoogleClientId) - data.Set("client_secret", config.Cfg.GoogleClientSecret) - data.Set("redirect_uri", config.Cfg.GoogleRedirectUri) + data.Set("client_id", dep.Cfg.GoogleClientId) + data.Set("client_secret", dep.Cfg.GoogleClientSecret) + data.Set("redirect_uri", dep.Cfg.GoogleRedirectUri) data.Set("grant_type", "authorization_code") req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode())) @@ -99,7 +98,7 @@ var ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Pay return nil, err } - payload, err := idtoken.Validate(ctx, tokenResp.IdToken, config.Cfg.GoogleClientId) + payload, err := idtoken.Validate(ctx, tokenResp.IdToken, dep.Cfg.GoogleClientId) if err != nil { return nil, err } @@ -197,7 +196,7 @@ func (s *UserService) createNewUserFromGoogleInfo(ctx context.Context, googleUse TwoFAToken: nil, } - err := gorm.G[model.User](s.DB).Create(ctx, &modelUser) + err := gorm.G[model.User](s.Dep.DB).Create(ctx, &modelUser) if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { if !isRetry { @@ -211,53 +210,53 @@ func (s *UserService) createNewUserFromGoogleInfo(ctx context.Context, googleUse return &modelUser, nil } -func HandleGoogleOAuthCallbackError(err error, errMsg string) string { +func HandleGoogleOAuthCallbackError(dep *dependency.Dependency, err error, errMsg string) string { publicMsg := "Failed to handle Google OAuth callback." - util.Logger.Error(errMsg, "error", err) - return assembleFrontendRedirectURL(nil, &publicMsg) + dep.Logger.Error(errMsg, "error", err) + return assembleFrontendRedirectURL(dep, nil, &publicMsg) } func (s *UserService) HandleGoogleOAuthCallback(ctx context.Context, code string, state string) string { var finalUserID uint - claims, err := jwt.ValidateOauthStateToken(state) + claims, err := jwt.ValidateOauthStateToken(s.Dep, state) if err != nil || claims.Type != jwt.GoogleOAuthStateType { - return HandleGoogleOAuthCallbackError(err, "invalid oauth state token") + return HandleGoogleOAuthCallbackError(s.Dep, err, "invalid oauth state token") } - googlePayload, err := ExchangeCodeForTokens(ctx, code) + googlePayload, err := ExchangeCodeForTokens(s.Dep, ctx, code) if err != nil { - return HandleGoogleOAuthCallbackError(err, "failed to exchange code for tokens") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to exchange code for tokens") } googleUserInfo, err := FetchGoogleUserInfo(googlePayload) if err != nil { - return HandleGoogleOAuthCallbackError(err, "failed to fetch google user info from id token") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to fetch google user info from id token") } - modelUser, err := gorm.G[model.User](s.DB).Where("google_oauth_id = ?", googleUserInfo.ID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("google_oauth_id = ?", googleUserInfo.ID).First(ctx) if err == nil { // User with this Google OAuth ID exists, log them in finalUserID = modelUser.ID } else if !errors.Is(err, gorm.ErrRecordNotFound) { - return HandleGoogleOAuthCallbackError(err, "failed to query user by google oauth id") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to query user by google oauth id") } else { // No user with this Google OAuth ID, check if a user with this email exists - modelUser, err = gorm.G[model.User](s.DB).Where("email = ?", googleUserInfo.Email).First(ctx) + modelUser, err = gorm.G[model.User](s.Dep.DB).Where("email = ?", googleUserInfo.Email).First(ctx) if err == nil { // User with this email exists, link Google account err = s.linkGoogleAccountToExistingUser(ctx, &modelUser, googleUserInfo) if err != nil { // Failed to link Google account - return HandleGoogleOAuthCallbackError(err, "failed to link google account to existing user") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to link google account to existing user") } // Successfully linked Google account finalUserID = modelUser.ID } else if !errors.Is(err, gorm.ErrRecordNotFound) { - return HandleGoogleOAuthCallbackError(err, "failed to query user by email") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to query user by email") } else { // No user with this email exists, create a new user newUser, err := s.createNewUserFromGoogleInfo(ctx, googleUserInfo, false) if err != nil { - return HandleGoogleOAuthCallbackError(err, "failed to create new user from google info") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to create new user from google info") } finalUserID = newUser.ID @@ -265,13 +264,13 @@ func (s *UserService) HandleGoogleOAuthCallback(ctx context.Context, code string } if finalUserID == 0 { - return HandleGoogleOAuthCallbackError(errors.New("finalUserID is zero"), "internal error determining final user ID") + return HandleGoogleOAuthCallbackError(s.Dep, errors.New("finalUserID is zero"), "internal error determining final user ID") } userToken, err := s.issueNewTokenForUser(ctx, finalUserID, false) if err != nil { - return HandleGoogleOAuthCallbackError(err, "failed to issue new token for user") + return HandleGoogleOAuthCallbackError(s.Dep, err, "failed to issue new token for user") } - return assembleFrontendRedirectURL(&userToken, nil) + return assembleFrontendRedirectURL(s.Dep, &userToken, nil) } diff --git a/backend/internal/service/helper.go b/backend/internal/service/helper.go index 4d21991..7521cd4 100644 --- a/backend/internal/service/helper.go +++ b/backend/internal/service/helper.go @@ -7,10 +7,9 @@ import ( "strings" "time" - "github.com/paularynty/transcendence/auth-service-go/internal/config" model "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/util" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "github.com/redis/go-redis/v9" "gorm.io/gorm" @@ -19,19 +18,18 @@ import ( const HeartBeatPrefix = "heartbeat:" -func NewUserService(db *gorm.DB, redis *redis.Client) *UserService { +func NewUserService(dep *dependency.Dependency) *UserService { - if db == nil { + if dep.DB == nil { panic("UserService: db is nil") } - if config.Cfg.IsRedisEnabled && redis == nil { + if dep.Cfg.IsRedisEnabled && dep.Redis == nil { panic("UserService: redis is enabled but redis client is nil") } return &UserService{ - DB: db, - Redis: redis, + Dep: dep, } } @@ -102,7 +100,7 @@ func (s *UserService) updateHeartBeatByDB(userID uint) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - err := s.DB.WithContext(ctx).Clauses(clause.OnConflict{ + err := s.Dep.DB.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "user_id"}}, UpdateAll: true, }).Create(&model.HeartBeat{ @@ -111,7 +109,7 @@ func (s *UserService) updateHeartBeatByDB(userID uint) { }).Error if err != nil { - util.Logger.Warn("failed to update heartbeat for user", fmt.Sprint(userID), err.Error()) + s.Dep.Logger.Warn("failed to update heartbeat for user", fmt.Sprint(userID), err.Error()) } }() } @@ -121,19 +119,19 @@ func (s *UserService) updateHeartBeatByRedis(userID uint) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - err := s.Redis.ZAdd(ctx, HeartBeatPrefix, redis.Z{ + err := s.Dep.Redis.ZAdd(ctx, HeartBeatPrefix, redis.Z{ Score: float64(time.Now().Unix()), Member: userID, }).Err() if err != nil { - util.Logger.Warn("failed to update heartbeat for user", fmt.Sprint(userID), err.Error()) + s.Dep.Logger.Warn("failed to update heartbeat for user", fmt.Sprint(userID), err.Error()) } }() } func (s *UserService) updateHeartBeat(userID uint) { - if config.Cfg.IsRedisEnabled { + if s.Dep.Cfg.IsRedisEnabled { s.updateHeartBeatByRedis(userID) } else { s.updateHeartBeatByDB(userID) @@ -141,7 +139,7 @@ func (s *UserService) updateHeartBeat(userID uint) { } func (s *UserService) getOnlineStatusByDB(ctx context.Context) ([]model.HeartBeat, error) { - onlineStatus, err := gorm.G[model.HeartBeat](s.DB).Where("last_seen_at > ?", time.Now().Add(-2*time.Minute)).Find(ctx) + onlineStatus, err := gorm.G[model.HeartBeat](s.Dep.DB).Where("last_seen_at > ?", time.Now().Add(-2*time.Minute)).Find(ctx) if err != nil { return nil, err } @@ -155,9 +153,9 @@ func (s *UserService) clearExpiredHeartBeatsByRedis() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - err := s.Redis.ZRemRangeByScore(ctx, HeartBeatPrefix, "-inf", strconv.FormatInt(time.Now().Add(-2*time.Minute).Unix(), 10)).Err() + err := s.Dep.Redis.ZRemRangeByScore(ctx, HeartBeatPrefix, "-inf", strconv.FormatInt(time.Now().Add(-2*time.Minute).Unix(), 10)).Err() if err != nil { - util.Logger.Warn("failed to clear expired heartbeats from redis", "err", err) + s.Dep.Logger.Warn("failed to clear expired heartbeats from redis", "err", err) } }() } @@ -165,7 +163,7 @@ func (s *UserService) clearExpiredHeartBeatsByRedis() { func (s *UserService) getOnlineStatusByRedis(ctx context.Context) ([]model.HeartBeat, error) { heartBeats := make([]model.HeartBeat, 0) - zs, err := s.Redis.ZRangeByScoreWithScores(ctx, HeartBeatPrefix, &redis.ZRangeBy{ + zs, err := s.Dep.Redis.ZRangeByScoreWithScores(ctx, HeartBeatPrefix, &redis.ZRangeBy{ Min: strconv.FormatInt(time.Now().Add(-2*time.Minute).Unix(), 10), Max: "+inf", }).Result() @@ -192,7 +190,7 @@ func (s *UserService) getOnlineStatusByRedis(ctx context.Context) ([]model.Heart } func (s *UserService) getOnlineStatus(ctx context.Context) ([]model.HeartBeat, error) { - if config.Cfg.IsRedisEnabled { + if s.Dep.Cfg.IsRedisEnabled { return s.getOnlineStatusByRedis(ctx) } else { return s.getOnlineStatusByDB(ctx) @@ -206,18 +204,18 @@ func buildTokenKey(userID uint, token string) string { func (s *UserService) issueNewTokenForUserByDB(ctx context.Context, userID uint, revokeAllTokens bool) (string, error) { if revokeAllTokens { - res := s.DB.WithContext(ctx).Exec("DELETE FROM tokens WHERE user_id = ?", userID) + res := s.Dep.DB.WithContext(ctx).Exec("DELETE FROM tokens WHERE user_id = ?", userID) if res.Error != nil { return "", res.Error } } - token, err := jwt.SignUserToken(userID) + token, err := jwt.SignUserToken(s.Dep, userID) if err != nil { return "", err } - err = gorm.G[model.Token](s.DB).Create(ctx, &model.Token{ + err = gorm.G[model.Token](s.Dep.DB).Create(ctx, &model.Token{ UserID: userID, Token: token, }) @@ -234,9 +232,9 @@ func (s *UserService) issueNewTokenForUserByRedis(ctx context.Context, userID ui if revokeAllTokens { // A rough way to delete all tokens for the user - iter := s.Redis.Scan(ctx, 0, buildTokenKey(userID, "*"), 100).Iterator() + iter := s.Dep.Redis.Scan(ctx, 0, buildTokenKey(userID, "*"), 100).Iterator() for iter.Next(ctx) { - err := s.Redis.Del(ctx, iter.Val()).Err() + err := s.Dep.Redis.Del(ctx, iter.Val()).Err() if err != nil { return "", err } @@ -246,12 +244,12 @@ func (s *UserService) issueNewTokenForUserByRedis(ctx context.Context, userID ui } } - token, err := jwt.SignUserToken(userID) + token, err := jwt.SignUserToken(s.Dep, userID) if err != nil { return "", err } - err = s.Redis.Set(ctx, buildTokenKey(userID, token), "", time.Duration(config.Cfg.UserTokenExpiry)*time.Second).Err() + err = s.Dep.Redis.Set(ctx, buildTokenKey(userID, token), "", time.Duration(s.Dep.Cfg.UserTokenExpiry)*time.Second).Err() if err != nil { return "", err } @@ -262,7 +260,7 @@ func (s *UserService) issueNewTokenForUserByRedis(ctx context.Context, userID ui } func (s *UserService) issueNewTokenForUser(ctx context.Context, userID uint, revokeAllTokens bool) (string, error) { - if config.Cfg.IsRedisEnabled { + if s.Dep.Cfg.IsRedisEnabled { return s.issueNewTokenForUserByRedis(ctx, userID, revokeAllTokens) } else { return s.issueNewTokenForUserByDB(ctx, userID, revokeAllTokens) diff --git a/backend/internal/service/twofa_service.go b/backend/internal/service/twofa_service.go index 705543d..c0f8e8b 100644 --- a/backend/internal/service/twofa_service.go +++ b/backend/internal/service/twofa_service.go @@ -15,7 +15,7 @@ import ( ) func (s *UserService) StartTwoFaSetup(ctx context.Context, userID uint) (*dto.TwoFASetupResponse, error) { - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", userID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") @@ -41,12 +41,12 @@ func (s *UserService) StartTwoFaSetup(ctx context.Context, userID uint) (*dto.Tw twoFAToken := TwoFAPrePrefix + secret.Secret() - _, err = gorm.G[model.User](s.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", twoFAToken) + _, err = gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", twoFAToken) if err != nil { return nil, err } - setupToken, err := jwt.SignTwoFASetupToken(userID, secret.Secret()) + setupToken, err := jwt.SignTwoFASetupToken(s.Dep, userID, secret.Secret()) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func (s *UserService) StartTwoFaSetup(ctx context.Context, userID uint) (*dto.Tw } func (s *UserService) ConfirmTwoFaSetup(ctx context.Context, userID uint, request *dto.TwoFAConfirmRequest) (*dto.UserWithTokenResponse, error) { - claims, err := jwt.ValidateTwoFASetupToken(request.SetupToken) + claims, err := jwt.ValidateTwoFASetupToken(s.Dep, request.SetupToken) if err != nil || claims.Type != jwt.TwoFASetupType { return nil, middleware.NewAuthError(400, "invalid setup token") } @@ -68,7 +68,7 @@ func (s *UserService) ConfirmTwoFaSetup(ctx context.Context, userID uint, reques return nil, middleware.NewAuthError(400, "setup token does not match user") } - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", userID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") @@ -94,7 +94,7 @@ func (s *UserService) ConfirmTwoFaSetup(ctx context.Context, userID uint, reques return nil, middleware.NewAuthError(400, "invalid 2FA code") } - _, err = gorm.G[model.User](s.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", twoFaSecret) + _, err = gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", twoFaSecret) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func (s *UserService) ConfirmTwoFaSetup(ctx context.Context, userID uint, reques } func (s *UserService) DisableTwoFA(ctx context.Context, userID uint, request *dto.DisableTwoFARequest) (*dto.UserWithTokenResponse, error) { - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", userID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") @@ -133,7 +133,7 @@ func (s *UserService) DisableTwoFA(ctx context.Context, userID uint, request *dt return nil, err } - _, err = gorm.G[model.User](s.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", nil) + _, err = gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).Update(ctx, "two_fa_token", nil) if err != nil { return nil, err } @@ -148,12 +148,12 @@ func (s *UserService) DisableTwoFA(ctx context.Context, userID uint, request *dt } func (s *UserService) SubmitTwoFAChallenge(ctx context.Context, request *dto.TwoFAChallengeRequest) (*dto.UserWithTokenResponse, error) { - claims, err := jwt.ValidateTwoFAToken(request.SessionToken) + claims, err := jwt.ValidateTwoFAToken(s.Dep, request.SessionToken) if err != nil || claims.Type != jwt.TwoFATokenType { return nil, middleware.NewAuthError(400, "invalid session token") } - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", claims.UserID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", claims.UserID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 2a7a558..97524d4 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -9,8 +9,8 @@ import ( "golang.org/x/crypto/bcrypt" "gorm.io/gorm" - "github.com/paularynty/transcendence/auth-service-go/internal/config" model "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" @@ -23,8 +23,7 @@ const MaxAvatarSize = 1 * 1024 * 1024 // 1 MB const BaseGoogleOAuthURL = "https://accounts.google.com/o/oauth2/v2/auth" type UserService struct { - DB *gorm.DB - Redis *redis.Client + Dep *dependency.Dependency } func (s *UserService) CreateUser(ctx context.Context, request *dto.CreateUserRequest) (*dto.UserWithoutTokenResponse, error) { @@ -45,7 +44,7 @@ func (s *UserService) CreateUser(ctx context.Context, request *dto.CreateUserReq TwoFAToken: nil, } - err = gorm.G[model.User](s.DB).Create(ctx, &modelUser) + err = gorm.G[model.User](s.Dep.DB).Create(ctx, &modelUser) if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return nil, middleware.NewAuthError(409, "username or email already in use") @@ -70,7 +69,7 @@ func (s *UserService) LoginUser(ctx context.Context, request *dto.LoginUserReque identifierField = "username" } - modelUser, err := gorm.G[model.User](s.DB).Where(identifierField+" = ?", request.Identifier.Identifier).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where(identifierField+" = ?", request.Identifier.Identifier).First(ctx) if err != nil || modelUser.PasswordHash == nil { if errors.Is(err, gorm.ErrRecordNotFound) || modelUser.PasswordHash == nil { return nil, middleware.NewAuthError(401, "invalid credentials") @@ -88,7 +87,7 @@ func (s *UserService) LoginUser(ctx context.Context, request *dto.LoginUserReque isTwoFAEnabled := isTwoFAEnabled(modelUser.TwoFAToken) if isTwoFAEnabled { - sessionToken, err := jwt.SignTwoFAToken(modelUser.ID) + sessionToken, err := jwt.SignTwoFAToken(s.Dep, modelUser.ID) if err != nil { return nil, err } @@ -112,7 +111,7 @@ func (s *UserService) LoginUser(ctx context.Context, request *dto.LoginUserReque } func (s *UserService) GetUserByID(ctx context.Context, userID uint) (*dto.UserWithoutTokenResponse, error) { - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", userID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") @@ -124,7 +123,7 @@ func (s *UserService) GetUserByID(ctx context.Context, userID uint) (*dto.UserWi } func (s *UserService) UpdateUserPassword(ctx context.Context, userID uint, request *dto.UpdateUserPasswordRequest) (*dto.UserWithTokenResponse, error) { - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", userID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") @@ -149,7 +148,7 @@ func (s *UserService) UpdateUserPassword(ctx context.Context, userID uint, reque return nil, err } - _, err = gorm.G[model.User](s.DB).Where("id = ?", userID).Update(ctx, "password_hash", string(newPasswordBytes)) + _, err = gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).Update(ctx, "password_hash", string(newPasswordBytes)) if err != nil { return nil, err } @@ -163,7 +162,7 @@ func (s *UserService) UpdateUserPassword(ctx context.Context, userID uint, reque } func (s *UserService) UpdateUserProfile(ctx context.Context, userID uint, request *dto.UpdateUserRequest) (*dto.UserWithoutTokenResponse, error) { - modelUser, err := gorm.G[model.User](s.DB).Where("id = ?", userID).First(ctx) + modelUser, err := gorm.G[model.User](s.Dep.DB).Where("id = ?", userID).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, middleware.NewAuthError(404, "user not found") @@ -175,7 +174,7 @@ func (s *UserService) UpdateUserProfile(ctx context.Context, userID uint, reques modelUser.Avatar = request.Avatar modelUser.Email = request.Email - err = s.DB.WithContext(ctx).Save(&modelUser).Error + err = s.Dep.DB.WithContext(ctx).Save(&modelUser).Error if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { @@ -188,14 +187,14 @@ func (s *UserService) UpdateUserProfile(ctx context.Context, userID uint, reques } func (s *UserService) DeleteUser(ctx context.Context, userID uint) error { - if config.Cfg.IsRedisEnabled { - err := logoutUserByRedis(ctx, s.Redis, userID) + if s.Dep.Cfg.IsRedisEnabled { + err := logoutUserByRedis(ctx, s.Dep.Redis, userID) if err != nil { return err } } - res := s.DB.WithContext(ctx).Unscoped().Delete(&model.User{}, userID) + res := s.Dep.DB.WithContext(ctx).Unscoped().Delete(&model.User{}, userID) if res.Error != nil { return res.Error } @@ -229,15 +228,15 @@ func logoutUserByRedis(ctx context.Context, redis *redis.Client, userID uint) er } func (s *UserService) LogoutUser(ctx context.Context, userID uint) error { - if config.Cfg.IsRedisEnabled { - return logoutUserByRedis(ctx, s.Redis, userID) + if s.Dep.Cfg.IsRedisEnabled { + return logoutUserByRedis(ctx, s.Dep.Redis, userID) } else { - return logoutUserByDB(ctx, s.DB, userID) + return logoutUserByDB(ctx, s.Dep.DB, userID) } } func (s *UserService) validateUserTokenDB(ctx context.Context, token string, userId uint) error { - modelToken, err := gorm.G[model.Token](s.DB).Where("token = ?", token).First(ctx) + modelToken, err := gorm.G[model.Token](s.Dep.DB).Where("token = ?", token).First(ctx) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return middleware.NewAuthError(401, "invalid token") @@ -254,7 +253,7 @@ func (s *UserService) validateUserTokenDB(ctx context.Context, token string, use } func (s *UserService) validateUserTokenRedis(ctx context.Context, token string, userId uint) error { - _, err := s.Redis.Get(ctx, buildTokenKey(userId, token)).Result() + _, err := s.Dep.Redis.Get(ctx, buildTokenKey(userId, token)).Result() if err != nil { if errors.Is(err, redis.Nil) { return middleware.NewAuthError(401, "invalid token") @@ -263,14 +262,14 @@ func (s *UserService) validateUserTokenRedis(ctx context.Context, token string, } // A rough way to implement sliding expiration - s.Redis.Expire(ctx, buildTokenKey(userId, token), time.Duration(config.Cfg.UserTokenExpiry)*time.Second) + s.Dep.Redis.Expire(ctx, buildTokenKey(userId, token), time.Duration(s.Dep.Cfg.UserTokenExpiry)*time.Second) s.updateHeartBeat(userId) return nil } func (s *UserService) ValidateUserToken(ctx context.Context, token string, userId uint) error { - if config.Cfg.IsRedisEnabled { + if s.Dep.Cfg.IsRedisEnabled { return s.validateUserTokenRedis(ctx, token, userId) } else { return s.validateUserTokenDB(ctx, token, userId) diff --git a/backend/internal/util/jwt/jwt.go b/backend/internal/util/jwt/jwt.go index 8f76365..bbd74e8 100644 --- a/backend/internal/util/jwt/jwt.go +++ b/backend/internal/util/jwt/jwt.go @@ -6,7 +6,7 @@ import ( libjwt "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" - "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" ) @@ -25,12 +25,12 @@ func generateRegisteredClaims(expiration int) libjwt.RegisteredClaims { } } -func SignUserToken(userID uint) (string, error) { - userTokenExpiry := config.Cfg.UserTokenExpiry +func SignUserToken(dep *dependency.Dependency, userID uint) (string, error) { + userTokenExpiry := dep.Cfg.UserTokenExpiry // For Redis mode, use absolute expiry to limit max token lifetime, // because the actual expiry is managed in Redis with sliding expiration. - if config.Cfg.IsRedisEnabled { - userTokenExpiry = config.Cfg.UserTokenAbsoluteExpiry + if dep.Cfg.IsRedisEnabled { + userTokenExpiry = dep.Cfg.UserTokenAbsoluteExpiry } claims := dto.UserJwtPayload{ @@ -40,7 +40,7 @@ func SignUserToken(userID uint) (string, error) { } token := libjwt.NewWithClaims(libjwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(config.Cfg.JwtSecret)) + signedToken, err := token.SignedString([]byte(dep.Cfg.JwtSecret)) if err != nil { return "", err } @@ -48,14 +48,14 @@ func SignUserToken(userID uint) (string, error) { return signedToken, nil } -func SignOauthStateToken() (string, error) { +func SignOauthStateToken(dep *dependency.Dependency) (string, error) { claims := dto.OauthStateJwtPayload{ Type: GoogleOAuthStateType, - RegisteredClaims: generateRegisteredClaims(config.Cfg.OauthStateTokenExpiry), + RegisteredClaims: generateRegisteredClaims(dep.Cfg.OauthStateTokenExpiry), } token := libjwt.NewWithClaims(libjwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(config.Cfg.JwtSecret)) + signedToken, err := token.SignedString([]byte(dep.Cfg.JwtSecret)) if err != nil { return "", err } @@ -63,16 +63,16 @@ func SignOauthStateToken() (string, error) { return signedToken, nil } -func SignTwoFASetupToken(userID uint, secret string) (string, error) { +func SignTwoFASetupToken(dep *dependency.Dependency, userID uint, secret string) (string, error) { claims := dto.TwoFaSetupJwtPayload{ UserID: userID, Secret: secret, Type: TwoFASetupType, - RegisteredClaims: generateRegisteredClaims(config.Cfg.TwoFaTokenExpiry), + RegisteredClaims: generateRegisteredClaims(dep.Cfg.TwoFaTokenExpiry), } token := libjwt.NewWithClaims(libjwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(config.Cfg.JwtSecret)) + signedToken, err := token.SignedString([]byte(dep.Cfg.JwtSecret)) if err != nil { return "", err } @@ -80,15 +80,15 @@ func SignTwoFASetupToken(userID uint, secret string) (string, error) { return signedToken, nil } -func SignTwoFAToken(userID uint) (string, error) { +func SignTwoFAToken(dep *dependency.Dependency, userID uint) (string, error) { claims := dto.TwoFaJwtPayload{ UserID: userID, Type: TwoFATokenType, - RegisteredClaims: generateRegisteredClaims(config.Cfg.TwoFaTokenExpiry), + RegisteredClaims: generateRegisteredClaims(dep.Cfg.TwoFaTokenExpiry), } token := libjwt.NewWithClaims(libjwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(config.Cfg.JwtSecret)) + signedToken, err := token.SignedString([]byte(dep.Cfg.JwtSecret)) if err != nil { return "", err } @@ -96,12 +96,12 @@ func SignTwoFAToken(userID uint) (string, error) { return signedToken, nil } -func validateToken[T libjwt.Claims](signedToken string, claims T) (T, error) { +func validateToken[T libjwt.Claims](dep *dependency.Dependency, signedToken string, claims T) (T, error) { token, err := libjwt.ParseWithClaims( signedToken, claims, func(token *libjwt.Token) (any, error) { - return []byte(config.Cfg.JwtSecret), nil + return []byte(dep.Cfg.JwtSecret), nil }, ) if err != nil { @@ -115,9 +115,9 @@ func validateToken[T libjwt.Claims](signedToken string, claims T) (T, error) { return claims, nil } -func ValidateUserTokenGeneric(signedToken string) (*dto.UserJwtPayload, error) { +func ValidateUserTokenGeneric(dep *dependency.Dependency, signedToken string) (*dto.UserJwtPayload, error) { claims := &dto.UserJwtPayload{} - parsedClaims, err := validateToken(signedToken, claims) + parsedClaims, err := validateToken(dep, signedToken, claims) if err != nil { return nil, err } @@ -129,9 +129,9 @@ func ValidateUserTokenGeneric(signedToken string) (*dto.UserJwtPayload, error) { return parsedClaims, nil } -func ValidateOauthStateToken(signedToken string) (*dto.OauthStateJwtPayload, error) { +func ValidateOauthStateToken(dep *dependency.Dependency, signedToken string) (*dto.OauthStateJwtPayload, error) { claims := &dto.OauthStateJwtPayload{} - parsedClaims, err := validateToken(signedToken, claims) + parsedClaims, err := validateToken(dep, signedToken, claims) if err != nil { return nil, err } @@ -143,9 +143,9 @@ func ValidateOauthStateToken(signedToken string) (*dto.OauthStateJwtPayload, err return parsedClaims, nil } -func ValidateTwoFAToken(signedToken string) (*dto.TwoFaJwtPayload, error) { +func ValidateTwoFAToken(dep *dependency.Dependency, signedToken string) (*dto.TwoFaJwtPayload, error) { claims := &dto.TwoFaJwtPayload{} - parsedClaims, err := validateToken(signedToken, claims) + parsedClaims, err := validateToken(dep, signedToken, claims) if err != nil { return nil, err } @@ -157,9 +157,9 @@ func ValidateTwoFAToken(signedToken string) (*dto.TwoFaJwtPayload, error) { return parsedClaims, nil } -func ValidateTwoFASetupToken(signedToken string) (*dto.TwoFaSetupJwtPayload, error) { +func ValidateTwoFASetupToken(dep *dependency.Dependency, signedToken string) (*dto.TwoFaSetupJwtPayload, error) { claims := &dto.TwoFaSetupJwtPayload{} - parsedClaims, err := validateToken(signedToken, claims) + parsedClaims, err := validateToken(dep, signedToken, claims) if err != nil { return nil, err } diff --git a/backend/internal/util/log.go b/backend/internal/util/log.go index 78f736c..eee3fcb 100644 --- a/backend/internal/util/log.go +++ b/backend/internal/util/log.go @@ -8,14 +8,13 @@ import ( "github.com/lmittmann/tint" ) -var Logger *slog.Logger - -func InitLogger(level slog.Leveler) { - Logger = slog.New(tint.NewHandler(os.Stdout, &tint.Options{ +func GetLogger(level slog.Leveler) *slog.Logger { + logger := slog.New(tint.NewHandler(os.Stdout, &tint.Options{ Level: level, TimeFormat: time.Kitchen, AddSource: true, })) - slog.SetDefault(Logger) + slog.SetDefault(logger) + return logger } From 0ded076f50175cb7831539ab1af9b8778026c102 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Fri, 30 Jan 2026 00:18:16 +0200 Subject: [PATCH 4/5] Refactor/backend: refactor test cases --- backend/cmd/server/main.go | 2 +- backend/internal/config/config_test.go | 95 +++ backend/internal/db/db.go | 2 +- backend/internal/dependency/dependency.go | 6 +- backend/internal/dto/schemas_test.go | 6 +- backend/internal/middleware/auth_test.go | 41 +- .../routers/users_router_failure_test.go | 365 ---------- .../routers/users_router_redis_test.go | 187 ----- backend/internal/routers/users_router_test.go | 680 +++++++++++++----- .../internal/service/friend_service_test.go | 188 +++-- .../service/google_oauth_service_test.go | 50 +- backend/internal/service/helper.go | 2 +- backend/internal/service/helper_test.go | 6 +- .../internal/service/redis_service_test.go | 48 +- backend/internal/service/setup_test.go | 68 +- .../internal/service/twofa_service_test.go | 38 +- backend/internal/service/user_service.go | 2 +- backend/internal/service/user_service_test.go | 597 +++++++-------- backend/internal/testutil/testutil.go | 50 ++ backend/internal/util/jwt/token_test.go | 191 ++--- 20 files changed, 1284 insertions(+), 1340 deletions(-) create mode 100644 backend/internal/config/config_test.go delete mode 100644 backend/internal/routers/users_router_failure_test.go delete mode 100644 backend/internal/routers/users_router_redis_test.go create mode 100644 backend/internal/testutil/testutil.go diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 0bfaa04..739ac45 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -24,8 +24,8 @@ import ( "github.com/gin-contrib/cors" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" + "github.com/paularynty/transcendence/auth-service-go/internal/middleware" ) func SetupRouter(dep *dependency.Dependency) *gin.Engine { diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go new file mode 100644 index 0000000..87f57d7 --- /dev/null +++ b/backend/internal/config/config_test.go @@ -0,0 +1,95 @@ +package config + +import ( + "testing" +) + +func assertPanics(t *testing.T, fn func(), name string) { + t.Helper() + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for %s", name) + } + }() + fn() +} + +func assertNotPanics(t *testing.T, fn func(), name string) { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic for %s: %v", name, r) + } + }() + fn() +} + +func TestGetEnvStrOrDefault(t *testing.T) { + t.Setenv("TEST_STR", "") + if got := getEnvStrOrDefault("TEST_STR", "fallback"); got != "fallback" { + t.Fatalf("expected default value, got %q", got) + } + + t.Setenv("TEST_STR", "value") + if got := getEnvStrOrDefault("TEST_STR", "fallback"); got != "value" { + t.Fatalf("expected env value, got %q", got) + } +} + +func TestGetEnvStrOrPanic(t *testing.T) { + t.Setenv("TEST_PANIC", "") + assertPanics(t, func() { + _ = getEnvStrOrPanic("TEST_PANIC") + }, "empty env") + + t.Setenv("TEST_PANIC", "value") + assertNotPanics(t, func() { + if got := getEnvStrOrPanic("TEST_PANIC"); got != "value" { + t.Fatalf("expected env value, got %q", got) + } + }, "set env") +} + +func TestGetEnvIntOrDefault(t *testing.T) { + t.Setenv("TEST_INT", "") + if got := getEnvIntOrDefault("TEST_INT", 7); got != 7 { + t.Fatalf("expected default value, got %d", got) + } + + t.Setenv("TEST_INT", "42") + if got := getEnvIntOrDefault("TEST_INT", 7); got != 42 { + t.Fatalf("expected env value, got %d", got) + } + + t.Setenv("TEST_INT", "not-an-int") + if got := getEnvIntOrDefault("TEST_INT", 7); got != 7 { + t.Fatalf("expected default value for invalid int, got %d", got) + } +} + +func TestLoadConfigFromEnv_PanicsOnMissingRequired(t *testing.T) { + t.Setenv("JWT_SECRET", "jwt") + t.Setenv("GOOGLE_CLIENT_ID", "client") + t.Setenv("GOOGLE_CLIENT_SECRET", "secret") + + assertNotPanics(t, func() { + _ = LoadConfigFromEnv() + }, "all required set") + + t.Setenv("JWT_SECRET", "") + assertPanics(t, func() { + _ = LoadConfigFromEnv() + }, "JWT_SECRET unset") + + t.Setenv("JWT_SECRET", "jwt") + t.Setenv("GOOGLE_CLIENT_ID", "") + assertPanics(t, func() { + _ = LoadConfigFromEnv() + }, "GOOGLE_CLIENT_ID unset") + + t.Setenv("GOOGLE_CLIENT_ID", "client") + t.Setenv("GOOGLE_CLIENT_SECRET", "") + assertPanics(t, func() { + _ = LoadConfigFromEnv() + }, "GOOGLE_CLIENT_SECRET unset") +} diff --git a/backend/internal/db/db.go b/backend/internal/db/db.go index 8379022..923361f 100644 --- a/backend/internal/db/db.go +++ b/backend/internal/db/db.go @@ -8,7 +8,7 @@ import ( "gorm.io/gorm" ) -func GetDB(dbName string, logger *slog.Logger) *gorm.DB { +func GetDB(dbName string, logger *slog.Logger) *gorm.DB { var err error db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) diff --git a/backend/internal/dependency/dependency.go b/backend/internal/dependency/dependency.go index 1d77735..0bad228 100644 --- a/backend/internal/dependency/dependency.go +++ b/backend/internal/dependency/dependency.go @@ -1,10 +1,10 @@ package dependency -import ( - "gorm.io/gorm" - "log/slog" +import ( "github.com/paularynty/transcendence/auth-service-go/internal/config" "github.com/redis/go-redis/v9" + "gorm.io/gorm" + "log/slog" ) type Dependency struct { diff --git a/backend/internal/dto/schemas_test.go b/backend/internal/dto/schemas_test.go index fade7ee..fa3d3ca 100644 --- a/backend/internal/dto/schemas_test.go +++ b/backend/internal/dto/schemas_test.go @@ -116,8 +116,8 @@ func TestUsernameValidatorRules(t *testing.T) { {"ValidTrimmedRight", "valid-user ", false}, {"ValidTrimmedLeft", " valid-user", false}, {"EmptyAfterTrim", " ", true}, - {"TooShort", "abcde", true}, - {"TooShortAfterTrim", " abcde ", true}, + {"TooShort", "ab", true}, + {"TooShortAfterTrim", " ab ", true}, {"ContainsSpace", "user name", true}, {"IllegalChars", "user@name", true}, } @@ -196,7 +196,7 @@ func TestIdentifierValidatorAcceptsUsernameOrEmail(t *testing.T) { {"TrimmedEmailLeft", " user@example.com", false}, {"EmptyAfterTrim", " ", true}, {"Invalid", "???", true}, - {"TooShort", "abcde", true}, + {"TooShort", "ab", true}, } for _, tc := range cases { diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go index 2eab42c..6f77ec9 100644 --- a/backend/internal/middleware/auth_test.go +++ b/backend/internal/middleware/auth_test.go @@ -8,34 +8,29 @@ import ( "github.com/gin-gonic/gin" - "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) -func setupAuthConfig(t *testing.T) func() { +func setupAuthDep(t *testing.T) *dependency.Dependency { t.Helper() - prev := config.Cfg - config.Cfg = &config.Config{ - JwtSecret: "test-secret-key", - UserTokenExpiry: 3600, - OauthStateTokenExpiry: 120, - TwoFaTokenExpiry: 300, - } - - return func() { - config.Cfg = prev - } + cfg := testutil.NewTestConfig() + cfg.JwtSecret = "test-secret-key" + cfg.UserTokenExpiry = 3600 + cfg.OauthStateTokenExpiry = 120 + cfg.TwoFaTokenExpiry = 300 + return testutil.NewTestDependency(cfg, nil, nil, nil) } func TestAuthMiddlewareRejectsMissingToken(t *testing.T) { gin.SetMode(gin.TestMode) - cleanup := setupAuthConfig(t) - defer cleanup() + dep := setupAuthDep(t) r := gin.New() r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth()) + r.Use(middleware.Auth(dep)) r.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) @@ -61,17 +56,16 @@ func TestAuthMiddlewareRejectsMissingToken(t *testing.T) { func TestAuthMiddlewareAllowsValidToken(t *testing.T) { gin.SetMode(gin.TestMode) - cleanup := setupAuthConfig(t) - defer cleanup() + dep := setupAuthDep(t) - token, err := jwt.SignUserToken(99) + token, err := jwt.SignUserToken(dep, 99) if err != nil { t.Fatalf("failed to sign user token: %v", err) } r := gin.New() r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth()) + r.Use(middleware.Auth(dep)) r.GET("/protected", func(c *gin.Context) { userID, ok := c.Get("userID") if !ok { @@ -103,17 +97,16 @@ func TestAuthMiddlewareAllowsValidToken(t *testing.T) { func TestAuthMiddlewareRejectsInvalidToken(t *testing.T) { gin.SetMode(gin.TestMode) - cleanup := setupAuthConfig(t) - defer cleanup() + dep := setupAuthDep(t) - token, err := jwt.SignTwoFAToken(10) + token, err := jwt.SignTwoFAToken(dep, 10) if err != nil { t.Fatalf("failed to sign 2fa token: %v", err) } r := gin.New() r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth()) + r.Use(middleware.Auth(dep)) r.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) diff --git a/backend/internal/routers/users_router_failure_test.go b/backend/internal/routers/users_router_failure_test.go deleted file mode 100644 index 0a767d9..0000000 --- a/backend/internal/routers/users_router_failure_test.go +++ /dev/null @@ -1,365 +0,0 @@ -package routers - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "log/slog" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - "time" - - "cloud.google.com/go/auth/credentials/idtoken" - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - - "github.com/paularynty/transcendence/auth-service-go/internal/config" - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/service" - "github.com/paularynty/transcendence/auth-service-go/internal/util" - "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" -) - -func setupUsersRouterTestFailure(t *testing.T) (*gin.Engine, func()) { - gin.SetMode(gin.TestMode) - - util.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelError, - })) - - prevCfg := config.Cfg - config.Cfg = &config.Config{ - JwtSecret: "test-secret", - UserTokenExpiry: 3600, - TwoFaTokenExpiry: 3600, - OauthStateTokenExpiry: 3600, - GoogleClientId: "test-client", - GoogleRedirectUri: "http://localhost/cb", - FrontendUrl: "http://localhost:3000", - } - dto.InitValidator() - - dbName := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared&_busy_timeout=5000&_foreign_keys=on" - var err error - model.DB, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) - if err != nil { - t.Fatalf("failed to connect to db: %v", err) - } - model.DB.Exec("PRAGMA foreign_keys = ON") - - err = model.DB.AutoMigrate(&model.User{}, &model.Friend{}, &model.Token{}, &model.HeartBeat{}) - if err != nil { - t.Fatalf("failed to migrate db: %v", err) - } - - router := gin.New() - UsersRouter(router.Group("/users")) - - // Set MaxOpenConns to 1 to avoid locking issues - if model.DB != nil { - sqlDB, _ := model.DB.DB() - if sqlDB != nil { - sqlDB.SetMaxOpenConns(1) - } - } - - return router, func() { - config.Cfg = prevCfg - if model.DB != nil { - sqlDB, _ := model.DB.DB() - if sqlDB != nil { - _ = sqlDB.Close() - } - model.DB = nil - } - } -} - -func TestUsersRouter_CreateUser_Failures(t *testing.T) { - router, cleanup := setupUsersRouterTestFailure(t) - defer cleanup() - - // 1. Invalid Body - reqBody := `{"username": "u"}` // Missing email, password, invalid username length - req := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBufferString(reqBody)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusBadRequest { - t.Errorf("expected 400 for invalid body, got %d", resp.Code) - } - - // 2. Duplicate User - // Create first - validReq := dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "dupuser"}, Email: "dup@e.com"}, - Password: dto.Password{Password: "pass"}, - } - body, _ := json.Marshal(validReq) - router.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(body))) - - // Try duplicate - req = httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusConflict { - t.Errorf("expected 409 for duplicate, got %d", resp.Code) - } -} - -func TestUsersRouter_LoginUser_Failures(t *testing.T) { - router, cleanup := setupUsersRouterTestFailure(t) - defer cleanup() - - // 1. Invalid Body - req := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBufferString("{}")) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusBadRequest { - t.Errorf("expected 400 for invalid body, got %d", resp.Code) - } - - // 2. User Not Found - loginReq := dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "missing"}, - Password: dto.Password{Password: "pass"}, - } - body, _ := json.Marshal(loginReq) - req = httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusUnauthorized { - t.Errorf("expected 401 for missing user, got %d", resp.Code) - } - - // 3. Invalid Credentials - // Create user - svc := service.NewUserService(model.DB, nil) - _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "loginfail"}, Email: "fail@e.com"}, - Password: dto.Password{Password: "correct"}, - }) - - loginReq.Identifier.Identifier = "loginfail" - loginReq.Password.Password = "wrong" - body, _ = json.Marshal(loginReq) - req = httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusUnauthorized { - t.Errorf("expected 401 for wrong password, got %d", resp.Code) - } -} - -func TestUsersRouter_UpdateUser_Failures(t *testing.T) { - router, cleanup := setupUsersRouterTestFailure(t) - defer cleanup() - - svc := service.NewUserService(model.DB, nil) - u, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, - Password: dto.Password{Password: "pass"}, - }) - token, _ := jwt.SignUserToken(u.ID) - model.DB.Create(&model.Token{UserID: u.ID, Token: token}) - - // 1. Update Profile Duplicate - // Create another user - _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, - Password: dto.Password{Password: "pass"}, - }) - - updateReq := dto.UpdateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "update_u2"}, Email: "u2@e.com"}, - } - body, _ := json.Marshal(updateReq) - req := httptest.NewRequest(http.MethodPut, "/users/me", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusConflict { - t.Errorf("expected 409 for duplicate profile update, got %d", resp.Code) - } - - // 2. Update Password Wrong Old - pwReq := dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "wrong"}, - NewPassword: dto.NewPassword{NewPassword: "newpass"}, - } - body, _ = json.Marshal(pwReq) - req = httptest.NewRequest(http.MethodPut, "/users/password", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusUnauthorized { - t.Errorf("expected 401 for wrong old password, got %d", resp.Code) - } -} - -func TestUsersRouter_Friends_Failures(t *testing.T) { - router, cleanup := setupUsersRouterTestFailure(t) - defer cleanup() - - svc := service.NewUserService(model.DB, nil) - u1, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "f1"}, Email: "f1@e.com"}, - Password: dto.Password{Password: "pass"}, - }) - u2, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "f2"}, Email: "f2@e.com"}, - Password: dto.Password{Password: "pass"}, - }) - token, _ := jwt.SignUserToken(u1.ID) - model.DB.Create(&model.Token{UserID: u1.ID, Token: token}) - - // 1. Add Self - reqBody := dto.AddNewFriendRequest{UserID: u1.ID} - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusBadRequest { - t.Errorf("expected 400 for adding self, got %d", resp.Code) - } - - // 2. Add Non-existent - reqBody = dto.AddNewFriendRequest{UserID: 999} - body, _ = json.Marshal(reqBody) - req = httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusNotFound { - t.Errorf("expected 404 for missing friend, got %d", resp.Code) - } - - // 3. Duplicate Friend - _ = svc.AddNewFriend(context.Background(), u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - - // Let DB settle - time.Sleep(200 * time.Millisecond) - - reqBody = dto.AddNewFriendRequest{UserID: u2.ID} - body, _ = json.Marshal(reqBody) - req = httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusConflict { - t.Errorf("expected 409 for duplicate friend, got %d", resp.Code) - } -} - -func TestUsersRouter_2FA_Failures(t *testing.T) { - router, cleanup := setupUsersRouterTestFailure(t) - defer cleanup() - - svc := service.NewUserService(model.DB, nil) - u, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "2fafail"}, Email: "2fafail@e.com"}, - Password: dto.Password{Password: "pass"}, - }) - token, _ := jwt.SignUserToken(u.ID) - model.DB.Create(&model.Token{UserID: u.ID, Token: token}) - - // 1. Confirm with invalid code - setupResp, _ := svc.StartTwoFaSetup(context.Background(), u.ID) - confirmReq := dto.TwoFAConfirmRequest{ - SetupToken: setupResp.SetupToken, - TwoFACode: "000000", - } - body, _ := json.Marshal(confirmReq) - req := httptest.NewRequest(http.MethodPost, "/users/2fa/confirm", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusBadRequest { - t.Errorf("expected 400 for invalid code, got %d", resp.Code) - } - - // 2. Start Setup when Already Enabled - // Enable it correctly first - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - confirmRes, _ := svc.ConfirmTwoFaSetup(context.Background(), u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - - // Update token as confirming 2FA issues a new one - token = confirmRes.Token - - req = httptest.NewRequest(http.MethodPost, "/users/2fa/setup", nil) - req.Header.Set("Authorization", "Bearer "+token) - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusBadRequest { - t.Errorf("expected 400 for setup when already enabled, got %d", resp.Code) - } - - // 3. Disable with Wrong Password - disableReq := dto.DisableTwoFARequest{Password: dto.Password{Password: "wrong"}} - body, _ = json.Marshal(disableReq) - req = httptest.NewRequest(http.MethodPut, "/users/2fa/disable", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusUnauthorized { - t.Errorf("expected 401 for wrong password disable, got %d", resp.Code) - } -} - -func TestUsersRouter_GoogleOAuth_Failures(t *testing.T) { - router, cleanup := setupUsersRouterTestFailure(t) - defer cleanup() - - // 1. Missing Params - req := httptest.NewRequest(http.MethodGet, "/users/google/callback", nil) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - if resp.Code != http.StatusBadRequest { // 400 from handler check - t.Errorf("expected 400 for missing params, got %d", resp.Code) - } - - // 2. Invalid State/Code (Service Fail) - // We mocked service vars in other test file, but here they are originals unless we mock them again. - // Since setupUsersRouterTestFailure is separate, vars are global in `service` package. - // We should mock them to RETURN ERROR. - - origExchange := service.ExchangeCodeForTokens - defer func() { service.ExchangeCodeForTokens = origExchange }() - service.ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { - return nil, errors.New("mock error") - } - - state, _ := jwt.SignOauthStateToken() - req = httptest.NewRequest(http.MethodGet, "/users/google/callback?code=c&state="+state, nil) - resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) - - if resp.Code != http.StatusFound { // Redirect to error page - t.Errorf("expected 302 redirect to error, got %d", resp.Code) - } - loc := resp.Header().Get("Location") - if !strings.Contains(loc, "error=") { - t.Errorf("expected error param in redirect: %s", loc) - } -} diff --git a/backend/internal/routers/users_router_redis_test.go b/backend/internal/routers/users_router_redis_test.go deleted file mode 100644 index daf0dff..0000000 --- a/backend/internal/routers/users_router_redis_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package routers - -import ( - "bytes" - "context" - "encoding/json" - "log/slog" - "net/http" - "net/http/httptest" - "os" - "strconv" - "strings" - "testing" - "time" - - "github.com/alicebob/miniredis/v2" - "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - - "github.com/paularynty/transcendence/auth-service-go/internal/config" - db "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/util" -) - -func setupUsersRouterTestRedis(t *testing.T) (*gin.Engine, *miniredis.Miniredis, func()) { - t.Helper() - gin.SetMode(gin.TestMode) - - util.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelError, - })) - - prevCfg := config.Cfg - config.Cfg = &config.Config{ - JwtSecret: "test-secret", - UserTokenExpiry: 60, - UserTokenAbsoluteExpiry: 600, - TwoFaTokenExpiry: 3600, - OauthStateTokenExpiry: 3600, - GoogleClientId: "test-client", - GoogleRedirectUri: "http://localhost/cb", - FrontendUrl: "http://localhost:3000", - IsRedisEnabled: true, - } - dto.InitValidator() - - // DB setup matches existing patterns. - dbName := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared&_busy_timeout=5000&_foreign_keys=on" - var err error - db.DB, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) - if err != nil { - t.Fatalf("failed to connect to db: %v", err) - } - db.DB.Exec("PRAGMA foreign_keys = ON") - - err = db.DB.AutoMigrate(&db.User{}, &db.Friend{}, &db.Token{}, &db.HeartBeat{}) - if err != nil { - t.Fatalf("failed to migrate db: %v", err) - } - - // Redis setup. - mr := miniredis.RunT(t) - redisClient := redis.NewClient(&redis.Options{Addr: mr.Addr()}) - db.Redis = redisClient - config.Cfg.RedisURL = "redis://" + mr.Addr() - - router := gin.New() - UsersRouter(router.Group("/users")) - - if db.DB != nil { - sqlDB, _ := db.DB.DB() - if sqlDB != nil { - sqlDB.SetMaxOpenConns(1) - } - } - - cleanup := func() { - config.Cfg = prevCfg - if db.Redis != nil { - _ = db.Redis.Close() - db.Redis = nil - } - mr.Close() - if db.DB != nil { - sqlDB, _ := db.DB.DB() - if sqlDB != nil { - _ = sqlDB.Close() - } - db.DB = nil - } - } - - return router, mr, cleanup -} - -func TestUsersRouter_Redis_LoginValidateLogout(t *testing.T) { - router, mr, cleanup := setupUsersRouterTestRedis(t) - defer cleanup() - - // Create user - createReq := dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "redisrouter"}, Email: "redisrouter@example.com"}, - Password: dto.Password{Password: "password123"}, - } - createBody, _ := json.Marshal(createReq) - createResp := httptest.NewRecorder() - createHTTP := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(createBody)) - createHTTP.Header.Set("Content-Type", "application/json") - router.ServeHTTP(createResp, createHTTP) - if createResp.Code != http.StatusCreated { - t.Fatalf("expected 201 on create, got %d. Body: %s", createResp.Code, createResp.Body.String()) - } - - // Login user - loginReq := dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "redisrouter"}, - Password: dto.Password{Password: "password123"}, - } - loginBody, _ := json.Marshal(loginReq) - loginResp := httptest.NewRecorder() - loginHTTP := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(loginBody)) - loginHTTP.Header.Set("Content-Type", "application/json") - router.ServeHTTP(loginResp, loginHTTP) - if loginResp.Code != http.StatusOK { - t.Fatalf("expected 200 on login, got %d. Body: %s", loginResp.Code, loginResp.Body.String()) - } - - var loginUser dto.UserWithTokenResponse - _ = json.Unmarshal(loginResp.Body.Bytes(), &loginUser) - if loginUser.Token == "" || loginUser.ID == 0 { - t.Fatalf("expected login token and id, got: %+v", loginUser) - } - - // Ensure token is stored in Redis (by key prefix). - keys := mr.Keys() - wantPrefix := "user_token:" + strconv.FormatUint(uint64(loginUser.ID), 10) + ":" - foundTokenKey := false - for _, k := range keys { - if strings.HasPrefix(k, wantPrefix) { - foundTokenKey = true - break - } - } - if !foundTokenKey { - t.Fatalf("expected redis token key with prefix %q, keys: %v", wantPrefix, keys) - } - - // Login should update heartbeat in Redis. - time.Sleep(100 * time.Millisecond) - score, err := db.Redis.ZScore(context.Background(), "heartbeat:", strconv.FormatUint(uint64(loginUser.ID), 10)).Result() - if err != nil { - t.Fatalf("expected heartbeat entry after login, got error: %v", err) - } - if int64(score) < time.Now().Unix()-5 { - t.Fatalf("expected recent heartbeat score after login, got %v", score) - } - - // Validate should succeed - validateResp := httptest.NewRecorder() - validateHTTP := httptest.NewRequest(http.MethodPost, "/users/validate", nil) - validateHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) - router.ServeHTTP(validateResp, validateHTTP) - if validateResp.Code != http.StatusOK { - t.Fatalf("expected 200 on validate, got %d. Body: %s", validateResp.Code, validateResp.Body.String()) - } - - // Logout should revoke redis tokens - logoutResp := httptest.NewRecorder() - logoutHTTP := httptest.NewRequest(http.MethodDelete, "/users/logout", nil) - logoutHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) - router.ServeHTTP(logoutResp, logoutHTTP) - if logoutResp.Code != http.StatusNoContent { - t.Fatalf("expected 204 on logout, got %d. Body: %s", logoutResp.Code, logoutResp.Body.String()) - } - - // Validate again should fail - validateAfterResp := httptest.NewRecorder() - validateAfterHTTP := httptest.NewRequest(http.MethodPost, "/users/validate", nil) - validateAfterHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) - router.ServeHTTP(validateAfterResp, validateAfterHTTP) - if validateAfterResp.Code != http.StatusUnauthorized { - t.Fatalf("expected 401 on validate after logout, got %d. Body: %s", validateAfterResp.Code, validateAfterResp.Body.String()) - } -} diff --git a/backend/internal/routers/users_router_test.go b/backend/internal/routers/users_router_test.go index 338547b..d4232f4 100644 --- a/backend/internal/routers/users_router_test.go +++ b/backend/internal/routers/users_router_test.go @@ -4,95 +4,140 @@ import ( "bytes" "context" "encoding/json" - "log/slog" + "errors" "net/http" "net/http/httptest" "net/url" - "os" + "strconv" "strings" "testing" "time" "cloud.google.com/go/auth/credentials/idtoken" + "github.com/alicebob/miniredis/v2" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" + "github.com/redis/go-redis/v9" "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/paularynty/transcendence/auth-service-go/internal/config" model "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/service" - "github.com/paularynty/transcendence/auth-service-go/internal/util" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) -func setupUsersRouterTestUnique(t *testing.T) (*gin.Engine, func()) { +type usersRouterEnv struct { + router *gin.Engine + dep *dependency.Dependency + mr *miniredis.Miniredis + cleanup func() +} + +func setupUsersRouterTest(t *testing.T, useRedis bool) *usersRouterEnv { + t.Helper() gin.SetMode(gin.TestMode) - // Mock Logger - util.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelError, - })) - - // Mock Config - prevCfg := config.Cfg - config.Cfg = &config.Config{ - JwtSecret: "test-secret", - UserTokenExpiry: 3600, - TwoFaTokenExpiry: 3600, - OauthStateTokenExpiry: 3600, - GoogleClientId: "test-client", - GoogleRedirectUri: "http://localhost/cb", - FrontendUrl: "http://localhost:3000", + logger := testutil.NewTestLogger() + cfg := testutil.NewTestConfig() + cfg.JwtSecret = "test-secret" + cfg.UserTokenExpiry = 3600 + cfg.UserTokenAbsoluteExpiry = 600 + cfg.TwoFaTokenExpiry = 3600 + cfg.OauthStateTokenExpiry = 3600 + cfg.GoogleClientId = "test-client" + cfg.GoogleRedirectUri = "http://localhost/cb" + cfg.FrontendUrl = "http://localhost:3000" + + if useRedis { + cfg.IsRedisEnabled = true } - dto.InitValidator() - // Mock DB - var err error - // Use unique DB name for each test run to avoid lock issues - // Sanitize test name - dbName := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared&_busy_timeout=5000" + dto.InitValidator() - model.DB, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) + dbName := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared&_busy_timeout=5000&_foreign_keys=on" + dbConn, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) if err != nil { t.Fatalf("failed to connect to db: %v", err) } - // Explicitly enable foreign keys - model.DB.Exec("PRAGMA foreign_keys = ON") - - err = model.DB.AutoMigrate(&model.User{}, &model.Friend{}, &model.Token{}, &model.HeartBeat{}) - if err != nil { + dbConn.Exec("PRAGMA foreign_keys = ON") + if err := dbConn.AutoMigrate(&model.User{}, &model.Friend{}, &model.Token{}, &model.HeartBeat{}); err != nil { t.Fatalf("failed to migrate db: %v", err) } + var mr *miniredis.Miniredis + var redisClient *redis.Client + if useRedis { + mr = miniredis.RunT(t) + redisClient = redis.NewClient(&redis.Options{Addr: mr.Addr()}) + cfg.RedisURL = "redis://" + mr.Addr() + } + + dep := dependency.NewDependency(cfg, dbConn, redisClient, logger) router := gin.New() - UsersRouter(router.Group("/users")) + UsersRouter(router.Group("/users"), dep) - // Set MaxOpenConns to 1 to avoid locking issues in tests with SQLite - if model.DB != nil { - sqlDB, _ := model.DB.DB() - if sqlDB != nil { - sqlDB.SetMaxOpenConns(1) - } + if sqlDB, err := dbConn.DB(); err == nil && sqlDB != nil { + sqlDB.SetMaxOpenConns(1) } - return router, func() { - config.Cfg = prevCfg - if model.DB != nil { - sqlDB, _ := model.DB.DB() - if sqlDB != nil { - _ = sqlDB.Close() - } - model.DB = nil + cleanup := func() { + if redisClient != nil { + _ = redisClient.Close() } + if mr != nil { + mr.Close() + } + if sqlDB, err := dbConn.DB(); err == nil && sqlDB != nil { + _ = sqlDB.Close() + } + } + + return &usersRouterEnv{ + router: router, + dep: dep, + mr: mr, + cleanup: cleanup, + } +} + +func signUserToken(t *testing.T, dep *dependency.Dependency, userID uint) string { + t.Helper() + token, err := jwt.SignUserToken(dep, userID) + if err != nil { + t.Fatalf("failed to sign user token: %v", err) + } + return token +} + +func addUserToken(t *testing.T, dep *dependency.Dependency, userID uint) string { + t.Helper() + token := signUserToken(t, dep, userID) + if err := dep.DB.Create(&model.Token{UserID: userID, Token: token}).Error; err != nil { + t.Fatalf("failed to insert token: %v", err) } + return token +} + +func createUser(t *testing.T, dep *dependency.Dependency, username, email, password string) *dto.UserWithoutTokenResponse { + t.Helper() + svc := service.NewUserService(dep) + user, err := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: username}, Email: email}, + Password: dto.Password{Password: password}, + }) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + return user } func TestUsersRouter_CreateUser(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() reqBody := dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "newuser"}, Email: "new@example.com"}, @@ -104,7 +149,7 @@ func TestUsersRouter_CreateUser(t *testing.T) { req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusCreated { t.Fatalf("expected status 201, got %d. Body: %s", resp.Code, resp.Body.String()) @@ -117,11 +162,48 @@ func TestUsersRouter_CreateUser(t *testing.T) { } } +func TestUsersRouter_CreateUser_Failures(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() + + cases := []struct { + name string + body string + wantStatus int + }{ + {"InvalidBody", `{"username": "u"}`, http.StatusBadRequest}, + {"DuplicateUser", "duplicate", http.StatusConflict}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body := tc.body + if tc.name == "DuplicateUser" { + validReq := dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "dupuser"}, Email: "dup@e.com"}, + Password: dto.Password{Password: "pass123"}, + } + payload, _ := json.Marshal(validReq) + env.router.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(payload))) + body = string(payload) + } + + req := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + env.router.ServeHTTP(resp, req) + + if resp.Code != tc.wantStatus { + t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) + } + }) + } +} + func TestUsersRouter_LoginUser(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() - // Create user createReq := dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "loginuser"}, Email: "login@example.com"}, Password: dto.Password{Password: "password123"}, @@ -130,9 +212,8 @@ func TestUsersRouter_LoginUser(t *testing.T) { cReq := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(createBody)) cReq.Header.Set("Content-Type", "application/json") - router.ServeHTTP(httptest.NewRecorder(), cReq) + env.router.ServeHTTP(httptest.NewRecorder(), cReq) - // Login loginReq := dto.LoginUserRequest{ Identifier: dto.Identifier{Identifier: "loginuser"}, Password: dto.Password{Password: "password123"}, @@ -143,7 +224,7 @@ func TestUsersRouter_LoginUser(t *testing.T) { req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) @@ -156,24 +237,66 @@ func TestUsersRouter_LoginUser(t *testing.T) { } } -func TestUsersRouter_GetProfile(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() +func TestUsersRouter_LoginUser_Failures(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() + + cases := []struct { + name string + body string + setup func() + wantStatus int + }{ + {"InvalidBody", `{}`, nil, http.StatusBadRequest}, + {"UserNotFound", "missing", nil, http.StatusUnauthorized}, + {"WrongPassword", "wrong", func() { + _ = createUser(t, env.dep, "loginfail", "fail@e.com", "correct123") + }, http.StatusUnauthorized}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + } + + var payload []byte + switch tc.body { + case "missing": + loginReq := dto.LoginUserRequest{Identifier: dto.Identifier{Identifier: "missing"}, Password: dto.Password{Password: "pass123"}} + payload, _ = json.Marshal(loginReq) + case "wrong": + loginReq := dto.LoginUserRequest{Identifier: dto.Identifier{Identifier: "loginfail"}, Password: dto.Password{Password: "wrong123"}} + payload, _ = json.Marshal(loginReq) + default: + payload = []byte(tc.body) + } + + req := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(payload)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + env.router.ServeHTTP(resp, req) - user := model.User{ - Username: "profileuser", - Email: "profile@example.com", + if resp.Code != tc.wantStatus { + t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) + } + }) } - model.DB.Create(&user) +} + +func TestUsersRouter_GetProfile(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() - tokenStr, _ := jwt.SignUserToken(user.ID) - model.DB.Create(&model.Token{UserID: user.ID, Token: tokenStr}) + user := model.User{Username: "profileuser", Email: "profile@example.com"} + env.dep.DB.Create(&user) + tokenStr := addUserToken(t, env.dep, user.ID) req := httptest.NewRequest(http.MethodGet, "/users/me", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) @@ -186,14 +309,14 @@ func TestUsersRouter_GetProfile(t *testing.T) { } } -func TestUsersRouter_Unathorized(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() +func TestUsersRouter_Unauthorized(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() req := httptest.NewRequest(http.MethodGet, "/users/me", nil) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusUnauthorized { t.Fatalf("expected status 401, got %d", resp.Code) @@ -201,18 +324,15 @@ func TestUsersRouter_Unathorized(t *testing.T) { } func TestUsersRouter_UpdateUserProfile(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() user := model.User{Username: "u", Email: "u@e.com"} - model.DB.Create(&user) - tokenStr, _ := jwt.SignUserToken(user.ID) - model.DB.Create(&model.Token{UserID: user.ID, Token: tokenStr}) + env.dep.DB.Create(&user) + tokenStr := addUserToken(t, env.dep, user.ID) newAvatar := "http://pic.com/1.png" - reqBody := dto.UpdateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "newname"}, Email: "new@e.com", Avatar: &newAvatar}, - } + reqBody := dto.UpdateUserRequest{User: dto.User{UserName: dto.UserName{Username: "newname"}, Email: "new@e.com", Avatar: &newAvatar}} body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPut, "/users/me", bytes.NewBuffer(body)) @@ -220,7 +340,7 @@ func TestUsersRouter_UpdateUserProfile(t *testing.T) { req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) @@ -233,17 +353,58 @@ func TestUsersRouter_UpdateUserProfile(t *testing.T) { } } +func TestUsersRouter_UpdateUser_Failures(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() + + svc := service.NewUserService(env.dep) + u, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, + Password: dto.Password{Password: "pass123"}, + }) + token := addUserToken(t, env.dep, u.ID) + + _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, + Password: dto.Password{Password: "pass123"}, + }) + + cases := []struct { + name string + method string + path string + body any + wantStatus int + }{ + {"DuplicateProfile", http.MethodPut, "/users/me", dto.UpdateUserRequest{User: dto.User{UserName: dto.UserName{Username: "update_u2"}, Email: "u2@e.com"}}, http.StatusConflict}, + {"WrongOldPassword", http.MethodPut, "/users/password", dto.UpdateUserPasswordRequest{OldPassword: dto.OldPassword{OldPassword: "wrong123"}, NewPassword: dto.NewPassword{NewPassword: "newpass"}}, http.StatusUnauthorized}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + payload, _ := json.Marshal(tc.body) + req := httptest.NewRequest(tc.method, tc.path, bytes.NewBuffer(payload)) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + env.router.ServeHTTP(resp, req) + if resp.Code != tc.wantStatus { + t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) + } + }) + } +} + func TestUsersRouter_UpdateUserPassword(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() - svc := service.NewUserService(model.DB, nil) + svc := service.NewUserService(env.dep) userResp, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "pw"}, Email: "pw@e.com"}, Password: dto.Password{Password: "oldpass"}, }) - tokenStr, _ := jwt.SignUserToken(userResp.ID) - model.DB.Create(&model.Token{UserID: userResp.ID, Token: tokenStr}) + tokenStr := addUserToken(t, env.dep, userResp.ID) reqBody := dto.UpdateUserPasswordRequest{ OldPassword: dto.OldPassword{OldPassword: "oldpass"}, @@ -256,7 +417,7 @@ func TestUsersRouter_UpdateUserPassword(t *testing.T) { req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) @@ -264,22 +425,20 @@ func TestUsersRouter_UpdateUserPassword(t *testing.T) { } func TestUsersRouter_DeleteUser(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() user := model.User{Username: "del", Email: "del@e.com"} - model.DB.Create(&user) - tokenStr, _ := jwt.SignUserToken(user.ID) - model.DB.Create(&model.Token{UserID: user.ID, Token: tokenStr}) + env.dep.DB.Create(&user) + tokenStr := addUserToken(t, env.dep, user.ID) - // Let DB settle time.Sleep(500 * time.Millisecond) req := httptest.NewRequest(http.MethodDelete, "/users/me", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusNoContent { t.Fatalf("expected status 204, got %d", resp.Code) @@ -287,19 +446,18 @@ func TestUsersRouter_DeleteUser(t *testing.T) { } func TestUsersRouter_GetUsersWithLimitedInfo(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() user := model.User{Username: "list", Email: "list@e.com"} - model.DB.Create(&user) - tokenStr, _ := jwt.SignUserToken(user.ID) - model.DB.Create(&model.Token{UserID: user.ID, Token: tokenStr}) + env.dep.DB.Create(&user) + tokenStr := addUserToken(t, env.dep, user.ID) req := httptest.NewRequest(http.MethodGet, "/users/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d", resp.Code) @@ -307,19 +465,18 @@ func TestUsersRouter_GetUsersWithLimitedInfo(t *testing.T) { } func TestUsersRouter_ValidateUser(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() user := model.User{Username: "val", Email: "val@e.com"} - model.DB.Create(&user) - tokenStr, _ := jwt.SignUserToken(user.ID) - model.DB.Create(&model.Token{UserID: user.ID, Token: tokenStr}) + env.dep.DB.Create(&user) + tokenStr := addUserToken(t, env.dep, user.ID) req := httptest.NewRequest(http.MethodPost, "/users/validate", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d", resp.Code) @@ -333,40 +490,32 @@ func TestUsersRouter_ValidateUser(t *testing.T) { } func TestUsersRouter_Friends(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() - svc := service.NewUserService(model.DB, nil) - u1, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "f1"}, Email: "f1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - u2, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "f2"}, Email: "f2@e.com"}, - Password: dto.Password{Password: "p"}, - }) + svc := service.NewUserService(env.dep) + u1 := createUser(t, env.dep, "f1", "f1@e.com", "pass123") + u2 := createUser(t, env.dep, "f2", "f2@e.com", "pass123") + _ = svc - tokenStr, _ := jwt.SignUserToken(u1.ID) - model.DB.Create(&model.Token{UserID: u1.ID, Token: tokenStr}) + tokenStr := addUserToken(t, env.dep, u1.ID) - // Add Friend reqBody := dto.AddNewFriendRequest{UserID: u2.ID} body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusCreated { t.Fatalf("expected status 201, got %d. Body: %s", resp.Code, resp.Body.String()) } - // Get Friends req = httptest.NewRequest(http.MethodGet, "/users/friends", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("expected status 200, got %d", resp.Code) @@ -378,23 +527,60 @@ func TestUsersRouter_Friends(t *testing.T) { } } +func TestUsersRouter_Friends_Failures(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() + + svc := service.NewUserService(env.dep) + u1 := createUser(t, env.dep, "f1", "f1@e.com", "pass123") + u2 := createUser(t, env.dep, "f2", "f2@e.com", "pass123") + token := addUserToken(t, env.dep, u1.ID) + + cases := []struct { + name string + payload dto.AddNewFriendRequest + setup func() + wantStatus int + }{ + {"AddSelf", dto.AddNewFriendRequest{UserID: u1.ID}, nil, http.StatusBadRequest}, + {"AddMissing", dto.AddNewFriendRequest{UserID: 999}, nil, http.StatusNotFound}, + {"Duplicate", dto.AddNewFriendRequest{UserID: u2.ID}, func() { + _ = svc.AddNewFriend(context.Background(), u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) + }, http.StatusConflict}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + if tc.name == "Duplicate" { + time.Sleep(200 * time.Millisecond) + } + } + body, _ := json.Marshal(tc.payload) + req := httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + env.router.ServeHTTP(resp, req) + if resp.Code != tc.wantStatus { + t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) + } + }) + } +} + func TestUsersRouter_2FA(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() - svc := service.NewUserService(model.DB, nil) - user, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "2fa"}, Email: "2fa@e.com"}, - Password: dto.Password{Password: "pass"}, - }) - tokenStr, _ := jwt.SignUserToken(user.ID) - model.DB.Create(&model.Token{UserID: user.ID, Token: tokenStr}) + user := createUser(t, env.dep, "2fa", "2fa@e.com", "pass123") + tokenStr := addUserToken(t, env.dep, user.ID) - // 1. Setup 2FA req := httptest.NewRequest(http.MethodPost, "/users/2fa/setup", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("setup failed: %d", resp.Code) @@ -402,87 +588,121 @@ func TestUsersRouter_2FA(t *testing.T) { var setupRes dto.TwoFASetupResponse _ = json.Unmarshal(resp.Body.Bytes(), &setupRes) - // Let DB settle time.Sleep(200 * time.Millisecond) - // 2. Confirm 2FA code, _ := totp.GenerateCode(setupRes.TwoFASecret, time.Now()) - confirmBody, _ := json.Marshal(dto.TwoFAConfirmRequest{ - SetupToken: setupRes.SetupToken, - TwoFACode: code, - }) + confirmBody, _ := json.Marshal(dto.TwoFAConfirmRequest{SetupToken: setupRes.SetupToken, TwoFACode: code}) req = httptest.NewRequest(http.MethodPost, "/users/2fa/confirm", bytes.NewBuffer(confirmBody)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tokenStr) resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("confirm failed: %d", resp.Code) } - // 3. Login challenge - // Need pending session token - sessionToken, _ := jwt.SignTwoFAToken(user.ID) + sessionToken, _ := jwt.SignTwoFAToken(env.dep, user.ID) code, _ = totp.GenerateCode(setupRes.TwoFASecret, time.Now()) - challengeBody, _ := json.Marshal(dto.TwoFAChallengeRequest{ - SessionToken: sessionToken, - TwoFACode: code, - }) + challengeBody, _ := json.Marshal(dto.TwoFAChallengeRequest{SessionToken: sessionToken, TwoFACode: code}) req = httptest.NewRequest(http.MethodPost, "/users/2fa", bytes.NewBuffer(challengeBody)) req.Header.Set("Content-Type", "application/json") resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("challenge failed: %d body: %s", resp.Code, resp.Body.String()) } - // 4. Disable 2FA - // We need a valid user token again (new one from confirm or challenge, or reuse old if valid) - // But `issueNewTokenForUser` revokes old tokens if true passed. Confirm passed true. - // So tokenStr is invalid. We need the one from confirm response. var userRes dto.UserWithTokenResponse _ = json.Unmarshal(resp.Body.Bytes(), &userRes) tokenStr = userRes.Token - // Let DB settle time.Sleep(200 * time.Millisecond) - disableBody, _ := json.Marshal(dto.DisableTwoFARequest{ - Password: dto.Password{Password: "pass"}, - }) + disableBody, _ := json.Marshal(dto.DisableTwoFARequest{Password: dto.Password{Password: "pass123"}}) req = httptest.NewRequest(http.MethodPut, "/users/2fa/disable", bytes.NewBuffer(disableBody)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tokenStr) resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Fatalf("disable failed: %d", resp.Code) } } +func TestUsersRouter_2FA_Failures(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() + + svc := service.NewUserService(env.dep) + u := createUser(t, env.dep, "2fafail", "2fafail@e.com", "pass123") + token := addUserToken(t, env.dep, u.ID) + + setupResp, _ := svc.StartTwoFaSetup(context.Background(), u.ID) + + cases := []struct { + name string + method string + path string + body any + wantStatus int + setup func() + }{ + {"InvalidCode", http.MethodPost, "/users/2fa/confirm", dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: "000000"}, http.StatusBadRequest, nil}, + {"SetupAlreadyEnabled", http.MethodPost, "/users/2fa/setup", nil, http.StatusBadRequest, func() { + code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) + confirmRes, _ := svc.ConfirmTwoFaSetup(context.Background(), u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) + token = confirmRes.Token + }}, + {"WrongDisablePassword", http.MethodPut, "/users/2fa/disable", dto.DisableTwoFARequest{Password: dto.Password{Password: "wrong123"}}, http.StatusUnauthorized, func() { + if token == "" { + code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) + confirmRes, _ := svc.ConfirmTwoFaSetup(context.Background(), u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) + token = confirmRes.Token + } + }}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + } + var body []byte + if tc.body != nil { + body, _ = json.Marshal(tc.body) + } + req := httptest.NewRequest(tc.method, tc.path, bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+token) + if tc.body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp := httptest.NewRecorder() + env.router.ServeHTTP(resp, req) + if resp.Code != tc.wantStatus { + t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) + } + }) + } +} + func TestUsersRouter_GoogleOAuth(t *testing.T) { - router, cleanup := setupUsersRouterTestUnique(t) - defer cleanup() + env := setupUsersRouterTest(t, false) + defer env.cleanup() - // 1. Google Login (Redirect) req := httptest.NewRequest(http.MethodGet, "/users/google/login", nil) resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusFound { t.Fatalf("expected status 302, got %d", resp.Code) } - // Verify location header format roughly - loc := resp.Header().Get("Location") - if loc == "" { + if loc := resp.Header().Get("Location"); loc == "" { t.Error("expected location header") } - // 2. Google Callback - // Mock service vars origExchange := service.ExchangeCodeForTokens origFetch := service.FetchGoogleUserInfo defer func() { @@ -490,23 +710,17 @@ func TestUsersRouter_GoogleOAuth(t *testing.T) { service.FetchGoogleUserInfo = origFetch }() - service.ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { + service.ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { return &idtoken.Payload{Subject: "g123"}, nil } service.FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ - ID: "g123", - Email: "test@google.com", - Name: "Google User", - }, nil + return &dto.GoogleUserData{ID: "g123", Email: "test@google.com", Name: "Google User"}, nil } - // Generate state - state, _ := jwt.SignOauthStateToken() - + state, _ := jwt.SignOauthStateToken(env.dep) req = httptest.NewRequest(http.MethodGet, "/users/google/callback?code=valid&state="+state, nil) resp = httptest.NewRecorder() - router.ServeHTTP(resp, req) + env.router.ServeHTTP(resp, req) if resp.Code != http.StatusFound { t.Fatalf("expected status 302, got %d", resp.Code) @@ -518,3 +732,119 @@ func TestUsersRouter_GoogleOAuth(t *testing.T) { t.Error("expected token in redirect") } } + +func TestUsersRouter_GoogleOAuth_Failures(t *testing.T) { + env := setupUsersRouterTest(t, false) + defer env.cleanup() + + cases := []struct { + name string + path string + setup func() + wantStatus int + }{ + {"MissingParams", "/users/google/callback", nil, http.StatusBadRequest}, + {"ExchangeError", "/users/google/callback?code=c&state=state", func() { + origExchange := service.ExchangeCodeForTokens + service.ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { + return nil, errors.New("mock error") + } + t.Cleanup(func() { service.ExchangeCodeForTokens = origExchange }) + }, http.StatusFound}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + } + path := tc.path + if strings.Contains(path, "state=state") { + state, _ := jwt.SignOauthStateToken(env.dep) + path = "/users/google/callback?code=c&state=" + state + } + req := httptest.NewRequest(http.MethodGet, path, nil) + resp := httptest.NewRecorder() + env.router.ServeHTTP(resp, req) + if resp.Code != tc.wantStatus { + t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) + } + if tc.name == "ExchangeError" { + loc := resp.Header().Get("Location") + if !strings.Contains(loc, "error=") { + t.Fatalf("expected error param in redirect: %s", loc) + } + } + }) + } +} + +func TestUsersRouter_Redis_LoginValidateLogout(t *testing.T) { + env := setupUsersRouterTest(t, true) + defer env.cleanup() + + createReq := dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "redisrouter"}, Email: "redisrouter@example.com"}, + Password: dto.Password{Password: "password123"}, + } + createBody, _ := json.Marshal(createReq) + createResp := httptest.NewRecorder() + createHTTP := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(createBody)) + createHTTP.Header.Set("Content-Type", "application/json") + env.router.ServeHTTP(createResp, createHTTP) + if createResp.Code != http.StatusCreated { + t.Fatalf("expected 201 on create, got %d. Body: %s", createResp.Code, createResp.Body.String()) + } + + loginReq := dto.LoginUserRequest{ + Identifier: dto.Identifier{Identifier: "redisrouter"}, + Password: dto.Password{Password: "password123"}, + } + loginBody, _ := json.Marshal(loginReq) + loginResp := httptest.NewRecorder() + loginHTTP := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(loginBody)) + loginHTTP.Header.Set("Content-Type", "application/json") + env.router.ServeHTTP(loginResp, loginHTTP) + if loginResp.Code != http.StatusOK { + t.Fatalf("expected 200 on login, got %d. Body: %s", loginResp.Code, loginResp.Body.String()) + } + + var loginUser dto.UserWithTokenResponse + _ = json.Unmarshal(loginResp.Body.Bytes(), &loginUser) + if loginUser.Token == "" || loginUser.ID == 0 { + t.Fatalf("expected login to return token and id") + } + + validateResp := httptest.NewRecorder() + validateHTTP := httptest.NewRequest(http.MethodPost, "/users/validate", nil) + validateHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) + env.router.ServeHTTP(validateResp, validateHTTP) + if validateResp.Code != http.StatusOK { + t.Fatalf("expected 200 on validate, got %d. Body: %s", validateResp.Code, validateResp.Body.String()) + } + + logoutResp := httptest.NewRecorder() + logoutHTTP := httptest.NewRequest(http.MethodDelete, "/users/logout", nil) + logoutHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) + env.router.ServeHTTP(logoutResp, logoutHTTP) + if logoutResp.Code != http.StatusNoContent { + t.Fatalf("expected 204 on logout, got %d. Body: %s", logoutResp.Code, logoutResp.Body.String()) + } + + validateAfterResp := httptest.NewRecorder() + validateAfterHTTP := httptest.NewRequest(http.MethodPost, "/users/validate", nil) + validateAfterHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) + env.router.ServeHTTP(validateAfterResp, validateAfterHTTP) + if validateAfterResp.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 on validate after logout, got %d. Body: %s", validateAfterResp.Code, validateAfterResp.Body.String()) + } + + time.Sleep(200 * time.Millisecond) + score, err := env.dep.Redis.ZScore(context.Background(), "heartbeat:", strconv.FormatUint(uint64(loginUser.ID), 10)).Result() + if err != nil { + t.Fatalf("expected heartbeat entry, got error: %v", err) + } + if int64(score) < time.Now().Unix()-10 { + t.Fatalf("expected recent heartbeat score, got %v", score) + } +} diff --git a/backend/internal/service/friend_service_test.go b/backend/internal/service/friend_service_test.go index 8d16b2d..93c9e1a 100644 --- a/backend/internal/service/friend_service_test.go +++ b/backend/internal/service/friend_service_test.go @@ -7,12 +7,11 @@ import ( model "github.com/paularynty/transcendence/auth-service-go/internal/db" "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" ) func TestGetAllUsersLimitedInfo(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() // Create users @@ -47,7 +46,7 @@ func TestGetAllUsersLimitedInfo(t *testing.T) { func TestAddNewFriend(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u1, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -59,67 +58,60 @@ func TestAddNewFriend(t *testing.T) { Password: dto.Password{Password: "p"}, }) - t.Run("Success", func(t *testing.T) { - err := svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - }) - - t.Run("AddSelf", func(t *testing.T) { - err := svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u1.ID}) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 400 { - t.Errorf("expected 400 error, got %v", err) - } - }) - - t.Run("DuplicateFriend", func(t *testing.T) { - err := svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 409 { - t.Errorf("expected 409 error, got %v", err) - } - }) - - t.Run("UserNotFound", func(t *testing.T) { - err := svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: 999}) - if err == nil { - // Check if friend was actually added (should be 0) - var count int64 - db.Model(&model.Friend{}).Where("user_id = ? AND friend_id = ?", u1.ID, 999).Count(&count) - if count > 0 { - t.Fatal("expected error, but friend was added despite FK violation") + cases := []struct { + name string + userID uint + friendID uint + setup func() + wantErrStatus int + checkFK bool + }{ + {"Success", u1.ID, u2.ID, nil, 0, false}, + {"AddSelf", u1.ID, u1.ID, nil, 400, false}, + {"DuplicateFriend", u1.ID, u2.ID, func() { + _ = svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) + }, 409, false}, + {"UserNotFound", u1.ID, 999, nil, 404, true}, + {"DBError", u1.ID, u2.ID, func() { + sqlDB, _ := db.DB() + _ = sqlDB.Close() + }, -1, false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() } - return - } - authErr, ok := err.(*middleware.AuthError) - if ok { - if authErr.Status != 404 { - t.Errorf("expected 404 error, got %d", authErr.Status) + err := svc.AddNewFriend(ctx, tc.userID, &dto.AddNewFriendRequest{UserID: tc.friendID}) + if tc.wantErrStatus == 0 { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + return } - } - }) - - t.Run("DBError", func(t *testing.T) { - sqlDB, _ := db.DB() - _ = sqlDB.Close() - err := svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - if err == nil { - t.Error("expected error on closed db") - } - }) + if tc.wantErrStatus == -1 { + if err == nil { + t.Error("expected error on closed db") + } + return + } + if err == nil && tc.checkFK { + var count int64 + db.Model(&model.Friend{}).Where("user_id = ? AND friend_id = ?", u1.ID, tc.friendID).Count(&count) + if count > 0 { + t.Fatal("expected error, but friend was added despite FK violation") + } + return + } + requireAuthStatus(t, err, tc.wantErrStatus) + }) + } } func TestGetUserFriends(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u1, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -134,44 +126,46 @@ func TestGetUserFriends(t *testing.T) { // Add friend _ = svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - t.Run("Success", func(t *testing.T) { - friends, err := svc.GetUserFriends(ctx, u1.ID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(friends) != 1 { - t.Errorf("expected 1 friend, got %d", len(friends)) - } - if friends[0].ID != u2.ID { - t.Errorf("expected friend ID %d, got %d", u2.ID, friends[0].ID) - } - if friends[0].Online { - t.Error("expected friend to be offline") - } - }) - - t.Run("OnlineFriend", func(t *testing.T) { - // Manually insert heartbeat for u2 - db.Create(&model.HeartBeat{ - UserID: u2.ID, - LastSeenAt: time.Now(), + cases := []struct { + name string + setup func() + wantOnline bool + wantErrStatus int + }{ + {"Success", nil, false, 0}, + {"OnlineFriend", func() { + db.Create(&model.HeartBeat{UserID: u2.ID, LastSeenAt: time.Now()}) + }, true, 0}, + {"DBError", func() { + sqlDB, _ := db.DB() + _ = sqlDB.Close() + }, false, -1}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + } + friends, err := svc.GetUserFriends(ctx, u1.ID) + if tc.wantErrStatus == -1 { + if err == nil { + t.Error("expected error on closed db") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(friends) != 1 { + t.Errorf("expected 1 friend, got %d", len(friends)) + } + if friends[0].ID != u2.ID { + t.Errorf("expected friend ID %d, got %d", u2.ID, friends[0].ID) + } + if friends[0].Online != tc.wantOnline { + t.Errorf("expected online=%v, got %v", tc.wantOnline, friends[0].Online) + } }) - - friends, err := svc.GetUserFriends(ctx, u1.ID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !friends[0].Online { - t.Error("expected friend to be online") - } - }) - - t.Run("DBError", func(t *testing.T) { - sqlDB, _ := db.DB() - _ = sqlDB.Close() - _, err := svc.GetUserFriends(ctx, u1.ID) - if err == nil { - t.Error("expected error on closed db") - } - }) + } } diff --git a/backend/internal/service/google_oauth_service_test.go b/backend/internal/service/google_oauth_service_test.go index 4e6fcc2..ed052aa 100644 --- a/backend/internal/service/google_oauth_service_test.go +++ b/backend/internal/service/google_oauth_service_test.go @@ -7,16 +7,18 @@ import ( "testing" "cloud.google.com/go/auth/credentials/idtoken" - "github.com/paularynty/transcendence/auth-service-go/internal/config" model "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) func TestGetGoogleOAuthURL(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + cfg := testutil.NewTestConfig() + svc := NewUserService(newTestDependencyWithConfig(cfg, db, nil)) ctx := context.Background() t.Run("Success", func(t *testing.T) { @@ -31,11 +33,11 @@ func TestGetGoogleOAuthURL(t *testing.T) { } q := u.Query() - if q.Get("client_id") != config.Cfg.GoogleClientId { - t.Errorf("expected client_id %s, got %s", config.Cfg.GoogleClientId, q.Get("client_id")) + if q.Get("client_id") != cfg.GoogleClientId { + t.Errorf("expected client_id %s, got %s", cfg.GoogleClientId, q.Get("client_id")) } - if q.Get("redirect_uri") != config.Cfg.GoogleRedirectUri { - t.Errorf("expected redirect_uri %s, got %s", config.Cfg.GoogleRedirectUri, q.Get("redirect_uri")) + if q.Get("redirect_uri") != cfg.GoogleRedirectUri { + t.Errorf("expected redirect_uri %s, got %s", cfg.GoogleRedirectUri, q.Get("redirect_uri")) } if q.Get("state") == "" { t.Error("expected state param") @@ -45,7 +47,7 @@ func TestGetGoogleOAuthURL(t *testing.T) { func TestHandleGoogleOAuthCallback_InvalidState(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() // Helper to parse redirect URL @@ -68,7 +70,7 @@ func TestHandleGoogleOAuthCallback_InvalidState(t *testing.T) { }) t.Run("ExpiredState", func(t *testing.T) { - userToken, _ := jwt.SignUserToken(1) + userToken, _ := jwt.SignUserToken(svc.Dep, 1) redirectURL := svc.HandleGoogleOAuthCallback(ctx, "somecode", userToken) _, errMsg := parseRedirect(redirectURL) @@ -80,7 +82,7 @@ func TestHandleGoogleOAuthCallback_InvalidState(t *testing.T) { func TestHandleGoogleOAuthCallback_Success(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() // Mock dependencies @@ -91,7 +93,7 @@ func TestHandleGoogleOAuthCallback_Success(t *testing.T) { FetchGoogleUserInfo = origFetch }() - ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { + ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { return &idtoken.Payload{Subject: "g123"}, nil } @@ -103,7 +105,7 @@ func TestHandleGoogleOAuthCallback_Success(t *testing.T) { }, nil } - state, _ := jwt.SignOauthStateToken() + state, _ := jwt.SignOauthStateToken(svc.Dep) t.Run("NewUser", func(t *testing.T) { redirectURL := svc.HandleGoogleOAuthCallback(ctx, "validcode", state) @@ -205,7 +207,7 @@ func TestHandleGoogleOAuthCallback_Success(t *testing.T) { func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() origExchange := ExchangeCodeForTokens @@ -215,10 +217,10 @@ func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { FetchGoogleUserInfo = origFetch }() - state, _ := jwt.SignOauthStateToken() + state, _ := jwt.SignOauthStateToken(svc.Dep) t.Run("ExchangeError", func(t *testing.T) { - ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { + ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { return nil, errors.New("exchange failed") } @@ -230,7 +232,7 @@ func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { }) t.Run("FetchError", func(t *testing.T) { - ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { + ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { return &idtoken.Payload{}, nil } FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { @@ -247,7 +249,7 @@ func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { func TestLinkGoogleAccountToExistingUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -317,7 +319,7 @@ func TestLinkGoogleAccountToExistingUser(t *testing.T) { func TestCreateNewUserFromGoogleInfo(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() t.Run("Success", func(t *testing.T) { @@ -367,7 +369,7 @@ func TestCreateNewUserFromGoogleInfo(t *testing.T) { func TestHandleGoogleOAuthCallback_DBError(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() origExchange := ExchangeCodeForTokens @@ -377,9 +379,9 @@ func TestHandleGoogleOAuthCallback_DBError(t *testing.T) { FetchGoogleUserInfo = origFetch }() - state, _ := jwt.SignOauthStateToken() + state, _ := jwt.SignOauthStateToken(svc.Dep) - ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { + ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { return &idtoken.Payload{Subject: "g123"}, nil } @@ -403,7 +405,7 @@ func TestHandleGoogleOAuthCallback_DBError(t *testing.T) { func TestHandleGoogleOAuthCallback_LinkError(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() origExchange := ExchangeCodeForTokens @@ -413,9 +415,9 @@ func TestHandleGoogleOAuthCallback_LinkError(t *testing.T) { FetchGoogleUserInfo = origFetch }() - state, _ := jwt.SignOauthStateToken() + state, _ := jwt.SignOauthStateToken(svc.Dep) - ExchangeCodeForTokens = func(ctx context.Context, code string) (*idtoken.Payload, error) { + ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { return &idtoken.Payload{Subject: "new_g_id"}, nil } @@ -429,7 +431,7 @@ func TestHandleGoogleOAuthCallback_LinkError(t *testing.T) { // Create user with SAME email but DIFFERENT google ID (already linked) googleID := "old_g_id" - svc.DB.Create(&model.User{ + svc.Dep.DB.Create(&model.User{ Username: "existing", Email: "test@google.com", GoogleOauthID: &googleID, diff --git a/backend/internal/service/helper.go b/backend/internal/service/helper.go index 7521cd4..6d93b42 100644 --- a/backend/internal/service/helper.go +++ b/backend/internal/service/helper.go @@ -29,7 +29,7 @@ func NewUserService(dep *dependency.Dependency) *UserService { } return &UserService{ - Dep: dep, + Dep: dep, } } diff --git a/backend/internal/service/helper_test.go b/backend/internal/service/helper_test.go index 863d331..3a2b640 100644 --- a/backend/internal/service/helper_test.go +++ b/backend/internal/service/helper_test.go @@ -61,7 +61,7 @@ func TestHelperFunctions(t *testing.T) { t.Run("UpdateHeartBeat", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) // Create user first to satisfy FK _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ @@ -83,7 +83,7 @@ func TestHelperFunctions(t *testing.T) { t.Run("IssueNewTokenForUser", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) // Create user first _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ @@ -113,7 +113,7 @@ func TestHelperFunctions(t *testing.T) { t.Run("IssueNewTokenForUser_DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) sqlDB, _ := db.DB() _ = sqlDB.Close() diff --git a/backend/internal/service/redis_service_test.go b/backend/internal/service/redis_service_test.go index a57a8c4..0396fd0 100644 --- a/backend/internal/service/redis_service_test.go +++ b/backend/internal/service/redis_service_test.go @@ -11,30 +11,23 @@ import ( "github.com/paularynty/transcendence/auth-service-go/internal/config" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/redis/go-redis/v9" ) -func withRedisTestExpiries(t *testing.T, userTTLSeconds int, absoluteTTLSeconds int) func() { - t.Helper() - - prevCfg := config.Cfg - cfgCopy := *prevCfg - cfgCopy.UserTokenExpiry = userTTLSeconds - cfgCopy.UserTokenAbsoluteExpiry = absoluteTTLSeconds - config.Cfg = &cfgCopy - - return func() { - config.Cfg = prevCfg - } +func withRedisTestExpiries(cfg *config.Config, userTTLSeconds int, absoluteTTLSeconds int) { + cfg.UserTokenExpiry = userTTLSeconds + cfg.UserTokenAbsoluteExpiry = absoluteTTLSeconds } func TestRedisTokenLifecycle(t *testing.T) { db := setupTestDB(t.Name()) - mr, redisClient, cleanupRedis := setupTestRedis(t) + cfg := testutil.NewTestConfig() + withRedisTestExpiries(cfg, 10, 30) + mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - defer withRedisTestExpiries(t, 10, 30)() - svc := NewUserService(db, redisClient) + svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -95,10 +88,11 @@ func TestRedisTokenLifecycle(t *testing.T) { func TestRedisHeartbeatOnlineStatusAndCleanup(t *testing.T) { db := setupTestDB(t.Name()) - _, redisClient, cleanupRedis := setupTestRedis(t) + cfg := testutil.NewTestConfig() + _, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(db, redisClient) + svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() u1, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -155,10 +149,11 @@ func TestRedisHeartbeatOnlineStatusAndCleanup(t *testing.T) { func TestRedisLoginUpdatesHeartbeat(t *testing.T) { db := setupTestDB(t.Name()) - _, redisClient, cleanupRedis := setupTestRedis(t) + cfg := testutil.NewTestConfig() + _, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(db, redisClient) + svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() created, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -195,10 +190,11 @@ func TestRedisLoginUpdatesHeartbeat(t *testing.T) { func TestRedisLogoutRevokesAllTokens(t *testing.T) { db := setupTestDB(t.Name()) - mr, redisClient, cleanupRedis := setupTestRedis(t) + cfg := testutil.NewTestConfig() + mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(db, redisClient) + svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -242,10 +238,11 @@ func TestRedisLogoutRevokesAllTokens(t *testing.T) { func TestRedisDeleteUserRevokesAllTokens(t *testing.T) { db := setupTestDB(t.Name()) - mr, redisClient, cleanupRedis := setupTestRedis(t) + cfg := testutil.NewTestConfig() + mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(db, redisClient) + svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -289,10 +286,11 @@ func TestRedisDeleteUserRevokesAllTokens(t *testing.T) { func TestRedisUpdatePasswordRevokesOldTokens(t *testing.T) { db := setupTestDB(t.Name()) - mr, redisClient, cleanupRedis := setupTestRedis(t) + cfg := testutil.NewTestConfig() + mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) defer cleanupRedis() - svc := NewUserService(db, redisClient) + svc := NewUserService(newTestDependencyWithConfig(cfg, db, redisClient)) ctx := context.Background() userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ diff --git a/backend/internal/service/setup_test.go b/backend/internal/service/setup_test.go index 1c19af0..a9b8ed0 100644 --- a/backend/internal/service/setup_test.go +++ b/backend/internal/service/setup_test.go @@ -1,7 +1,6 @@ package service import ( - "log/slog" "os" "strings" "testing" @@ -13,7 +12,8 @@ import ( "github.com/paularynty/transcendence/auth-service-go/internal/config" model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/util" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" ) func setupTestDB(testName string) *gorm.DB { @@ -44,35 +44,41 @@ func setupTestDB(testName string) *gorm.DB { return db } -func setupConfig() { - config.Cfg = &config.Config{ - JwtSecret: "test-secret", - UserTokenExpiry: 3600, - UserTokenAbsoluteExpiry: 2592000, - OauthStateTokenExpiry: 600, - GoogleClientId: "test-client-id", - GoogleClientSecret: "test-client-secret", - GoogleRedirectUri: "http://localhost:8080/callback", - FrontendUrl: "http://localhost:3000", - TwoFaUrlPrefix: "otpauth://totp/Transcendence?secret=", - TwoFaTokenExpiry: 600, - RedisURL: "", - IsRedisEnabled: false, - } - - // Mock logger to discard output during tests - util.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelError, // Only show errors - })) -} - func TestMain(m *testing.M) { - setupConfig() code := m.Run() os.Exit(code) } -func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *redis.Client, func()) { +func newTestDependency(db *gorm.DB, redis *redis.Client, cfgMutators ...func(*config.Config)) *dependency.Dependency { + cfg := testutil.NewTestConfig() + for _, mutate := range cfgMutators { + mutate(cfg) + } + logger := testutil.NewTestLogger() + if redis != nil { + cfg.IsRedisEnabled = true + if cfg.RedisURL == "" { + cfg.RedisURL = "redis://test" + } + } + return dependency.NewDependency(cfg, db, redis, logger) +} + +func newTestDependencyWithConfig(cfg *config.Config, db *gorm.DB, redis *redis.Client) *dependency.Dependency { + if cfg == nil { + cfg = testutil.NewTestConfig() + } + logger := testutil.NewTestLogger() + if redis != nil { + cfg.IsRedisEnabled = true + if cfg.RedisURL == "" { + cfg.RedisURL = "redis://test" + } + } + return dependency.NewDependency(cfg, db, redis, logger) +} + +func setupTestRedis(t *testing.T, cfg *config.Config) (*miniredis.Miniredis, *redis.Client, func()) { t.Helper() mr := miniredis.RunT(t) @@ -80,16 +86,14 @@ func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *redis.Client, func()) Addr: mr.Addr(), }) - prevCfg := config.Cfg - cfgCopy := *prevCfg - cfgCopy.RedisURL = "redis://" + mr.Addr() - cfgCopy.IsRedisEnabled = true - config.Cfg = &cfgCopy + if cfg != nil { + cfg.RedisURL = "redis://" + mr.Addr() + cfg.IsRedisEnabled = true + } cleanup := func() { _ = client.Close() mr.Close() - config.Cfg = prevCfg } return mr, client, cleanup diff --git a/backend/internal/service/twofa_service_test.go b/backend/internal/service/twofa_service_test.go index 3c22ac7..ee42bd1 100644 --- a/backend/internal/service/twofa_service_test.go +++ b/backend/internal/service/twofa_service_test.go @@ -16,7 +16,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, Password: dto.Password{Password: "p"}, @@ -33,7 +33,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("ConfirmSetup_Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, Password: dto.Password{Password: "p"}, @@ -61,7 +61,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_AlreadyEnabled", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u3"}, Email: "u3@e.com"}, Password: dto.Password{Password: "p"}, @@ -85,7 +85,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_OAuthUser", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) // Mock OAuth user oauthUser := dto.GoogleUserData{ ID: "oauth123", @@ -108,7 +108,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { t.Run("StartSetup_DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "u4"}, Email: "u4@e.com"}, Password: dto.Password{Password: "p"}, @@ -125,7 +125,7 @@ func TestTwoFASetupAndConfirm(t *testing.T) { func TestConfirmTwoFaSetup_Errors(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -165,7 +165,7 @@ func TestConfirmTwoFaSetup_Errors(t *testing.T) { }) t.Run("WrongTokenType", func(t *testing.T) { - token, _ := jwt.SignUserToken(u.ID) + token, _ := jwt.SignUserToken(svc.Dep, u.ID) req := &dto.TwoFAConfirmRequest{ SetupToken: token, TwoFACode: code, @@ -193,7 +193,7 @@ func TestConfirmTwoFaSetup_Errors(t *testing.T) { t.Run("NotInitiated", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) // User with no 2FA token u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ni"}, Email: "ni@e.com"}, @@ -201,7 +201,7 @@ func TestConfirmTwoFaSetup_Errors(t *testing.T) { }) // Create a valid setup token manually - setupToken, _ := jwt.SignTwoFASetupToken(u.ID, "secret") + setupToken, _ := jwt.SignTwoFASetupToken(svc.Dep, u.ID, "secret") code, _ := totp.GenerateCode("secret", time.Now()) req := &dto.TwoFAConfirmRequest{ @@ -221,7 +221,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ch1"}, Email: "ch1@e.com"}, Password: dto.Password{Password: "p"}, @@ -253,7 +253,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("InvalidCode", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ch2"}, Email: "ch2@e.com"}, Password: dto.Password{Password: "p"}, @@ -281,7 +281,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("NotEnabled", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "chne"}, Email: "chne@e.com"}, Password: dto.Password{Password: "p"}, @@ -289,7 +289,7 @@ func TestTwoFAChallenge(t *testing.T) { // Do NOT enable 2FA // Create session token manually - sessionToken, _ := jwt.SignTwoFAToken(u.ID) + sessionToken, _ := jwt.SignTwoFAToken(svc.Dep, u.ID) req := &dto.TwoFAChallengeRequest{ SessionToken: sessionToken, @@ -304,7 +304,7 @@ func TestTwoFAChallenge(t *testing.T) { t.Run("DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "ch3"}, Email: "ch3@e.com"}, Password: dto.Password{Password: "p"}, @@ -338,7 +338,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("Success", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "dis1"}, Email: "dis1@e.com"}, Password: dto.Password{Password: "p"}, @@ -361,7 +361,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("AlreadyDisabled", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "dis2"}, Email: "dis2@e.com"}, Password: dto.Password{Password: "p"}, @@ -378,7 +378,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("OAuthUser", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) // Mock OAuth user oauthUser := dto.GoogleUserData{ ID: "oauth456", @@ -401,7 +401,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("DBError", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "dis3"}, Email: "dis3@e.com"}, Password: dto.Password{Password: "p"}, @@ -421,7 +421,7 @@ func TestDisableTwoFA(t *testing.T) { t.Run("InvalidPassword", func(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "disinv"}, Email: "disinv@e.com"}, Password: dto.Password{Password: "correct"}, diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 97524d4..d9c4480 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -23,7 +23,7 @@ const MaxAvatarSize = 1 * 1024 * 1024 // 1 MB const BaseGoogleOAuthURL = "https://accounts.google.com/o/oauth2/v2/auth" type UserService struct { - Dep *dependency.Dependency + Dep *dependency.Dependency } func (s *UserService) CreateUser(ctx context.Context, request *dto.CreateUserRequest) (*dto.UserWithoutTokenResponse, error) { diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 540196a..d3ee9d9 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -10,80 +10,109 @@ import ( "github.com/paularynty/transcendence/auth-service-go/internal/middleware" ) +func requireAuthStatus(t *testing.T, err error, status int) { + t.Helper() + authErr, ok := err.(*middleware.AuthError) + if !ok || authErr.Status != status { + t.Fatalf("expected %d error, got %v", status, err) + } +} + func TestCreateUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() - t.Run("Success", func(t *testing.T) { - req := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "testuser"}, - Email: "test@example.com", + cases := []struct { + name string + req *dto.CreateUserRequest + setup func() + wantErrStatus int + }{ + { + name: "Success", + req: &dto.CreateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "testuser"}, + Email: "test@example.com", + }, + Password: dto.Password{Password: "password123"}, }, - Password: dto.Password{Password: "password123"}, - } - - resp, err := svc.CreateUser(ctx, req) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if resp.Username != req.Username { - t.Errorf("expected username %s, got %s", req.Username, resp.Username) - } - if resp.Email != req.Email { - t.Errorf("expected email %s, got %s", req.Email, resp.Email) - } - if resp.ID == 0 { - t.Error("expected valid ID") - } - }) - - t.Run("DuplicateUsername", func(t *testing.T) { - req := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "testuser"}, - Email: "other@example.com", + }, + { + name: "DuplicateUsername", + req: &dto.CreateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "testuser"}, + Email: "other@example.com", + }, + Password: dto.Password{Password: "password123"}, }, - Password: dto.Password{Password: "password123"}, - } - - _, err := svc.CreateUser(ctx, req) - if err == nil { - t.Fatal("expected error for duplicate username") - } - - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 409 { - t.Errorf("expected 409 error, got %v", err) - } - }) - - t.Run("DuplicateEmail", func(t *testing.T) { - req := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "otheruser"}, - Email: "test@example.com", + setup: func() { + _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "testuser"}, + Email: "test@example.com", + }, + Password: dto.Password{Password: "password123"}, + }) }, - Password: dto.Password{Password: "password123"}, - } - - _, err := svc.CreateUser(ctx, req) - if err == nil { - t.Fatal("expected error for duplicate email") - } + wantErrStatus: 409, + }, + { + name: "DuplicateEmail", + req: &dto.CreateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "otheruser"}, + Email: "test@example.com", + }, + Password: dto.Password{Password: "password123"}, + }, + setup: func() { + _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "seeduser"}, + Email: "test@example.com", + }, + Password: dto.Password{Password: "password123"}, + }) + }, + wantErrStatus: 409, + }, + } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 409 { - t.Errorf("expected 409 error, got %v", err) - } - }) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + } + resp, err := svc.CreateUser(ctx, tc.req) + if tc.wantErrStatus == 0 { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if resp.Username != tc.req.Username { + t.Errorf("expected username %s, got %s", tc.req.Username, resp.Username) + } + if resp.Email != tc.req.Email { + t.Errorf("expected email %s, got %s", tc.req.Email, resp.Email) + } + if resp.ID == 0 { + t.Error("expected valid ID") + } + return + } + if err == nil { + t.Fatalf("expected error for %s", tc.name) + } + requireAuthStatus(t, err, tc.wantErrStatus) + }) + } } func TestLoginUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() // Setup user @@ -232,7 +261,7 @@ func TestLoginUser(t *testing.T) { func TestGetUserByID(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -243,31 +272,38 @@ func TestGetUserByID(t *testing.T) { Password: dto.Password{Password: "pass"}, }) - t.Run("Success", func(t *testing.T) { - got, err := svc.GetUserByID(ctx, u.ID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.ID != u.ID { - t.Errorf("want ID %d, got %d", u.ID, got.ID) - } - }) + cases := []struct { + name string + userID uint + wantErrStatus int + }{ + {"Success", u.ID, 0}, + {"NotFound", 9999, 404}, + } - t.Run("NotFound", func(t *testing.T) { - _, err := svc.GetUserByID(ctx, 9999) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 404 { - t.Errorf("expected 404 error, got %v", err) - } - }) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := svc.GetUserByID(ctx, tc.userID) + if tc.wantErrStatus == 0 { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.ID != u.ID { + t.Errorf("want ID %d, got %d", u.ID, got.ID) + } + return + } + if err == nil { + t.Fatal("expected error") + } + requireAuthStatus(t, err, tc.wantErrStatus) + }) + } } func TestUpdateUserPassword(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -292,59 +328,62 @@ func TestUpdateUserPassword(t *testing.T) { t.Error("expected new token") } - // Verify new password works loginReq := &dto.LoginUserRequest{ Identifier: dto.Identifier{Identifier: "passupdate"}, Password: dto.Password{Password: "newpass"}, } - _, err = svc.LoginUser(ctx, loginReq) - if err != nil { + if _, err := svc.LoginUser(ctx, loginReq); err != nil { t.Error("failed to login with new password") } }) - t.Run("InvalidOldPassword", func(t *testing.T) { - req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "wrongold"}, - NewPassword: dto.NewPassword{NewPassword: "newpass2"}, - } - _, err := svc.UpdateUserPassword(ctx, u.ID, req) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 401 { - t.Errorf("expected 401 error, got %v", err) - } - }) - - t.Run("OAuthUser", func(t *testing.T) { - oauthUser := dto.GoogleUserData{ - ID: "passoauth", - Email: "passoauth@e.com", - } - user, _ := svc.createNewUserFromGoogleInfo(ctx, &oauthUser, false) - - req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "any"}, - NewPassword: dto.NewPassword{NewPassword: "new"}, - } + errorCases := []struct { + name string + setup func() uint + req *dto.UpdateUserPasswordRequest + wantErrStatus int + }{ + { + name: "InvalidOldPassword", + setup: func() uint { + return u.ID + }, + req: &dto.UpdateUserPasswordRequest{ + OldPassword: dto.OldPassword{OldPassword: "wrongold"}, + NewPassword: dto.NewPassword{NewPassword: "newpass2"}, + }, + wantErrStatus: 401, + }, + { + name: "OAuthUser", + setup: func() uint { + oauthUser := dto.GoogleUserData{ID: "passoauth", Email: "passoauth@e.com"} + user, _ := svc.createNewUserFromGoogleInfo(ctx, &oauthUser, false) + return user.ID + }, + req: &dto.UpdateUserPasswordRequest{ + OldPassword: dto.OldPassword{OldPassword: "any"}, + NewPassword: dto.NewPassword{NewPassword: "new"}, + }, + wantErrStatus: 400, + }, + } - _, err := svc.UpdateUserPassword(ctx, user.ID, req) - if err == nil { - t.Fatal("expected error for oauth user") - } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 400 { - t.Errorf("expected 400 error, got %v", err) - } - }) + for _, tc := range errorCases { + t.Run(tc.name, func(t *testing.T) { + userID := tc.setup() + _, err := svc.UpdateUserPassword(ctx, userID, tc.req) + if err == nil { + t.Fatal("expected error") + } + requireAuthStatus(t, err, tc.wantErrStatus) + }) + } t.Run("InvalidHash", func(t *testing.T) { - // Manually create user with invalid hash _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ User: dto.User{UserName: dto.UserName{Username: "badhash2"}, Email: "badhash2@e.com"}, - Password: dto.Password{Password: "p"}, + Password: dto.Password{Password: "password123"}, }) badHash := "invalid_hash" var user model.User @@ -352,7 +391,7 @@ func TestUpdateUserPassword(t *testing.T) { db.Model(&user).Update("password_hash", badHash) req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "p"}, + OldPassword: dto.OldPassword{OldPassword: "password123"}, NewPassword: dto.NewPassword{NewPassword: "new"}, } @@ -360,7 +399,6 @@ func TestUpdateUserPassword(t *testing.T) { if err == nil { t.Fatal("expected error") } - // Should return raw error if _, ok := err.(*middleware.AuthError); ok { t.Error("expected raw error") } @@ -369,7 +407,7 @@ func TestUpdateUserPassword(t *testing.T) { func TestUpdateUserProfile(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -380,56 +418,69 @@ func TestUpdateUserProfile(t *testing.T) { Password: dto.Password{Password: "pass"}, }) - t.Run("Success", func(t *testing.T) { - newAvatar := "new_avatar.png" - req := &dto.UpdateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "newname"}, - Email: "new@example.com", - Avatar: &newAvatar, + cases := []struct { + name string + setup func() + req *dto.UpdateUserRequest + wantErrStatus int + }{ + { + name: "Success", + req: &dto.UpdateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "newname"}, + Email: "new@example.com", + Avatar: func() *string { v := "new_avatar.png"; return &v }(), + }, }, - } - - got, err := svc.UpdateUserProfile(ctx, u.ID, req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.Username != "newname" { - t.Errorf("want username newname, got %s", got.Username) - } - if got.Email != "new@example.com" { - t.Errorf("want email new@example.com, got %s", got.Email) - } - }) - - t.Run("Duplicate", func(t *testing.T) { - // Create another user - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "other"}, Email: "other@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - req := &dto.UpdateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "other"}, // Duplicate - Email: "new@example.com", + }, + { + name: "Duplicate", + setup: func() { + _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "other"}, Email: "other@e.com"}, + Password: dto.Password{Password: "password123"}, + }) }, - } + req: &dto.UpdateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "other"}, + Email: "new@example.com", + }, + }, + wantErrStatus: 409, + }, + } - _, err := svc.UpdateUserProfile(ctx, u.ID, req) - if err == nil { - t.Fatal("expected error for duplicate") - } - authErr, ok := err.(*middleware.AuthError) - if !ok || authErr.Status != 409 { - t.Errorf("expected 409 error, got %v", err) - } - }) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup() + } + got, err := svc.UpdateUserProfile(ctx, u.ID, tc.req) + if tc.wantErrStatus == 0 { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Username != tc.req.Username { + t.Errorf("want username %s, got %s", tc.req.Username, got.Username) + } + if got.Email != tc.req.Email { + t.Errorf("want email %s, got %s", tc.req.Email, got.Email) + } + return + } + if err == nil { + t.Fatal("expected error for duplicate") + } + requireAuthStatus(t, err, tc.wantErrStatus) + }) + } } func TestDeleteUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ @@ -455,7 +506,7 @@ func TestDeleteUser(t *testing.T) { func TestValidateUserToken(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() createReq := &dto.CreateUserRequest{ @@ -512,7 +563,7 @@ func TestValidateUserToken(t *testing.T) { func TestLogoutUser(t *testing.T) { db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + svc := NewUserService(newTestDependency(db, nil)) ctx := context.Background() createReq := &dto.CreateUserRequest{ @@ -545,126 +596,90 @@ func TestLogoutUser(t *testing.T) { func TestDBErrors(t *testing.T) { ctx := context.Background() - t.Run("CreateUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - req := &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "db1"}, Email: "db1@e.com"}, - Password: dto.Password{Password: "p"}, - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.CreateUser(ctx, req) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("LoginUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "db1"}, - Password: dto.Password{Password: "p"}, - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.LoginUser(ctx, req) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("GetUserByID", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.GetUserByID(ctx, 1) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("UpdateUserPassword", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "p"}, - NewPassword: dto.NewPassword{NewPassword: "p2"}, - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.UpdateUserPassword(ctx, 1, req) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("UpdateUserProfile", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - req := &dto.UpdateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "n"}, Email: "n@e.com"}, - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.UpdateUserProfile(ctx, 1, req) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("DeleteUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - err := svc.DeleteUser(ctx, 1) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("ValidateUserToken", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - err := svc.ValidateUserToken(ctx, "token", 1) - if err == nil { - t.Error("expected db error") - } - }) - - t.Run("LogoutUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := NewUserService(db, nil) + cases := []struct { + name string + run func(svc *UserService) error + }{ + { + name: "CreateUser", + run: func(svc *UserService) error { + req := &dto.CreateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "db1"}, Email: "db1@e.com"}, + Password: dto.Password{Password: "password123"}, + } + _, err := svc.CreateUser(ctx, req) + return err + }, + }, + { + name: "LoginUser", + run: func(svc *UserService) error { + req := &dto.LoginUserRequest{ + Identifier: dto.Identifier{Identifier: "db1"}, + Password: dto.Password{Password: "password123"}, + } + _, err := svc.LoginUser(ctx, req) + return err + }, + }, + { + name: "GetUserByID", + run: func(svc *UserService) error { + _, err := svc.GetUserByID(ctx, 1) + return err + }, + }, + { + name: "UpdateUserPassword", + run: func(svc *UserService) error { + req := &dto.UpdateUserPasswordRequest{ + OldPassword: dto.OldPassword{OldPassword: "password123"}, + NewPassword: dto.NewPassword{NewPassword: "password456"}, + } + _, err := svc.UpdateUserPassword(ctx, 1, req) + return err + }, + }, + { + name: "UpdateUserProfile", + run: func(svc *UserService) error { + req := &dto.UpdateUserRequest{ + User: dto.User{UserName: dto.UserName{Username: "n"}, Email: "n@e.com"}, + } + _, err := svc.UpdateUserProfile(ctx, 1, req) + return err + }, + }, + { + name: "DeleteUser", + run: func(svc *UserService) error { + return svc.DeleteUser(ctx, 1) + }, + }, + { + name: "ValidateUserToken", + run: func(svc *UserService) error { + return svc.ValidateUserToken(ctx, "token", 1) + }, + }, + { + name: "LogoutUser", + run: func(svc *UserService) error { + return svc.LogoutUser(ctx, 1) + }, + }, + } - sqlDB, _ := db.DB() - _ = sqlDB.Close() + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + db := setupTestDB(t.Name()) + svc := NewUserService(newTestDependency(db, nil)) + sqlDB, _ := db.DB() + _ = sqlDB.Close() - err := svc.LogoutUser(ctx, 1) - if err == nil { - t.Error("expected db error") - } - }) + if err := tc.run(svc); err == nil { + t.Error("expected db error") + } + }) + } } diff --git a/backend/internal/testutil/testutil.go b/backend/internal/testutil/testutil.go new file mode 100644 index 0000000..b00d263 --- /dev/null +++ b/backend/internal/testutil/testutil.go @@ -0,0 +1,50 @@ +package testutil + +import ( + "io" + "log/slog" + + "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +func NewTestConfig() *config.Config { + return &config.Config{ + JwtSecret: "test-secret", + UserTokenExpiry: 3600, + UserTokenAbsoluteExpiry: 2592000, + OauthStateTokenExpiry: 600, + GoogleClientId: "test-client-id", + GoogleClientSecret: "test-client-secret", + GoogleRedirectUri: "http://localhost:8080/callback", + FrontendUrl: "http://localhost:3000", + TwoFaUrlPrefix: "otpauth://totp/Transcendence?secret=", + TwoFaTokenExpiry: 600, + RedisURL: "", + IsRedisEnabled: false, + } +} + +func NewTestLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + Level: slog.LevelError, + })) +} + +func NewTestDependency(cfg *config.Config, db *gorm.DB, redis *redis.Client, logger *slog.Logger) *dependency.Dependency { + if cfg == nil { + cfg = NewTestConfig() + } + if logger == nil { + logger = NewTestLogger() + } + if redis != nil { + cfg.IsRedisEnabled = true + if cfg.RedisURL == "" { + cfg.RedisURL = "redis://test" + } + } + return dependency.NewDependency(cfg, db, redis, logger) +} diff --git a/backend/internal/util/jwt/token_test.go b/backend/internal/util/jwt/token_test.go index bbfea3d..299efcd 100644 --- a/backend/internal/util/jwt/token_test.go +++ b/backend/internal/util/jwt/token_test.go @@ -7,101 +7,116 @@ import ( libjwt "github.com/golang-jwt/jwt/v5" - "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" + "github.com/paularynty/transcendence/auth-service-go/internal/dto" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) -func setupTokenConfig(t *testing.T) func() { +func setupTokenDep(t *testing.T) *dependency.Dependency { t.Helper() - prev := config.Cfg - config.Cfg = &config.Config{ - JwtSecret: "test-secret-key", - UserTokenExpiry: 3600, - OauthStateTokenExpiry: 120, - TwoFaTokenExpiry: 300, - } - - return func() { - config.Cfg = prev - } -} - -func TestSignAndValidateUserToken(t *testing.T) { - cleanup := setupTokenConfig(t) - defer cleanup() - - token, err := jwt.SignUserToken(42) - if err != nil { - t.Fatalf("SignUserToken returned error: %v", err) - } - - claims, err := jwt.ValidateUserTokenGeneric(token) - if err != nil { - t.Fatalf("ValidateUserTokenGeneric returned error: %v", err) - } - - if claims.UserID != 42 { - t.Fatalf("expected user id 42, got %d", claims.UserID) - } - - if claims.Type != jwt.UserTokenType { - t.Fatalf("expected claim type %q, got %q", jwt.UserTokenType, claims.Type) - } - - if claims.ExpiresAt == nil || claims.ExpiresAt.Before(time.Now()) { - t.Fatalf("expected future expiration, got %v", claims.ExpiresAt) - } -} - -func TestValidateUserTokenRejectsWrongType(t *testing.T) { - cleanup := setupTokenConfig(t) - defer cleanup() - - token, err := jwt.SignTwoFAToken(10) - if err != nil { - t.Fatalf("SignTwoFAToken returned error: %v", err) - } - - _, err = jwt.ValidateUserTokenGeneric(token) - if !errors.Is(err, libjwt.ErrTokenInvalidClaims) { - t.Fatalf("expected ErrTokenInvalidClaims, got %v", err) - } -} - -func TestValidateOauthStateToken(t *testing.T) { - cleanup := setupTokenConfig(t) - defer cleanup() - - token, err := jwt.SignOauthStateToken() - if err != nil { - t.Fatalf("SignOauthStateToken returned error: %v", err) - } - - claims, err := jwt.ValidateOauthStateToken(token) - if err != nil { - t.Fatalf("ValidateOauthStateToken returned error: %v", err) - } - - if claims.Type != jwt.GoogleOAuthStateType { - t.Fatalf("expected oauth state type %q, got %q", jwt.GoogleOAuthStateType, claims.Type) - } + cfg := testutil.NewTestConfig() + cfg.JwtSecret = "test-secret-key" + cfg.UserTokenExpiry = 3600 + cfg.OauthStateTokenExpiry = 120 + cfg.TwoFaTokenExpiry = 300 + return testutil.NewTestDependency(cfg, nil, nil, nil) } -func TestValidateTwoFASetupToken(t *testing.T) { - cleanup := setupTokenConfig(t) - defer cleanup() - - token, err := jwt.SignTwoFASetupToken(7, "secret") - if err != nil { - t.Fatalf("SignTwoFASetupToken returned error: %v", err) - } - - claims, err := jwt.ValidateTwoFASetupToken(token) - if err != nil { - t.Fatalf("ValidateTwoFASetupToken returned error: %v", err) +func TestTokenRoundTrip(t *testing.T) { + dep := setupTokenDep(t) + + cases := []struct { + name string + sign func() (string, error) + validate func(string) (any, error) + assert func(t *testing.T, claims any) + expectedError error + }{ + { + name: "UserToken", + sign: func() (string, error) { + return jwt.SignUserToken(dep, 42) + }, + validate: func(token string) (any, error) { + return jwt.ValidateUserTokenGeneric(dep, token) + }, + assert: func(t *testing.T, claims any) { + parsed := claims.(*dto.UserJwtPayload) + if parsed.UserID != 42 { + t.Fatalf("expected user id 42, got %d", parsed.UserID) + } + if parsed.Type != jwt.UserTokenType { + t.Fatalf("expected claim type %q, got %q", jwt.UserTokenType, parsed.Type) + } + if parsed.ExpiresAt == nil || parsed.ExpiresAt.Before(time.Now()) { + t.Fatalf("expected future expiration, got %v", parsed.ExpiresAt) + } + }, + }, + { + name: "OauthStateToken", + sign: func() (string, error) { + return jwt.SignOauthStateToken(dep) + }, + validate: func(token string) (any, error) { + return jwt.ValidateOauthStateToken(dep, token) + }, + assert: func(t *testing.T, claims any) { + parsed := claims.(*dto.OauthStateJwtPayload) + if parsed.Type != jwt.GoogleOAuthStateType { + t.Fatalf("expected oauth state type %q, got %q", jwt.GoogleOAuthStateType, parsed.Type) + } + }, + }, + { + name: "TwoFASetupToken", + sign: func() (string, error) { + return jwt.SignTwoFASetupToken(dep, 7, "secret") + }, + validate: func(token string) (any, error) { + return jwt.ValidateTwoFASetupToken(dep, token) + }, + assert: func(t *testing.T, claims any) { + parsed := claims.(*dto.TwoFaSetupJwtPayload) + if parsed.Secret != "secret" { + t.Fatalf("expected secret to be propagated, got %q", parsed.Secret) + } + }, + }, + { + name: "UserTokenRejectsWrongType", + sign: func() (string, error) { + return jwt.SignTwoFAToken(dep, 10) + }, + validate: func(token string) (any, error) { + return jwt.ValidateUserTokenGeneric(dep, token) + }, + expectedError: libjwt.ErrTokenInvalidClaims, + }, } - if claims.Secret != "secret" { - t.Fatalf("expected secret to be propagated, got %q", claims.Secret) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + token, err := tc.sign() + if err != nil { + t.Fatalf("sign returned error: %v", err) + } + + claims, err := tc.validate(token) + if tc.expectedError != nil { + if !errors.Is(err, tc.expectedError) { + t.Fatalf("expected %v, got %v", tc.expectedError, err) + } + return + } + if err != nil { + t.Fatalf("validate returned error: %v", err) + } + + if tc.assert != nil { + tc.assert(t, claims) + } + }) } } From ef5faa68dd08a0388de90b98526d1e91f251b29c Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Fri, 30 Jan 2026 00:25:28 +0200 Subject: [PATCH 5/5] refactor/backend: replace type assertion by errors.As --- backend/internal/middleware/error_handler.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/backend/internal/middleware/error_handler.go b/backend/internal/middleware/error_handler.go index 81e085a..ea3c992 100644 --- a/backend/internal/middleware/error_handler.go +++ b/backend/internal/middleware/error_handler.go @@ -1,6 +1,8 @@ package middleware import ( + "errors" + "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" ) @@ -33,18 +35,22 @@ func ErrorHandler() gin.HandlerFunc { err := c.Errors.Last().Err + var authErr *AuthError + // Handle AuthError specifically - if authErr, ok := err.(*AuthError); ok { + if errors.As(err, &authErr) { c.AbortWithStatusJSON(authErr.Status, gin.H{ "error": authErr.Message, }) return } + var validationErr validator.ValidationErrors + // Handle validation errors - if ve, ok := err.(validator.ValidationErrors); ok { - messages := make([]string, 0, len(ve)) - for _, fe := range ve { + if errors.As(err, &validationErr) { + messages := make([]string, 0, len(validationErr)) + for _, fe := range validationErr { messages = append(messages, fe.Error()) } c.AbortWithStatusJSON(400, gin.H{