From a374c9c39eaa948ca506aea35c0c38c0b1972f89 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Sun, 1 Feb 2026 22:08:33 +0200 Subject: [PATCH 01/15] refactor/backend: new config test --- backend/internal/config/config_test.go | 234 +++++++++++++++++++------ 1 file changed, 183 insertions(+), 51 deletions(-) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 51cabcc..b94cc77 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,84 +1,216 @@ package config import ( + "fmt" "testing" ) -func assertError(t *testing.T, err error, name string) { +const testKey = "TEST_KEY" +const notSet = "notSet" + +func setEnv(t *testing.T, v string) { t.Helper() - if err == nil { - t.Fatalf("expected error for %s", name) + + if v == notSet { + return } + + t.Setenv(testKey, v) } -func assertNoError(t *testing.T, err error, name string) { - t.Helper() - if err != nil { - t.Fatalf("unexpected error for %s: %v", name, err) +func TestGetEnvStrOrDefault(t *testing.T) { + const validValue = "v1" + const defaultValue = "v" + const emptyValue = "" + + testCases := []struct { + name string + envValue string + expected string + }{ + { + name: "valid env string", + envValue: validValue, + expected: validValue, + }, + { + name: "empty env string", + envValue: emptyValue, + expected: defaultValue, + }, + { + name: "env not set", + envValue: notSet, + expected: defaultValue, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + setEnv(t, tc.envValue) + + if got := getEnvStrOrDefault(testKey, defaultValue); got != tc.expected { + t.Fatalf("expect: %q, got %q", tc.expected, got) + } + }) } } -func TestGetEnvStrOrDefault(t *testing.T) { - t.Setenv("TEST_STR", "") - if got := getEnvStrOrDefault("TEST_STR", "fallback"); got != "fallback" { - t.Fatalf("expected default value, got %q", got) +func TestGetEnvIntOrDefault(t *testing.T) { + const validValue = "10" + const validExpected = 10 + const defaultValue = 22 + const invalidValue = "a" + + testCases := []struct { + name string + envValue string + expected int + }{ + { + name: "valid env (int)", + envValue: validValue, + expected: validExpected, + }, + { + name: "env not set (int)", + envValue: notSet, + expected: defaultValue, + }, + { + name: "invalid env (int)", + envValue: invalidValue, + expected: defaultValue, + }, } - t.Setenv("TEST_STR", "value") - if got := getEnvStrOrDefault("TEST_STR", "fallback"); got != "value" { - t.Fatalf("expected env value, got %q", got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setEnv(t, tc.envValue) + + if got := getEnvIntOrDefault(testKey, defaultValue); got != tc.expected { + t.Fatalf("expected: %d, got: %d", tc.expected, got) + } + }) } } func TestGetEnvStrOrError(t *testing.T) { - t.Setenv("TEST_PANIC", "") - _, err := getEnvStrOrError("TEST_PANIC") - assertError(t, err, "empty env") - - t.Setenv("TEST_PANIC", "value") - got, err := getEnvStrOrError("TEST_PANIC") - assertNoError(t, err, "set env") - if got != "value" { - t.Fatalf("expected env value, got %q", got) + const validValue = "v1" + const emptyValue = "" + const errorValue = "error" + + testCases := []struct { + name string + envValue string + expected string + expectErr bool + }{ + { + name: "valid env string", + envValue: validValue, + expected: validValue, + expectErr: false, + }, + { + name: "empty env string", + envValue: emptyValue, + expected: errorValue, + expectErr: true, + }, + { + name: "env not set", + envValue: notSet, + expected: errorValue, + expectErr: true, + }, } -} -func TestGetEnvIntOrDefault(t *testing.T) { - t.Setenv("TEST_INT", "") - if got := getEnvIntOrDefault("TEST_INT", 7); got != 7 { - t.Fatalf("expected default value, got %d", got) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setEnv(t, tc.envValue) + + got, err := getEnvStrOrError(testKey) + + if tc.expectErr && err == nil { + t.Fatalf("expected error, got %q", got) + } + + if !tc.expectErr && err != nil { + t.Fatalf("expected %q, got error %v", tc.expected, err) + } + + if !tc.expectErr && got != tc.expected { + t.Fatalf("expected %q, got %q", tc.expected, got) + } + }) } +} + +var mandatoryItems = []string{ + "JWT_SECRET", + "GOOGLE_CLIENT_ID", + "GOOGLE_CLIENT_SECRET", +} - t.Setenv("TEST_INT", "42") - if got := getEnvIntOrDefault("TEST_INT", 7); got != 42 { - t.Fatalf("expected env value, got %d", got) +func setEnvForMandatoryItem(t *testing.T, keys []string) { + t.Helper() + + for _, key := range mandatoryItems { + t.Setenv(key, "") } - t.Setenv("TEST_INT", "not-an-int") - if got := getEnvIntOrDefault("TEST_INT", 7); got != 7 { - t.Fatalf("expected default value for invalid int, got %d", got) + for _, key := range keys { + t.Setenv(key, "test_value") } } -func TestLoadConfigFromEnv_ErrsOnMissingRequired(t *testing.T) { - t.Setenv("JWT_SECRET", "jwt") - t.Setenv("GOOGLE_CLIENT_ID", "client") - t.Setenv("GOOGLE_CLIENT_SECRET", "secret") +func TestLoadConfigFromEnv_MissingMandatory(t *testing.T) { + type testCase struct { + name string + expectErr bool + keys []string + } + + testCases := []testCase{ + { + name: "normal case", + expectErr: false, + keys: mandatoryItems, + }, + } + + for i, item := range mandatoryItems { + keys := make([]string, 0, len(mandatoryItems)-1) + keys = append(keys, mandatoryItems[:i]...) + keys = append(keys, mandatoryItems[i+1:]...) - _, err := LoadConfigFromEnv() - assertNoError(t, err, "all required set") + tc := testCase{ + name: fmt.Sprintf("missing %s", item), + expectErr: true, + keys: keys, + } + testCases = append(testCases, tc) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { - t.Setenv("JWT_SECRET", "") - _, err = LoadConfigFromEnv() - assertError(t, err, "JWT_SECRET unset") + setEnvForMandatoryItem(t, tc.keys) + cfg, err := LoadConfigFromEnv() - t.Setenv("JWT_SECRET", "jwt") - t.Setenv("GOOGLE_CLIENT_ID", "") - _, err = LoadConfigFromEnv() - assertError(t, err, "GOOGLE_CLIENT_ID unset") + if tc.expectErr && err == nil { + t.Fatalf("expected error, but got cfg: %v.", cfg) + } - t.Setenv("GOOGLE_CLIENT_ID", "client") - t.Setenv("GOOGLE_CLIENT_SECRET", "") - _, err = LoadConfigFromEnv() - assertError(t, err, "GOOGLE_CLIENT_SECRET unset") + if !tc.expectErr && cfg == nil { + t.Fatalf("expected cfg, but got nil") + } + + if !tc.expectErr && err != nil { + t.Fatalf("expected cfg, but got err: %v", err) + } + }) + } } From ec4f2ca32cf75ebd9510778ca672434414f23cba Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Sun, 1 Feb 2026 23:14:15 +0200 Subject: [PATCH 02/15] refactor/backend: new schema tests --- backend/internal/config/config_test.go | 2 +- backend/internal/dto/schemas.go | 7 +- backend/internal/dto/schemas_test.go | 435 +++++++++++++------------ 3 files changed, 231 insertions(+), 213 deletions(-) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index b94cc77..da024ce 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -156,7 +156,7 @@ var mandatoryItems = []string{ func setEnvForMandatoryItem(t *testing.T, keys []string) { t.Helper() - + for _, key := range mandatoryItems { t.Setenv(key, "") } diff --git a/backend/internal/dto/schemas.go b/backend/internal/dto/schemas.go index adb7f90..b2e59c8 100644 --- a/backend/internal/dto/schemas.go +++ b/backend/internal/dto/schemas.go @@ -29,6 +29,7 @@ func InitValidator() { _ = Validate.RegisterValidation("username", validateUsername) _ = Validate.RegisterValidation("password", validatePassword) _ = Validate.RegisterValidation("identifier", validateIdentifier) + Validate.RegisterAlias("passwordField", "required,trim,min=6,max=20,password") registerUsernameTranslation(Validate, Trans) registerPasswordTranslation(Validate, Trans) registerIdentifierTranslation(Validate, Trans) @@ -80,15 +81,15 @@ func registerUsernameTranslation(v *validator.Validate, trans ut.Translator) { // Password type Password struct { - Password string `json:"password" validate:"required,trim,min=6,max=20,password"` + Password string `json:"password" validate:"passwordField"` } type OldPassword struct { - OldPassword string `json:"oldPassword" validate:"required,trim,password,min=6,max=20"` + OldPassword string `json:"oldPassword" validate:"passwordField"` } type NewPassword struct { - NewPassword string `json:"newPassword" validate:"required,trim,password,min=6,max=20"` + NewPassword string `json:"newPassword" validate:"passwordField"` } // Contains only letters, numbers, ".", "_" or "-" diff --git a/backend/internal/dto/schemas_test.go b/backend/internal/dto/schemas_test.go index fa3d3ca..60b9aee 100644 --- a/backend/internal/dto/schemas_test.go +++ b/backend/internal/dto/schemas_test.go @@ -1,283 +1,300 @@ package dto_test import ( - "encoding/json" + "errors" + "fmt" "strings" "testing" + "github.com/go-playground/validator/v10" "github.com/paularynty/transcendence/auth-service-go/internal/dto" ) -func init() { +func TestUsername_HappyPath(t *testing.T) { dto.InitValidator() -} - -func TestUserAvatarMustBeURL(t *testing.T) { - avatar := "avatar.png" - payload := dto.User{ - UserName: dto.UserName{Username: "valid_user"}, - Email: "user@example.com", - Avatar: &avatar, - } - - if err := dto.Validate.Struct(&payload); err == nil { - t.Fatalf("expected non-URL avatar to be rejected by url validator") - } - - validAvatar := "https://example.com/avatar.png" - payload.Avatar = &validAvatar - if err := dto.Validate.Struct(&payload); err != nil { - t.Fatalf("expected URL avatar to pass validation, got error: %v", err) - } -} - -func TestTwoFAChallengeRequiresNumericOTP(t *testing.T) { - invalid := dto.TwoFAChallengeRequest{TwoFACode: "AB12CD", SessionToken: "session-token"} - if err := dto.Validate.Struct(&invalid); err == nil { - t.Fatalf("expected alphanumeric code to fail numeric validator") + testCases := []struct { + value string + expectedValue string + }{ + {value: "aaa", expectedValue: "aaa"}, + {value: " aaa ", expectedValue: "aaa"}, + {value: "aA0_-", expectedValue: "aA0_-"}, } - valid := dto.TwoFAChallengeRequest{TwoFACode: "123456", SessionToken: "session-token"} - if err := dto.Validate.Struct(&valid); err != nil { - t.Fatalf("expected numeric code to pass validation, got error: %v", err) - } -} + for _, tc := range testCases { + t.Run(fmt.Sprintf("schema username happy path test: %q", tc.value), func(t *testing.T) { + req := &dto.UserName{ + Username: tc.value, + } -func TestTwoFAPendingUserResponseRequiresTaggedFields(t *testing.T) { - payload := dto.TwoFAPendingUserResponse{ - Message: "ANY_VALUE", - SessionToken: "session-token", - } + err := dto.Validate.Struct(req) + if err != nil { + t.Fatalf("expected %q, got err: %v", tc.expectedValue, err) + } - if err := dto.Validate.Struct(&payload); err != nil { - t.Fatalf("expected arbitrary message/twoFaUrl to be accepted, got error: %v", err) + if req.Username != tc.expectedValue { + t.Fatalf("expected %q, got %q", tc.expectedValue, req.Username) + } + }) } } -func TestUserJWTTypeAllowsAnyString(t *testing.T) { - payload := dto.UserJwtPayload{ - UserID: 1, - Type: "OTHER", - } +func TestUsername_Errors(t *testing.T) { + dto.InitValidator() - if err := dto.Validate.Struct(&payload); err != nil { - t.Fatalf("expected arbitrary type value to be accepted, got error: %v", err) - } -} + testCases := []string{ + "", // empty + "aa", // too short + " aa", // too short after trimming + " aa ", // too short after trimming + "aa ", // too short after trimming + "a a", // invalid char + "a%a", // invalid char + strings.Repeat("a", 51), // too long + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("schema username test: %q", tc), func(t *testing.T) { + req := &dto.UserName{ + Username: tc, + } -func TestUsersResponseMarshalsAsObjectWithSlice(t *testing.T) { - payload := dto.UsersResponse{Users: []dto.SimpleUser{}} + err := dto.Validate.Struct(req) + if err == nil { + t.Fatalf("expected error, got : %q", req.Username) + } - bytes, err := json.Marshal(payload) - if err != nil { - t.Fatalf("failed to marshal users response: %v", err) - } + var ve validator.ValidationErrors + if !errors.As(err, &ve) { + t.Fatalf("expected validation error, got %v", err) + } - expected := "{\"users\":[]}" - if string(bytes) != expected { - t.Fatalf("expected users response to marshal to %s, got %s", expected, string(bytes)) + for _, fe := range ve { + if fe.Field() != "Username" { + t.Fatalf("expected validation error on Username, got %v", err) + } + } + }) } } -func TestTrimValidationStripsWhitespace(t *testing.T) { - type payload struct { - Value string `validate:"required,trim,min=6"` - } - - data := &payload{Value: " foobar "} - if err := dto.Validate.Struct(data); err != nil { - t.Fatalf("expected trimmed value to pass validation, got error: %v", err) - } - - if data.Value != "foobar" { - t.Fatalf("expected trim validator to remove outer spaces, got %q", data.Value) - } - - tooShort := &payload{Value: " abcde "} - if err := dto.Validate.Struct(tooShort); err == nil { - t.Fatalf("expected trimmed value shorter than min to fail validation") - } +type passwordReqFactory func(string) (any, func() string) - emptyAfterTrim := &payload{Value: " "} - if err := dto.Validate.Struct(emptyAfterTrim); err == nil { - t.Fatalf("expected whitespace-only value to fail validation after trim") - } -} +func runPasswordTests(t *testing.T, label string, fieldName string, build passwordReqFactory) { + t.Helper() + dto.InitValidator() -func TestUsernameValidatorRules(t *testing.T) { - cases := []struct { - name string - input string - wantErr bool + validCases := []struct { + value string + expected string }{ - {"Valid", "valid_user", false}, - {"ValidTrimmed", " valid-user ", false}, - {"ValidTrimmedRight", "valid-user ", false}, - {"ValidTrimmedLeft", " valid-user", false}, - {"EmptyAfterTrim", " ", true}, - {"TooShort", "ab", true}, - {"TooShortAfterTrim", " ab ", true}, - {"ContainsSpace", "user name", true}, - {"IllegalChars", "user@name", true}, + {value: "pass123", expected: "pass123"}, + {value: " pass123 ", expected: "pass123"}, + {value: "aA0,.#$%@^;|_!*&?", expected: "aA0,.#$%@^;|_!*&?"}, + {value: strings.Repeat("a", 20), expected: strings.Repeat("a", 20)}, } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - payload := &dto.UserName{Username: tc.input} - err := dto.Validate.Struct(payload) - if tc.wantErr { - if err == nil { - t.Fatalf("expected username %q to be invalid", tc.input) - } - return - } + for _, tc := range validCases { + t.Run(fmt.Sprintf("%s valid %q", label, tc.value), func(t *testing.T) { + req, getValue := build(tc.value) + err := dto.Validate.Struct(req) if err != nil { - t.Fatalf("expected username %q to be valid, got error: %v", tc.input, err) + t.Fatalf("expected %q, got err: %v", tc.expected, err) } - if payload.Username != strings.TrimSpace(tc.input) { - t.Fatalf("expected username to be trimmed to %q, got %q", strings.TrimSpace(tc.input), payload.Username) + if got := getValue(); got != tc.expected { + t.Fatalf("expected %q, got %q", tc.expected, got) } }) } -} -func TestPasswordValidatorRules(t *testing.T) { - cases := []struct { - name string - input string - wantErr bool - }{ - {"ValidBasic", "Abc123", false}, - {"ValidSymbols", "pass,#$%", false}, - {"ValidTrimmedRight", "Abc123 ", false}, - {"ValidTrimmedLeft", " Abc123", false}, - {"EmptyAfterTrim", " ", true}, - {"TooShort", "ab", true}, - {"TooShortAfterTrim", " ab ", true}, - {"ContainsSpace", "pass word", true}, - {"DisallowedChar", "bad~pass", true}, + invalidCases := []string{ + "", // empty + "aaaaa", // too short + " aaaaa ", // too short after trimming + "aaa aa", // invalid char + "aa{}aa", // invalid char + strings.Repeat("a", 21), // too long } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - payload := &dto.Password{Password: tc.input} - err := dto.Validate.Struct(payload) - if tc.wantErr { - if err == nil { - t.Fatalf("expected password %q to be invalid", tc.input) - } - return + for _, tc := range invalidCases { + t.Run(fmt.Sprintf("%s invalid %q", label, tc), func(t *testing.T) { + req, _ := build(tc) + + err := dto.Validate.Struct(req) + if err == nil { + t.Fatalf("expected error, got nil") } - if err != nil { - t.Fatalf("expected password %q to be valid, got error: %v", tc.input, err) + var ve validator.ValidationErrors + if !errors.As(err, &ve) { + t.Fatalf("expected validation error, got %v", err) } - if payload.Password != strings.TrimSpace(tc.input) { - t.Fatalf("expected password to be trimmed to %q, got %q", strings.TrimSpace(tc.input), payload.Password) + for _, fe := range ve { + if fe.Field() != fieldName { + t.Fatalf("expected validation error on %s, got %v", fieldName, err) + } } }) } } -func TestIdentifierValidatorAcceptsUsernameOrEmail(t *testing.T) { - cases := []struct { - name string - input string - wantErr bool +func TestPasswordSchemas(t *testing.T) { + runPasswordTests(t, "Password", "Password", func(value string) (any, func() string) { + req := &dto.Password{Password: value} + return req, func() string { return req.Password } + }) + + runPasswordTests(t, "OldPassword", "OldPassword", func(value string) (any, func() string) { + req := &dto.OldPassword{OldPassword: value} + return req, func() string { return req.OldPassword } + }) + + runPasswordTests(t, "NewPassword", "NewPassword", func(value string) (any, func() string) { + req := &dto.NewPassword{NewPassword: value} + return req, func() string { return req.NewPassword } + }) +} + +func TestIdentifier_HappyPath(t *testing.T) { + dto.InitValidator() + + testCases := []struct { + value string + expectedValue string }{ - {"Username", "valid_user", false}, - {"Email", "user@example.com", false}, - {"TrimmedEmail", " user@example.com ", false}, - {"TrimmedEmailRight", "user@example.com ", false}, - {"TrimmedEmailLeft", " user@example.com", false}, - {"EmptyAfterTrim", " ", true}, - {"Invalid", "???", true}, - {"TooShort", "ab", true}, + {value: "user_01", expectedValue: "user_01"}, + {value: " user@example.com ", expectedValue: "user@example.com"}, } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - payload := &dto.Identifier{Identifier: tc.input} - err := dto.Validate.Struct(payload) - if tc.wantErr { - if err == nil { - t.Fatalf("expected identifier %q to be invalid", tc.input) - } - return - } + for _, tc := range testCases { + t.Run(fmt.Sprintf("schema identifier happy path test: %q", tc.value), func(t *testing.T) { + req := &dto.Identifier{Identifier: tc.value} + err := dto.Validate.Struct(req) if err != nil { - t.Fatalf("expected identifier %q to be valid, got error: %v", tc.input, err) + t.Fatalf("expected %q, got err: %v", tc.expectedValue, err) } - if payload.Identifier != strings.TrimSpace(tc.input) { - t.Fatalf("expected identifier to be trimmed to %q, got %q", strings.TrimSpace(tc.input), payload.Identifier) + if req.Identifier != tc.expectedValue { + t.Fatalf("expected %q, got %q", tc.expectedValue, req.Identifier) } }) } } -func TestCreateUserRequestValidation(t *testing.T) { - avatar := "https://example.com/avatar.png" - valid := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "valid_user"}, - Email: "user@example.com", - Avatar: &avatar, - }, - Password: dto.Password{Password: "Valid123"}, - } - - if err := dto.Validate.Struct(valid); err != nil { - t.Fatalf("expected create user request to be valid, got error: %v", err) - } - - invalid := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "valid_user"}, - Email: "user@example.com", - Avatar: &avatar, - }, - Password: dto.Password{Password: "no~"}, - } +func TestIdentifier_Errors(t *testing.T) { + dto.InitValidator() - if err := dto.Validate.Struct(invalid); err == nil { - t.Fatalf("expected create user request with disallowed password to fail validation") + testCases := []string{ + "", // empty + "ab", // too short for username + "a a", // invalid char + "bad@", // invalid email + "@bad.com", // invalid email } -} -func TestLoginUserRequestValidation(t *testing.T) { - valid := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "valid_user"}, - Password: dto.Password{Password: "Valid123"}, - } + for _, tc := range testCases { + t.Run(fmt.Sprintf("schema identifier error test: %q", tc), func(t *testing.T) { + req := &dto.Identifier{Identifier: tc} - if err := dto.Validate.Struct(valid); err != nil { - t.Fatalf("expected login request to be valid, got error: %v", err) - } + err := dto.Validate.Struct(req) + if err == nil { + t.Fatalf("expected error, got nil") + } - invalid := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "??"}, - Password: dto.Password{Password: "Valid123"}, - } + var ve validator.ValidationErrors + if !errors.As(err, &ve) { + t.Fatalf("expected validation error, got %v", err) + } - if err := dto.Validate.Struct(invalid); err == nil { - t.Fatalf("expected login request with invalid identifier to fail validation") + for _, fe := range ve { + if fe.Field() != "Identifier" { + t.Fatalf("expected validation error on Identifier, got %v", err) + } + } + }) } } -func TestTwoFAConfirmRequestValidation(t *testing.T) { - valid := &dto.TwoFAConfirmRequest{TwoFACode: "123456", SetupToken: "token"} - if err := dto.Validate.Struct(valid); err != nil { - t.Fatalf("expected valid 2FA confirm request to pass, got error: %v", err) +func TestRequestSchemas_HappyPath(t *testing.T) { + dto.InitValidator() + + testCases := []struct { + name string + req any + }{ + { + name: "CreateUserRequest", + req: &dto.CreateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "user1"}, + Email: "user1@example.com", + }, + Password: dto.Password{Password: "pass123"}, + }, + }, + { + name: "UpdateUserPasswordRequest", + req: &dto.UpdateUserPasswordRequest{ + OldPassword: dto.OldPassword{OldPassword: "oldpass"}, + NewPassword: dto.NewPassword{NewPassword: "newpass"}, + }, + }, + { + name: "LoginUserRequest", + req: &dto.LoginUserRequest{ + Identifier: dto.Identifier{Identifier: "user1"}, + Password: dto.Password{Password: "pass123"}, + }, + }, + { + name: "UpdateUserRequest", + req: &dto.UpdateUserRequest{ + User: dto.User{ + UserName: dto.UserName{Username: "user1"}, + Email: "user1@example.com", + }, + }, + }, + { + name: "UsernameRequest", + req: &dto.UsernameRequest{UserName: dto.UserName{Username: "user1"}}, + }, + { + name: "SetTwoFARequest", + req: &dto.SetTwoFARequest{TwoFA: true}, + }, + { + name: "DisableTwoFARequest", + req: &dto.DisableTwoFARequest{Password: dto.Password{Password: "pass123"}}, + }, + { + name: "TwoFAConfirmRequest", + req: &dto.TwoFAConfirmRequest{TwoFACode: "123456", SetupToken: "setup"}, + }, + { + name: "TwoFAChallengeRequest", + req: &dto.TwoFAChallengeRequest{TwoFACode: "123456", SessionToken: "session"}, + }, + { + name: "AddNewFriendRequest", + req: &dto.AddNewFriendRequest{UserID: 1}, + }, + { + name: "GoogleOauthCallback", + req: &dto.GoogleOauthCallback{Code: "code", State: "state"}, + }, } - invalid := &dto.TwoFAConfirmRequest{TwoFACode: "ABC123", SetupToken: "token"} - if err := dto.Validate.Struct(invalid); err == nil { - t.Fatalf("expected non-numeric 2FA code to fail validation") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := dto.Validate.Struct(tc.req); err != nil { + t.Fatalf("expected valid %s, got err: %v", tc.name, err) + } + }) } } From 63a3cf8e89c06fb040dc8c5ee15ecd500ad7825e Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Sun, 1 Feb 2026 23:45:53 +0200 Subject: [PATCH 03/15] feat/backend: add NewMiddlewareTestRouter for middleware testing --- backend/internal/testutil/testutil.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/backend/internal/testutil/testutil.go b/backend/internal/testutil/testutil.go index b00d263..2852ee0 100644 --- a/backend/internal/testutil/testutil.go +++ b/backend/internal/testutil/testutil.go @@ -4,6 +4,7 @@ import ( "io" "log/slog" + "github.com/gin-gonic/gin" "github.com/paularynty/transcendence/auth-service-go/internal/config" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/redis/go-redis/v9" @@ -48,3 +49,23 @@ func NewTestDependency(cfg *config.Config, db *gorm.DB, redis *redis.Client, log } return dependency.NewDependency(cfg, db, redis, logger) } + +func NewMiddlewareTestRouter(middleware1 gin.HandlerFunc, middleware2 gin.HandlerFunc) *gin.Engine { + r := gin.New() + + if middleware1 != nil { + r.Use(middleware1) + } + + if middleware2 != nil { + r.Use(middleware2) + } + + r.POST("/middleware-test", func(c *gin.Context) { + c.JSON(200, gin.H{ + "message": "ok", + }) + }) + + return r +} From 8d75bad8b7351ad8aa1e9c7b67779975b13f343f Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:05:55 +0200 Subject: [PATCH 04/15] refactor/backend: slightly re-structure userService, and add AuthServiceInterface for testing --- backend/internal/middleware/auth.go | 14 +++++-- backend/internal/service/helper.go | 16 ------- backend/internal/service/user_service.go | 20 +++++++++ backend/internal/testutil/testutil.go | 53 ++++++++++++++++-------- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index 09a47ef..ad3099b 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -1,19 +1,25 @@ package middleware import ( + "context" "errors" "strings" "github.com/gin-gonic/gin" authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" - "github.com/paularynty/transcendence/auth-service-go/internal/service" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) const PrefixBearer = "Bearer " -func Auth(userService *service.UserService) gin.HandlerFunc { +type AuthService interface { + ValidateUserToken(ctx context.Context, tokenString string, userID uint) error + GetDependency() *dependency.Dependency +} + +func Auth(authService AuthService) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") @@ -24,13 +30,13 @@ func Auth(userService *service.UserService) gin.HandlerFunc { tokenString := authHeader[len(PrefixBearer):] - userJwtPayload, err := jwt.ValidateUserTokenGeneric(userService.Dep, tokenString) + userJwtPayload, err := jwt.ValidateUserTokenGeneric(authService.GetDependency(), tokenString) if err != nil { _ = c.AbortWithError(401, authError.NewAuthError(401, "Invalid or expired token")) return } - err = userService.ValidateUserToken(c.Request.Context(), tokenString, userJwtPayload.UserID) + err = authService.ValidateUserToken(c.Request.Context(), tokenString, userJwtPayload.UserID) var authError *authError.AuthError if err != nil { diff --git a/backend/internal/service/helper.go b/backend/internal/service/helper.go index b6b7868..2b608e1 100644 --- a/backend/internal/service/helper.go +++ b/backend/internal/service/helper.go @@ -8,7 +8,6 @@ import ( "time" model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" "github.com/redis/go-redis/v9" @@ -18,21 +17,6 @@ import ( const HeartBeatPrefix = "heartbeat:" -func NewUserService(dep *dependency.Dependency) (*UserService, error) { - - if dep.DB == nil { - return nil, fmt.Errorf("UserService: db is nil") - } - - if dep.Cfg.IsRedisEnabled && dep.Redis == nil { - return nil, fmt.Errorf("UserService: redis is enabled but redis client is nil") - } - - return &UserService{ - Dep: dep, - }, nil -} - func isTwoFAEnabled(twoFAToken *string) bool { return twoFAToken != nil && *twoFAToken != "" && !strings.HasPrefix(*twoFAToken, TwoFAPrePrefix) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a14761f..6b15d6b 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "fmt" "strings" "time" @@ -26,6 +27,25 @@ type UserService struct { Dep *dependency.Dependency } +func NewUserService(dep *dependency.Dependency) (*UserService, error) { + + if dep.DB == nil { + return nil, fmt.Errorf("UserService: db is nil") + } + + if dep.Cfg.IsRedisEnabled && dep.Redis == nil { + return nil, fmt.Errorf("UserService: redis is enabled but redis client is nil") + } + + return &UserService{ + Dep: dep, + }, nil +} + +func (s *UserService) GetDependency() *dependency.Dependency { + return s.Dep +} + func (s *UserService) CreateUser(ctx context.Context, request *dto.CreateUserRequest) (*dto.UserWithoutTokenResponse, error) { passwordBytes, err := bcrypt.GenerateFromPassword([]byte(request.Password.Password), BcryptSaltRounds) diff --git a/backend/internal/testutil/testutil.go b/backend/internal/testutil/testutil.go index 2852ee0..4c37276 100644 --- a/backend/internal/testutil/testutil.go +++ b/backend/internal/testutil/testutil.go @@ -11,36 +11,44 @@ import ( "gorm.io/gorm" ) -func NewTestConfig() *config.Config { - return &config.Config{ - JwtSecret: "test-secret", - UserTokenExpiry: 3600, - UserTokenAbsoluteExpiry: 2592000, - OauthStateTokenExpiry: 600, - GoogleClientId: "test-client-id", - GoogleClientSecret: "test-client-secret", - GoogleRedirectUri: "http://localhost:8080/callback", - FrontendUrl: "http://localhost:3000", - TwoFaUrlPrefix: "otpauth://totp/Transcendence?secret=", - TwoFaTokenExpiry: 600, - RedisURL: "", - IsRedisEnabled: false, - } -} - func NewTestLogger() *slog.Logger { return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ Level: slog.LevelError, })) } +func NewTestConfig() *config.Config { + return &config.Config{ + GinMode: "test", + DbAddress: "inmemory://test", + JwtSecret: "test-jwt-secret", + UserTokenExpiry: 5, + OauthStateTokenExpiry: 5, + GoogleClientId: "test-google-client-id", + GoogleClientSecret: "test-google-client-secret", + GoogleRedirectUri: "test-google-redirect-uri", + FrontendUrl: "http://localhost:5173", + TwoFaUrlPrefix: "otpauth://totp/Transcendence?secret=", + TwoFaTokenExpiry: 5, + RedisURL: "", + IsRedisEnabled: false, + UserTokenAbsoluteExpiry: 2592000, + Port: 3003, + RateLimiterDurationInSec: 5, + RateLimiterRequestLimit: 10, + RateLimiterCleanupIntervalInSec: 10, + } +} + func NewTestDependency(cfg *config.Config, db *gorm.DB, redis *redis.Client, logger *slog.Logger) *dependency.Dependency { if cfg == nil { cfg = NewTestConfig() } + if logger == nil { logger = NewTestLogger() } + if redis != nil { cfg.IsRedisEnabled = true if cfg.RedisURL == "" { @@ -69,3 +77,14 @@ func NewMiddlewareTestRouter(middleware1 gin.HandlerFunc, middleware2 gin.Handle return r } + +func NewIntegrationTestRouter(dep *dependency.Dependency, handlers ...gin.HandlerFunc) *gin.Engine { + r := gin.New() + + for _, handler := range handlers { + r.Use(handler) + } + + return r +} + From 06485dce191afe85e4bd8a571b6cf00ddabbc9a1 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:57:23 +0200 Subject: [PATCH 05/15] refactor/backend: new auth test --- backend/internal/middleware/auth.go | 2 + backend/internal/middleware/auth_test.go | 161 +++++++---------------- backend/internal/testutil/testutil.go | 3 +- 3 files changed, 49 insertions(+), 117 deletions(-) diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index ad3099b..8f92705 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -30,12 +30,14 @@ func Auth(authService AuthService) gin.HandlerFunc { tokenString := authHeader[len(PrefixBearer):] + // JWT validation userJwtPayload, err := jwt.ValidateUserTokenGeneric(authService.GetDependency(), tokenString) if err != nil { _ = c.AbortWithError(401, authError.NewAuthError(401, "Invalid or expired token")) return } + // Online validation err = authService.ValidateUserToken(c.Request.Context(), tokenString, userJwtPayload.UserID) var authError *authError.AuthError diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go index 417cc5a..c00b9e7 100644 --- a/backend/internal/middleware/auth_test.go +++ b/backend/internal/middleware/auth_test.go @@ -1,146 +1,77 @@ package middleware_test import ( - "encoding/json" + "context" + "fmt" "net/http" "net/http/httptest" "testing" - "github.com/gin-gonic/gin" - - model "github.com/paularynty/transcendence/auth-service-go/internal/db" + authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" "github.com/paularynty/transcendence/auth-service-go/internal/dependency" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" - "github.com/paularynty/transcendence/auth-service-go/internal/service" "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" - "gorm.io/driver/sqlite" - "gorm.io/gorm" ) -func setupAuthDeps(t *testing.T) (*dependency.Dependency, *service.UserService) { - t.Helper() - cfg := testutil.NewTestConfig() - cfg.JwtSecret = "test-secret-key" - cfg.UserTokenExpiry = 3600 - cfg.OauthStateTokenExpiry = 120 - cfg.TwoFaTokenExpiry = 300 - dbName := "file:" + t.Name() + "?mode=memory&cache=shared" - db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) - if err != nil { - t.Fatalf("failed to connect to db: %v", err) - } - if err := db.AutoMigrate(&model.User{}, &model.Token{}); err != nil { - t.Fatalf("failed to migrate db: %v", err) - } - dep := testutil.NewTestDependency(cfg, db, nil, nil) - userService, err := service.NewUserService(dep) - if err != nil { - t.Fatalf("failed to create user service: %v", err) - } - return dep, userService -} - -func TestAuthMiddlewareRejectsMissingToken(t *testing.T) { - gin.SetMode(gin.TestMode) - _, userService := setupAuthDeps(t) +var testDep = testutil.NewTestDependency(nil, nil, nil, nil) - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth(userService)) - r.GET("/protected", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - - req := httptest.NewRequest(http.MethodGet, "/protected", nil) - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) +type testAuthService struct { + returnCode int +} - if resp.Code != http.StatusUnauthorized { - t.Fatalf("expected status 401, got %d", resp.Code) - } +func (ts *testAuthService) GetDependency() *dependency.Dependency { + return testDep +} - var body map[string]string - if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) +func (ts *testAuthService) ValidateUserToken(ctx context.Context, tokenString string, userID uint) error { + switch ts.returnCode { + case 200: + return nil + case 401: + return authError.NewAuthError(401, "invalid token") + default: + return fmt.Errorf("unexpected error") } +} - if body["error"] != "Invalid or expired token" { - t.Fatalf("unexpected error message: %v", body) - } +func newTestAuthService(returnCode int) middleware.AuthService { + return &testAuthService{returnCode: returnCode} } -func TestAuthMiddlewareAllowsValidToken(t *testing.T) { - gin.SetMode(gin.TestMode) - dep, userService := setupAuthDeps(t) +const notSet = "notSet" - token, err := jwt.SignUserToken(dep, 99) +func TestAuth(t *testing.T) { + validToken, err := jwt.SignUserToken(testDep, 1) if err != nil { - t.Fatalf("failed to sign user token: %v", err) - } - if err := dep.DB.Create(&model.User{Model: gorm.Model{ID: 99}, Username: "u99", Email: "u99@example.com"}).Error; err != nil { - t.Fatalf("failed to create user: %v", err) - } - if err := dep.DB.Create(&model.Token{UserID: 99, Token: token}).Error; err != nil { - t.Fatalf("failed to create token: %v", err) - } - - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth(userService)) - r.GET("/protected", func(c *gin.Context) { - userID, ok := c.Get("userID") - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": "missing userID"}) - return - } - c.JSON(http.StatusOK, gin.H{"userId": userID}) - }) - - req := httptest.NewRequest(http.MethodGet, "/protected", nil) - req.Header.Set("Authorization", middleware.PrefixBearer+token) - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) + t.Fatalf("failed to sign test token, err: %v", err) } - var body map[string]any - if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if val, ok := body["userId"].(float64); !ok || val != 99 { - t.Fatalf("expected userId 99, got %v", body["userId"]) - } -} - -func TestAuthMiddlewareRejectsInvalidToken(t *testing.T) { - gin.SetMode(gin.TestMode) - dep, userService := setupAuthDeps(t) - - token, err := jwt.SignTwoFAToken(dep, 10) - if err != nil { - t.Fatalf("failed to sign 2fa token: %v", err) + testCases := []struct { + name string + token string + expectedStatus int + }{ + {name: "valid token", token: validToken, expectedStatus: 200}, + {name: "invalid token", token: "aaa", expectedStatus: 401}, + {name: "token not set", token: notSet, expectedStatus: 401}, } - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.Auth(userService)) - r.GET("/protected", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := testutil.NewMiddlewareTestRouter(middleware.Auth(newTestAuthService(tc.expectedStatus)), nil) + req, _ := http.NewRequest("POST", "/middleware-test", nil) - req := httptest.NewRequest(http.MethodGet, "/protected", nil) - req.Header.Set("Authorization", middleware.PrefixBearer+token) - resp := httptest.NewRecorder() + if tc.token != notSet { + req.Header.Add("Authorization", fmt.Sprintf("%s%s", middleware.PrefixBearer, tc.token)) + } - r.ServeHTTP(resp, req) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) - if resp.Code != http.StatusUnauthorized { - t.Fatalf("expected status 401, got %d", resp.Code) + if w.Code != tc.expectedStatus { + t.Fatalf("expected: %d, got: %d", tc.expectedStatus, w.Code) + } + }) } } diff --git a/backend/internal/testutil/testutil.go b/backend/internal/testutil/testutil.go index 4c37276..d75eaf5 100644 --- a/backend/internal/testutil/testutil.go +++ b/backend/internal/testutil/testutil.go @@ -74,7 +74,7 @@ func NewMiddlewareTestRouter(middleware1 gin.HandlerFunc, middleware2 gin.Handle "message": "ok", }) }) - + return r } @@ -87,4 +87,3 @@ func NewIntegrationTestRouter(dep *dependency.Dependency, handlers ...gin.Handle return r } - From 2575a02077c84926bf7f3b2097ddc2f7442c2270 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 22:13:07 +0200 Subject: [PATCH 06/15] refactor/backend: enhance auth test to validate userID and token in response --- backend/internal/middleware/auth_test.go | 23 ++++++++++++++++++++++- backend/internal/testutil/testutil.go | 6 +++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go index c00b9e7..cacb235 100644 --- a/backend/internal/middleware/auth_test.go +++ b/backend/internal/middleware/auth_test.go @@ -2,6 +2,7 @@ package middleware_test import ( "context" + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -40,9 +41,15 @@ func newTestAuthService(returnCode int) middleware.AuthService { } const notSet = "notSet" +const userID = 11 + +var resp struct { + UserID uint `json:"userID"` + Token string `json:"token"` +} func TestAuth(t *testing.T) { - validToken, err := jwt.SignUserToken(testDep, 1) + validToken, err := jwt.SignUserToken(testDep, userID) if err != nil { t.Fatalf("failed to sign test token, err: %v", err) } @@ -72,6 +79,20 @@ func TestAuth(t *testing.T) { if w.Code != tc.expectedStatus { t.Fatalf("expected: %d, got: %d", tc.expectedStatus, w.Code) } + + if w.Code == 200 { + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp.Token != validToken { + t.Fatalf("token does not set to request") + } + + if resp.UserID != userID { + t.Fatalf("userID does not set to request, expected: %d, got: %d", userID, resp.UserID) + } + } }) } } diff --git a/backend/internal/testutil/testutil.go b/backend/internal/testutil/testutil.go index d75eaf5..3bde0f1 100644 --- a/backend/internal/testutil/testutil.go +++ b/backend/internal/testutil/testutil.go @@ -70,8 +70,12 @@ func NewMiddlewareTestRouter(middleware1 gin.HandlerFunc, middleware2 gin.Handle } r.POST("/middleware-test", func(c *gin.Context) { + userID := c.GetUint("userID") + token := c.GetString("token") + c.JSON(200, gin.H{ - "message": "ok", + "userID": userID, + "token": token, }) }) From 19f4b164e57a8a437d70ea3d03ffd5ea443e041d Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 22:52:36 +0200 Subject: [PATCH 07/15] refactor/backend: simplify error handler tests and consolidate error generation logic --- .../internal/middleware/error_handler_test.go | 153 ++++++------------ 1 file changed, 47 insertions(+), 106 deletions(-) diff --git a/backend/internal/middleware/error_handler_test.go b/backend/internal/middleware/error_handler_test.go index e55585c..ada2c36 100644 --- a/backend/internal/middleware/error_handler_test.go +++ b/backend/internal/middleware/error_handler_test.go @@ -1,130 +1,71 @@ package middleware_test import ( - "bytes" - "encoding/json" - "errors" + "fmt" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" + "github.com/go-playground/validator/v10" authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" ) -func TestErrorHandlerReturnsAuthErrorPayload(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.GET("/auth", func(c *gin.Context) { - _ = c.AbortWithError(http.StatusUnauthorized, authError.NewAuthError(http.StatusUnauthorized, "Invalid or expired token")) - }) - - req := httptest.NewRequest(http.MethodGet, "/auth", nil) - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusUnauthorized { - t.Fatalf("expected status 401, got %d", resp.Code) - } - - var body map[string]string - if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if body["error"] != "Invalid or expired token" { - t.Fatalf("unexpected error payload: %v", body) - } +type testStruct struct { + UserId int `json:"userID" validate:"required"` } -func TestErrorHandlerDifferentErrors(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.GET("/unknown", func(c *gin.Context) { - _ = c.AbortWithError(http.StatusTeapot, errors.New("boom")) - }) - - req := httptest.NewRequest(http.MethodGet, "/unknown", nil) - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusTeapot { - t.Fatalf("expected status 418, got %d body=%s", resp.Code, resp.Body.String()) - } - - var body map[string]string - if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if body["error"] != "Internal Server Error" { - t.Fatalf("unexpected error payload: %v", body) +func errorGenerator(code int) gin.HandlerFunc { + return func(c *gin.Context) { + switch code { + case 200: + c.Next() + case 401: + _ = c.AbortWithError(401, authError.NewAuthError(401, "invalid user")) + case 400: + v := validator.New() + s := testStruct{} + + err := v.Struct(s) + if err == nil { + panic("expected validation error, got nil") + } + + _ = c.AbortWithError(400, err) + case 500: + _ = c.AbortWithError(500, fmt.Errorf("unknown error")) + default: + panic("panic test") + } } } -func TestValidationMiddlewareReturnsValidationErrors(t *testing.T) { - gin.SetMode(gin.TestMode) - dto.InitValidator() - - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.ValidateBody[dto.UserName]()) - r.POST("/validate", func(c *gin.Context) { - // Should not reach when validation fails - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - - req := httptest.NewRequest(http.MethodPost, "/validate", bytes.NewBufferString(`{"username":" a( "}`)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected status 400, got %d body=%s", resp.Code, resp.Body.String()) - } - - var body map[string]any - if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - errorsField, ok := body["error"].([]any) - if !ok || len(errorsField) != 1 { - t.Fatalf("expected validation errors array, got %v", body) +func TestErrorHandler(t *testing.T) { + testCases := []struct { + name string + code int + expectedCode int + }{ + {name: "normal case", code: 200, expectedCode: 200}, + {name: "400 error", code: 400, expectedCode: 400}, + {name: "401 error", code: 401, expectedCode: 401}, + {name: "500 error", code: 500, expectedCode: 500}, } -} - -func TestPanicHandlerReturnsJSON500(t *testing.T) { - gin.SetMode(gin.TestMode) - r := gin.New() - r.Use(middleware.PanicHandler()) - r.GET("/panic", func(c *gin.Context) { - panic("boom") - }) - - req := httptest.NewRequest(http.MethodGet, "/panic", nil) - resp := httptest.NewRecorder() - r.ServeHTTP(resp, req) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := testutil.NewMiddlewareTestRouter(errorGenerator(tc.code), middleware.ErrorHandler()) + w := httptest.NewRecorder() - if resp.Code != http.StatusInternalServerError { - t.Fatalf("expected status 500, got %d body=%s", resp.Code, resp.Body.String()) - } - - var body map[string]string - if err := json.Unmarshal(resp.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } + req, _ := http.NewRequest("POST", "/middleware-test", nil) + r.ServeHTTP(w, req) - if body["error"] != "Internal Server Error" { - t.Fatalf("unexpected error payload: %v", body) + if w.Code != tc.expectedCode { + t.Fatalf("expected: %d, got: %d", tc.expectedCode, w.Code) + } + }) } } From 1d276edc55d8f52de05532e22c456f39749ac799 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:01:54 +0200 Subject: [PATCH 08/15] refactor/backend: reorganize rate limiter tests for clarity and consistency --- .../middleware/rate_limiter_internal_test.go | 49 ----- .../internal/middleware/rate_limiter_test.go | 179 ++++++++---------- 2 files changed, 81 insertions(+), 147 deletions(-) delete mode 100644 backend/internal/middleware/rate_limiter_internal_test.go diff --git a/backend/internal/middleware/rate_limiter_internal_test.go b/backend/internal/middleware/rate_limiter_internal_test.go deleted file mode 100644 index 24046f6..0000000 --- a/backend/internal/middleware/rate_limiter_internal_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package middleware - -import ( - "testing" - "time" -) - -func TestAllowRequestCleansExpiredEntriesAtInterval(t *testing.T) { - rl := NewRateLimiter(10*time.Millisecond, 1, 5*time.Millisecond) - - now := time.Now() - rl.requestCounts["old"] = 2 - rl.requestExpiry["old"] = now.Add(-time.Second) - rl.lastCleanup = now.Add(-rl.cleanupInterval - time.Second) - - _ = rl.AllowRequest("new-client") - - if _, exists := rl.requestCounts["old"]; exists { - t.Fatalf("expected expired request count to be removed during cleanup") - } - if _, exists := rl.requestExpiry["old"]; exists { - t.Fatalf("expected expired request expiry to be removed during cleanup") - } -} - -func TestUnsafeClearExpiredEntriesRemovesOnlyExpired(t *testing.T) { - rl := NewRateLimiter(10*time.Millisecond, 1, time.Minute) - - now := time.Now() - rl.requestCounts["expired"] = 1 - rl.requestExpiry["expired"] = now.Add(-time.Second) - rl.requestCounts["active"] = 1 - rl.requestExpiry["active"] = now.Add(time.Second) - - unSafeClearExpiredEntries(now, rl) - - if _, exists := rl.requestCounts["expired"]; exists { - t.Fatalf("expected expired request count to be removed") - } - if _, exists := rl.requestExpiry["expired"]; exists { - t.Fatalf("expected expired request expiry to be removed") - } - if _, exists := rl.requestCounts["active"]; !exists { - t.Fatalf("expected active request count to remain") - } - if _, exists := rl.requestExpiry["active"]; !exists { - t.Fatalf("expected active request expiry to remain") - } -} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 3b544a9..17a0306 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -10,142 +10,125 @@ import ( "github.com/gin-gonic/gin" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" ) -func TestAllowRequestResetsAfterWindow(t *testing.T) { - gin.SetMode(gin.TestMode) +const rateLimiterPath = "/middleware-test" - rl := middleware.NewRateLimiter(30*time.Millisecond, 2, time.Minute) - clientID := "client-1" +func newRateLimiterRouter(rl *middleware.RateLimiter) *gin.Engine { + r := testutil.NewIntegrationTestRouter(nil, middleware.ErrorHandler(), rl.RateLimit()) - if !rl.AllowRequest(clientID) { - t.Fatalf("expected first request to pass") - } - if !rl.AllowRequest(clientID) { - t.Fatalf("expected second request to pass within window") - } - if rl.AllowRequest(clientID) { - t.Fatalf("expected third request to be blocked within window") - } + r.POST(rateLimiterPath, func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + r.OPTIONS(rateLimiterPath, func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) - time.Sleep(40 * time.Millisecond) + return r +} - if !rl.AllowRequest(clientID) { - t.Fatalf("expected requests to be allowed after window resets") - } +func doRequest(r http.Handler, method string, ip string) *httptest.ResponseRecorder { + req, _ := http.NewRequest(method, rateLimiterPath, nil) + req.RemoteAddr = ip + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + return w } -func TestRateLimitMiddlewareBlocksAfterLimit(t *testing.T) { - gin.SetMode(gin.TestMode) +func assertErrorMessage(t *testing.T, w *httptest.ResponseRecorder, expected string) { + t.Helper() - rl := middleware.NewRateLimiter(50*time.Millisecond, 1, time.Minute) + var body map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if body["error"] != expected { + t.Fatalf("unexpected error payload: %v", body) + } +} - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(rl.RateLimit()) - r.GET("/limited", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) +func TestRateLimiterBlocksAfterLimit(t *testing.T) { + rl := middleware.NewRateLimiter(100*time.Millisecond, 2, time.Minute) + r := newRateLimiterRouter(rl) - req1 := httptest.NewRequest(http.MethodGet, "/limited", nil) - req1.RemoteAddr = "198.51.100.10:1234" - resp1 := httptest.NewRecorder() - r.ServeHTTP(resp1, req1) + resp1 := doRequest(r, http.MethodPost, "203.0.113.1:1000") if resp1.Code != http.StatusOK { - t.Fatalf("expected first request status 200, got %d", resp1.Code) + t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) } - req2 := httptest.NewRequest(http.MethodGet, "/limited", nil) - req2.RemoteAddr = "198.51.100.10:5678" - resp2 := httptest.NewRecorder() - r.ServeHTTP(resp2, req2) + resp2 := doRequest(r, http.MethodPost, "203.0.113.1:1000") + if resp2.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp2.Code) + } - if resp2.Code != http.StatusTooManyRequests { - t.Fatalf("expected status 429, got %d body=%s", resp2.Code, resp2.Body.String()) + blocked := doRequest(r, http.MethodPost, "203.0.113.1:1000") + if blocked.Code != http.StatusTooManyRequests { + t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, blocked.Code) } + assertErrorMessage(t, blocked, "Too many requests") +} - var body map[string]string - if err := json.Unmarshal(resp2.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) +func TestRateLimiterResetsAfterWindow(t *testing.T) { + rl := middleware.NewRateLimiter(30*time.Millisecond, 1, time.Minute) + r := newRateLimiterRouter(rl) + + resp1 := doRequest(r, http.MethodPost, "198.51.100.3:9999") + if resp1.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) } - if body["error"] != "Too many requests" { - t.Fatalf("unexpected error payload: %v", body) + + resp2 := doRequest(r, http.MethodPost, "198.51.100.3:9999") + if resp2.Code != http.StatusTooManyRequests { + t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp2.Code) } -} -func TestRateLimitMiddlewareSkipsOptions(t *testing.T) { - gin.SetMode(gin.TestMode) + time.Sleep(60 * time.Millisecond) - rl := middleware.NewRateLimiter(50*time.Millisecond, 1, time.Minute) + resp3 := doRequest(r, http.MethodPost, "198.51.100.3:9999") + if resp3.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp3.Code) + } +} - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(rl.RateLimit()) - r.OPTIONS("/limited", func(c *gin.Context) { - c.Status(http.StatusNoContent) - }) - r.GET("/limited", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) +func TestRateLimiterOptionsBypass(t *testing.T) { + rl := middleware.NewRateLimiter(100*time.Millisecond, 1, time.Minute) + r := newRateLimiterRouter(rl) - reqOptions := httptest.NewRequest(http.MethodOptions, "/limited", nil) - reqOptions.RemoteAddr = "198.51.100.20:9999" for i := 0; i < 3; i++ { - resp := httptest.NewRecorder() - r.ServeHTTP(resp, reqOptions) + resp := doRequest(r, http.MethodOptions, "203.0.113.2:5555") if resp.Code != http.StatusNoContent { - t.Fatalf("expected OPTIONS to bypass limiter with 204, got %d", resp.Code) + t.Fatalf("expected status %d, got %d", http.StatusNoContent, resp.Code) } } - reqGet := httptest.NewRequest(http.MethodGet, "/limited", nil) - reqGet.RemoteAddr = "198.51.100.20:9999" - respGet1 := httptest.NewRecorder() - r.ServeHTTP(respGet1, reqGet) - if respGet1.Code != http.StatusOK { - t.Fatalf("expected first GET 200 after OPTIONS calls, got %d", respGet1.Code) + resp1 := doRequest(r, http.MethodPost, "203.0.113.2:5555") + if resp1.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) } - respGet2 := httptest.NewRecorder() - r.ServeHTTP(respGet2, reqGet) - if respGet2.Code != http.StatusTooManyRequests { - t.Fatalf("expected second GET to be rate limited with 429, got %d", respGet2.Code) + resp2 := doRequest(r, http.MethodPost, "203.0.113.2:5555") + if resp2.Code != http.StatusTooManyRequests { + t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp2.Code) } } -func TestRateLimitMiddlewareUsesClientSpecificCounters(t *testing.T) { - gin.SetMode(gin.TestMode) - +func TestRateLimiterClientIsolation(t *testing.T) { rl := middleware.NewRateLimiter(100*time.Millisecond, 1, time.Minute) + r := newRateLimiterRouter(rl) - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(rl.RateLimit()) - r.GET("/limited", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - - reqClientA := httptest.NewRequest(http.MethodGet, "/limited", nil) - reqClientA.RemoteAddr = "203.0.113.1:5000" - respA1 := httptest.NewRecorder() - r.ServeHTTP(respA1, reqClientA) - if respA1.Code != http.StatusOK { - t.Fatalf("expected client A first request 200, got %d", respA1.Code) + resp1 := doRequest(r, http.MethodPost, "203.0.113.10:5000") + if resp1.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) } - reqClientB := httptest.NewRequest(http.MethodGet, "/limited", nil) - reqClientB.RemoteAddr = "203.0.113.2:5000" - respB := httptest.NewRecorder() - r.ServeHTTP(respB, reqClientB) - if respB.Code != http.StatusOK { - t.Fatalf("expected client B request 200, got %d body=%s", respB.Code, respB.Body.String()) + resp2 := doRequest(r, http.MethodPost, "203.0.113.11:5000") + if resp2.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp2.Code) } - reqClientA2 := httptest.NewRequest(http.MethodGet, "/limited", nil) - reqClientA2.RemoteAddr = "203.0.113.1:6000" - respA2 := httptest.NewRecorder() - r.ServeHTTP(respA2, reqClientA2) - if respA2.Code != http.StatusTooManyRequests { - t.Fatalf("expected client A second request 429, got %d", respA2.Code) + resp3 := doRequest(r, http.MethodPost, "203.0.113.10:6000") + if resp3.Code != http.StatusTooManyRequests { + t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp3.Code) } } From 6545ae0bf192974aaa0ca7eb7e113b05c4a44a70 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:17:20 +0200 Subject: [PATCH 09/15] refactor/backend: refactor test for rate limiter, and validation --- .../internal/middleware/rate_limiter_test.go | 229 +++++++++--------- .../internal/middleware/validation_test.go | 138 +++-------- 2 files changed, 152 insertions(+), 215 deletions(-) diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 17a0306..be8b5c8 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -1,7 +1,6 @@ package middleware_test import ( - "encoding/json" "net/http" "net/http/httptest" "testing" @@ -13,122 +12,134 @@ import ( "github.com/paularynty/transcendence/auth-service-go/internal/testutil" ) -const rateLimiterPath = "/middleware-test" - -func newRateLimiterRouter(rl *middleware.RateLimiter) *gin.Engine { - r := testutil.NewIntegrationTestRouter(nil, middleware.ErrorHandler(), rl.RateLimit()) - - r.POST(rateLimiterPath, func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - r.OPTIONS(rateLimiterPath, func(c *gin.Context) { - c.Status(http.StatusNoContent) - }) - - return r -} - -func doRequest(r http.Handler, method string, ip string) *httptest.ResponseRecorder { - req, _ := http.NewRequest(method, rateLimiterPath, nil) - req.RemoteAddr = ip - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - return w -} - -func assertErrorMessage(t *testing.T, w *httptest.ResponseRecorder, expected string) { - t.Helper() - - var body map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if body["error"] != expected { - t.Fatalf("unexpected error payload: %v", body) - } -} - -func TestRateLimiterBlocksAfterLimit(t *testing.T) { - rl := middleware.NewRateLimiter(100*time.Millisecond, 2, time.Minute) - r := newRateLimiterRouter(rl) - - resp1 := doRequest(r, http.MethodPost, "203.0.113.1:1000") - if resp1.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) +func TestRateLimiter(t *testing.T) { + testCases := []struct { + name string + duration time.Duration + limit int + sleep time.Duration + methods []string + remoteAddrs []string + expectedStatus []int + needOptions bool + }{ + { + name: "blocks after limit", + duration: 100 * time.Millisecond, + limit: 2, + methods: []string{http.MethodPost, http.MethodPost, http.MethodPost}, + remoteAddrs: []string{"203.0.113.1:1000", "203.0.113.1:1000", "203.0.113.1:1000"}, + expectedStatus: []int{200, 200, 429}, + }, + { + name: "resets after window", + duration: 30 * time.Millisecond, + limit: 1, + sleep: 60 * time.Millisecond, + methods: []string{http.MethodPost, http.MethodPost, http.MethodPost}, + remoteAddrs: []string{"198.51.100.3:9999", "198.51.100.3:9999", "198.51.100.3:9999"}, + expectedStatus: []int{200, 429, 200}, + }, + { + name: "options bypass", + duration: 100 * time.Millisecond, + limit: 1, + methods: []string{http.MethodOptions, http.MethodOptions, http.MethodOptions, http.MethodPost, http.MethodPost}, + remoteAddrs: []string{"203.0.113.2:5555", "203.0.113.2:5555", "203.0.113.2:5555", "203.0.113.2:5555", "203.0.113.2:5555"}, + expectedStatus: []int{204, 204, 204, 200, 429}, + needOptions: true, + }, + { + name: "client isolation", + duration: 100 * time.Millisecond, + limit: 1, + methods: []string{http.MethodPost, http.MethodPost, http.MethodPost}, + remoteAddrs: []string{"203.0.113.10:5000", "203.0.113.11:5000", "203.0.113.10:6000"}, + expectedStatus: []int{200, 200, 429}, + }, } - resp2 := doRequest(r, http.MethodPost, "203.0.113.1:1000") - if resp2.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp2.Code) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rl := middleware.NewRateLimiter(tc.duration, tc.limit, time.Minute) + r := testutil.NewMiddlewareTestRouter(rl.RateLimit(), middleware.ErrorHandler()) + if tc.needOptions { + r.OPTIONS("/middleware-test", func(c *gin.Context) { + c.Status(204) + }) + } + + for i := 0; i < len(tc.methods); i++ { + req, _ := http.NewRequest(tc.methods[i], "/middleware-test", nil) + req.RemoteAddr = tc.remoteAddrs[i] + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tc.expectedStatus[i] { + t.Fatalf("expected: %d, got: %d", tc.expectedStatus[i], w.Code) + } + + if tc.sleep > 0 && i == 1 { + time.Sleep(tc.sleep) + } + } + }) } - - blocked := doRequest(r, http.MethodPost, "203.0.113.1:1000") - if blocked.Code != http.StatusTooManyRequests { - t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, blocked.Code) - } - assertErrorMessage(t, blocked, "Too many requests") } -func TestRateLimiterResetsAfterWindow(t *testing.T) { - rl := middleware.NewRateLimiter(30*time.Millisecond, 1, time.Minute) - r := newRateLimiterRouter(rl) - - resp1 := doRequest(r, http.MethodPost, "198.51.100.3:9999") - if resp1.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) - } - - resp2 := doRequest(r, http.MethodPost, "198.51.100.3:9999") - if resp2.Code != http.StatusTooManyRequests { - t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp2.Code) - } - - time.Sleep(60 * time.Millisecond) - - resp3 := doRequest(r, http.MethodPost, "198.51.100.3:9999") - if resp3.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp3.Code) - } -} - -func TestRateLimiterOptionsBypass(t *testing.T) { - rl := middleware.NewRateLimiter(100*time.Millisecond, 1, time.Minute) - r := newRateLimiterRouter(rl) - - for i := 0; i < 3; i++ { - resp := doRequest(r, http.MethodOptions, "203.0.113.2:5555") - if resp.Code != http.StatusNoContent { - t.Fatalf("expected status %d, got %d", http.StatusNoContent, resp.Code) - } - } - - resp1 := doRequest(r, http.MethodPost, "203.0.113.2:5555") - if resp1.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) - } - - resp2 := doRequest(r, http.MethodPost, "203.0.113.2:5555") - if resp2.Code != http.StatusTooManyRequests { - t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp2.Code) - } -} - -func TestRateLimiterClientIsolation(t *testing.T) { - rl := middleware.NewRateLimiter(100*time.Millisecond, 1, time.Minute) - r := newRateLimiterRouter(rl) - - resp1 := doRequest(r, http.MethodPost, "203.0.113.10:5000") - if resp1.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp1.Code) +func TestAllowRequest(t *testing.T) { + type step struct { + sleep time.Duration + client string + expect bool } - resp2 := doRequest(r, http.MethodPost, "203.0.113.11:5000") - if resp2.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, resp2.Code) + testCases := []struct { + name string + duration time.Duration + limit int + cleanup time.Duration + steps []step + }{ + { + name: "limit reached", + duration: 50 * time.Millisecond, + limit: 1, + cleanup: time.Minute, + steps: []step{ + {client: "client-a", expect: true}, + {client: "client-a", expect: false}, + }, + }, + { + name: "cleanup path", + duration: 50 * time.Millisecond, + limit: 1, + cleanup: 1 * time.Millisecond, + steps: []step{ + {client: "client-a", expect: true}, + {sleep: 2 * time.Millisecond}, + {client: "client-b", expect: true}, + }, + }, } - resp3 := doRequest(r, http.MethodPost, "203.0.113.10:6000") - if resp3.Code != http.StatusTooManyRequests { - t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp3.Code) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rl := middleware.NewRateLimiter(tc.duration, tc.limit, tc.cleanup) + + for _, s := range tc.steps { + if s.sleep > 0 { + time.Sleep(s.sleep) + continue + } + + allowed := rl.AllowRequest(s.client) + if allowed != s.expect { + t.Fatalf("expected: %v, got: %v", s.expect, allowed) + } + } + }) } } diff --git a/backend/internal/middleware/validation_test.go b/backend/internal/middleware/validation_test.go index 3d1b996..cfe4b7b 100644 --- a/backend/internal/middleware/validation_test.go +++ b/backend/internal/middleware/validation_test.go @@ -1,123 +1,49 @@ package middleware_test import ( - "bytes" - "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" - "github.com/gin-gonic/gin" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" ) -func TestValidateBodyPassesValidPayloadAndStoresInContext(t *testing.T) { - gin.SetMode(gin.TestMode) - dto.InitValidator() - - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.ValidateBody[dto.UserName]()) - r.POST("/ok", func(c *gin.Context) { - val, exists := c.Get("validatedBody") - if !exists { - c.JSON(http.StatusInternalServerError, gin.H{"error": "validatedBody missing"}) - return - } - - name, ok := val.(dto.UserName) - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": "wrong type"}) - return - } - - c.JSON(http.StatusOK, gin.H{"username": name.Username}) - }) - - payload := dto.UserName{Username: "valid_user"} - body, _ := json.Marshal(payload) - req := httptest.NewRequest(http.MethodPost, "/ok", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d body=%s", resp.Code, resp.Body.String()) - } - - var respBody map[string]string - if err := json.Unmarshal(resp.Body.Bytes(), &respBody); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if respBody["username"] != "valid_user" { - t.Fatalf("expected username to propagate from validatedBody, got %v", respBody) - } -} - -func TestValidateBodyHandlesBindErrors(t *testing.T) { - gin.SetMode(gin.TestMode) - dto.InitValidator() - - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.ValidateBody[dto.UserName]()) - r.POST("/bad", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - - req := httptest.NewRequest(http.MethodPost, "/bad", bytes.NewBufferString("not-json")) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected status 400, got %d body=%s", resp.Code, resp.Body.String()) - } - - var respBody map[string]any - if err := json.Unmarshal(resp.Body.Bytes(), &respBody); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if respBody["error"] == nil { - t.Fatalf("expected error field in response, got %v", respBody) - } +type testPayload struct { + Name string `json:"name" validate:"required"` } -func TestValidateBodyReturnsValidationErrors(t *testing.T) { - gin.SetMode(gin.TestMode) +func TestValidateBody(t *testing.T) { dto.InitValidator() - r := gin.New() - r.Use(middleware.ErrorHandler()) - r.Use(middleware.ValidateBody[dto.UserName]()) - r.POST("/fail", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - - // too short username triggers validator rule - req := httptest.NewRequest(http.MethodPost, "/fail", bytes.NewBufferString(`{"username":"ab"}`)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - - r.ServeHTTP(resp, req) - - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected status 400, got %d body=%s", resp.Code, resp.Body.String()) - } - - var respBody map[string]any - if err := json.Unmarshal(resp.Body.Bytes(), &respBody); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - errorsField, ok := respBody["error"].([]any) - if !ok || len(errorsField) == 0 { - t.Fatalf("expected validation errors array, got %v", respBody) + testCases := []struct { + name string + payload string + expectedStatus int + }{ + {name: "success", payload: `{"name":"ok"}`, expectedStatus: 200}, + {name: "validation error", payload: `{}`, expectedStatus: 400}, + {name: "invalid json", payload: `{`, expectedStatus: 400}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := testutil.NewMiddlewareTestRouter( + middleware.ValidateBody[testPayload](), + middleware.ErrorHandler(), + ) + reqBody := strings.NewReader(tc.payload) + req, _ := http.NewRequest(http.MethodPost, "/middleware-test", reqBody) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tc.expectedStatus { + t.Fatalf("expected: %d, got: %d", tc.expectedStatus, w.Code) + } + }) } } From 8b340daf428e8b44bdb27eab4ef8198245463ac8 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:30:56 +0200 Subject: [PATCH 10/15] refactor/backend: improve JWT token tests --- backend/internal/util/jwt/token_test.go | 247 +++++++++++++----------- 1 file changed, 138 insertions(+), 109 deletions(-) diff --git a/backend/internal/util/jwt/token_test.go b/backend/internal/util/jwt/token_test.go index 299efcd..29dc6ab 100644 --- a/backend/internal/util/jwt/token_test.go +++ b/backend/internal/util/jwt/token_test.go @@ -1,122 +1,151 @@ package jwt_test import ( - "errors" "testing" - "time" - libjwt "github.com/golang-jwt/jwt/v5" - - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" "github.com/paularynty/transcendence/auth-service-go/internal/testutil" "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" ) -func setupTokenDep(t *testing.T) *dependency.Dependency { - t.Helper() - cfg := testutil.NewTestConfig() - cfg.JwtSecret = "test-secret-key" - cfg.UserTokenExpiry = 3600 - cfg.OauthStateTokenExpiry = 120 - cfg.TwoFaTokenExpiry = 300 - return testutil.NewTestDependency(cfg, nil, nil, nil) +var testDep = testutil.NewTestDependency(nil, nil, nil, nil) + +func TestUserToken(t *testing.T) { + token, err := jwt.SignUserToken(testDep, 3) + if err != nil { + t.Fatalf("failed to generate token, got an error: %v", err) + } + + parsed, err := jwt.ValidateUserTokenGeneric(testDep, token) + if err != nil { + t.Fatalf("faled to parse user token, got an error: %v", err) + } + + if parsed.Type != jwt.UserTokenType { + t.Fatalf("expected token type: %s, got %s", jwt.UserTokenType, parsed.Type) + } + + if parsed.UserID != 3 { + t.Fatalf("expected userID: %d, got %d", 3, parsed.UserID) + } + + _, err = jwt.ValidateUserTokenGeneric(testDep, "aaa") + if err == nil { + t.Fatalf("expected error, got nil") + } + + invalidToken, err := jwt.SignOauthStateToken(testDep) + if err != nil { + t.Fatalf("failed to generate OauthStateToken") + } + + _, err = jwt.ValidateUserTokenGeneric(testDep, invalidToken) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestOauthStateToken(t *testing.T) { + token, err := jwt.SignOauthStateToken(testDep) + if err != nil { + t.Fatalf("failed to generate token, got an error: %v", err) + } + + parsed, err := jwt.ValidateOauthStateToken(testDep, token) + if err != nil { + t.Fatalf("faled to parse oauth state token, got an error: %v", err) + } + + if parsed.Type != jwt.GoogleOAuthStateType { + t.Fatalf("expected token type: %s, got %s", jwt.GoogleOAuthStateType, parsed.Type) + } + + _, err = jwt.ValidateOauthStateToken(testDep, "aaa") + if err == nil { + t.Fatalf("expected error, got nil") + } + + invalidToken, err := jwt.SignUserToken(testDep, 3) + if err != nil { + t.Fatalf("failed to generate user token") + } + + _, err = jwt.ValidateOauthStateToken(testDep, invalidToken) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestTwoFASetupToken(t *testing.T) { + token, err := jwt.SignTwoFASetupToken(testDep, 7, "test-secret") + if err != nil { + t.Fatalf("failed to generate token, got an error: %v", err) + } + + parsed, err := jwt.ValidateTwoFASetupToken(testDep, token) + if err != nil { + t.Fatalf("faled to parse two-fa setup token, got an error: %v", err) + } + + if parsed.Type != jwt.TwoFASetupType { + t.Fatalf("expected token type: %s, got %s", jwt.TwoFASetupType, parsed.Type) + } + + if parsed.UserID != 7 { + t.Fatalf("expected userID: %d, got %d", 7, parsed.UserID) + } + + if parsed.Secret != "test-secret" { + t.Fatalf("expected secret: %s, got %s", "test-secret", parsed.Secret) + } + + _, err = jwt.ValidateTwoFASetupToken(testDep, "aaa") + if err == nil { + t.Fatalf("expected error, got nil") + } + + invalidToken, err := jwt.SignTwoFAToken(testDep, 7) + if err != nil { + t.Fatalf("failed to generate two-fa token") + } + + _, err = jwt.ValidateTwoFASetupToken(testDep, invalidToken) + if err == nil { + t.Fatalf("expected error, got nil") + } } -func TestTokenRoundTrip(t *testing.T) { - dep := setupTokenDep(t) - - cases := []struct { - name string - sign func() (string, error) - validate func(string) (any, error) - assert func(t *testing.T, claims any) - expectedError error - }{ - { - name: "UserToken", - sign: func() (string, error) { - return jwt.SignUserToken(dep, 42) - }, - validate: func(token string) (any, error) { - return jwt.ValidateUserTokenGeneric(dep, token) - }, - assert: func(t *testing.T, claims any) { - parsed := claims.(*dto.UserJwtPayload) - if parsed.UserID != 42 { - t.Fatalf("expected user id 42, got %d", parsed.UserID) - } - if parsed.Type != jwt.UserTokenType { - t.Fatalf("expected claim type %q, got %q", jwt.UserTokenType, parsed.Type) - } - if parsed.ExpiresAt == nil || parsed.ExpiresAt.Before(time.Now()) { - t.Fatalf("expected future expiration, got %v", parsed.ExpiresAt) - } - }, - }, - { - name: "OauthStateToken", - sign: func() (string, error) { - return jwt.SignOauthStateToken(dep) - }, - validate: func(token string) (any, error) { - return jwt.ValidateOauthStateToken(dep, token) - }, - assert: func(t *testing.T, claims any) { - parsed := claims.(*dto.OauthStateJwtPayload) - if parsed.Type != jwt.GoogleOAuthStateType { - t.Fatalf("expected oauth state type %q, got %q", jwt.GoogleOAuthStateType, parsed.Type) - } - }, - }, - { - name: "TwoFASetupToken", - sign: func() (string, error) { - return jwt.SignTwoFASetupToken(dep, 7, "secret") - }, - validate: func(token string) (any, error) { - return jwt.ValidateTwoFASetupToken(dep, token) - }, - assert: func(t *testing.T, claims any) { - parsed := claims.(*dto.TwoFaSetupJwtPayload) - if parsed.Secret != "secret" { - t.Fatalf("expected secret to be propagated, got %q", parsed.Secret) - } - }, - }, - { - name: "UserTokenRejectsWrongType", - sign: func() (string, error) { - return jwt.SignTwoFAToken(dep, 10) - }, - validate: func(token string) (any, error) { - return jwt.ValidateUserTokenGeneric(dep, token) - }, - expectedError: libjwt.ErrTokenInvalidClaims, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - token, err := tc.sign() - if err != nil { - t.Fatalf("sign returned error: %v", err) - } - - claims, err := tc.validate(token) - if tc.expectedError != nil { - if !errors.Is(err, tc.expectedError) { - t.Fatalf("expected %v, got %v", tc.expectedError, err) - } - return - } - if err != nil { - t.Fatalf("validate returned error: %v", err) - } - - if tc.assert != nil { - tc.assert(t, claims) - } - }) +func TestTwoFAToken(t *testing.T) { + token, err := jwt.SignTwoFAToken(testDep, 11) + if err != nil { + t.Fatalf("failed to generate token, got an error: %v", err) + } + + parsed, err := jwt.ValidateTwoFAToken(testDep, token) + if err != nil { + t.Fatalf("faled to parse two-fa token, got an error: %v", err) + } + + if parsed.Type != jwt.TwoFATokenType { + t.Fatalf("expected token type: %s, got %s", jwt.TwoFATokenType, parsed.Type) + } + + if parsed.UserID != 11 { + t.Fatalf("expected userID: %d, got %d", 11, parsed.UserID) + } + + _, err = jwt.ValidateTwoFAToken(testDep, "aaa") + if err == nil { + t.Fatalf("expected error, got nil") + } + + invalidToken, err := jwt.SignTwoFASetupToken(testDep, 11, "test-secret") + if err != nil { + t.Fatalf("failed to generate two-fa setup token") + } + + _, err = jwt.ValidateTwoFAToken(testDep, invalidToken) + if err == nil { + t.Fatalf("expected error, got nil") } } + From 1d450261c2fea7c6019eb388c2ed610c72700d2a Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:41:14 +0200 Subject: [PATCH 11/15] refactor/backend: enhance rate limiter tests --- .../internal/middleware/rate_limiter_test.go | 69 ++++++++++++++----- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index be8b5c8..f5d71e5 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -13,11 +13,21 @@ import ( ) func TestRateLimiter(t *testing.T) { + testDurationShort := 30 * time.Millisecond + testDurationMedium := 50 * time.Millisecond + testDurationLong := 100 * time.Millisecond + testLimitLow := 1 + testLimitMedium := 2 + testSleepShort := 20 * time.Millisecond + testSleepReset := 60 * time.Millisecond + testCleanup := time.Minute + testCases := []struct { name string duration time.Duration limit int sleep time.Duration + sleepAfter int methods []string remoteAddrs []string expectedStatus []int @@ -25,25 +35,45 @@ func TestRateLimiter(t *testing.T) { }{ { name: "blocks after limit", - duration: 100 * time.Millisecond, - limit: 2, + duration: testDurationLong, + limit: testLimitMedium, methods: []string{http.MethodPost, http.MethodPost, http.MethodPost}, remoteAddrs: []string{"203.0.113.1:1000", "203.0.113.1:1000", "203.0.113.1:1000"}, expectedStatus: []int{200, 200, 429}, }, + { + name: "blocks within window", + duration: testDurationMedium, + limit: testLimitLow, + sleep: testSleepShort, + sleepAfter: 0, + methods: []string{http.MethodPost, http.MethodPost}, + remoteAddrs: []string{"203.0.113.9:1111", "203.0.113.9:1111"}, + expectedStatus: []int{200, 429}, + }, { name: "resets after window", - duration: 30 * time.Millisecond, - limit: 1, - sleep: 60 * time.Millisecond, + duration: testDurationShort, + limit: testLimitLow, + sleep: testSleepReset, + sleepAfter: 1, methods: []string{http.MethodPost, http.MethodPost, http.MethodPost}, remoteAddrs: []string{"198.51.100.3:9999", "198.51.100.3:9999", "198.51.100.3:9999"}, expectedStatus: []int{200, 429, 200}, }, + { + name: "options not limited", + duration: testDurationLong, + limit: testLimitLow, + methods: []string{http.MethodOptions, http.MethodOptions, http.MethodOptions}, + remoteAddrs: []string{"203.0.113.8:4444", "203.0.113.8:4444", "203.0.113.8:4444"}, + expectedStatus: []int{204, 204, 204}, + needOptions: true, + }, { name: "options bypass", - duration: 100 * time.Millisecond, - limit: 1, + duration: testDurationLong, + limit: testLimitLow, methods: []string{http.MethodOptions, http.MethodOptions, http.MethodOptions, http.MethodPost, http.MethodPost}, remoteAddrs: []string{"203.0.113.2:5555", "203.0.113.2:5555", "203.0.113.2:5555", "203.0.113.2:5555", "203.0.113.2:5555"}, expectedStatus: []int{204, 204, 204, 200, 429}, @@ -51,8 +81,8 @@ func TestRateLimiter(t *testing.T) { }, { name: "client isolation", - duration: 100 * time.Millisecond, - limit: 1, + duration: testDurationLong, + limit: testLimitLow, methods: []string{http.MethodPost, http.MethodPost, http.MethodPost}, remoteAddrs: []string{"203.0.113.10:5000", "203.0.113.11:5000", "203.0.113.10:6000"}, expectedStatus: []int{200, 200, 429}, @@ -61,7 +91,7 @@ func TestRateLimiter(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - rl := middleware.NewRateLimiter(tc.duration, tc.limit, time.Minute) + rl := middleware.NewRateLimiter(tc.duration, tc.limit, testCleanup) r := testutil.NewMiddlewareTestRouter(rl.RateLimit(), middleware.ErrorHandler()) if tc.needOptions { r.OPTIONS("/middleware-test", func(c *gin.Context) { @@ -80,7 +110,7 @@ func TestRateLimiter(t *testing.T) { t.Fatalf("expected: %d, got: %d", tc.expectedStatus[i], w.Code) } - if tc.sleep > 0 && i == 1 { + if tc.sleep > 0 && i == tc.sleepAfter { time.Sleep(tc.sleep) } } @@ -89,6 +119,11 @@ func TestRateLimiter(t *testing.T) { } func TestAllowRequest(t *testing.T) { + testDuration := 50 * time.Millisecond + testLimit := 1 + testCleanup := time.Minute + testCleanupFast := 1 * time.Millisecond + type step struct { sleep time.Duration client string @@ -104,9 +139,9 @@ func TestAllowRequest(t *testing.T) { }{ { name: "limit reached", - duration: 50 * time.Millisecond, - limit: 1, - cleanup: time.Minute, + duration: testDuration, + limit: testLimit, + cleanup: testCleanup, steps: []step{ {client: "client-a", expect: true}, {client: "client-a", expect: false}, @@ -114,9 +149,9 @@ func TestAllowRequest(t *testing.T) { }, { name: "cleanup path", - duration: 50 * time.Millisecond, - limit: 1, - cleanup: 1 * time.Millisecond, + duration: testDuration, + limit: testLimit, + cleanup: testCleanupFast, steps: []step{ {client: "client-a", expect: true}, {sleep: 2 * time.Millisecond}, From a0f30749b659686123c49df889299e8331b260f8 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:50:36 +0200 Subject: [PATCH 12/15] refactor/backend: remove testing to prepare the refactoring --- backend/internal/routers/users_router_test.go | 860 ------------------ .../internal/service/friend_service_test.go | 171 ---- .../service/google_oauth_service_test.go | 445 --------- backend/internal/service/helper_test.go | 135 --- .../internal/service/redis_service_test.go | 344 ------- backend/internal/service/setup_test.go | 100 -- .../internal/service/twofa_service_test.go | 446 --------- backend/internal/service/user_service_test.go | 685 -------------- 8 files changed, 3186 deletions(-) delete mode 100644 backend/internal/routers/users_router_test.go delete mode 100644 backend/internal/service/friend_service_test.go delete mode 100644 backend/internal/service/google_oauth_service_test.go delete mode 100644 backend/internal/service/helper_test.go delete mode 100644 backend/internal/service/redis_service_test.go delete mode 100644 backend/internal/service/setup_test.go delete mode 100644 backend/internal/service/twofa_service_test.go delete mode 100644 backend/internal/service/user_service_test.go diff --git a/backend/internal/routers/users_router_test.go b/backend/internal/routers/users_router_test.go deleted file mode 100644 index 342b7e3..0000000 --- a/backend/internal/routers/users_router_test.go +++ /dev/null @@ -1,860 +0,0 @@ -package routers - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "net/url" - "strconv" - "strings" - "testing" - "time" - - "cloud.google.com/go/auth/credentials/idtoken" - "github.com/alicebob/miniredis/v2" - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "github.com/redis/go-redis/v9" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/service" - "github.com/paularynty/transcendence/auth-service-go/internal/testutil" - "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" -) - -func mustNewUserService(t *testing.T, dep *dependency.Dependency) *service.UserService { - t.Helper() - svc, err := service.NewUserService(dep) - if err != nil { - t.Fatalf("failed to create user service: %v", err) - } - return svc -} - -type usersRouterEnv struct { - router *gin.Engine - dep *dependency.Dependency - mr *miniredis.Miniredis - cleanup func() -} - -func setupUsersRouterTest(t *testing.T, useRedis bool) *usersRouterEnv { - t.Helper() - gin.SetMode(gin.TestMode) - - logger := testutil.NewTestLogger() - cfg := testutil.NewTestConfig() - cfg.JwtSecret = "test-secret" - cfg.UserTokenExpiry = 3600 - cfg.UserTokenAbsoluteExpiry = 600 - cfg.TwoFaTokenExpiry = 3600 - cfg.OauthStateTokenExpiry = 3600 - cfg.GoogleClientId = "test-client" - cfg.GoogleRedirectUri = "http://localhost/cb" - cfg.FrontendUrl = "http://localhost:3000" - - if useRedis { - cfg.IsRedisEnabled = true - } - - dto.InitValidator() - - dbName := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared&_busy_timeout=5000&_foreign_keys=on" - dbConn, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{TranslateError: true}) - if err != nil { - t.Fatalf("failed to connect to db: %v", err) - } - - dbConn.Exec("PRAGMA foreign_keys = ON") - if err := dbConn.AutoMigrate(&model.User{}, &model.Friend{}, &model.Token{}, &model.HeartBeat{}); err != nil { - t.Fatalf("failed to migrate db: %v", err) - } - - var mr *miniredis.Miniredis - var redisClient *redis.Client - if useRedis { - mr = miniredis.RunT(t) - redisClient = redis.NewClient(&redis.Options{Addr: mr.Addr()}) - cfg.RedisURL = "redis://" + mr.Addr() - } - - dep := dependency.NewDependency(cfg, dbConn, redisClient, logger) - router := gin.New() - userService := mustNewUserService(t, dep) - UsersRouter(router.Group("/users"), userService) - - if sqlDB, err := dbConn.DB(); err == nil && sqlDB != nil { - sqlDB.SetMaxOpenConns(1) - } - - cleanup := func() { - if redisClient != nil { - _ = redisClient.Close() - } - if mr != nil { - mr.Close() - } - if sqlDB, err := dbConn.DB(); err == nil && sqlDB != nil { - _ = sqlDB.Close() - } - } - - return &usersRouterEnv{ - router: router, - dep: dep, - mr: mr, - cleanup: cleanup, - } -} - -func signUserToken(t *testing.T, dep *dependency.Dependency, userID uint) string { - t.Helper() - token, err := jwt.SignUserToken(dep, userID) - if err != nil { - t.Fatalf("failed to sign user token: %v", err) - } - return token -} - -func addUserToken(t *testing.T, dep *dependency.Dependency, userID uint) string { - t.Helper() - token := signUserToken(t, dep, userID) - if err := dep.DB.Create(&model.Token{UserID: userID, Token: token}).Error; err != nil { - t.Fatalf("failed to insert token: %v", err) - } - return token -} - -func createUser(t *testing.T, dep *dependency.Dependency, username, email, password string) *dto.UserWithoutTokenResponse { - t.Helper() - svc := mustNewUserService(t, dep) - user, err := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: username}, Email: email}, - Password: dto.Password{Password: password}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - return user -} - -func TestUsersRouter_CreateUser(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - reqBody := dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "newuser"}, Email: "new@example.com"}, - Password: dto.Password{Password: "password123"}, - } - body, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusCreated { - t.Fatalf("expected status 201, got %d. Body: %s", resp.Code, resp.Body.String()) - } - - var user dto.UserWithoutTokenResponse - _ = json.Unmarshal(resp.Body.Bytes(), &user) - if user.Username != "newuser" { - t.Errorf("expected username newuser, got %s", user.Username) - } -} - -func TestUsersRouter_CreateUser_Failures(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - cases := []struct { - name string - body string - wantStatus int - }{ - {"InvalidBody", `{"username": "u"}`, http.StatusBadRequest}, - {"DuplicateUser", "duplicate", http.StatusConflict}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - body := tc.body - if tc.name == "DuplicateUser" { - validReq := dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "dupuser"}, Email: "dup@e.com"}, - Password: dto.Password{Password: "pass123"}, - } - payload, _ := json.Marshal(validReq) - env.router.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(payload))) - body = string(payload) - } - - req := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != tc.wantStatus { - t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) - } - }) - } -} - -func TestUsersRouter_LoginUser(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - createReq := dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "loginuser"}, Email: "login@example.com"}, - Password: dto.Password{Password: "password123"}, - } - createBody, _ := json.Marshal(createReq) - - cReq := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(createBody)) - cReq.Header.Set("Content-Type", "application/json") - env.router.ServeHTTP(httptest.NewRecorder(), cReq) - - loginReq := dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "loginuser"}, - Password: dto.Password{Password: "password123"}, - } - body, _ := json.Marshal(loginReq) - - req := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) - } - - var res dto.UserWithTokenResponse - _ = json.Unmarshal(resp.Body.Bytes(), &res) - if res.Token == "" { - t.Errorf("expected token in response. Body: %s", resp.Body.String()) - } -} - -func TestUsersRouter_LoginUser_Failures(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - cases := []struct { - name string - body string - setup func() - wantStatus int - }{ - {"InvalidBody", `{}`, nil, http.StatusBadRequest}, - {"UserNotFound", "missing", nil, http.StatusUnauthorized}, - {"WrongPassword", "wrong", func() { - _ = createUser(t, env.dep, "loginfail", "fail@e.com", "correct123") - }, http.StatusUnauthorized}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - - var payload []byte - switch tc.body { - case "missing": - loginReq := dto.LoginUserRequest{Identifier: dto.Identifier{Identifier: "missing"}, Password: dto.Password{Password: "pass123"}} - payload, _ = json.Marshal(loginReq) - case "wrong": - loginReq := dto.LoginUserRequest{Identifier: dto.Identifier{Identifier: "loginfail"}, Password: dto.Password{Password: "wrong123"}} - payload, _ = json.Marshal(loginReq) - default: - payload = []byte(tc.body) - } - - req := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(payload)) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != tc.wantStatus { - t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) - } - }) - } -} - -func TestUsersRouter_GetProfile(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - user := model.User{Username: "profileuser", Email: "profile@example.com"} - env.dep.DB.Create(&user) - tokenStr := addUserToken(t, env.dep, user.ID) - - req := httptest.NewRequest(http.MethodGet, "/users/me", nil) - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) - } - - var res dto.UserWithoutTokenResponse - _ = json.Unmarshal(resp.Body.Bytes(), &res) - if res.Username != "profileuser" { - t.Errorf("expected username profileuser, got %s", res.Username) - } -} - -func TestUsersRouter_Unauthorized(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - req := httptest.NewRequest(http.MethodGet, "/users/me", nil) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusUnauthorized { - t.Fatalf("expected status 401, got %d", resp.Code) - } -} - -func TestUsersRouter_UpdateUserProfile(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - user := model.User{Username: "u", Email: "u@e.com"} - env.dep.DB.Create(&user) - tokenStr := addUserToken(t, env.dep, user.ID) - - newAvatar := "http://pic.com/1.png" - reqBody := dto.UpdateUserRequest{User: dto.User{UserName: dto.UserName{Username: "newname"}, Email: "new@e.com", Avatar: &newAvatar}} - body, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPut, "/users/me", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) - } - - var res dto.UserWithoutTokenResponse - _ = json.Unmarshal(resp.Body.Bytes(), &res) - if res.Username != "newname" { - t.Errorf("expected new username, got %s", res.Username) - } -} - -func TestUsersRouter_UpdateUser_Failures(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - svc := mustNewUserService(t, env.dep) - u, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, - Password: dto.Password{Password: "pass123"}, - }) - token := addUserToken(t, env.dep, u.ID) - - _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, - Password: dto.Password{Password: "pass123"}, - }) - - cases := []struct { - name string - method string - path string - body any - wantStatus int - }{ - {"DuplicateProfile", http.MethodPut, "/users/me", dto.UpdateUserRequest{User: dto.User{UserName: dto.UserName{Username: "update_u2"}, Email: "u2@e.com"}}, http.StatusConflict}, - {"WrongOldPassword", http.MethodPut, "/users/password", dto.UpdateUserPasswordRequest{OldPassword: dto.OldPassword{OldPassword: "wrong123"}, NewPassword: dto.NewPassword{NewPassword: "newpass"}}, http.StatusUnauthorized}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - payload, _ := json.Marshal(tc.body) - req := httptest.NewRequest(tc.method, tc.path, bytes.NewBuffer(payload)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - if resp.Code != tc.wantStatus { - t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) - } - }) - } -} - -func TestUsersRouter_UpdateUserPassword(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - svc := mustNewUserService(t, env.dep) - userResp, _ := svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "pw"}, Email: "pw@e.com"}, - Password: dto.Password{Password: "oldpass"}, - }) - tokenStr := addUserToken(t, env.dep, userResp.ID) - - reqBody := dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "oldpass"}, - NewPassword: dto.NewPassword{NewPassword: "newpass"}, - } - body, _ := json.Marshal(reqBody) - - req := httptest.NewRequest(http.MethodPut, "/users/password", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) - } -} - -func TestUsersRouter_DeleteUser(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - user := model.User{Username: "del", Email: "del@e.com"} - env.dep.DB.Create(&user) - tokenStr := addUserToken(t, env.dep, user.ID) - - time.Sleep(500 * time.Millisecond) - - req := httptest.NewRequest(http.MethodDelete, "/users/me", nil) - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusNoContent { - t.Fatalf("expected status 204, got %d", resp.Code) - } -} - -func TestUsersRouter_GetUsersWithLimitedInfo(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - user := model.User{Username: "list", Email: "list@e.com"} - env.dep.DB.Create(&user) - tokenStr := addUserToken(t, env.dep, user.ID) - - req := httptest.NewRequest(http.MethodGet, "/users/", nil) - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } -} - -func TestUsersRouter_ValidateUser(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - user := model.User{Username: "val", Email: "val@e.com"} - env.dep.DB.Create(&user) - tokenStr := addUserToken(t, env.dep, user.ID) - - req := httptest.NewRequest(http.MethodPost, "/users/validate", nil) - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } - - var res dto.UserValidationResponse - _ = json.Unmarshal(resp.Body.Bytes(), &res) - if res.UserID != user.ID { - t.Errorf("expected userID %d, got %d", user.ID, res.UserID) - } -} - -func TestUsersRouter_Friends(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - svc := mustNewUserService(t, env.dep) - u1 := createUser(t, env.dep, "f1", "f1@e.com", "pass123") - u2 := createUser(t, env.dep, "f2", "f2@e.com", "pass123") - _ = svc - - tokenStr := addUserToken(t, env.dep, u1.ID) - - reqBody := dto.AddNewFriendRequest{UserID: u2.ID} - body, _ := json.Marshal(reqBody) - req := httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusCreated { - t.Fatalf("expected status 201, got %d. Body: %s", resp.Code, resp.Body.String()) - } - - req = httptest.NewRequest(http.MethodGet, "/users/friends", nil) - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp = httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", resp.Code) - } - var friends []dto.FriendResponse - _ = json.Unmarshal(resp.Body.Bytes(), &friends) - if len(friends) != 1 || friends[0].ID != u2.ID { - t.Error("expected friend f2") - } -} - -func TestUsersRouter_Friends_Failures(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - svc := mustNewUserService(t, env.dep) - u1 := createUser(t, env.dep, "f1", "f1@e.com", "pass123") - u2 := createUser(t, env.dep, "f2", "f2@e.com", "pass123") - token := addUserToken(t, env.dep, u1.ID) - - cases := []struct { - name string - payload dto.AddNewFriendRequest - setup func() - wantStatus int - }{ - {"AddSelf", dto.AddNewFriendRequest{UserID: u1.ID}, nil, http.StatusBadRequest}, - {"AddMissing", dto.AddNewFriendRequest{UserID: 999}, nil, http.StatusNotFound}, - {"Duplicate", dto.AddNewFriendRequest{UserID: u2.ID}, func() { - _ = svc.AddNewFriend(context.Background(), u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - }, http.StatusConflict}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - if tc.name == "Duplicate" { - time.Sleep(200 * time.Millisecond) - } - } - body, _ := json.Marshal(tc.payload) - req := httptest.NewRequest(http.MethodPost, "/users/friends", bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - if resp.Code != tc.wantStatus { - t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) - } - }) - } -} - -func TestUsersRouter_2FA(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - user := createUser(t, env.dep, "2fa", "2fa@e.com", "pass123") - tokenStr := addUserToken(t, env.dep, user.ID) - - req := httptest.NewRequest(http.MethodPost, "/users/2fa/setup", nil) - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("setup failed: %d", resp.Code) - } - var setupRes dto.TwoFASetupResponse - _ = json.Unmarshal(resp.Body.Bytes(), &setupRes) - - time.Sleep(200 * time.Millisecond) - - code, _ := totp.GenerateCode(setupRes.TwoFASecret, time.Now()) - confirmBody, _ := json.Marshal(dto.TwoFAConfirmRequest{SetupToken: setupRes.SetupToken, TwoFACode: code}) - req = httptest.NewRequest(http.MethodPost, "/users/2fa/confirm", bytes.NewBuffer(confirmBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp = httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("confirm failed: %d", resp.Code) - } - - sessionToken, _ := jwt.SignTwoFAToken(env.dep, user.ID) - code, _ = totp.GenerateCode(setupRes.TwoFASecret, time.Now()) - challengeBody, _ := json.Marshal(dto.TwoFAChallengeRequest{SessionToken: sessionToken, TwoFACode: code}) - req = httptest.NewRequest(http.MethodPost, "/users/2fa", bytes.NewBuffer(challengeBody)) - req.Header.Set("Content-Type", "application/json") - resp = httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("challenge failed: %d body: %s", resp.Code, resp.Body.String()) - } - - var userRes dto.UserWithTokenResponse - _ = json.Unmarshal(resp.Body.Bytes(), &userRes) - tokenStr = userRes.Token - - time.Sleep(200 * time.Millisecond) - - disableBody, _ := json.Marshal(dto.DisableTwoFARequest{Password: dto.Password{Password: "pass123"}}) - req = httptest.NewRequest(http.MethodPut, "/users/2fa/disable", bytes.NewBuffer(disableBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+tokenStr) - resp = httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusOK { - t.Fatalf("disable failed: %d", resp.Code) - } -} - -func TestUsersRouter_2FA_Failures(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - svc := mustNewUserService(t, env.dep) - u := createUser(t, env.dep, "2fafail", "2fafail@e.com", "pass123") - token := addUserToken(t, env.dep, u.ID) - - setupResp, _ := svc.StartTwoFaSetup(context.Background(), u.ID) - - cases := []struct { - name string - method string - path string - body any - wantStatus int - setup func() - }{ - {"InvalidCode", http.MethodPost, "/users/2fa/confirm", dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: "000000"}, http.StatusBadRequest, nil}, - {"SetupAlreadyEnabled", http.MethodPost, "/users/2fa/setup", nil, http.StatusBadRequest, func() { - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - confirmRes, _ := svc.ConfirmTwoFaSetup(context.Background(), u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - token = confirmRes.Token - }}, - {"WrongDisablePassword", http.MethodPut, "/users/2fa/disable", dto.DisableTwoFARequest{Password: dto.Password{Password: "wrong123"}}, http.StatusUnauthorized, func() { - if token == "" { - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - confirmRes, _ := svc.ConfirmTwoFaSetup(context.Background(), u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - token = confirmRes.Token - } - }}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - var body []byte - if tc.body != nil { - body, _ = json.Marshal(tc.body) - } - req := httptest.NewRequest(tc.method, tc.path, bytes.NewBuffer(body)) - req.Header.Set("Authorization", "Bearer "+token) - if tc.body != nil { - req.Header.Set("Content-Type", "application/json") - } - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - if resp.Code != tc.wantStatus { - t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) - } - }) - } -} - -func TestUsersRouter_GoogleOAuth(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - req := httptest.NewRequest(http.MethodGet, "/users/google/login", nil) - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusFound { - t.Fatalf("expected status 302, got %d", resp.Code) - } - if loc := resp.Header().Get("Location"); loc == "" { - t.Error("expected location header") - } - - origExchange := service.ExchangeCodeForTokens - origFetch := service.FetchGoogleUserInfo - defer func() { - service.ExchangeCodeForTokens = origExchange - service.FetchGoogleUserInfo = origFetch - }() - - service.ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return &idtoken.Payload{Subject: "g123"}, nil - } - service.FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ID: "g123", Email: "test@google.com", Name: "Google User"}, nil - } - - state, _ := jwt.SignOauthStateToken(env.dep) - req = httptest.NewRequest(http.MethodGet, "/users/google/callback?code=valid&state="+state, nil) - resp = httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - - if resp.Code != http.StatusFound { - t.Fatalf("expected status 302, got %d", resp.Code) - } - - redirectURL, _ := url.Parse(resp.Header().Get("Location")) - token := redirectURL.Query().Get("token") - if token == "" { - t.Error("expected token in redirect") - } -} - -func TestUsersRouter_GoogleOAuth_Failures(t *testing.T) { - env := setupUsersRouterTest(t, false) - defer env.cleanup() - - cases := []struct { - name string - path string - setup func() - wantStatus int - }{ - {"MissingParams", "/users/google/callback", nil, http.StatusBadRequest}, - {"ExchangeError", "/users/google/callback?code=c&state=state", func() { - origExchange := service.ExchangeCodeForTokens - service.ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return nil, errors.New("mock error") - } - t.Cleanup(func() { service.ExchangeCodeForTokens = origExchange }) - }, http.StatusFound}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - path := tc.path - if strings.Contains(path, "state=state") { - state, _ := jwt.SignOauthStateToken(env.dep) - path = "/users/google/callback?code=c&state=" + state - } - req := httptest.NewRequest(http.MethodGet, path, nil) - resp := httptest.NewRecorder() - env.router.ServeHTTP(resp, req) - if resp.Code != tc.wantStatus { - t.Fatalf("expected %d, got %d", tc.wantStatus, resp.Code) - } - if tc.name == "ExchangeError" { - loc := resp.Header().Get("Location") - if !strings.Contains(loc, "error=") { - t.Fatalf("expected error param in redirect: %s", loc) - } - } - }) - } -} - -func TestUsersRouter_Redis_LoginValidateLogout(t *testing.T) { - env := setupUsersRouterTest(t, true) - defer env.cleanup() - - createReq := dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "redisrouter"}, Email: "redisrouter@example.com"}, - Password: dto.Password{Password: "password123"}, - } - createBody, _ := json.Marshal(createReq) - createResp := httptest.NewRecorder() - createHTTP := httptest.NewRequest(http.MethodPost, "/users/", bytes.NewBuffer(createBody)) - createHTTP.Header.Set("Content-Type", "application/json") - env.router.ServeHTTP(createResp, createHTTP) - if createResp.Code != http.StatusCreated { - t.Fatalf("expected 201 on create, got %d. Body: %s", createResp.Code, createResp.Body.String()) - } - - loginReq := dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "redisrouter"}, - Password: dto.Password{Password: "password123"}, - } - loginBody, _ := json.Marshal(loginReq) - loginResp := httptest.NewRecorder() - loginHTTP := httptest.NewRequest(http.MethodPost, "/users/loginByIdentifier", bytes.NewBuffer(loginBody)) - loginHTTP.Header.Set("Content-Type", "application/json") - env.router.ServeHTTP(loginResp, loginHTTP) - if loginResp.Code != http.StatusOK { - t.Fatalf("expected 200 on login, got %d. Body: %s", loginResp.Code, loginResp.Body.String()) - } - - var loginUser dto.UserWithTokenResponse - _ = json.Unmarshal(loginResp.Body.Bytes(), &loginUser) - if loginUser.Token == "" || loginUser.ID == 0 { - t.Fatalf("expected login to return token and id") - } - - validateResp := httptest.NewRecorder() - validateHTTP := httptest.NewRequest(http.MethodPost, "/users/validate", nil) - validateHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) - env.router.ServeHTTP(validateResp, validateHTTP) - if validateResp.Code != http.StatusOK { - t.Fatalf("expected 200 on validate, got %d. Body: %s", validateResp.Code, validateResp.Body.String()) - } - - logoutResp := httptest.NewRecorder() - logoutHTTP := httptest.NewRequest(http.MethodDelete, "/users/logout", nil) - logoutHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) - env.router.ServeHTTP(logoutResp, logoutHTTP) - if logoutResp.Code != http.StatusNoContent { - t.Fatalf("expected 204 on logout, got %d. Body: %s", logoutResp.Code, logoutResp.Body.String()) - } - - validateAfterResp := httptest.NewRecorder() - validateAfterHTTP := httptest.NewRequest(http.MethodPost, "/users/validate", nil) - validateAfterHTTP.Header.Set("Authorization", "Bearer "+loginUser.Token) - env.router.ServeHTTP(validateAfterResp, validateAfterHTTP) - if validateAfterResp.Code != http.StatusUnauthorized { - t.Fatalf("expected 401 on validate after logout, got %d. Body: %s", validateAfterResp.Code, validateAfterResp.Body.String()) - } - - time.Sleep(200 * time.Millisecond) - score, err := env.dep.Redis.ZScore(context.Background(), "heartbeat:", strconv.FormatUint(uint64(loginUser.ID), 10)).Result() - if err != nil { - t.Fatalf("expected heartbeat entry, got error: %v", err) - } - if int64(score) < time.Now().Unix()-10 { - t.Fatalf("expected recent heartbeat score, got %v", score) - } -} diff --git a/backend/internal/service/friend_service_test.go b/backend/internal/service/friend_service_test.go deleted file mode 100644 index 0534e2b..0000000 --- a/backend/internal/service/friend_service_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package service - -import ( - "context" - "testing" - "time" - - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" -) - -func TestGetAllUsersLimitedInfo(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - // Create users - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - t.Run("Success", func(t *testing.T) { - users, err := svc.GetAllUsersLimitedInfo(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(users) < 2 { - t.Errorf("expected at least 2 users, got %d", len(users)) - } - }) - - t.Run("DBError", func(t *testing.T) { - sqlDB, _ := db.DB() - _ = sqlDB.Close() - _, err := svc.GetAllUsersLimitedInfo(ctx) - if err == nil { - t.Error("expected error on closed db") - } - }) -} - -func TestAddNewFriend(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u1, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "f1"}, Email: "f1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - u2, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "f2"}, Email: "f2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - cases := []struct { - name string - userID uint - friendID uint - setup func() - wantErrStatus int - checkFK bool - }{ - {"Success", u1.ID, u2.ID, nil, 0, false}, - {"AddSelf", u1.ID, u1.ID, nil, 400, false}, - {"DuplicateFriend", u1.ID, u2.ID, func() { - _ = svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - }, 409, false}, - {"UserNotFound", u1.ID, 999, nil, 404, true}, - {"DBError", u1.ID, u2.ID, func() { - sqlDB, _ := db.DB() - _ = sqlDB.Close() - }, -1, false}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - err := svc.AddNewFriend(ctx, tc.userID, &dto.AddNewFriendRequest{UserID: tc.friendID}) - if tc.wantErrStatus == 0 { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - return - } - if tc.wantErrStatus == -1 { - if err == nil { - t.Error("expected error on closed db") - } - return - } - if err == nil && tc.checkFK { - var count int64 - db.Model(&model.Friend{}).Where("user_id = ? AND friend_id = ?", u1.ID, tc.friendID).Count(&count) - if count > 0 { - t.Fatal("expected error, but friend was added despite FK violation") - } - return - } - requireAuthStatus(t, err, tc.wantErrStatus) - }) - } -} - -func TestGetUserFriends(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u1, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "gf1"}, Email: "gf1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - u2, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "gf2"}, Email: "gf2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - // Add friend - _ = svc.AddNewFriend(ctx, u1.ID, &dto.AddNewFriendRequest{UserID: u2.ID}) - - cases := []struct { - name string - setup func() - wantOnline bool - wantErrStatus int - }{ - {"Success", nil, false, 0}, - {"OnlineFriend", func() { - db.Create(&model.HeartBeat{UserID: u2.ID, LastSeenAt: time.Now()}) - }, true, 0}, - {"DBError", func() { - sqlDB, _ := db.DB() - _ = sqlDB.Close() - }, false, -1}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - friends, err := svc.GetUserFriends(ctx, u1.ID) - if tc.wantErrStatus == -1 { - if err == nil { - t.Error("expected error on closed db") - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(friends) != 1 { - t.Errorf("expected 1 friend, got %d", len(friends)) - } - if friends[0].ID != u2.ID { - t.Errorf("expected friend ID %d, got %d", u2.ID, friends[0].ID) - } - if friends[0].Online != tc.wantOnline { - t.Errorf("expected online=%v, got %v", tc.wantOnline, friends[0].Online) - } - }) - } -} diff --git a/backend/internal/service/google_oauth_service_test.go b/backend/internal/service/google_oauth_service_test.go deleted file mode 100644 index 5141c55..0000000 --- a/backend/internal/service/google_oauth_service_test.go +++ /dev/null @@ -1,445 +0,0 @@ -package service - -import ( - "context" - "errors" - "net/url" - "testing" - - "cloud.google.com/go/auth/credentials/idtoken" - authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/testutil" - "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" -) - -func TestGetGoogleOAuthURL(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, nil)) - ctx := context.Background() - - t.Run("Success", func(t *testing.T) { - authURL, err := svc.GetGoogleOAuthURL(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - u, err := url.Parse(authURL) - if err != nil { - t.Fatalf("failed to parse url: %v", err) - } - - q := u.Query() - if q.Get("client_id") != cfg.GoogleClientId { - t.Errorf("expected client_id %s, got %s", cfg.GoogleClientId, q.Get("client_id")) - } - if q.Get("redirect_uri") != cfg.GoogleRedirectUri { - t.Errorf("expected redirect_uri %s, got %s", cfg.GoogleRedirectUri, q.Get("redirect_uri")) - } - if q.Get("state") == "" { - t.Error("expected state param") - } - }) -} - -func TestHandleGoogleOAuthCallback_InvalidState(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - // Helper to parse redirect URL - parseRedirect := func(redirectURL string) (string, string) { - u, _ := url.Parse(redirectURL) - q := u.Query() - return q.Get("token"), q.Get("error") - } - - t.Run("InvalidState", func(t *testing.T) { - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "somecode", "invalidstate") - token, errMsg := parseRedirect(redirectURL) - - if token != "" { - t.Error("expected no token") - } - if errMsg == "" { - t.Error("expected error message") - } - }) - - t.Run("ExpiredState", func(t *testing.T) { - userToken, _ := jwt.SignUserToken(svc.Dep, 1) - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "somecode", userToken) - - _, errMsg := parseRedirect(redirectURL) - if errMsg == "" { - t.Error("expected error message for wrong token type") - } - }) -} - -func TestHandleGoogleOAuthCallback_Success(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - // Mock dependencies - origExchange := ExchangeCodeForTokens - origFetch := FetchGoogleUserInfo - defer func() { - ExchangeCodeForTokens = origExchange - FetchGoogleUserInfo = origFetch - }() - - ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return &idtoken.Payload{Subject: "g123"}, nil - } - - FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ - ID: "g123", - Email: "test@google.com", - Name: "Google User", - }, nil - } - - state, _ := jwt.SignOauthStateToken(svc.Dep) - - t.Run("NewUser", func(t *testing.T) { - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "validcode", state) - - u, _ := url.Parse(redirectURL) - q := u.Query() - if q.Get("token") == "" { - t.Error("expected token in redirect") - } - if q.Get("error") != "" { - t.Errorf("unexpected error in redirect: %s", q.Get("error")) - } - - // Verify user created - var user model.User - err := db.Where("email = ?", "test@google.com").First(&user).Error - if err != nil { - t.Error("expected user to be created") - } - if *user.GoogleOauthID != "g123" { - t.Error("expected google oauth id to be set") - } - }) - - t.Run("ExistingUser", func(t *testing.T) { - // User already created in previous run - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "validcode", state) - - u, _ := url.Parse(redirectURL) - q := u.Query() - if q.Get("token") == "" { - t.Error("expected token in redirect") - } - }) - - t.Run("ExistingEmailLink", func(t *testing.T) { - // Create a non-OAuth user with matching email - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "emailmatch"}, Email: "linkme@google.com"}, - Password: dto.Password{Password: "p"}, - }) - - FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ - ID: "g_link", - Email: "linkme@google.com", - Name: "Link Me", - }, nil - } - - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "validcode", state) - u, _ := url.Parse(redirectURL) - q := u.Query() - if q.Get("token") != "" { - t.Error("expected no token in redirect") - } - if q.Get("error") == "" { - t.Error("expected error in redirect for same-email linking") - } - - // Verify existing user is NOT linked - var user model.User - err := db.Where("email = ?", "linkme@google.com").First(&user).Error - if err != nil { - t.Fatal("expected existing user") - } - if user.GoogleOauthID != nil { - t.Error("expected google oauth id to remain unset") - } - }) - - t.Run("ExistingEmailWith2FA", func(t *testing.T) { - // Create a user with 2FA enabled - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "email2fa"}, Email: "2fa@google.com"}, - Password: dto.Password{Password: "p"}, - }) - db.Model(&model.User{}).Where("id = ?", u.ID).Update("two_fa_token", "secret") - - FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ - ID: "g_2fa", - Email: "2fa@google.com", - Name: "Two Fa", - }, nil - } - - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "validcode", state) - u2, _ := url.Parse(redirectURL) - q := u2.Query() - if q.Get("token") != "" { - t.Error("expected no token in redirect") - } - if q.Get("error") == "" { - t.Error("expected error in redirect for 2FA user") - } - }) -} - -func TestHandleGoogleOAuthCallback_Errors(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - origExchange := ExchangeCodeForTokens - origFetch := FetchGoogleUserInfo - defer func() { - ExchangeCodeForTokens = origExchange - FetchGoogleUserInfo = origFetch - }() - - state, _ := jwt.SignOauthStateToken(svc.Dep) - - t.Run("ExchangeError", func(t *testing.T) { - ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return nil, errors.New("exchange failed") - } - - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "code", state) - u, _ := url.Parse(redirectURL) - if u.Query().Get("error") == "" { - t.Error("expected error message") - } - }) - - t.Run("FetchError", func(t *testing.T) { - ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return &idtoken.Payload{}, nil - } - FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return nil, errors.New("fetch failed") - } - - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "code", state) - u, _ := url.Parse(redirectURL) - if u.Query().Get("error") == "" { - t.Error("expected error message") - } - }) -} - -func TestLinkGoogleAccountToExistingUser(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "linkuser"}, Email: "link@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - // Fetch model user - var modelUser model.User - db.First(&modelUser, u.ID) - - t.Run("BlockedForSafety", func(t *testing.T) { - picture := "pic.png" - googleInfo := &dto.GoogleUserData{ - ID: "g123", - Email: "link@e.com", - Picture: &picture, - } - - err := svc.linkGoogleAccountToExistingUser(ctx, &modelUser, googleInfo) - if err == nil { - t.Fatal("expected linking to be blocked") - } - authErr, ok := err.(*authError.AuthError) - if !ok { - t.Fatalf("expected AuthError, got %T: %v", err, err) - } - if authErr.Status != 409 { - t.Fatalf("expected 409 error, got %d: %v", authErr.Status, authErr) - } - if authErr.Message != "same email exists" { - t.Fatalf("expected safety message, got %q", authErr.Message) - } - if modelUser.GoogleOauthID != nil { - t.Error("expected google id to remain unset") - } - if modelUser.Avatar != nil { - t.Error("expected avatar to remain unchanged") - } - }) - - t.Run("EmailMismatch", func(t *testing.T) { - googleInfo := &dto.GoogleUserData{ - ID: "g456", - Email: "other@e.com", - } - err := svc.linkGoogleAccountToExistingUser(ctx, &modelUser, googleInfo) - authErr, ok := err.(*authError.AuthError) - if err == nil || !ok || authErr.Status != 409 { - t.Errorf("expected 409 AuthError, got %v", err) - } - }) - - t.Run("AlreadyLinked", func(t *testing.T) { - // Linking is currently blocked regardless of state. - googleInfo := &dto.GoogleUserData{ - ID: "g789", - Email: "link@e.com", - } - err := svc.linkGoogleAccountToExistingUser(ctx, &modelUser, googleInfo) - authErr, ok := err.(*authError.AuthError) - if err == nil || !ok || authErr.Status != 409 { - t.Errorf("expected 409 AuthError, got %v", err) - } - }) -} - -func TestCreateNewUserFromGoogleInfo(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - t.Run("Success", func(t *testing.T) { - googleInfo := &dto.GoogleUserData{ - ID: "newg1", - Email: "new@g.com", - Name: "New User", - } - - user, err := svc.createNewUserFromGoogleInfo(ctx, googleInfo, false) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if user.Email != "new@g.com" { - t.Errorf("expected email new@g.com, got %s", user.Email) - } - if user.Username != "G_newg1" { - t.Errorf("expected username G_newg1, got %s", user.Username) - } - }) - - t.Run("DuplicateUsernameRetry", func(t *testing.T) { - // Create a user that conflicts with the default google username - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "G_gdup"}, Email: "existing@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - googleInfo := &dto.GoogleUserData{ - ID: "gdup", - Email: "unique@g.com", - } - - user, err := svc.createNewUserFromGoogleInfo(ctx, googleInfo, false) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - // Should have generated a random UUID based username - if user.Username == "G_gdup" { - t.Error("expected random username on collision") - } - if user.Email != "unique@g.com" { - t.Error("expected correct email") - } - }) -} - -func TestHandleGoogleOAuthCallback_DBError(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - origExchange := ExchangeCodeForTokens - origFetch := FetchGoogleUserInfo - defer func() { - ExchangeCodeForTokens = origExchange - FetchGoogleUserInfo = origFetch - }() - - state, _ := jwt.SignOauthStateToken(svc.Dep) - - ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return &idtoken.Payload{Subject: "g123"}, nil - } - - FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ - ID: "g123", - Email: "test@google.com", - Name: "Google User", - }, nil - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "code", state) - u, _ := url.Parse(redirectURL) - if u.Query().Get("error") == "" { - t.Error("expected error message on closed db") - } -} - -func TestHandleGoogleOAuthCallback_LinkError(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - origExchange := ExchangeCodeForTokens - origFetch := FetchGoogleUserInfo - defer func() { - ExchangeCodeForTokens = origExchange - FetchGoogleUserInfo = origFetch - }() - - state, _ := jwt.SignOauthStateToken(svc.Dep) - - ExchangeCodeForTokens = func(_ *dependency.Dependency, ctx context.Context, code string) (*idtoken.Payload, error) { - return &idtoken.Payload{Subject: "new_g_id"}, nil - } - - FetchGoogleUserInfo = func(payload *idtoken.Payload) (*dto.GoogleUserData, error) { - return &dto.GoogleUserData{ - ID: "new_g_id", - Email: "test@google.com", - Name: "Google User", - }, nil - } - - // Create user with SAME email but DIFFERENT google ID (already linked) - googleID := "old_g_id" - svc.Dep.DB.Create(&model.User{ - Username: "existing", - Email: "test@google.com", - GoogleOauthID: &googleID, - }) - - redirectURL := svc.HandleGoogleOAuthCallback(ctx, "code", state) - u, _ := url.Parse(redirectURL) - if u.Query().Get("error") == "" { - t.Error("expected error message for link failure") - } -} diff --git a/backend/internal/service/helper_test.go b/backend/internal/service/helper_test.go deleted file mode 100644 index d13dd95..0000000 --- a/backend/internal/service/helper_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package service - -import ( - "context" - "testing" - "time" - - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" -) - -func mustNewUserService(t *testing.T, dep *dependency.Dependency) *UserService { - t.Helper() - svc, err := NewUserService(dep) - if err != nil { - t.Fatalf("failed to create user service: %v", err) - } - return svc -} - -func TestHelperFunctions(t *testing.T) { - t.Run("isTwoFAEnabled", func(t *testing.T) { - token := "pre-secret" - if isTwoFAEnabled(&token) { - t.Error("expected false for pre- prefix") - } - - token = "secret" - if !isTwoFAEnabled(&token) { - t.Error("expected true for valid secret") - } - - token = "" - if isTwoFAEnabled(&token) { - t.Error("expected false for empty token") - } - - if isTwoFAEnabled(nil) { - t.Error("expected false for nil token") - } - }) - - t.Run("userToUserWithTokenResponse", func(t *testing.T) { - token := "secret" - user := &model.User{ - Username: "u", - TwoFAToken: &token, - } - resp := userToUserWithTokenResponse(user, "jwt") - if !resp.TwoFA { - t.Error("expected 2FA true") - } - if resp.Token != "jwt" { - t.Error("expected token match") - } - }) - - t.Run("OnlineStatusChecker", func(t *testing.T) { - hbs := []model.HeartBeat{ - {UserID: 1}, - } - checker := newOnlineStatusChecker(hbs) - if !checker.isOnline(1) { - t.Error("expected 1 online") - } - if checker.isOnline(2) { - t.Error("expected 2 offline") - } - }) - - t.Run("UpdateHeartBeat", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - - // Create user first to satisfy FK - _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "hb"}, Email: "hb@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - // Create heartbeat entry - svc.updateHeartBeat(1) - - // Wait for goroutine - time.Sleep(100 * time.Millisecond) - - var hb model.HeartBeat - if err := db.Where("user_id = ?", 1).First(&hb).Error; err != nil { - t.Fatalf("expected heartbeat created: %v", err) - } - }) - - t.Run("IssueNewTokenForUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - - // Create user first - _, _ = svc.CreateUser(context.Background(), &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "issue"}, Email: "issue@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - token, err := svc.issueNewTokenForUser(context.Background(), 1, false) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if token == "" { - t.Error("expected token") - } - - // Allow async heartbeat to finish - time.Sleep(200 * time.Millisecond) - - // Revoke old tokens - _, _ = svc.issueNewTokenForUser(context.Background(), 1, true) - var count int64 - db.Model(&model.Token{}).Where("user_id = ?", 1).Count(&count) - if count != 1 { - t.Errorf("expected 1 token, got %d", count) - } - }) - - t.Run("IssueNewTokenForUser_DBError", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.issueNewTokenForUser(context.Background(), 1, true) - if err == nil { - t.Error("expected error on closed db") - } - }) -} diff --git a/backend/internal/service/redis_service_test.go b/backend/internal/service/redis_service_test.go deleted file mode 100644 index 34deed0..0000000 --- a/backend/internal/service/redis_service_test.go +++ /dev/null @@ -1,344 +0,0 @@ -package service - -import ( - "context" - "errors" - "fmt" - "strings" - "testing" - "time" - - authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" - "github.com/paularynty/transcendence/auth-service-go/internal/config" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/testutil" - "github.com/redis/go-redis/v9" -) - -func withRedisTestExpiries(cfg *config.Config, userTTLSeconds int, absoluteTTLSeconds int) { - cfg.UserTokenExpiry = userTTLSeconds - cfg.UserTokenAbsoluteExpiry = absoluteTTLSeconds -} - -func TestRedisTokenLifecycle(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - withRedisTestExpiries(cfg, 10, 30) - mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) - defer cleanupRedis() - - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) - ctx := context.Background() - - userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "redisuser"}, Email: "redis@example.com"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - token, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token: %v", err) - } - if token == "" { - t.Fatal("expected non-empty token") - } - - key := buildTokenKey(userResp.ID, token) - if !mr.Exists(key) { - t.Fatalf("expected redis token key to exist: %s", key) - } - - // Drive time close to expiry, then validate and ensure TTL slides forward. - mr.FastForward(9 * time.Second) - ttlBefore := mr.TTL(key) - if ttlBefore <= 0 { - t.Fatalf("expected TTL before validation to be positive, got %v", ttlBefore) - } - - if err := svc.ValidateUserToken(ctx, token, userResp.ID); err != nil { - t.Fatalf("expected token to validate, got %v", err) - } - - ttlAfter := mr.TTL(key) - if ttlAfter < 8*time.Second { - t.Fatalf("expected sliding TTL refresh, got %v", ttlAfter) - } - - // Logout should revoke all redis tokens for the user. - if err := svc.LogoutUser(ctx, userResp.ID); err != nil { - t.Fatalf("logout failed: %v", err) - } - - if mr.Exists(key) { - t.Fatal("expected redis token key to be deleted on logout") - } - - err = svc.ValidateUserToken(ctx, token, userResp.ID) - if err == nil { - t.Fatal("expected token to be invalid after logout") - } - var authErr *authError.AuthError - if !strings.Contains(err.Error(), "invalid token") || !errors.As(err, &authErr) { - t.Fatalf("expected auth error for invalid token, got %v", err) - } -} - -func TestRedisHeartbeatOnlineStatusAndCleanup(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - _, redisClient, cleanupRedis := setupTestRedis(t, cfg) - defer cleanupRedis() - - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) - ctx := context.Background() - - u1, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "hb1"}, Email: "hb1@example.com"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("failed to create user1: %v", err) - } - - _, err = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "hb2"}, Email: "hb2@example.com"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("failed to create user2: %v", err) - } - - svc.updateHeartBeat(u1.ID) - time.Sleep(100 * time.Millisecond) - - onlineNow, err := svc.getOnlineStatus(ctx) - if err != nil { - t.Fatalf("getOnlineStatus failed: %v", err) - } - - checkerNow := newOnlineStatusChecker(onlineNow) - if !checkerNow.isOnline(u1.ID) { - t.Fatal("expected user1 to be online after heartbeat") - } - - // Force the heartbeat score to be old, then ensure cleanup happens. - oldScore := float64(time.Now().Add(-3 * time.Minute).Unix()) - if err := redisClient.ZAdd(ctx, HeartBeatPrefix, redis.Z{Score: oldScore, Member: u1.ID}).Err(); err != nil { - t.Fatalf("failed to set old heartbeat score: %v", err) - } - - onlineLater, err := svc.getOnlineStatus(ctx) - if err != nil { - t.Fatalf("getOnlineStatus later failed: %v", err) - } - - checkerLater := newOnlineStatusChecker(onlineLater) - if checkerLater.isOnline(u1.ID) { - t.Fatal("expected user1 to be offline after expiration window") - } - - // Cleanup should have removed the expired heartbeat entry. - time.Sleep(100 * time.Millisecond) - if _, err := redisClient.ZScore(ctx, HeartBeatPrefix, fmt.Sprint(u1.ID)).Result(); err == nil { - t.Fatal("expected expired heartbeat to be removed from redis") - } -} - -func TestRedisLoginUpdatesHeartbeat(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - _, redisClient, cleanupRedis := setupTestRedis(t, cfg) - defer cleanupRedis() - - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) - ctx := context.Background() - - created, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "loginhb"}, Email: "loginhb@example.com"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - userID := created.ID - - res, err := svc.LoginUser(ctx, &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "loginhb"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("login failed: %v", err) - } - if res.User == nil || res.User.Token == "" { - t.Fatal("expected login to issue a valid token") - } - - time.Sleep(100 * time.Millisecond) - - score, err := redisClient.ZScore(ctx, HeartBeatPrefix, fmt.Sprint(userID)).Result() - if err != nil { - t.Fatalf("expected heartbeat entry for user, got error: %v", err) - } - now := time.Now().Unix() - if int64(score) < now-5 { - t.Fatalf("expected recent heartbeat score, got %v (now=%d)", score, now) - } -} - -func TestRedisLogoutRevokesAllTokens(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) - defer cleanupRedis() - - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) - ctx := context.Background() - - userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "logoutmulti"}, Email: "logoutmulti@example.com"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - token1, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token1: %v", err) - } - token2, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token2: %v", err) - } - - key1 := buildTokenKey(userResp.ID, token1) - key2 := buildTokenKey(userResp.ID, token2) - if !mr.Exists(key1) || !mr.Exists(key2) { - t.Fatalf("expected both redis token keys to exist: %s, %s", key1, key2) - } - - if err := svc.LogoutUser(ctx, userResp.ID); err != nil { - t.Fatalf("logout failed: %v", err) - } - - if mr.Exists(key1) || mr.Exists(key2) { - t.Fatal("expected redis token keys to be deleted on logout") - } - - if err := svc.ValidateUserToken(ctx, token1, userResp.ID); err == nil { - t.Fatal("expected token1 to be invalid after logout") - } - if err := svc.ValidateUserToken(ctx, token2, userResp.ID); err == nil { - t.Fatal("expected token2 to be invalid after logout") - } -} - -func TestRedisDeleteUserRevokesAllTokens(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) - defer cleanupRedis() - - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) - ctx := context.Background() - - userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "delredis"}, Email: "delredis@example.com"}, - Password: dto.Password{Password: "password123"}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - token1, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token1: %v", err) - } - token2, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token2: %v", err) - } - - key1 := buildTokenKey(userResp.ID, token1) - key2 := buildTokenKey(userResp.ID, token2) - if !mr.Exists(key1) || !mr.Exists(key2) { - t.Fatalf("expected both redis token keys to exist: %s, %s", key1, key2) - } - - if err := svc.DeleteUser(ctx, userResp.ID); err != nil { - t.Fatalf("delete failed: %v", err) - } - - if mr.Exists(key1) || mr.Exists(key2) { - t.Fatal("expected redis token keys to be deleted on user deletion") - } - - if err := svc.ValidateUserToken(ctx, token1, userResp.ID); err == nil { - t.Fatal("expected token1 to be invalid after delete") - } - if err := svc.ValidateUserToken(ctx, token2, userResp.ID); err == nil { - t.Fatal("expected token2 to be invalid after delete") - } -} - -func TestRedisUpdatePasswordRevokesOldTokens(t *testing.T) { - db := setupTestDB(t.Name()) - cfg := testutil.NewTestConfig() - mr, redisClient, cleanupRedis := setupTestRedis(t, cfg) - defer cleanupRedis() - - svc := mustNewUserService(t, newTestDependencyWithConfig(cfg, db, redisClient)) - ctx := context.Background() - - userResp, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "pwredis"}, Email: "pwredis@example.com"}, - Password: dto.Password{Password: "oldpass"}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - token1, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token1: %v", err) - } - token2, err := svc.issueNewTokenForUser(ctx, userResp.ID, false) - if err != nil { - t.Fatalf("failed to issue token2: %v", err) - } - - key1 := buildTokenKey(userResp.ID, token1) - key2 := buildTokenKey(userResp.ID, token2) - if !mr.Exists(key1) || !mr.Exists(key2) { - t.Fatalf("expected both redis token keys to exist: %s, %s", key1, key2) - } - - updateReq := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "oldpass"}, - NewPassword: dto.NewPassword{NewPassword: "newpass"}, - } - resp, err := svc.UpdateUserPassword(ctx, userResp.ID, updateReq) - if err != nil { - t.Fatalf("update password failed: %v", err) - } - if resp.Token == "" { - t.Fatal("expected new token from password update") - } - - if mr.Exists(key1) || mr.Exists(key2) { - t.Fatal("expected old redis token keys to be deleted on password change") - } - - if err := svc.ValidateUserToken(ctx, token1, userResp.ID); err == nil { - t.Fatal("expected token1 to be invalid after password change") - } - if err := svc.ValidateUserToken(ctx, token2, userResp.ID); err == nil { - t.Fatal("expected token2 to be invalid after password change") - } - if err := svc.ValidateUserToken(ctx, resp.Token, userResp.ID); err != nil { - t.Fatalf("expected new token to be valid after password change, got %v", err) - } -} diff --git a/backend/internal/service/setup_test.go b/backend/internal/service/setup_test.go deleted file mode 100644 index a9b8ed0..0000000 --- a/backend/internal/service/setup_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package service - -import ( - "os" - "strings" - "testing" - - "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - - "github.com/paularynty/transcendence/auth-service-go/internal/config" - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" - "github.com/paularynty/transcendence/auth-service-go/internal/testutil" -) - -func setupTestDB(testName string) *gorm.DB { - // Sanitize test name for use as DB identifier - // Add busy_timeout to reduce locking errors - // Add _foreign_keys=on to enforce FK constraints - dbName := "file:" + strings.ReplaceAll(testName, "/", "_") + "?mode=memory&cache=shared&_busy_timeout=5000&_foreign_keys=on" - - db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{ - TranslateError: true, - }) - if err != nil { - panic("failed to connect database") - } - - // Explicitly enable foreign keys for SQLite just in case the DSN parameter isn't enough for the driver version - db.Exec("PRAGMA foreign_keys = ON") - - err = db.AutoMigrate(&model.User{}, &model.Friend{}, &model.Token{}, &model.HeartBeat{}) - if err != nil { - panic("failed to migrate database") - } - - if sqlDB, err := db.DB(); err == nil { - sqlDB.SetMaxOpenConns(1) - } - - return db -} - -func TestMain(m *testing.M) { - code := m.Run() - os.Exit(code) -} - -func newTestDependency(db *gorm.DB, redis *redis.Client, cfgMutators ...func(*config.Config)) *dependency.Dependency { - cfg := testutil.NewTestConfig() - for _, mutate := range cfgMutators { - mutate(cfg) - } - logger := testutil.NewTestLogger() - if redis != nil { - cfg.IsRedisEnabled = true - if cfg.RedisURL == "" { - cfg.RedisURL = "redis://test" - } - } - return dependency.NewDependency(cfg, db, redis, logger) -} - -func newTestDependencyWithConfig(cfg *config.Config, db *gorm.DB, redis *redis.Client) *dependency.Dependency { - if cfg == nil { - cfg = testutil.NewTestConfig() - } - logger := testutil.NewTestLogger() - if redis != nil { - cfg.IsRedisEnabled = true - if cfg.RedisURL == "" { - cfg.RedisURL = "redis://test" - } - } - return dependency.NewDependency(cfg, db, redis, logger) -} - -func setupTestRedis(t *testing.T, cfg *config.Config) (*miniredis.Miniredis, *redis.Client, func()) { - t.Helper() - - mr := miniredis.RunT(t) - client := redis.NewClient(&redis.Options{ - Addr: mr.Addr(), - }) - - if cfg != nil { - cfg.RedisURL = "redis://" + mr.Addr() - cfg.IsRedisEnabled = true - } - - cleanup := func() { - _ = client.Close() - mr.Close() - } - - return mr, client, cleanup -} diff --git a/backend/internal/service/twofa_service_test.go b/backend/internal/service/twofa_service_test.go deleted file mode 100644 index b086534..0000000 --- a/backend/internal/service/twofa_service_test.go +++ /dev/null @@ -1,446 +0,0 @@ -package service - -import ( - "context" - "testing" - "time" - - authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" - "github.com/paularynty/transcendence/auth-service-go/internal/util/jwt" - "github.com/pquerna/otp/totp" -) - -func TestTwoFASetupAndConfirm(t *testing.T) { - ctx := context.Background() - - t.Run("StartSetup_Success", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u1"}, Email: "u1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - resp, err := svc.StartTwoFaSetup(ctx, u.ID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.TwoFASecret == "" { - t.Error("expected secret") - } - }) - - t.Run("ConfirmSetup_Success", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u2"}, Email: "u2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - resp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, err := totp.GenerateCode(resp.TwoFASecret, time.Now()) - if err != nil { - t.Fatalf("failed to generate code: %v", err) - } - - req := &dto.TwoFAConfirmRequest{ - SetupToken: resp.SetupToken, - TwoFACode: code, - } - - res, err := svc.ConfirmTwoFaSetup(ctx, u.ID, req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !res.TwoFA { - t.Error("expected 2FA to be enabled") - } - }) - - t.Run("StartSetup_AlreadyEnabled", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u3"}, Email: "u3@e.com"}, - Password: dto.Password{Password: "p"}, - }) - resp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(resp.TwoFASecret, time.Now()) - _, _ = svc.ConfirmTwoFaSetup(ctx, u.ID, &dto.TwoFAConfirmRequest{ - SetupToken: resp.SetupToken, - TwoFACode: code, - }) - - _, err := svc.StartTwoFaSetup(ctx, u.ID) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 400 { - t.Errorf("expected 400 error, got %v", err) - } - }) - - t.Run("StartSetup_OAuthUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - // Mock OAuth user - oauthUser := dto.GoogleUserData{ - ID: "oauth123", - Email: "oauth@test.com", - } - user, err := svc.createNewUserFromGoogleInfo(ctx, &oauthUser, false) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - _, err = svc.StartTwoFaSetup(ctx, user.ID) - if err == nil { - t.Fatal("expected error for oauth user") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 400 { - t.Errorf("expected 400 error, got %v", err) - } - }) - - t.Run("StartSetup_DBError", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "u4"}, Email: "u4@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - _, err := svc.StartTwoFaSetup(ctx, u.ID) - if err == nil { - t.Error("expected error on closed db") - } - }) -} - -func TestConfirmTwoFaSetup_Errors(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "c"}, Email: "c@e.com"}, - Password: dto.Password{Password: "p"}, - }) - resp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(resp.TwoFASecret, time.Now()) - - t.Run("InvalidToken", func(t *testing.T) { - req := &dto.TwoFAConfirmRequest{ - SetupToken: "invalid", - TwoFACode: code, - } - _, err := svc.ConfirmTwoFaSetup(ctx, u.ID, req) - if err == nil { - t.Error("expected error for invalid token") - } - }) - - t.Run("UserMismatch", func(t *testing.T) { - // Create setup token for another user - u2, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "c2"}, Email: "c2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - resp2, _ := svc.StartTwoFaSetup(ctx, u2.ID) - - req := &dto.TwoFAConfirmRequest{ - SetupToken: resp2.SetupToken, - TwoFACode: code, - } - _, err := svc.ConfirmTwoFaSetup(ctx, u.ID, req) // Wrong user ID - if err == nil { - t.Error("expected error for user mismatch") - } - }) - - t.Run("WrongTokenType", func(t *testing.T) { - token, _ := jwt.SignUserToken(svc.Dep, u.ID) - req := &dto.TwoFAConfirmRequest{ - SetupToken: token, - TwoFACode: code, - } - _, err := svc.ConfirmTwoFaSetup(ctx, u.ID, req) - if err == nil { - t.Error("expected error for wrong token type") - } - }) - - t.Run("DBError", func(t *testing.T) { - req := &dto.TwoFAConfirmRequest{ - SetupToken: resp.SetupToken, - TwoFACode: code, - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - _, err := svc.ConfirmTwoFaSetup(ctx, u.ID, req) - if err == nil { - t.Error("expected error on closed db") - } - }) - - t.Run("NotInitiated", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - // User with no 2FA token - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "ni"}, Email: "ni@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - // Create a valid setup token manually - setupToken, _ := jwt.SignTwoFASetupToken(svc.Dep, u.ID, "secret") - code, _ := totp.GenerateCode("secret", time.Now()) - - req := &dto.TwoFAConfirmRequest{ - SetupToken: setupToken, - TwoFACode: code, - } - - _, err := svc.ConfirmTwoFaSetup(ctx, u.ID, req) - if err == nil { - t.Error("expected error for not initiated") - } - }) -} - -func TestTwoFAChallenge(t *testing.T) { - ctx := context.Background() - - t.Run("Success", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "ch1"}, Email: "ch1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - setupResp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - _, _ = svc.ConfirmTwoFaSetup(ctx, u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - - loginResp, _ := svc.LoginUser(ctx, &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "ch1"}, - Password: dto.Password{Password: "p"}, - }) - sessionToken := loginResp.TwoFAPending.SessionToken - - code, _ = totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - req := &dto.TwoFAChallengeRequest{ - SessionToken: sessionToken, - TwoFACode: code, - } - - resp, err := svc.SubmitTwoFAChallenge(ctx, req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.Token == "" { - t.Error("expected valid user token") - } - }) - - t.Run("InvalidCode", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "ch2"}, Email: "ch2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - setupResp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - _, _ = svc.ConfirmTwoFaSetup(ctx, u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - - loginResp, _ := svc.LoginUser(ctx, &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "ch2"}, - Password: dto.Password{Password: "p"}, - }) - sessionToken := loginResp.TwoFAPending.SessionToken - - req := &dto.TwoFAChallengeRequest{ - SessionToken: sessionToken, - TwoFACode: "000000", - } - - _, err := svc.SubmitTwoFAChallenge(ctx, req) - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("NotEnabled", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "chne"}, Email: "chne@e.com"}, - Password: dto.Password{Password: "p"}, - }) - // Do NOT enable 2FA - - // Create session token manually - sessionToken, _ := jwt.SignTwoFAToken(svc.Dep, u.ID) - - req := &dto.TwoFAChallengeRequest{ - SessionToken: sessionToken, - TwoFACode: "000000", - } - - _, err := svc.SubmitTwoFAChallenge(ctx, req) - if err == nil { - t.Fatal("expected error for not enabled") - } - }) - - t.Run("DBError", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "ch3"}, Email: "ch3@e.com"}, - Password: dto.Password{Password: "p"}, - }) - setupResp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - _, _ = svc.ConfirmTwoFaSetup(ctx, u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - - loginResp, _ := svc.LoginUser(ctx, &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "ch3"}, - Password: dto.Password{Password: "p"}, - }) - sessionToken := loginResp.TwoFAPending.SessionToken - - req := &dto.TwoFAChallengeRequest{ - SessionToken: sessionToken, - TwoFACode: code, - } - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - _, err := svc.SubmitTwoFAChallenge(ctx, req) - if err == nil { - t.Error("expected error on closed db") - } - }) -} - -func TestDisableTwoFA(t *testing.T) { - ctx := context.Background() - - t.Run("Success", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "dis1"}, Email: "dis1@e.com"}, - Password: dto.Password{Password: "p"}, - }) - setupResp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - _, _ = svc.ConfirmTwoFaSetup(ctx, u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - req := &dto.DisableTwoFARequest{ - Password: dto.Password{Password: "p"}, - } - - resp, err := svc.DisableTwoFA(ctx, u.ID, req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.TwoFA { - t.Error("expected 2FA to be disabled") - } - }) - - t.Run("AlreadyDisabled", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "dis2"}, Email: "dis2@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - req := &dto.DisableTwoFARequest{ - Password: dto.Password{Password: "p"}, - } - _, err := svc.DisableTwoFA(ctx, u.ID, req) - if err == nil { - t.Fatal("expected error") - } - }) - - t.Run("OAuthUser", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - // Mock OAuth user - oauthUser := dto.GoogleUserData{ - ID: "oauth456", - Email: "oauth2@test.com", - } - user, _ := svc.createNewUserFromGoogleInfo(ctx, &oauthUser, false) - - req := &dto.DisableTwoFARequest{ - Password: dto.Password{Password: "any"}, - } - _, err := svc.DisableTwoFA(ctx, user.ID, req) - if err == nil { - t.Fatal("expected error for oauth user") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 400 { - t.Errorf("expected 400 error, got %v", err) - } - }) - - t.Run("DBError", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "dis3"}, Email: "dis3@e.com"}, - Password: dto.Password{Password: "p"}, - }) - - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - req := &dto.DisableTwoFARequest{ - Password: dto.Password{Password: "p"}, - } - _, err := svc.DisableTwoFA(ctx, u.ID, req) - if err == nil { - t.Error("expected error on closed db") - } - }) - - t.Run("InvalidPassword", func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "disinv"}, Email: "disinv@e.com"}, - Password: dto.Password{Password: "correct"}, - }) - - // Enable 2FA manually - setupResp, _ := svc.StartTwoFaSetup(ctx, u.ID) - code, _ := totp.GenerateCode(setupResp.TwoFASecret, time.Now()) - _, _ = svc.ConfirmTwoFaSetup(ctx, u.ID, &dto.TwoFAConfirmRequest{SetupToken: setupResp.SetupToken, TwoFACode: code}) - req := &dto.DisableTwoFARequest{ - Password: dto.Password{Password: "wrong"}, - } - _, err := svc.DisableTwoFA(ctx, u.ID, req) - if err == nil { - t.Fatal("expected error for invalid password") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 401 { - t.Errorf("expected 401 error, got %v", err) - } - }) -} diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go deleted file mode 100644 index 8a35244..0000000 --- a/backend/internal/service/user_service_test.go +++ /dev/null @@ -1,685 +0,0 @@ -package service - -import ( - "context" - "strings" - "testing" - - authError "github.com/paularynty/transcendence/auth-service-go/internal/auth_error" - model "github.com/paularynty/transcendence/auth-service-go/internal/db" - "github.com/paularynty/transcendence/auth-service-go/internal/dto" -) - -func requireAuthStatus(t *testing.T, err error, status int) { - t.Helper() - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != status { - t.Fatalf("expected %d error, got %v", status, err) - } -} - -func TestCreateUser(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - cases := []struct { - name string - req *dto.CreateUserRequest - setup func() - wantErrStatus int - }{ - { - name: "Success", - req: &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "testuser"}, - Email: "test@example.com", - }, - Password: dto.Password{Password: "password123"}, - }, - }, - { - name: "DuplicateUsername", - req: &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "testuser"}, - Email: "other@example.com", - }, - Password: dto.Password{Password: "password123"}, - }, - setup: func() { - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "testuser"}, - Email: "test@example.com", - }, - Password: dto.Password{Password: "password123"}, - }) - }, - wantErrStatus: 409, - }, - { - name: "DuplicateEmail", - req: &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "otheruser"}, - Email: "test@example.com", - }, - Password: dto.Password{Password: "password123"}, - }, - setup: func() { - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "seeduser"}, - Email: "test@example.com", - }, - Password: dto.Password{Password: "password123"}, - }) - }, - wantErrStatus: 409, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - resp, err := svc.CreateUser(ctx, tc.req) - if tc.wantErrStatus == 0 { - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if resp.Username != tc.req.Username { - t.Errorf("expected username %s, got %s", tc.req.Username, resp.Username) - } - if resp.Email != tc.req.Email { - t.Errorf("expected email %s, got %s", tc.req.Email, resp.Email) - } - if resp.ID == 0 { - t.Error("expected valid ID") - } - return - } - if err == nil { - t.Fatalf("expected error for %s", tc.name) - } - requireAuthStatus(t, err, tc.wantErrStatus) - }) - } -} - -func TestLoginUser(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - // Setup user - createReq := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "loginuser"}, - Email: "login@example.com", - }, - Password: dto.Password{Password: "password123"}, - } - _, _ = svc.CreateUser(ctx, createReq) - - t.Run("SuccessUsername", func(t *testing.T) { - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "loginuser"}, - Password: dto.Password{Password: "password123"}, - } - - res, err := svc.LoginUser(ctx, req) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if res.User == nil || res.User.Token == "" { - t.Error("expected user with token") - } - }) - - t.Run("SuccessEmail", func(t *testing.T) { - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "login@example.com"}, - Password: dto.Password{Password: "password123"}, - } - - res, err := svc.LoginUser(ctx, req) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if res.User == nil || res.User.Token == "" { - t.Error("expected user with token") - } - }) - - t.Run("InvalidPassword", func(t *testing.T) { - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "loginuser"}, - Password: dto.Password{Password: "wrongpass"}, - } - - _, err := svc.LoginUser(ctx, req) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 401 { - t.Errorf("expected 401 error, got %v", err) - } - }) - - t.Run("UserNotFound", func(t *testing.T) { - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "nonexistent"}, - Password: dto.Password{Password: "password123"}, - } - - _, err := svc.LoginUser(ctx, req) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 401 { - t.Errorf("expected 401 error, got %v", err) - } - }) - - t.Run("2FARequired", func(t *testing.T) { - // Enable 2FA for user - user, _ := svc.GetUserByID(ctx, 1) // First user created (loginuser) - _, _ = svc.StartTwoFaSetup(ctx, user.ID) - // We need to confirm it properly, but we can hack it for this test - db.Model(&model.User{}).Where("id = ?", user.ID).Update("two_fa_token", "secret") - - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "loginuser"}, - Password: dto.Password{Password: "password123"}, - } - - res, err := svc.LoginUser(ctx, req) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if res.User != nil { - t.Error("expected no user token when 2FA required") - } - if res.TwoFAPending == nil { - t.Error("expected 2FA pending response") - } - if res.TwoFAPending.Message != "2FA_REQUIRED" { - t.Errorf("expected message 2FA_REQUIRED, got %s", res.TwoFAPending.Message) - } - }) - - t.Run("OAuthUser", func(t *testing.T) { - // Create oauth user - oauthUser := dto.GoogleUserData{ID: "login_oauth", Email: "login_oauth@e.com"} - user, _ := svc.createNewUserFromGoogleInfo(ctx, &oauthUser, false) - - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: user.Username}, - Password: dto.Password{Password: "any"}, - } - - _, err := svc.LoginUser(ctx, req) - if err == nil { - t.Fatal("expected error") - } - authErr, ok := err.(*authError.AuthError) - if !ok || authErr.Status != 401 { - t.Errorf("expected 401 error, got %v", err) - } - }) - - t.Run("InvalidHash", func(t *testing.T) { - // Manually create user with invalid hash - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "badhash"}, Email: "badhash@e.com"}, - Password: dto.Password{Password: "p"}, - }) - badHash := "invalid_hash" - db.Model(&model.User{}).Where("username = ?", "badhash").Update("password_hash", badHash) - - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "badhash"}, - Password: dto.Password{Password: "p"}, - } - - _, err := svc.LoginUser(ctx, req) - if err == nil { - t.Fatal("expected error") - } - // Should return raw error, not AuthError - if _, ok := err.(*authError.AuthError); ok { - t.Error("expected raw error for invalid hash") - } - }) -} - -func TestGetUserByID(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "getuser"}, - Email: "get@example.com", - }, - Password: dto.Password{Password: "pass"}, - }) - - cases := []struct { - name string - userID uint - wantErrStatus int - }{ - {"Success", u.ID, 0}, - {"NotFound", 9999, 404}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - got, err := svc.GetUserByID(ctx, tc.userID) - if tc.wantErrStatus == 0 { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.ID != u.ID { - t.Errorf("want ID %d, got %d", u.ID, got.ID) - } - return - } - if err == nil { - t.Fatal("expected error") - } - requireAuthStatus(t, err, tc.wantErrStatus) - }) - } -} - -func TestUpdateUserPassword(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "passupdate"}, - Email: "pass@example.com", - }, - Password: dto.Password{Password: "oldpass"}, - }) - - t.Run("Success", func(t *testing.T) { - req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "oldpass"}, - NewPassword: dto.NewPassword{NewPassword: "newpass"}, - } - - resp, err := svc.UpdateUserPassword(ctx, u.ID, req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.Token == "" { - t.Error("expected new token") - } - - loginReq := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "passupdate"}, - Password: dto.Password{Password: "newpass"}, - } - if _, err := svc.LoginUser(ctx, loginReq); err != nil { - t.Error("failed to login with new password") - } - }) - - errorCases := []struct { - name string - setup func() uint - req *dto.UpdateUserPasswordRequest - wantErrStatus int - }{ - { - name: "InvalidOldPassword", - setup: func() uint { - return u.ID - }, - req: &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "wrongold"}, - NewPassword: dto.NewPassword{NewPassword: "newpass2"}, - }, - wantErrStatus: 401, - }, - { - name: "OAuthUser", - setup: func() uint { - oauthUser := dto.GoogleUserData{ID: "passoauth", Email: "passoauth@e.com"} - user, _ := svc.createNewUserFromGoogleInfo(ctx, &oauthUser, false) - return user.ID - }, - req: &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "any"}, - NewPassword: dto.NewPassword{NewPassword: "new"}, - }, - wantErrStatus: 400, - }, - } - - for _, tc := range errorCases { - t.Run(tc.name, func(t *testing.T) { - userID := tc.setup() - _, err := svc.UpdateUserPassword(ctx, userID, tc.req) - if err == nil { - t.Fatal("expected error") - } - requireAuthStatus(t, err, tc.wantErrStatus) - }) - } - - t.Run("InvalidHash", func(t *testing.T) { - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "badhash2"}, Email: "badhash2@e.com"}, - Password: dto.Password{Password: "password123"}, - }) - badHash := "invalid_hash" - var user model.User - db.Where("username = ?", "badhash2").First(&user) - db.Model(&user).Update("password_hash", badHash) - - req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "password123"}, - NewPassword: dto.NewPassword{NewPassword: "new"}, - } - - _, err := svc.UpdateUserPassword(ctx, user.ID, req) - if err == nil { - t.Fatal("expected error") - } - if _, ok := err.(*authError.AuthError); ok { - t.Error("expected raw error") - } - }) -} - -func TestUpdateUserProfile(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "updateprofile"}, - Email: "update@example.com", - }, - Password: dto.Password{Password: "pass"}, - }) - - cases := []struct { - name string - setup func() - req *dto.UpdateUserRequest - wantErrStatus int - }{ - { - name: "Success", - req: &dto.UpdateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "newname"}, - Email: "new@example.com", - Avatar: func() *string { v := "new_avatar.png"; return &v }(), - }, - }, - }, - { - name: "Duplicate", - setup: func() { - _, _ = svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "other"}, Email: "other@e.com"}, - Password: dto.Password{Password: "password123"}, - }) - }, - req: &dto.UpdateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "other"}, - Email: "new@example.com", - }, - }, - wantErrStatus: 409, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - if tc.setup != nil { - tc.setup() - } - got, err := svc.UpdateUserProfile(ctx, u.ID, tc.req) - if tc.wantErrStatus == 0 { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.Username != tc.req.Username { - t.Errorf("want username %s, got %s", tc.req.Username, got.Username) - } - if got.Email != tc.req.Email { - t.Errorf("want email %s, got %s", tc.req.Email, got.Email) - } - return - } - if err == nil { - t.Fatal("expected error for duplicate") - } - requireAuthStatus(t, err, tc.wantErrStatus) - }) - } -} - -func TestDeleteUser(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - u, _ := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "deleteuser"}, - Email: "del@example.com", - }, - Password: dto.Password{Password: "pass"}, - }) - - t.Run("Success", func(t *testing.T) { - err := svc.DeleteUser(ctx, u.ID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - _, err = svc.GetUserByID(ctx, u.ID) - if err == nil { - t.Error("expected user to be deleted") - } - }) -} - -func TestValidateUserToken(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - createReq := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "tokenuser"}, - Email: "token@example.com", - }, - Password: dto.Password{Password: "pass"}, - } - _, _ = svc.CreateUser(ctx, createReq) - - loginRes, _ := svc.LoginUser(ctx, &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "tokenuser"}, - Password: dto.Password{Password: "pass"}, - }) - token := loginRes.User.Token - userID := loginRes.User.ID - - t.Run("Success", func(t *testing.T) { - err := svc.ValidateUserToken(ctx, token, userID) - if err != nil { - t.Errorf("expected token to be valid, got %v", err) - } - }) - - t.Run("InvalidToken", func(t *testing.T) { - err := svc.ValidateUserToken(ctx, "invalidtoken", userID) - if err == nil { - t.Error("expected error for invalid token") - } - }) - - t.Run("TokenMismatchUser", func(t *testing.T) { - u2, err := svc.CreateUser(ctx, &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "user2"}, - Email: "u2@ex.com", - }, - Password: dto.Password{Password: "pass"}, - }) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - err = svc.ValidateUserToken(ctx, token, u2.ID) - if err == nil { - t.Error("expected error for token mismatch") - } - if !strings.Contains(err.Error(), "token does not match user") { - t.Errorf("expected mismatch error, got %v", err) - } - }) -} - -func TestLogoutUser(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - ctx := context.Background() - - createReq := &dto.CreateUserRequest{ - User: dto.User{ - UserName: dto.UserName{Username: "logoutuser"}, - Email: "logout@example.com", - }, - Password: dto.Password{Password: "pass"}, - } - _, _ = svc.CreateUser(ctx, createReq) - - loginRes, _ := svc.LoginUser(ctx, &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "logoutuser"}, - Password: dto.Password{Password: "pass"}, - }) - token := loginRes.User.Token - userID := loginRes.User.ID - - err := svc.LogoutUser(ctx, userID) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - err = svc.ValidateUserToken(ctx, token, userID) - if err == nil { - t.Error("expected token to be invalid after logout") - } -} - -func TestDBErrors(t *testing.T) { - ctx := context.Background() - - cases := []struct { - name string - run func(svc *UserService) error - }{ - { - name: "CreateUser", - run: func(svc *UserService) error { - req := &dto.CreateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "db1"}, Email: "db1@e.com"}, - Password: dto.Password{Password: "password123"}, - } - _, err := svc.CreateUser(ctx, req) - return err - }, - }, - { - name: "LoginUser", - run: func(svc *UserService) error { - req := &dto.LoginUserRequest{ - Identifier: dto.Identifier{Identifier: "db1"}, - Password: dto.Password{Password: "password123"}, - } - _, err := svc.LoginUser(ctx, req) - return err - }, - }, - { - name: "GetUserByID", - run: func(svc *UserService) error { - _, err := svc.GetUserByID(ctx, 1) - return err - }, - }, - { - name: "UpdateUserPassword", - run: func(svc *UserService) error { - req := &dto.UpdateUserPasswordRequest{ - OldPassword: dto.OldPassword{OldPassword: "password123"}, - NewPassword: dto.NewPassword{NewPassword: "password456"}, - } - _, err := svc.UpdateUserPassword(ctx, 1, req) - return err - }, - }, - { - name: "UpdateUserProfile", - run: func(svc *UserService) error { - req := &dto.UpdateUserRequest{ - User: dto.User{UserName: dto.UserName{Username: "n"}, Email: "n@e.com"}, - } - _, err := svc.UpdateUserProfile(ctx, 1, req) - return err - }, - }, - { - name: "DeleteUser", - run: func(svc *UserService) error { - return svc.DeleteUser(ctx, 1) - }, - }, - { - name: "ValidateUserToken", - run: func(svc *UserService) error { - return svc.ValidateUserToken(ctx, "token", 1) - }, - }, - { - name: "LogoutUser", - run: func(svc *UserService) error { - return svc.LogoutUser(ctx, 1) - }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - db := setupTestDB(t.Name()) - svc := mustNewUserService(t, newTestDependency(db, nil)) - sqlDB, _ := db.DB() - _ = sqlDB.Close() - - if err := tc.run(svc); err == nil { - t.Error("expected db error") - } - }) - } -} From 8433761c8be1d1965f5245ccfc60944d84ca5d5b Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:37:13 +0200 Subject: [PATCH 13/15] refactor/backend: move router setup to a dedicated file for better organization --- backend/cmd/server/main.go | 46 +------------------------- backend/internal/routers/router.go | 52 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 45 deletions(-) create mode 100644 backend/internal/routers/router.go diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 78a791e..1388f48 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -6,7 +6,6 @@ import ( "net/http" "os" "os/signal" - "strings" "syscall" "time" @@ -23,52 +22,9 @@ import ( "log/slog" - sloggin "github.com/samber/slog-gin" - - "github.com/gin-contrib/cors" - "github.com/paularynty/transcendence/auth-service-go/internal/dependency" - "github.com/paularynty/transcendence/auth-service-go/internal/middleware" ) -func SetupRouter(dep *dependency.Dependency) *gin.Engine { - r := gin.New() - - logConfig := sloggin.Config{ - DefaultLevel: slog.LevelInfo, - ClientErrorLevel: slog.LevelWarn, - ServerErrorLevel: slog.LevelError, - } - - // A rough CORS - r.Use(cors.New(cors.Config{ - AllowOriginFunc: func(origin string) bool { - if origin == "http://localhost:5173" || - origin == "http://localhost:4173" { - return true - } - if strings.HasSuffix(origin, ".vercel.app") { - return true - } - return false - }, - AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}, - AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, - ExposeHeaders: []string{"Content-Length"}, - AllowCredentials: true, - MaxAge: 12 * time.Hour, - })) - - rateLimiter := middleware.NewRateLimiter(time.Duration(dep.Cfg.RateLimiterDurationInSec)*time.Second, dep.Cfg.RateLimiterRequestLimit, time.Duration(dep.Cfg.RateLimiterCleanupIntervalInSec)*time.Second) - r.Use(rateLimiter.RateLimit()) - - r.Use(middleware.PanicHandler()) - r.Use(sloggin.NewWithConfig(dep.Logger, logConfig)) - r.Use(middleware.ErrorHandler()) - - return r -} - // @title Auth Service API // @version 1.0 // @description Auth service @@ -97,7 +53,7 @@ func main() { } // router - r := SetupRouter(dep) + r := routers.SetupRouter(dep) routers.UsersRouter(r.Group("/api/users"), userService) // Health check diff --git a/backend/internal/routers/router.go b/backend/internal/routers/router.go new file mode 100644 index 0000000..42af5f6 --- /dev/null +++ b/backend/internal/routers/router.go @@ -0,0 +1,52 @@ +package routers + +import ( + "log/slog" + "strings" + "time" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" + "github.com/paularynty/transcendence/auth-service-go/internal/dependency" + "github.com/paularynty/transcendence/auth-service-go/internal/middleware" + sloggin "github.com/samber/slog-gin" +) + +func SetupRouter(dep *dependency.Dependency) *gin.Engine { + r := gin.New() + + r.Use(middleware.PanicHandler()) + + logConfig := sloggin.Config{ + DefaultLevel: slog.LevelInfo, + ClientErrorLevel: slog.LevelWarn, + ServerErrorLevel: slog.LevelError, + } + + // A rough CORS + r.Use(cors.New(cors.Config{ + AllowOriginFunc: func(origin string) bool { + if origin == "http://localhost:5173" || + origin == "http://localhost:4173" { + return true + } + if strings.HasSuffix(origin, ".vercel.app") { + return true + } + return false + }, + AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}, + AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + ExposeHeaders: []string{"Content-Length"}, + AllowCredentials: true, + MaxAge: 12 * time.Hour, + })) + + rateLimiter := middleware.NewRateLimiter(time.Duration(dep.Cfg.RateLimiterDurationInSec)*time.Second, dep.Cfg.RateLimiterRequestLimit, time.Duration(dep.Cfg.RateLimiterCleanupIntervalInSec)*time.Second) + r.Use(rateLimiter.RateLimit()) + + r.Use(sloggin.NewWithConfig(dep.Logger, logConfig)) + r.Use(middleware.ErrorHandler()) + + return r +} From b4510613d2b6a66a242366ca0069f53db2587246 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Wed, 4 Feb 2026 23:22:40 +0200 Subject: [PATCH 14/15] refactor/backend: add integration tests --- backend/internal/routers/user_router_test.go | 1218 ++++++++++++++++++ backend/internal/service/user_service.go | 6 +- backend/internal/testutil/testutil.go | 2 +- backend/internal/util/jwt/token_test.go | 1 - 4 files changed, 1224 insertions(+), 3 deletions(-) create mode 100644 backend/internal/routers/user_router_test.go diff --git a/backend/internal/routers/user_router_test.go b/backend/internal/routers/user_router_test.go new file mode 100644 index 0000000..4ec60ec --- /dev/null +++ b/backend/internal/routers/user_router_test.go @@ -0,0 +1,1218 @@ +package routers_test + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/pquerna/otp/totp" + "github.com/redis/go-redis/v9" + + "github.com/gin-gonic/gin" + "github.com/paularynty/transcendence/auth-service-go/internal/config" + "github.com/paularynty/transcendence/auth-service-go/internal/db" + "github.com/paularynty/transcendence/auth-service-go/internal/dto" + "github.com/paularynty/transcendence/auth-service-go/internal/routers" + "github.com/paularynty/transcendence/auth-service-go/internal/service" + "github.com/paularynty/transcendence/auth-service-go/internal/testutil" +) + +func testRouterFactory(t *testing.T, testCfg *config.Config, setDBDown bool) *gin.Engine { + t.Helper() + + dto.InitValidator() + + testLogger := testutil.NewTestLogger() + + if testCfg.DbAddress == "file::memory:?cache=shared" { + safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()) + testCfg.DbAddress = fmt.Sprintf("file:%s?mode=memory&cache=shared", safeName) + } + + myDB, err := db.GetDB(testCfg.DbAddress, testLogger) + if err != nil { + t.Fatalf("failed to init the test db, err: %v", err) + } + db.ResetDB(myDB, testLogger) + + var redisClient *redis.Client + + if testCfg.IsRedisEnabled { + mr := miniredis.RunT(t) + redisClient, err = db.GetRedis("redis://"+mr.Addr(), testCfg, testLogger) + if err != nil { + t.Fatalf("failed to init the test redis, err: %v", err) + } + } + + dep := testutil.NewTestDependency(testCfg, myDB, redisClient, testLogger) + + userService, err := service.NewUserService(dep) + + if err != nil { + t.Fatalf("faled to create user service") + } + + if setDBDown { + userService.Dep.DB = nil + } + + r := routers.SetupRouter(dep) + routers.UsersRouter(r.Group("/"), userService) + + return r +} + +func toJSON(t *testing.T, v any) *strings.Reader { + t.Helper() + + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + t.Fatalf("failed to serialize the obj %v, err: %v", v, err) + } + return strings.NewReader(string(b)) +} + +var testUsername1 = "test1" +var testEmail1 = "test1@test.com" +var testUsername2 = "test2" +var testEmail2 = "test2@test.com" +var testPwd = "Password.777" +var testNewPassword = "Password.888" +var testAvatar = "https://example.com/a.png" +var loginUsername1 = "loginuser1" +var loginEmail1 = "loginuser1@test.com" + +var mockRegisterRequest = map[string]string{ + "username": testUsername1, + "email": testEmail1, + "password": testPwd, +} + +var mockRegisterRequest2 = map[string]string{ + "username": testUsername2, + "email": testEmail2, + "password": testPwd, +} + +var mockUpdateUserPasswordRequest = map[string]string{ + "oldPassword": testPwd, + "newPassword": testNewPassword, +} + +var mockLoginUserByUsernameRequest = map[string]string{ + "identifier": testUsername1, + "password": testPwd, +} + +var mockLoginUserByEmailRequest = map[string]string{ + "identifier": testEmail1, + "password": testPwd, +} + +var mockUpdateUserRequest = map[string]string{ + "username": testUsername1, + "email": testEmail1, + "avatar": testAvatar, +} + +func TestCreateUserEndpoint(t *testing.T) { + testCases := []struct { + name string + setup func(t *testing.T, r *gin.Engine) + body any + want int + }{ + { + name: "happy", + body: mockRegisterRequest, + want: 201, + }, + { + name: "duplicate username", + setup: func(t *testing.T, r *gin.Engine) { + userReq := map[string]string{ + "username": "dupusername", + "email": "dupusername@test.com", + "password": testPwd, + } + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, userReq)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + }, + body: map[string]string{ + "username": "dupusername", + "email": "other_dupusername@test.com", + "password": testPwd, + }, + want: 409, + }, + { + name: "duplicate email", + setup: func(t *testing.T, r *gin.Engine) { + userReq := map[string]string{ + "username": "dupemail", + "email": "dupemail@test.com", + "password": testPwd, + } + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, userReq)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + }, + body: map[string]string{ + "username": "other_dupemail", + "email": "dupemail@test.com", + "password": testPwd, + }, + want: 409, + }, + { + name: "invalid email", + body: map[string]string{ + "username": testUsername1, + "email": "not-an-email", + "password": testPwd, + }, + want: 400, + }, + {name: "missing body", body: nil, want: 400}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RedisURL = "redis" + testCfg.IsRedisEnabled = true + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + body := tc.body + if tc.setup != nil { + tc.setup(t, r) + } + + w := httptest.NewRecorder() + var bodyReader io.Reader + if body != nil { + bodyReader = toJSON(t, body) + } else { + bodyReader = strings.NewReader("") + } + req, _ := http.NewRequest("POST", "/", bodyReader) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != tc.want { + t.Fatalf("expected: %d, got %d", tc.want, w.Code) + } + }) + } +} + +func TestLoginEndpoint(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + // Register a user once for login tests. + registerReq := map[string]string{ + "username": loginUsername1, + "email": loginEmail1, + "password": testPwd, + } + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, registerReq)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + + testCases := []struct { + name string + body any + want int + }{ + { + name: "happy username", + body: map[string]string{ + "identifier": loginUsername1, + "password": testPwd, + }, + want: 200, + }, + { + name: "happy email", + body: map[string]string{ + "identifier": loginEmail1, + "password": testPwd, + }, + want: 200, + }, + { + name: "wrong password", + body: map[string]string{ + "identifier": loginEmail1, + "password": "WrongPassword.123", + }, + want: 401, + }, + { + name: "wrong identifier", + body: map[string]string{ + "identifier": "unknown@test.com", + "password": testPwd, + }, + want: 401, + }, + { + name: "missing body", + body: nil, + want: 400, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + var bodyReader io.Reader + if tc.body != nil { + bodyReader = toJSON(t, tc.body) + } else { + bodyReader = strings.NewReader("") + } + req, _ := http.NewRequest("POST", "/loginByIdentifier", bodyReader) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != tc.want { + t.Fatalf("expected: %d, got %d", tc.want, w.Code) + } + }) + } +} + +func TestUpdatePasswordEndpoint(t *testing.T) { + testCases := []struct { + name string + body any + want int + }{ + {name: "happy", body: mockUpdateUserPasswordRequest, want: 200}, + { + name: "wrong old password", + body: map[string]string{ + "oldPassword": "WrongPassword.123", + "newPassword": testNewPassword, + }, + want: 401, + }, + {name: "missing body", body: nil, want: 400}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + // setup user + login + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + var secondToken string + if tc.name == "happy" { + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup second login failed, got %d", w.Code) + } + var login2 dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login2); err != nil { + t.Fatalf("failed to unmarshal second login response: %v", err) + } + secondToken = login2.Token + } + + w = httptest.NewRecorder() + var bodyReader io.Reader + if tc.body != nil { + bodyReader = toJSON(t, tc.body) + } else { + bodyReader = strings.NewReader("") + } + req, _ = http.NewRequest("PUT", "/password", bodyReader) + req.Header.Add("Authorization", "Bearer "+login.Token) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != tc.want { + t.Fatalf("expected: %d, got %d", tc.want, w.Code) + } + + if tc.name == "happy" { + // Old token should be invalid after password update + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected old token to be invalid after password update, got %d", w.Code) + } + + if secondToken == "" { + t.Fatalf("expected second token to be set") + } + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+secondToken) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected second token to be invalid after password update, got %d", w.Code) + } + } + }) + } +} + +func TestUpdateProfileEndpoint(t *testing.T) { + createUser := func(t *testing.T, r *gin.Engine, body any) dto.UserWithoutTokenResponse { + t.Helper() + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, body)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + var user dto.UserWithoutTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &user); err != nil { + t.Fatalf("failed to unmarshal register response: %v", err) + } + return user + } + + loginUser := func(t *testing.T, r *gin.Engine, body any) dto.UserWithTokenResponse { + t.Helper() + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/loginByIdentifier", toJSON(t, body)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var user dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &user); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + return user + } + + testCases := []struct { + name string + body any + want int + check func(t *testing.T, body []byte) + }{ + { + name: "happy", + body: mockUpdateUserRequest, + want: 200, + }, + { + name: "avatar null", + body: map[string]any{ + "username": testUsername1, + "email": testEmail1, + "avatar": nil, + }, + want: 200, + check: func(t *testing.T, body []byte) { + var resp dto.UserWithoutTokenResponse + if err := json.Unmarshal(body, &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if resp.Avatar != nil { + t.Fatalf("expected avatar to be nil, got %v", *resp.Avatar) + } + }, + }, + { + name: "avatar empty", + body: map[string]any{ + "username": testUsername1, + "email": testEmail1, + "avatar": "", + }, + want: 400, + check: func(t *testing.T, body []byte) { + var resp struct { + Error []string `json:"error"` + } + if err := json.Unmarshal(body, &resp); err != nil { + t.Fatalf("failed to unmarshal error response: %v", err) + } + if len(resp.Error) == 0 { + t.Fatalf("expected validation errors, got none") + } + found := false + for _, msg := range resp.Error { + if strings.Contains(msg, "Avatar") { + found = true + break + } + } + if !found { + t.Fatalf("expected avatar validation error, got %v", resp.Error) + } + }, + }, + { + name: "duplicate", + body: map[string]string{ + "username": testUsername2, + "email": testEmail2, + "avatar": testAvatar, + }, + want: 409, + }, + {name: "missing body", body: nil, want: 400}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + createUser(t, r, mockRegisterRequest) + createUser(t, r, mockRegisterRequest2) + login := loginUser(t, r, mockLoginUserByEmailRequest) + + w := httptest.NewRecorder() + var bodyReader io.Reader + if tc.body != nil { + bodyReader = toJSON(t, tc.body) + } else { + bodyReader = strings.NewReader("") + } + req, _ := http.NewRequest("PUT", "/me", bodyReader) + req.Header.Add("Authorization", "Bearer "+login.Token) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != tc.want { + t.Fatalf("expected: %d, got %d", tc.want, w.Code) + } + if tc.check != nil { + tc.check(t, w.Body.Bytes()) + } + }) + } +} + +func TestFriendsEndpoints(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RedisURL = "redis" + testCfg.IsRedisEnabled = true + testCfg.RateLimiterRequestLimit = 1000 + + testCases := []struct { + name string + run func(t *testing.T, r *gin.Engine, token string, user1 dto.UserWithoutTokenResponse, user2 dto.UserWithoutTokenResponse) + want int + }{ + { + name: "list empty", + run: func(t *testing.T, r *gin.Engine, token string, _ dto.UserWithoutTokenResponse, _ dto.UserWithoutTokenResponse) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/friends", nil) + req.Header.Add("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("expected: 200, got %d", w.Code) + } + }, + want: 200, + }, + { + name: "add friend", + run: func(t *testing.T, r *gin.Engine, token string, _ dto.UserWithoutTokenResponse, user2 dto.UserWithoutTokenResponse) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/friends", toJSON(t, map[string]uint{"userId": user2.ID})) + req.Header.Add("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("expected: 201, got %d", w.Code) + } + }, + want: 201, + }, + { + name: "add friend duplicate", + run: func(t *testing.T, r *gin.Engine, token string, _ dto.UserWithoutTokenResponse, user2 dto.UserWithoutTokenResponse) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/friends", toJSON(t, map[string]uint{"userId": user2.ID})) + req.Header.Add("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup add friend failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/friends", toJSON(t, map[string]uint{"userId": user2.ID})) + req.Header.Add("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + if w.Code != 409 { + t.Fatalf("expected: 409, got %d", w.Code) + } + }, + want: 409, + }, + { + name: "add friend self", + run: func(t *testing.T, r *gin.Engine, token string, user1 dto.UserWithoutTokenResponse, _ dto.UserWithoutTokenResponse) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/friends", toJSON(t, map[string]uint{"userId": user1.ID})) + req.Header.Add("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + if w.Code != 400 { + t.Fatalf("expected: 400, got %d", w.Code) + } + }, + want: 400, + }, + { + name: "add friend not found", + run: func(t *testing.T, r *gin.Engine, token string, _ dto.UserWithoutTokenResponse, _ dto.UserWithoutTokenResponse) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/friends", toJSON(t, map[string]uint{"userId": 999999})) + req.Header.Add("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + if w.Code != 404 { + t.Fatalf("expected: 404, got %d", w.Code) + } + }, + want: 404, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := testRouterFactory(t, testCfg, false) + + // setup users + login + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register user1 failed, got %d", w.Code) + } + var user1 dto.UserWithoutTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &user1); err != nil { + t.Fatalf("failed to unmarshal user1 response: %v", err) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest2)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register user2 failed, got %d", w.Code) + } + + var user2 dto.UserWithoutTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &user2); err != nil { + t.Fatalf("failed to unmarshal user2 response: %v", err) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + tc.run(t, r, login.Token, user1, user2) + }) + } +} + +func TestValidateEndpoint(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("expected: 200, got %d", w.Code) + } +} + +func TestListUsersEndpoint(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register user1 failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest2)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register user2 failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("expected: 200, got %d", w.Code) + } +} + +func TestGoogleOAuthEndpoints(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + t.Run("login redirect", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/google/login", nil) + r.ServeHTTP(w, req) + if w.Code != 302 { + t.Fatalf("expected: 302, got %d", w.Code) + } + if w.Header().Get("Location") == "" { + t.Fatalf("expected redirect location header") + } + }) + + t.Run("callback missing code", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/google/callback?state=abc", nil) + r.ServeHTTP(w, req) + if w.Code != 400 { + t.Fatalf("expected: 400, got %d", w.Code) + } + }) + + t.Run("callback missing state", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/google/callback?code=abc", nil) + r.ServeHTTP(w, req) + if w.Code != 400 { + t.Fatalf("expected: 400, got %d", w.Code) + } + }) +} + +func TestLogoutEndpoint(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + // Create a second token + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup second login failed, got %d", w.Code) + } + var login2 dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login2); err != nil { + t.Fatalf("failed to unmarshal second login response: %v", err) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("DELETE", "/logout", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 204 { + t.Fatalf("expected: 204, got %d", w.Code) + } + + // Token should be invalid after logout + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected token to be invalid after logout, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+login2.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected second token to be invalid after logout, got %d", w.Code) + } +} + +func TestDeleteUserEndpoint(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RedisURL = "redis" + testCfg.IsRedisEnabled = true + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + // Create a second token + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup second login failed, got %d", w.Code) + } + var login2 dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login2); err != nil { + t.Fatalf("failed to unmarshal second login response: %v", err) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("DELETE", "/me", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 204 { + t.Fatalf("expected: 204, got %d", w.Code) + } + + // Token should be invalid after deletion + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected token to be invalid after deletion, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+login2.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected second token to be invalid after deletion, got %d", w.Code) + } +} + +func TestTwoFAEndpoints(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RedisURL = "redis" + testCfg.IsRedisEnabled = true + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + // setup user + login + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("setup register failed, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("setup login failed, got %d", w.Code) + } + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + + // 2FA setup + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/2fa/setup", nil) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("2fa setup, expected: 200, got %d", w.Code) + } + var setup dto.TwoFASetupResponse + if err := json.Unmarshal(w.Body.Bytes(), &setup); err != nil { + t.Fatalf("failed to unmarshal 2fa setup response: %v", err) + } + twoFACode, err := totp.GenerateCode(setup.TwoFASecret, time.Now()) + if err != nil { + t.Fatalf("failed to generate 2fa code: %v", err) + } + + // 2FA confirm invalid code + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/2fa/confirm", toJSON(t, map[string]string{ + "twoFaCode": "000000", + "setupToken": setup.SetupToken, + })) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 400 { + t.Fatalf("2fa confirm invalid, expected: 400, got %d", w.Code) + } + + // 2FA confirm happy + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/2fa/confirm", toJSON(t, map[string]string{ + "twoFaCode": twoFACode, + "setupToken": setup.SetupToken, + })) + req.Header.Add("Authorization", "Bearer "+login.Token) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("2fa confirm, expected: 200, got %d", w.Code) + } + + // login now returns 428 + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByUsernameRequest)) + r.ServeHTTP(w, req) + if w.Code != 428 { + t.Fatalf("login with 2fa enabled, expected: 428, got %d", w.Code) + } + var pending dto.TwoFAPendingUserResponse + if err := json.Unmarshal(w.Body.Bytes(), &pending); err != nil { + t.Fatalf("failed to unmarshal 2fa pending response: %v", err) + } + + // 2FA submit invalid code + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/2fa", toJSON(t, map[string]string{ + "twoFaCode": "000000", + "sessionToken": pending.SessionToken, + })) + r.ServeHTTP(w, req) + if w.Code != 400 { + t.Fatalf("2fa submit invalid, expected: 400, got %d", w.Code) + } + + // 2FA submit happy + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/2fa", toJSON(t, map[string]string{ + "twoFaCode": twoFACode, + "sessionToken": pending.SessionToken, + })) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("2fa submit, expected: 200, got %d", w.Code) + } + var afterChallenge dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &afterChallenge); err != nil { + t.Fatalf("failed to unmarshal 2fa submit response: %v", err) + } + + // Create a second token after 2FA enabled + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByUsernameRequest)) + r.ServeHTTP(w, req) + if w.Code != 428 { + t.Fatalf("login with 2fa enabled, expected: 428, got %d", w.Code) + } + var pending2 dto.TwoFAPendingUserResponse + if err := json.Unmarshal(w.Body.Bytes(), &pending2); err != nil { + t.Fatalf("failed to unmarshal 2fa pending response: %v", err) + } + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/2fa", toJSON(t, map[string]string{ + "twoFaCode": twoFACode, + "sessionToken": pending2.SessionToken, + })) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("2fa submit second token, expected: 200, got %d", w.Code) + } + var afterChallenge2 dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &afterChallenge2); err != nil { + t.Fatalf("failed to unmarshal 2fa submit second response: %v", err) + } + + // 2FA disable wrong password + w = httptest.NewRecorder() + req, _ = http.NewRequest("PUT", "/2fa/disable", toJSON(t, map[string]string{"password": "WrongPassword.123"})) + req.Header.Add("Authorization", "Bearer "+afterChallenge.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("2fa disable wrong password, expected: 401, got %d", w.Code) + } + + // 2FA disable happy + w = httptest.NewRecorder() + req, _ = http.NewRequest("PUT", "/2fa/disable", toJSON(t, map[string]string{"password": testPwd})) + req.Header.Add("Authorization", "Bearer "+afterChallenge.Token) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("2fa disable, expected: 200, got %d", w.Code) + } + + // Tokens should be invalid after 2FA disable + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+afterChallenge.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected token to be invalid after 2fa disable, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/validate", nil) + req.Header.Add("Authorization", "Bearer "+afterChallenge2.Token) + r.ServeHTTP(w, req) + if w.Code != 401 { + t.Fatalf("expected second token to be invalid after 2fa disable, got %d", w.Code) + } +} + +func TestDBIsDown(t *testing.T) { + testCfg := testutil.NewTestConfig() + r := testRouterFactory(t, testCfg, true) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + + if w.Code != 500 { + t.Fatalf("expected: 500, got: %d", w.Code) + } +} + +func TestAuthRequiredEndpoints(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + testCases := []struct { + name string + method string + path string + }{ + {name: "get me", method: http.MethodGet, path: "/me"}, + {name: "update password", method: http.MethodPut, path: "/password"}, + {name: "update profile", method: http.MethodPut, path: "/me"}, + {name: "logout", method: http.MethodDelete, path: "/logout"}, + {name: "delete me", method: http.MethodDelete, path: "/me"}, + {name: "2fa setup", method: http.MethodPost, path: "/2fa/setup"}, + {name: "2fa confirm", method: http.MethodPost, path: "/2fa/confirm"}, + {name: "2fa disable", method: http.MethodPut, path: "/2fa/disable"}, + {name: "get friends", method: http.MethodGet, path: "/friends"}, + {name: "add friend", method: http.MethodPost, path: "/friends"}, + {name: "validate user", method: http.MethodPost, path: "/validate"}, + {name: "list users", method: http.MethodGet, path: "/"}, + } + + for _, tc := range testCases { + t.Run(tc.name+"/no-token", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest(tc.method, tc.path, nil) + r.ServeHTTP(w, req) + + if w.Code != 401 { + t.Fatalf("expected: 401, got %d", w.Code) + } + }) + + t.Run(tc.name+"/invalid-token", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest(tc.method, tc.path, nil) + req.Header.Add("Authorization", "Bearer aaa") + r.ServeHTTP(w, req) + + if w.Code != 401 { + t.Fatalf("expected: 401, got %d", w.Code) + } + }) + } +} + +func TestValidationMissingBody(t *testing.T) { + testCfg := testutil.NewTestConfig() + testCfg.RateLimiterRequestLimit = 1000 + r := testRouterFactory(t, testCfg, false) + + // Create user and get a valid token for auth-required endpoints. + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", toJSON(t, mockRegisterRequest)) + r.ServeHTTP(w, req) + if w.Code != 201 { + t.Fatalf("registering user, expected: 201, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/loginByIdentifier", toJSON(t, mockLoginUserByEmailRequest)) + r.ServeHTTP(w, req) + if w.Code != 200 { + t.Fatalf("login user, expected: 200, got %d", w.Code) + } + + var login dto.UserWithTokenResponse + if err := json.Unmarshal(w.Body.Bytes(), &login); err != nil { + t.Fatalf("failed to unmarshal login response: %v", err) + } + if login.Token == "" { + t.Fatalf("login user, expected token to be set") + } + + type validationErrorResp struct { + Error []string `json:"error"` + } + + assertValidationFields := func(t *testing.T, body []byte, expectedFields []string) { + t.Helper() + + var resp validationErrorResp + if err := json.Unmarshal(body, &resp); err != nil { + t.Fatalf("failed to unmarshal validation error response: %v", err) + } + if len(resp.Error) == 0 { + t.Fatalf("expected validation errors, got none") + } + + for _, field := range expectedFields { + found := false + for _, msg := range resp.Error { + if strings.Contains(msg, field) { + found = true + break + } + } + if !found { + t.Fatalf("expected validation error to mention field %q, got: %v", field, resp.Error) + } + } + } + + testCases := []struct { + name string + method string + path string + needsAuth bool + expectedFields []string + }{ + {name: "create user", method: http.MethodPost, path: "/", needsAuth: false, expectedFields: []string{"Username", "Email", "Password"}}, + {name: "login by identifier", method: http.MethodPost, path: "/loginByIdentifier", needsAuth: false, expectedFields: []string{"Identifier", "Password"}}, + {name: "2fa submit", method: http.MethodPost, path: "/2fa", needsAuth: false, expectedFields: []string{"TwoFACode", "SessionToken"}}, + {name: "update password", method: http.MethodPut, path: "/password", needsAuth: true, expectedFields: []string{"OldPassword", "NewPassword"}}, + {name: "update profile", method: http.MethodPut, path: "/me", needsAuth: true, expectedFields: []string{"Username", "Email"}}, + {name: "2fa confirm", method: http.MethodPost, path: "/2fa/confirm", needsAuth: true, expectedFields: []string{"TwoFACode", "SetupToken"}}, + {name: "2fa disable", method: http.MethodPut, path: "/2fa/disable", needsAuth: true, expectedFields: []string{"Password"}}, + {name: "add friend", method: http.MethodPost, path: "/friends", needsAuth: true, expectedFields: []string{"UserID"}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest(tc.method, tc.path, strings.NewReader("{}")) + req.Header.Set("Content-Type", "application/json") + if tc.needsAuth { + req.Header.Add("Authorization", "Bearer "+login.Token) + } + r.ServeHTTP(w, req) + + if w.Code != 400 { + t.Fatalf("expected: 400, got %d", w.Code) + } + assertValidationFields(t, w.Body.Bytes(), tc.expectedFields) + }) + } +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 6b15d6b..87604f5 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -191,7 +191,11 @@ func (s *UserService) UpdateUserProfile(ctx context.Context, userID uint, reques } modelUser.Username = request.Username - modelUser.Avatar = request.Avatar + if request.Avatar != nil && strings.TrimSpace(*request.Avatar) == "" { + modelUser.Avatar = nil + } else { + modelUser.Avatar = request.Avatar + } modelUser.Email = request.Email err = s.Dep.DB.WithContext(ctx).Save(&modelUser).Error diff --git a/backend/internal/testutil/testutil.go b/backend/internal/testutil/testutil.go index 3bde0f1..e5c35eb 100644 --- a/backend/internal/testutil/testutil.go +++ b/backend/internal/testutil/testutil.go @@ -20,7 +20,7 @@ func NewTestLogger() *slog.Logger { func NewTestConfig() *config.Config { return &config.Config{ GinMode: "test", - DbAddress: "inmemory://test", + DbAddress: "file::memory:?cache=shared", JwtSecret: "test-jwt-secret", UserTokenExpiry: 5, OauthStateTokenExpiry: 5, diff --git a/backend/internal/util/jwt/token_test.go b/backend/internal/util/jwt/token_test.go index 29dc6ab..90e0e4c 100644 --- a/backend/internal/util/jwt/token_test.go +++ b/backend/internal/util/jwt/token_test.go @@ -148,4 +148,3 @@ func TestTwoFAToken(t *testing.T) { t.Fatalf("expected error, got nil") } } - From 20d69124e2e353167dbe887a45210dd1865f7344 Mon Sep 17 00:00:00 2001 From: Xin Feng <126309503+danielxfeng@users.noreply.github.com> Date: Wed, 4 Feb 2026 23:32:59 +0200 Subject: [PATCH 15/15] refactor/backend: improve Redis setup in tests for better error handling and cleanup --- backend/internal/routers/user_router_test.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/backend/internal/routers/user_router_test.go b/backend/internal/routers/user_router_test.go index 4ec60ec..ca65769 100644 --- a/backend/internal/routers/user_router_test.go +++ b/backend/internal/routers/user_router_test.go @@ -44,11 +44,22 @@ func testRouterFactory(t *testing.T, testCfg *config.Config, setDBDown bool) *gi var redisClient *redis.Client if testCfg.IsRedisEnabled { - mr := miniredis.RunT(t) - redisClient, err = db.GetRedis("redis://"+mr.Addr(), testCfg, testLogger) + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("failed to start miniredis, err: %v", err) + } + t.Cleanup(func() { + mr.Close() + }) + + testCfg.RedisURL = "redis://" + mr.Addr() + redisClient, err = db.GetRedis(testCfg.RedisURL, testCfg, testLogger) if err != nil { t.Fatalf("failed to init the test redis, err: %v", err) } + t.Cleanup(func() { + db.CloseRedis(redisClient, testLogger) + }) } dep := testutil.NewTestDependency(testCfg, myDB, redisClient, testLogger)