diff --git a/billing/model.go b/billing/model.go index 10f6bf0..6a11efc 100644 --- a/billing/model.go +++ b/billing/model.go @@ -7,9 +7,13 @@ import ( "os" "strconv" "time" + + "github.com/michaelboegner/interviewer/user" ) -type Billing struct { +type BillingService struct { + BillingRepo BillingRepo + UserRepo user.UserRepo APIKey string VariantIDIndividual int VariantIDPro int @@ -102,7 +106,7 @@ type BillingRepo interface { MarkWebhookProcessed(id string, event string) error } -func NewBilling(logger *slog.Logger) (*Billing, error) { +func NewBillingService(billingRepo BillingRepo, userRepo user.UserRepo, logger *slog.Logger) (*BillingService, error) { individualID, err := strconv.Atoi(os.Getenv("LEMON_VARIANT_ID_INDIVIDUAL")) if err != nil { return nil, fmt.Errorf("invalid INDIVIDUAL ID: %w", err) @@ -115,7 +119,9 @@ func NewBilling(logger *slog.Logger) (*Billing, error) { if err != nil { return nil, fmt.Errorf("invalid PREMIUM ID: %w", err) } - return &Billing{ + return &BillingService{ + BillingRepo: billingRepo, + UserRepo: userRepo, APIKey: os.Getenv("LEMON_API_KEY"), VariantIDIndividual: individualID, VariantIDPro: proID, diff --git a/billing/service.go b/billing/service.go index 7f32790..6b6312f 100644 --- a/billing/service.go +++ b/billing/service.go @@ -12,11 +12,9 @@ import ( "os" "strconv" "time" - - "github.com/michaelboegner/interviewer/user" ) -func (b *Billing) RequestCheckoutSession(userEmail string, variantID int) (string, error) { +func (b *BillingService) RequestCheckoutSession(userEmail string, variantID int) (string, error) { payload := CheckoutPayload{ Data: CheckoutData{ Type: "checkouts", @@ -83,7 +81,7 @@ func (b *Billing) RequestCheckoutSession(userEmail string, variantID int) (strin return result.Data.Attributes.URL, nil } -func (b *Billing) RequestDeleteSubscription(subscriptionID string) error { +func (b *BillingService) RequestDeleteSubscription(subscriptionID string) error { client := &http.Client{Timeout: 10 * time.Second} req, err := http.NewRequest("DELETE", "https://api.lemonsqueezy.com/v1/subscriptions/"+subscriptionID, nil) @@ -108,7 +106,7 @@ func (b *Billing) RequestDeleteSubscription(subscriptionID string) error { return nil } -func (b *Billing) RequestResumeSubscription(subscriptionID string) error { +func (b *BillingService) RequestResumeSubscription(subscriptionID string) error { client := &http.Client{Timeout: 10 * time.Second} payload := map[string]interface{}{ @@ -148,7 +146,7 @@ func (b *Billing) RequestResumeSubscription(subscriptionID string) error { return nil } -func (b *Billing) RequestUpdateSubscriptionVariant(subscriptionID string, newVariantID int) error { +func (b *BillingService) RequestUpdateSubscriptionVariant(subscriptionID string, newVariantID int) error { payload := map[string]interface{}{ "data": map[string]interface{}{ "type": "subscriptions", @@ -179,15 +177,15 @@ func (b *Billing) RequestUpdateSubscriptionVariant(subscriptionID string, newVar return nil } -func (b *Billing) VerifyBillingSignature(signature string, body []byte, secret string) bool { +func (b *BillingService) VerifyBillingSignature(signature string, body []byte, secret string) bool { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(body) expected := hex.EncodeToString(mac.Sum(nil)) return hmac.Equal([]byte(expected), []byte(signature)) } -func (b *Billing) ApplyCredits(userRepo user.UserRepo, billingRepo BillingRepo, email string, variantID int) error { - user, err := userRepo.GetUserByEmail(email) +func (b *BillingService) ApplyCredits(email string, variantID int) error { + user, err := b.UserRepo.GetUserByEmail(email) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err @@ -216,7 +214,7 @@ func (b *Billing) ApplyCredits(userRepo user.UserRepo, billingRepo BillingRepo, return fmt.Errorf("unknown variant ID: %d", variantID) } - if err := userRepo.AddCredits(user.ID, credits, creditType); err != nil { + if err := b.UserRepo.AddCredits(user.ID, credits, creditType); err != nil { b.Logger.Error("repo.AddCredits failed", "error", err) return err } @@ -227,7 +225,7 @@ func (b *Billing) ApplyCredits(userRepo user.UserRepo, billingRepo BillingRepo, CreditType: creditType, Reason: reason, } - if err := billingRepo.LogCreditTransaction(tx); err != nil { + if err := b.BillingRepo.LogCreditTransaction(tx); err != nil { b.Logger.Error("Warning: credit granted but failed to log transaction", "error", err) return err } @@ -235,8 +233,8 @@ func (b *Billing) ApplyCredits(userRepo user.UserRepo, billingRepo BillingRepo, return nil } -func (b *Billing) DeductCredits(userRepo user.UserRepo, billingRepo BillingRepo, orderAttrs OrderAttributes) error { - user, err := userRepo.GetUserByEmail(orderAttrs.UserEmail) +func (b *BillingService) DeductCredits(orderAttrs OrderAttributes) error { + user, err := b.UserRepo.GetUserByEmail(orderAttrs.UserEmail) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err @@ -267,7 +265,7 @@ func (b *Billing) DeductCredits(userRepo user.UserRepo, billingRepo BillingRepo, return fmt.Errorf("unknown variant ID: %d", variantID) } - if err := userRepo.AddCredits(user.ID, -credits, creditType); err != nil { + if err := b.UserRepo.AddCredits(user.ID, -credits, creditType); err != nil { b.Logger.Error("repo.DeductCredits failed", "error", err) return err } @@ -278,7 +276,7 @@ func (b *Billing) DeductCredits(userRepo user.UserRepo, billingRepo BillingRepo, CreditType: creditType, Reason: reason, } - if err := billingRepo.LogCreditTransaction(tx); err != nil { + if err := b.BillingRepo.LogCreditTransaction(tx); err != nil { b.Logger.Warn("Refund deduction succeeded but failed to log transaction", "error", err) return err } @@ -286,8 +284,8 @@ func (b *Billing) DeductCredits(userRepo user.UserRepo, billingRepo BillingRepo, return nil } -func (b *Billing) CreateSubscription(userRepo user.UserRepo, subCreatedAttrs SubscriptionAttributes, subscriptionID string) error { - user, err := userRepo.GetUserByEmail(subCreatedAttrs.UserEmail) +func (b *BillingService) CreateSubscription(subCreatedAttrs SubscriptionAttributes, subscriptionID string) error { + user, err := b.UserRepo.GetUserByEmail(subCreatedAttrs.UserEmail) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err @@ -304,7 +302,7 @@ func (b *Billing) CreateSubscription(userRepo user.UserRepo, subCreatedAttrs Sub return fmt.Errorf("unknown variant ID: %d", subCreatedAttrs.VariantID) } - err = userRepo.UpdateSubscriptionData( + err = b.UserRepo.UpdateSubscriptionData( user.ID, "active", tier, @@ -320,14 +318,14 @@ func (b *Billing) CreateSubscription(userRepo user.UserRepo, subCreatedAttrs Sub return nil } -func (b *Billing) CancelSubscription(userRepo user.UserRepo, email string) error { - user, err := userRepo.GetUserByEmail(email) +func (b *BillingService) CancelSubscription(email string) error { + user, err := b.UserRepo.GetUserByEmail(email) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err } - err = userRepo.UpdateSubscriptionStatusData( + err = b.UserRepo.UpdateSubscriptionStatusData( user.ID, "cancelled", ) @@ -339,14 +337,14 @@ func (b *Billing) CancelSubscription(userRepo user.UserRepo, email string) error return nil } -func (b *Billing) ResumeSubscription(userRepo user.UserRepo, email string) error { - user, err := userRepo.GetUserByEmail(email) +func (b *BillingService) ResumeSubscription(email string) error { + user, err := b.UserRepo.GetUserByEmail(email) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err } - err = userRepo.UpdateSubscriptionStatusData( + err = b.UserRepo.UpdateSubscriptionStatusData( user.ID, "active", ) @@ -358,14 +356,14 @@ func (b *Billing) ResumeSubscription(userRepo user.UserRepo, email string) error return nil } -func (b *Billing) ExpireSubscription(userRepo user.UserRepo, billingRepo BillingRepo, email string) error { - user, err := userRepo.GetUserByEmail(email) +func (b *BillingService) ExpireSubscription(email string) error { + user, err := b.UserRepo.GetUserByEmail(email) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err } - err = userRepo.UpdateSubscriptionStatusData( + err = b.UserRepo.UpdateSubscriptionStatusData( user.ID, "expired", ) @@ -375,7 +373,7 @@ func (b *Billing) ExpireSubscription(userRepo user.UserRepo, billingRepo Billing } if user.SubscriptionCredits > 0 { - err = userRepo.AddCredits(user.ID, -user.SubscriptionCredits, "subscription") + err = b.UserRepo.AddCredits(user.ID, -user.SubscriptionCredits, "subscription") if err != nil { b.Logger.Error("repo.AddCredits failed", "error", err) return err @@ -387,7 +385,7 @@ func (b *Billing) ExpireSubscription(userRepo user.UserRepo, billingRepo Billing CreditType: "subscription", Reason: "Zeroed out credits on subscription expiration", } - if err := billingRepo.LogCreditTransaction(tx); err != nil { + if err := b.BillingRepo.LogCreditTransaction(tx); err != nil { b.Logger.Warn("Zero-out succeeded but failed to log transaction", "error", err) } } @@ -395,8 +393,8 @@ func (b *Billing) ExpireSubscription(userRepo user.UserRepo, billingRepo Billing return nil } -func (b *Billing) RenewSubscription(userRepo user.UserRepo, billingRepo BillingRepo, subRenewAttrs SubscriptionRenewAttributes) error { - user, err := userRepo.GetUserByEmail(subRenewAttrs.UserEmail) +func (b *BillingService) RenewSubscription(subRenewAttrs SubscriptionRenewAttributes) error { + user, err := b.UserRepo.GetUserByEmail(subRenewAttrs.UserEmail) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err @@ -422,7 +420,7 @@ func (b *Billing) RenewSubscription(userRepo user.UserRepo, billingRepo BillingR return fmt.Errorf("unknown user.SubscriptionTier: %s", user.SubscriptionTier) } - if err := userRepo.AddCredits(user.ID, credits, "subscription"); err != nil { + if err := b.UserRepo.AddCredits(user.ID, credits, "subscription"); err != nil { b.Logger.Error("repo.AddCredits failed", "error", err) return err } @@ -433,7 +431,7 @@ func (b *Billing) RenewSubscription(userRepo user.UserRepo, billingRepo BillingR CreditType: "subscription", Reason: reason, } - if err := billingRepo.LogCreditTransaction(tx); err != nil { + if err := b.BillingRepo.LogCreditTransaction(tx); err != nil { b.Logger.Warn("credit granted but failed to log transaction", "error", err) return err } @@ -441,8 +439,8 @@ func (b *Billing) RenewSubscription(userRepo user.UserRepo, billingRepo BillingR return nil } -func (b *Billing) ChangeSubscription(userRepo user.UserRepo, billingRepo BillingRepo, subChangedAttrs SubscriptionAttributes) error { - user, err := userRepo.GetUserByEmail(subChangedAttrs.UserEmail) +func (b *BillingService) ChangeSubscription(subChangedAttrs SubscriptionAttributes) error { + user, err := b.UserRepo.GetUserByEmail(subChangedAttrs.UserEmail) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err @@ -471,7 +469,7 @@ func (b *Billing) ChangeSubscription(userRepo user.UserRepo, billingRepo Billing return fmt.Errorf("unknown user.SubscriptionTier: %s", user.SubscriptionTier) } - if err := userRepo.AddCredits(user.ID, credits, "subscription"); err != nil { + if err := b.UserRepo.AddCredits(user.ID, credits, "subscription"); err != nil { b.Logger.Error("repo.AddCredits failed", "error", err) return err } @@ -482,7 +480,7 @@ func (b *Billing) ChangeSubscription(userRepo user.UserRepo, billingRepo Billing CreditType: "subscription", Reason: reason, } - if err := billingRepo.LogCreditTransaction(tx); err != nil { + if err := b.BillingRepo.LogCreditTransaction(tx); err != nil { b.Logger.Warn("credit granted but failed to log transaction", "error", err) return err } @@ -490,8 +488,8 @@ func (b *Billing) ChangeSubscription(userRepo user.UserRepo, billingRepo Billing return nil } -func (b *Billing) UpdateSubscription(userRepo user.UserRepo, subUpdatedAttrs SubscriptionAttributes, subscriptionID string) error { - user, err := userRepo.GetUserByEmail(subUpdatedAttrs.UserEmail) +func (b *BillingService) UpdateSubscription(subUpdatedAttrs SubscriptionAttributes, subscriptionID string) error { + user, err := b.UserRepo.GetUserByEmail(subUpdatedAttrs.UserEmail) if err != nil { b.Logger.Error("repo.GetUserByEmail failed", "error", err) return err @@ -508,7 +506,7 @@ func (b *Billing) UpdateSubscription(userRepo user.UserRepo, subUpdatedAttrs Sub return fmt.Errorf("unknown variant ID: %d", subUpdatedAttrs.VariantID) } - err = userRepo.UpdateSubscriptionData( + err = b.UserRepo.UpdateSubscriptionData( user.ID, subUpdatedAttrs.Status, tier, diff --git a/billing/service_test.go b/billing/service_test.go index 07b13b9..9cb65c4 100644 --- a/billing/service_test.go +++ b/billing/service_test.go @@ -4,23 +4,19 @@ import ( "crypto/hmac" "crypto/sha256" "fmt" - "log" "log/slog" "os" - "strings" "testing" "github.com/michaelboegner/interviewer/billing" "github.com/michaelboegner/interviewer/user" ) -func NewTestBilling() *billing.Billing { - handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelDebug, - }) - logger := slog.New(handler) - - return &billing.Billing{ +func NewTestBillingService(billingRepo billing.BillingRepo, userRepo user.UserRepo, logger *slog.Logger) *billing.BillingService { + return &billing.BillingService{ + BillingRepo: billingRepo, + UserRepo: userRepo, + APIKey: "", VariantIDIndividual: 1, VariantIDPro: 2, VariantIDPremium: 3, @@ -89,20 +85,24 @@ func TestApplyCredits(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) + logger := slog.New(handler) - userRepo := user.NewMockRepo() billingRepo := billing.NewMockRepo() + userRepo := user.NewMockRepo() - userRepo.FailGetUserByEmail = tc.failUser - userRepo.FailAddCredits = tc.failCredit - billingRepo.FailLogCreditTransaction = tc.failLog + billingService := NewTestBillingService(billingRepo, userRepo, logger) - b := NewTestBilling() + if mockUserRepo, ok := billingService.UserRepo.(*user.MockRepo); ok { + mockUserRepo.FailGetUserByEmail = tc.failUser + mockUserRepo.FailAddCredits = tc.failCredit + } + if mockBillingRepo, ok := billingService.BillingRepo.(*billing.MockRepo); ok { + mockBillingRepo.FailLogCreditTransaction = tc.failLog + } + + err := billingService.ApplyCredits("test@example.com", tc.variantID) - err := b.ApplyCredits(userRepo, billingRepo, "test@example.com", tc.variantID) if tc.expectErr && err == nil { t.Fatal("expected error but got nil") } @@ -166,18 +166,21 @@ func TestDeductCredits(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) + logger := slog.New(handler) - userRepo := user.NewMockRepo() billingRepo := billing.NewMockRepo() + userRepo := user.NewMockRepo() - userRepo.FailGetUserByEmail = tc.failUser - userRepo.FailAddCredits = tc.failCredit - billingRepo.FailLogCreditTransaction = tc.failLog + billingService := NewTestBillingService(billingRepo, userRepo, logger) - b := NewTestBilling() + if mockUserRepo, ok := billingService.UserRepo.(*user.MockRepo); ok { + mockUserRepo.FailGetUserByEmail = tc.failUser + mockUserRepo.FailAddCredits = tc.failCredit + } + if mockBillingRepo, ok := billingService.BillingRepo.(*billing.MockRepo); ok { + mockBillingRepo.FailLogCreditTransaction = tc.failLog + } attrs := billing.OrderAttributes{ UserEmail: "test@example.com", @@ -188,7 +191,7 @@ func TestDeductCredits(t *testing.T) { }, } - err := b.DeductCredits(userRepo, billingRepo, attrs) + err := billingService.DeductCredits(attrs) if tc.expectErr && err == nil { t.Fatal("expected error but got nil") } @@ -200,14 +203,21 @@ func TestDeductCredits(t *testing.T) { } func TestVerifyBillingSignature(t *testing.T) { - b := NewTestBilling() + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) + logger := slog.New(handler) + + billingRepo := billing.NewMockRepo() + userRepo := user.NewMockRepo() + + billingService := NewTestBillingService(billingRepo, userRepo, logger) + body := []byte(`{"key":"value"}`) secret := "testsecret" mac := hmacSha256(body, secret) - if !b.VerifyBillingSignature(mac, body, secret) { + if !billingService.VerifyBillingSignature(mac, body, secret) { t.Fatal("expected signature to be valid") } - if b.VerifyBillingSignature("invalid", body, secret) { + if billingService.VerifyBillingSignature("invalid", body, secret) { t.Fatal("expected signature to be invalid") } } @@ -217,9 +227,3 @@ func hmacSha256(message []byte, secret string) string { mac.Write(message) return fmt.Sprintf("%x", mac.Sum(nil)) } - -func showLogsIfFail(t *testing.T, name string, buf strings.Builder) { - if t.Failed() { - t.Logf("---- logs for test: %s ----\n%s\n", name, buf.String()) - } -} diff --git a/chatgpt/model.go b/chatgpt/model.go index 9a1c81b..6aec051 100644 --- a/chatgpt/model.go +++ b/chatgpt/model.go @@ -23,7 +23,7 @@ type ChatGPTResponse struct { Level string `json:"level"` } -type OpenAIClient struct { +type AIService struct { APIKey string Logger *slog.Logger } @@ -45,8 +45,8 @@ func (e *OpenAIError) Error() string { return fmt.Sprintf("OpenAI error %d: %s", e.StatusCode, e.Message) } -func NewOpenAI(logger *slog.Logger) *OpenAIClient { - return &OpenAIClient{ +func NewAIService(logger *slog.Logger) *AIService { + return &AIService{ APIKey: os.Getenv("OPENAI_API_KEY"), Logger: logger, } diff --git a/chatgpt/service.go b/chatgpt/service.go index 03c6d19..f388030 100644 --- a/chatgpt/service.go +++ b/chatgpt/service.go @@ -11,7 +11,7 @@ import ( "strings" ) -func (c *OpenAIClient) GetChatGPTResponse(prompt string) (*ChatGPTResponse, error) { +func (c *AIService) GetChatGPTResponse(prompt string) (*ChatGPTResponse, error) { ctx := context.Background() var messagesArray []map[string]string @@ -84,7 +84,7 @@ func (c *OpenAIClient) GetChatGPTResponse(prompt string) (*ChatGPTResponse, erro return &chatGPTResponse, nil } -func (c *OpenAIClient) GetChatGPTResponseConversation(conversationHistory []map[string]string) (*ChatGPTResponse, error) { +func (c *AIService) GetChatGPTResponseConversation(conversationHistory []map[string]string) (*ChatGPTResponse, error) { ctx := context.Background() requestBody, err := json.Marshal(map[string]interface{}{ @@ -151,7 +151,7 @@ func (c *OpenAIClient) GetChatGPTResponseConversation(conversationHistory []map[ return &chatGPTResponse, nil } -func (c *OpenAIClient) GetChatGPT35Response(prompt string) (*ChatGPTResponse, error) { +func (c *AIService) GetChatGPT35Response(prompt string) (*ChatGPTResponse, error) { ctx := context.Background() var messagesArray []map[string]string @@ -224,7 +224,7 @@ func (c *OpenAIClient) GetChatGPT35Response(prompt string) (*ChatGPTResponse, er return &chatGPTResponse, nil } -func (c *OpenAIClient) ExtractJDInput(jd string) (*JDParsedOutput, error) { +func (c *AIService) ExtractJDInput(jd string) (*JDParsedOutput, error) { systemPrompt := BuildJDPromptInput(jd) response, err := c.GetChatGPT35Response(systemPrompt) if err != nil { @@ -239,7 +239,7 @@ func (c *OpenAIClient) ExtractJDInput(jd string) (*JDParsedOutput, error) { }, nil } -func (c *OpenAIClient) ExtractJDSummary(jdInput *JDParsedOutput) (string, error) { +func (c *AIService) ExtractJDSummary(jdInput *JDParsedOutput) (string, error) { jdJSON, err := json.MarshalIndent(jdInput, "", " ") if err != nil { return "", fmt.Errorf("failed to marshal JDParsedOutput: %w", err) diff --git a/conversation/model.go b/conversation/model.go index f3fc1a8..052654f 100644 --- a/conversation/model.go +++ b/conversation/model.go @@ -1,6 +1,12 @@ package conversation -import "time" +import ( + "log/slog" + "time" + + "github.com/michaelboegner/interviewer/chatgpt" + "github.com/michaelboegner/interviewer/interview" +) type Author string @@ -64,6 +70,22 @@ type Message struct { Content string `json:"content"` } +type ConversationService struct { + ConversationRepo ConversationRepo + InterviewRepo interview.InterviewRepo + AIService chatgpt.AIClient + Logger *slog.Logger +} + +func NewConversationService(conversationRepo ConversationRepo, interviewRepo interview.InterviewRepo, aiService chatgpt.AIClient, logger *slog.Logger) *ConversationService { + return &ConversationService{ + ConversationRepo: conversationRepo, + InterviewRepo: interviewRepo, + AIService: aiService, + Logger: logger, + } +} + type ConversationRepo interface { CheckForConversation(interviewID int) (bool, error) GetConversation(interviewID int) (*Conversation, error) diff --git a/conversation/service.go b/conversation/service.go index 4d1d4e1..1081b88 100644 --- a/conversation/service.go +++ b/conversation/service.go @@ -3,16 +3,13 @@ package conversation import ( "errors" "log" - - "github.com/michaelboegner/interviewer/chatgpt" - "github.com/michaelboegner/interviewer/interview" ) -func CheckForConversation(repo ConversationRepo, interviewID int) (bool, error) { - return repo.CheckForConversation(interviewID) +func (c *ConversationService) CheckForConversation(interviewID int) (bool, error) { + return c.ConversationRepo.CheckForConversation(interviewID) } -func CreateEmptyConversation(repo ConversationRepo, interviewID int, subTopic string) (int, error) { +func (c *ConversationService) CreateEmptyConversation(interviewID int, subTopic string) (int, error) { conversation := &Conversation{ Topics: ClonePredefinedTopics(), CurrentTopic: 1, @@ -20,7 +17,7 @@ func CreateEmptyConversation(repo ConversationRepo, interviewID int, subTopic st CurrentQuestionNumber: 1, } - conversationID, err := repo.CreateConversation(interviewID, conversation) + conversationID, err := c.ConversationRepo.CreateConversation(interviewID, conversation) if err != nil { log.Printf("CreateConversation failed: %v", err) return 0, err @@ -29,10 +26,7 @@ func CreateEmptyConversation(repo ConversationRepo, interviewID int, subTopic st return conversationID, nil } -func CreateConversation( - repo ConversationRepo, - interviewRepo interview.InterviewRepo, - openAI chatgpt.AIClient, +func (c *ConversationService) CreateConversation( conversation *Conversation, interviewID int, prompt, @@ -44,7 +38,7 @@ func CreateConversation( topicID := conversation.CurrentTopic questionNumber := conversation.CurrentQuestionNumber - _, err := repo.CreateQuestion(conversation, firstQuestion) + _, err := c.ConversationRepo.CreateQuestion(conversation, firstQuestion) if err != nil { log.Printf("CreateQuestion failed: %v", err) return nil, err @@ -60,19 +54,19 @@ func CreateConversation( topic.Questions = make(map[int]*Question) topic.Questions[questionNumber] = NewQuestion(conversationID, topicID, questionNumber, firstQuestion, messages) - err = repo.CreateMessages(conversation, messages) + err = c.ConversationRepo.CreateMessages(conversation, messages) if err != nil { log.Printf("repo.CreateMessages failed: %v", err) return nil, err } - chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, openAI, interviewRepo) + chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, c.AIService, c.InterviewRepo) if err != nil { log.Printf("getChatGPTResponses failed: %v", err) return nil, err } - err = interviewRepo.UpdateScore(interviewID, chatGPTResponse.Score) + err = c.InterviewRepo.UpdateScore(interviewID, chatGPTResponse.Score) if err != nil { log.Printf("interviewRepo.UpdateScore failed: %v", err) return nil, err @@ -81,7 +75,7 @@ func CreateConversation( conversation.CurrentQuestionNumber++ conversation.CurrentSubtopic = chatGPTResponse.NextSubtopic questionNumber++ - _, err = repo.UpdateConversationCurrents(conversationID, topicID, questionNumber, chatGPTResponse.NextSubtopic) + _, err = c.ConversationRepo.UpdateConversationCurrents(conversationID, topicID, questionNumber, chatGPTResponse.NextSubtopic) if err != nil { log.Printf("UpdateConversationTopic error: %v", err) return nil, err @@ -92,12 +86,12 @@ func CreateConversation( } conversation.Topics[topicID].Questions[questionNumber] = NewQuestion(conversationID, topicID, questionNumber, chatGPTResponse.NextQuestion, messagesQ2) - _, err = repo.AddQuestion(conversation.Topics[topicID].Questions[questionNumber]) + _, err = c.ConversationRepo.AddQuestion(conversation.Topics[topicID].Questions[questionNumber]) if err != nil { log.Printf("AddQuestion in CreateConversation err: %v", err) return nil, err } - _, err = repo.AddMessage(conversationID, topicID, questionNumber, messagesQ2[0]) + _, err = c.ConversationRepo.AddMessage(conversationID, topicID, questionNumber, messagesQ2[0]) if err != nil { log.Printf("AddMessage in CreateConversation err: %v", err) return nil, err @@ -106,10 +100,7 @@ func CreateConversation( return conversation, nil } -func AppendConversation( - repo ConversationRepo, - interviewRepo interview.InterviewRepo, - openAI chatgpt.AIClient, +func (c *ConversationService) AppendConversation( interviewID, userID int, conversation *Conversation, @@ -124,19 +115,19 @@ func AppendConversation( } messageUser := NewMessage(conversationID, topicID, questionNumber, User, message) - _, err := repo.AddMessage(conversationID, topicID, questionNumber, messageUser) + _, err := c.ConversationRepo.AddMessage(conversationID, topicID, questionNumber, messageUser) if err != nil { return nil, err } conversation.Topics[topicID].Questions[questionNumber].Messages = append(conversation.Topics[topicID].Questions[questionNumber].Messages, messageUser) - chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, openAI, interviewRepo) + chatGPTResponse, chatGPTResponseString, err := GetChatGPTResponses(conversation, c.AIService, c.InterviewRepo) if err != nil { log.Printf("getChatGPTResponses failed: %v", err) return nil, err } - err = interviewRepo.UpdateScore(interviewID, chatGPTResponse.Score) + err = c.InterviewRepo.UpdateScore(interviewID, chatGPTResponse.Score) if err != nil { log.Printf("interviewRepo.UpdateScore failed: %v", err) return nil, err @@ -153,20 +144,20 @@ func AppendConversation( conversation.CurrentSubtopic = "finished" conversation.CurrentQuestionNumber = 0 - err := interviewRepo.UpdateStatus(interviewID, userID, "finished") + err := c.InterviewRepo.UpdateStatus(interviewID, userID, "finished") if err != nil { log.Printf("interviewRepo.UpdateStatus failed: %v", err) return nil, err } - _, err = repo.UpdateConversationCurrents(conversationID, conversation.CurrentTopic, 0, conversation.CurrentSubtopic) + _, err = c.ConversationRepo.UpdateConversationCurrents(conversationID, conversation.CurrentTopic, 0, conversation.CurrentSubtopic) if err != nil { log.Printf("UpdateConversationTopic error: %v", err) return nil, err } messageFinal := NewMessage(conversationID, topicID, questionNumber, Interviewer, chatGPTResponseString) - _, err = repo.AddMessage(conversationID, topicID, questionNumber, messageFinal) + _, err = c.ConversationRepo.AddMessage(conversationID, topicID, questionNumber, messageFinal) if err != nil { return nil, err } @@ -183,7 +174,7 @@ func AppendConversation( conversation.CurrentSubtopic = chatGPTResponse.NextSubtopic conversation.CurrentQuestionNumber = resetQuestionNumber - _, err := repo.UpdateConversationCurrents(conversationID, nextTopicID, resetQuestionNumber, chatGPTResponse.NextSubtopic) + _, err := c.ConversationRepo.UpdateConversationCurrents(conversationID, nextTopicID, resetQuestionNumber, chatGPTResponse.NextSubtopic) if err != nil { log.Printf("UpdateConversationTopic error: %v", err) return nil, err @@ -198,11 +189,11 @@ func AppendConversation( topic.Questions = make(map[int]*Question) topic.Questions[resetQuestionNumber] = question - _, err = repo.AddQuestion(question) + _, err = c.ConversationRepo.AddQuestion(question) if err != nil { log.Printf("AddQuestion in AppendConversation err: %v", err) } - _, err = repo.AddMessage(conversationID, nextTopicID, resetQuestionNumber, messages[0]) + _, err = c.ConversationRepo.AddMessage(conversationID, nextTopicID, resetQuestionNumber, messages[0]) if err != nil { return nil, err } @@ -213,7 +204,7 @@ func AppendConversation( if incrementQuestion { conversation.CurrentQuestionNumber++ questionNumber++ - _, err := repo.UpdateConversationCurrents(conversationID, topicID, questionNumber, chatGPTResponse.NextSubtopic) + _, err := c.ConversationRepo.UpdateConversationCurrents(conversationID, topicID, questionNumber, chatGPTResponse.NextSubtopic) if err != nil { log.Printf("UpdateConversationTopic error: %v", err) return nil, err @@ -225,12 +216,12 @@ func AppendConversation( messageInterviewer := NewMessage(conversationID, topicID, questionNumber, Interviewer, chatGPTResponseString) conversation.Topics[topicID].Questions[questionNumber].Messages = append(conversation.Topics[topicID].Questions[questionNumber].Messages, messageInterviewer) - _, err = repo.AddQuestion(conversation.Topics[topicID].Questions[questionNumber]) + _, err = c.ConversationRepo.AddQuestion(conversation.Topics[topicID].Questions[questionNumber]) if err != nil { log.Printf("AddQuestion in AppendConversation failed: %v", err) return nil, err } - _, err = repo.AddMessage(conversationID, topicID, questionNumber, messageInterviewer) + _, err = c.ConversationRepo.AddMessage(conversationID, topicID, questionNumber, messageInterviewer) if err != nil { log.Printf("AddMessage in AppendConversation failed: %v", err) return nil, err @@ -239,15 +230,15 @@ func AppendConversation( return conversation, nil } -func GetConversation(repo ConversationRepo, interviewID int) (*Conversation, error) { - conversation, err := repo.GetConversation(interviewID) +func (c *ConversationService) GetConversation(interviewID int) (*Conversation, error) { + conversation, err := c.ConversationRepo.GetConversation(interviewID) if err != nil { return nil, err } conversation.Topics = ClonePredefinedTopics() - questionsReturned, err := repo.GetQuestions(conversation) + questionsReturned, err := c.ConversationRepo.GetQuestions(conversation) if err != nil { return nil, err } @@ -263,9 +254,9 @@ func GetConversation(repo ConversationRepo, interviewID int) (*Conversation, err topic.Questions[question.QuestionNumber] = question - messagesReturned, err := repo.GetMessages(conversation.ID, topicID, question.QuestionNumber) + messagesReturned, err := c.ConversationRepo.GetMessages(conversation.ID, topicID, question.QuestionNumber) if err != nil { - log.Printf("repo.GetMessages failed: %v\n", err) + log.Printf("c.ConversationRepo.GetMessages failed: %v\n", err) return nil, err } diff --git a/conversation/service_test.go b/conversation/service_test.go index b6ebc13..03a503b 100644 --- a/conversation/service_test.go +++ b/conversation/service_test.go @@ -1,9 +1,7 @@ package conversation_test import ( - "fmt" - "log" - "os" + "log/slog" "strings" "testing" @@ -15,7 +13,7 @@ import ( ) func TestCreateConversation(t *testing.T) { - ai := &mocks.MockOpenAIClient{} + mockAIService := &mocks.MockAIService{} tests := []struct { name string @@ -57,7 +55,7 @@ func TestCreateConversation(t *testing.T) { CurrentQuestionNumber: 3, }, setup: func() { - ai.Scenario = mocks.ScenarioCreated + mockAIService.Scenario = mocks.ScenarioCreated }, }, { @@ -83,21 +81,17 @@ func TestCreateConversation(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) if tc.setup != nil { tc.setup() } - - repo := conversation.NewMockRepo() + var buf strings.Builder + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + conversationRepo := conversation.NewMockRepo() interviewRepo := interview.NewMockRepo() - repo.FailRepo = tc.failRepo + conversationService := conversation.NewConversationService(conversationRepo, interviewRepo, mockAIService, logger) + conversationRepo.FailRepo = tc.failRepo - convo, err := conversation.CreateConversation( - repo, - interviewRepo, - ai, + convo, err := conversationService.CreateConversation( tc.convo, tc.interviewID, tc.prompt, @@ -126,8 +120,7 @@ func TestCreateConversation(t *testing.T) { } func TestAppendConversation(t *testing.T) { - ai := &mocks.MockOpenAIClient{} - + mockAIService := &mocks.MockAIService{} tests := []struct { name string message string @@ -156,7 +149,7 @@ func TestAppendConversation(t *testing.T) { failRepo: false, expectError: false, setup: func() { - ai.Scenario = mocks.ScenarioAppended1 + mockAIService.Scenario = mocks.ScenarioAppended1 }, }, { @@ -180,21 +173,17 @@ func TestAppendConversation(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) if tc.setup != nil { tc.setup() } - - repo := conversation.NewMockRepo() + var buf strings.Builder + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + conversationRepo := conversation.NewMockRepo() interviewRepo := interview.NewMockRepo() - repo.FailRepo = tc.failRepo + conversationService := conversation.NewConversationService(conversationRepo, interviewRepo, mockAIService, logger) + conversationRepo.FailRepo = tc.failRepo - convo, err := conversation.CreateConversation( - repo, - interviewRepo, - ai, + convo, err := conversationService.CreateConversation( tc.convo, tc.interviewID, "Prompt", @@ -209,7 +198,12 @@ func TestAppendConversation(t *testing.T) { t.Fatalf("failed to create initial conversation: %v", err) } - updatedConvo, err := conversation.AppendConversation(repo, interviewRepo, ai, tc.interviewID, tc.userID, convo, tc.message, tc.prompt) + updatedConvo, err := conversationService.AppendConversation( + tc.interviewID, + tc.userID, + convo, + tc.message, + tc.prompt) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -225,10 +219,3 @@ func TestAppendConversation(t *testing.T) { }) } } - -func showLogsIfFail(t *testing.T, name string, buf strings.Builder) { - log.SetOutput(os.Stderr) - if t.Failed() { - fmt.Printf("---- logs for test: %s ----\n%s\n", name, buf.String()) - } -} diff --git a/dashboard/model.go b/dashboard/model.go index bc09d9c..6151e6d 100644 --- a/dashboard/model.go +++ b/dashboard/model.go @@ -1,9 +1,11 @@ package dashboard import ( + "log/slog" "time" "github.com/michaelboegner/interviewer/interview" + "github.com/michaelboegner/interviewer/user" ) type DashboardData struct { @@ -16,3 +18,17 @@ type DashboardData struct { SubscriptionCredits int `json:"subscription_credits"` PastInterviews []interview.Summary `json:"past_interviews"` } + +type DashboardService struct { + UserRepo user.UserRepo + InterviewRepo interview.InterviewRepo + Logger *slog.Logger +} + +func NewDashboardService(userRepo user.UserRepo, interviewRepo interview.InterviewRepo, logger *slog.Logger) *DashboardService { + return &DashboardService{ + UserRepo: userRepo, + InterviewRepo: interviewRepo, + Logger: logger, + } +} diff --git a/dashboard/service.go b/dashboard/service.go index 773d1e5..03d395f 100644 --- a/dashboard/service.go +++ b/dashboard/service.go @@ -2,19 +2,16 @@ package dashboard import ( "log" - - "github.com/michaelboegner/interviewer/interview" - "github.com/michaelboegner/interviewer/user" ) -func GetDashboardData(userID int, userRepo user.UserRepo, interviewRepo interview.InterviewRepo) (*DashboardData, error) { - user, err := userRepo.GetUser(userID) +func (d *DashboardService) GetDashboardData(userID int) (*DashboardData, error) { + user, err := d.UserRepo.GetUser(userID) if err != nil { log.Printf("GetUser failed for userID %d: %v", userID, err) return nil, err } - interviews, err := interviewRepo.GetInterviewSummariesByUserID(userID) + interviews, err := d.InterviewRepo.GetInterviewSummariesByUserID(userID) if err != nil { log.Printf("GetInterviewSummariesByUserID failed for userID %d: %v", userID, err) return nil, err diff --git a/handlers/handlers.go b/handlers/handlers.go index 57e16e4..9a5dd6c 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "io" + "log" "net/http" "net/url" "os" @@ -13,11 +14,8 @@ import ( "github.com/michaelboegner/interviewer/billing" "github.com/michaelboegner/interviewer/chatgpt" - "github.com/michaelboegner/interviewer/conversation" - "github.com/michaelboegner/interviewer/dashboard" "github.com/michaelboegner/interviewer/interview" "github.com/michaelboegner/interviewer/middleware" - "github.com/michaelboegner/interviewer/token" "github.com/michaelboegner/interviewer/user" ) @@ -52,7 +50,7 @@ func (h *Handler) RequestVerificationHandler(w http.ResponseWriter, r *http.Requ return } - verificationJWT, err := user.VerificationToken(req.Email, req.Username, req.Password) + verificationJWT, err := h.UserService.VerificationToken(req.Email, req.Username, req.Password) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to create token") return @@ -90,7 +88,7 @@ func (h *Handler) CheckEmailHandler(w http.ResponseWriter, r *http.Request) { return } - err := user.GetUserByEmail(h.UserRepo, req.Email) + err := h.UserService.GetUserByEmail(req.Email) if err != nil { if errors.Is(err, sql.ErrNoRows) { RespondWithJSON(w, http.StatusOK, map[string]bool{"exists": false}) @@ -118,14 +116,15 @@ func (h *Handler) CreateUsersHandler(w http.ResponseWriter, r *http.Request) { return } - userCreated, err := user.CreateUser(h.UserRepo, req.Token) + userCreated, err := h.UserService.CreateUser(req.Token) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Internal server error") return } - jwt, err := token.CreateJWT(strconv.Itoa(userCreated.ID), 0) + jwt, err := h.TokenService.CreateJWT(strconv.Itoa(userCreated.ID), 0) if err != nil { + h.Logger.Error("h.TokenService.CreateJWT failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Internal server error") return } @@ -167,7 +166,7 @@ func (h *Handler) GetUsersHandler(w http.ResponseWriter, r *http.Request) { return } - userReturned, err := user.GetUser(h.UserRepo, userID) + userReturned, err := h.UserService.GetUser(userID) if err != nil { return } @@ -204,25 +203,26 @@ func (h *Handler) DeleteUserHandler(w http.ResponseWriter, r *http.Request) { return } - userReturned, err := user.GetUser(h.UserRepo, userID) + userReturned, err := h.UserService.GetUser(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to find user") return } - err = h.Billing.CancelSubscription(h.UserRepo, userReturned.Email) + err = h.BillingService.CancelSubscription(userReturned.Email) if err != nil { + h.Logger.Error("h.BillingService.CancelSubscription failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return } - err = user.MarkUserDeleted(h.UserRepo, userID) + err = h.UserService.MarkUserDeleted(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to delete user") return } - err = token.DeleteRefreshToken(h.TokenRepo, userID) + err = h.TokenService.DeleteRefreshToken(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Internal server error") return @@ -256,7 +256,7 @@ func (h *Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { } - jwToken, username, userID, err := user.LoginUser(h.UserRepo, params.Email, params.Password) + username, userID, err := h.UserService.LoginUser(params.Email, params.Password) if err != nil { if errors.Is(err, user.ErrAccountDeleted) { RespondWithError(w, http.StatusUnauthorized, user.ErrAccountDeleted.Error()) @@ -266,7 +266,14 @@ func (h *Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { return } - refreshToken, err := token.CreateRefreshToken(h.TokenRepo, userID) + jwToken, err := h.TokenService.CreateJWT(strconv.Itoa(userID), 0) + if err != nil { + log.Printf("JWT creation failed: %v", err) + RespondWithError(w, http.StatusInternalServerError, "Internal server error") + return + } + + refreshToken, err := h.TokenService.CreateRefreshToken(userID) if err != nil { RespondWithError(w, http.StatusUnauthorized, "") return @@ -380,19 +387,21 @@ func (h *Handler) GithubLoginHandler(w http.ResponseWriter, r *http.Request) { return } - user, err := user.GetOrCreateByEmail(h.UserRepo, githubUser.Email, githubUser.Login) + user, err := h.UserService.GetOrCreateByEmail(githubUser.Email, githubUser.Login) if err != nil { RespondWithError(w, http.StatusInternalServerError, "User creation failed") return } - jwt, err := token.CreateJWT(strconv.Itoa(user.ID), 0) + jwt, err := h.TokenService.CreateJWT(strconv.Itoa(user.ID), 0) if err != nil { + h.Logger.Error("h.TokenService.CreateJWT failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Internal server error") return } - refreshToken, err := token.CreateRefreshToken(h.TokenRepo, user.ID) + refreshToken, err := h.TokenService.CreateRefreshToken(user.ID) if err != nil { + h.Logger.Error("h.TokenService.CreateRefreshToken failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Internal server error") return } @@ -430,25 +439,25 @@ func (h *Handler) RefreshTokensHandler(w http.ResponseWriter, r *http.Request) { return } - storedToken, err := token.GetStoredRefreshToken(h.TokenRepo, params.UserID) + storedToken, err := h.TokenService.GetStoredRefreshToken(params.UserID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid user_id") return } - ok := token.VerifyRefreshToken(storedToken, providedToken) + ok := h.TokenService.VerifyRefreshToken(storedToken, providedToken) if !ok { RespondWithError(w, http.StatusUnauthorized, "Refresh token is invalid") return } - refreshToken, err := token.CreateRefreshToken(h.TokenRepo, params.UserID) + refreshToken, err := h.TokenService.CreateRefreshToken(params.UserID) if err != nil { RespondWithError(w, http.StatusUnauthorized, "") return } - user, err := h.UserRepo.GetUser(params.UserID) + user, err := h.UserService.UserRepo.GetUser(params.UserID) if err != nil { RespondWithError(w, http.StatusUnauthorized, "Account deactivated") return @@ -458,7 +467,7 @@ func (h *Handler) RefreshTokensHandler(w http.ResponseWriter, r *http.Request) { return } - jwToken, err := token.CreateJWT(strconv.Itoa(params.UserID), 0) + jwToken, err := h.TokenService.CreateJWT(strconv.Itoa(params.UserID), 0) if err != nil { RespondWithError(w, http.StatusInternalServerError, "") return @@ -491,17 +500,13 @@ func (h *Handler) InterviewsHandler(w http.ResponseWriter, r *http.Request) { return } - userReturned, err := user.GetUser(h.UserRepo, userID) + userReturned, err := h.UserService.GetUser(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to find user") return } - interviewStarted, err := interview.StartInterview( - h.InterviewRepo, - h.UserRepo, - h.BillingRepo, - h.OpenAI, + interviewStarted, err := h.InterviewService.StartInterview( userReturned, 30, 3, @@ -521,13 +526,13 @@ func (h *Handler) InterviewsHandler(w http.ResponseWriter, r *http.Request) { return } - conversationID, err := conversation.CreateEmptyConversation(h.ConversationRepo, interviewStarted.Id, interviewStarted.Subtopic) + conversationID, err := h.ConversationService.CreateEmptyConversation(interviewStarted.Id, interviewStarted.Subtopic) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Internal server error") return } - err = interview.LinkConversation(h.InterviewRepo, interviewStarted.Id, conversationID) + err = h.InterviewService.LinkConversation(interviewStarted.Id, conversationID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Internal server error") return @@ -560,7 +565,7 @@ func (h *Handler) GetInterviewHandler(w http.ResponseWriter, r *http.Request) { return } - interviewReturned, err := interview.GetInterview(h.InterviewRepo, interviewID) + interviewReturned, err := h.InterviewService.GetInterview(interviewID) if err != nil { RespondWithError(w, http.StatusNotFound, "Interview not found") return @@ -603,7 +608,7 @@ func (h *Handler) UpdateInterviewStatusHandler(w http.ResponseWriter, r *http.Re return } - interviewReturned, err := interview.GetInterview(h.InterviewRepo, interviewID) + interviewReturned, err := h.InterviewService.GetInterview(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID") return @@ -615,7 +620,7 @@ func (h *Handler) UpdateInterviewStatusHandler(w http.ResponseWriter, r *http.Re return } - err = h.InterviewRepo.UpdateStatus(interviewID, userID, payload.Status) + err = h.InterviewService.InterviewRepo.UpdateStatus(interviewID, userID, payload.Status) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not update status") return @@ -650,7 +655,7 @@ func (h *Handler) CreateConversationsHandler(w http.ResponseWriter, r *http.Requ return } - interviewReturned, err := interview.GetInterview(h.InterviewRepo, interviewID) + interviewReturned, err := h.InterviewService.GetInterview(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID") return @@ -664,16 +669,13 @@ func (h *Handler) CreateConversationsHandler(w http.ResponseWriter, r *http.Requ return } - conversationReturned, err := conversation.GetConversation(h.ConversationRepo, interviewID) + conversationReturned, err := h.ConversationService.GetConversation(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID") return } - conversationCreated, err := conversation.CreateConversation( - h.ConversationRepo, - h.InterviewRepo, - h.OpenAI, + conversationCreated, err := h.ConversationService.CreateConversation( conversationReturned, interviewID, interviewReturned.Prompt, @@ -726,7 +728,7 @@ func (h *Handler) AppendConversationsHandler(w http.ResponseWriter, r *http.Requ return } - interviewReturned, err := interview.GetInterview(h.InterviewRepo, interviewID) + interviewReturned, err := h.InterviewService.GetInterview(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID") return @@ -740,16 +742,13 @@ func (h *Handler) AppendConversationsHandler(w http.ResponseWriter, r *http.Requ return } - conversationReturned, err := conversation.GetConversation(h.ConversationRepo, interviewID) + conversationReturned, err := h.ConversationService.GetConversation(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID.") return } - conversationReturned, err = conversation.AppendConversation( - h.ConversationRepo, - h.InterviewRepo, - h.OpenAI, + conversationReturned, err = h.ConversationService.AppendConversation( interviewID, userID, conversationReturned, @@ -789,7 +788,7 @@ func (h *Handler) GetConversationHandler(w http.ResponseWriter, r *http.Request) return } - interviewReturned, err := interview.GetInterview(h.InterviewRepo, interviewID) + interviewReturned, err := h.InterviewService.GetInterview(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID") return @@ -799,7 +798,7 @@ func (h *Handler) GetConversationHandler(w http.ResponseWriter, r *http.Request) return } - conversationReturned, err := conversation.GetConversation(h.ConversationRepo, interviewID) + conversationReturned, err := h.ConversationService.GetConversation(interviewID) if err != nil { RespondWithError(w, http.StatusBadRequest, "Invalid ID.") return @@ -824,7 +823,13 @@ func (h *Handler) RequestResetHandler(w http.ResponseWriter, r *http.Request) { return } - resetJWT, err := user.RequestPasswordReset(h.UserRepo, params.Email) + err = h.UserService.GetUserByEmail(params.Email) + if err != nil { + w.WriteHeader(http.StatusOK) + return + } + + resetJWT, err := h.TokenService.CreateJWT(params.Email, 900) if err != nil { w.WriteHeader(http.StatusOK) return @@ -855,7 +860,7 @@ func (h *Handler) ResetPasswordHandler(w http.ResponseWriter, r *http.Request) { RespondWithError(w, http.StatusBadRequest, "Invalid request body") return } - err := user.ResetPassword(h.UserRepo, params.NewPassword, params.Token) + err := h.UserService.ResetPassword(params.NewPassword, params.Token) if err != nil { RespondWithError(w, http.StatusUnauthorized, "Invalid or expired token") return @@ -883,7 +888,7 @@ func (h *Handler) CreateCheckoutSessionHandler(w http.ResponseWriter, r *http.Re return } - user, err := user.GetUser(h.UserRepo, userID) + user, err := h.UserService.GetUser(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not find user") return @@ -908,7 +913,7 @@ func (h *Handler) CreateCheckoutSessionHandler(w http.ResponseWriter, r *http.Re return } - url, err := h.Billing.RequestCheckoutSession(user.Email, priceIDInt) + url, err := h.BillingService.RequestCheckoutSession(user.Email, priceIDInt) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not start checkout") return @@ -929,13 +934,13 @@ func (h *Handler) CancelSubscriptionHandler(w http.ResponseWriter, r *http.Reque return } - userReturned, err := user.GetUser(h.UserRepo, userID) + userReturned, err := h.UserService.GetUser(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not retrieve user") return } - err = h.Billing.RequestDeleteSubscription(userReturned.SubscriptionID) + err = h.BillingService.RequestDeleteSubscription(userReturned.SubscriptionID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not cancel subscription") return @@ -956,13 +961,13 @@ func (h *Handler) ResumeSubscriptionHandler(w http.ResponseWriter, r *http.Reque return } - userReturned, err := user.GetUser(h.UserRepo, userID) + userReturned, err := h.UserService.GetUser(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not retrieve user") return } - err = h.Billing.RequestResumeSubscription(userReturned.SubscriptionID) + err = h.BillingService.RequestResumeSubscription(userReturned.SubscriptionID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not cancel subscription") return @@ -989,7 +994,7 @@ func (h *Handler) ChangePlanHandler(w http.ResponseWriter, r *http.Request) { return } - user, err := user.GetUser(h.UserRepo, userID) + user, err := h.UserService.GetUser(userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Could not find user") return @@ -1012,7 +1017,8 @@ func (h *Handler) ChangePlanHandler(w http.ResponseWriter, r *http.Request) { return } - if err := h.Billing.RequestUpdateSubscriptionVariant(user.SubscriptionID, priceIDInt); err != nil { + if err := h.BillingService.RequestUpdateSubscriptionVariant(user.SubscriptionID, priceIDInt); err != nil { + h.Logger.Error("UpdateLemonSubscriptionVariant failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Failed to update subscription") return } @@ -1034,7 +1040,8 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) defer r.Body.Close() signature := r.Header.Get("X-Signature") - if !h.Billing.VerifyBillingSignature(signature, body, os.Getenv("LEMON_WEBHOOK_SECRET")) { + if !h.BillingService.VerifyBillingSignature(signature, body, os.Getenv("LEMON_WEBHOOK_SECRET")) { + h.Logger.Error("Invalid billing event signature") RespondWithError(w, http.StatusUnauthorized, "Invalid signature") return } @@ -1047,8 +1054,9 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) } subscriptionID := webhookPayload.Data.SubscriptionID webhookID := webhookPayload.Meta.WebhookID - exists, err := h.BillingRepo.HasWebhookBeenProcessed(webhookID) + exists, err := h.BillingService.BillingRepo.HasWebhookBeenProcessed(webhookID) if err != nil { + h.Logger.Error("h.BillingRepo.HasWebhookBeenProcessed failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Error checking webhook") return } @@ -1071,8 +1079,9 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.ApplyCredits(h.UserRepo, h.BillingRepo, orderAttrs.UserEmail, orderAttrs.FirstOrderItem.VariantID) + err = h.BillingService.ApplyCredits(orderAttrs.UserEmail, orderAttrs.FirstOrderItem.VariantID) if err != nil { + h.Logger.Error("h.Billing.ApplyCredits failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return } @@ -1083,7 +1092,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - exists, err := h.UserRepo.HasActiveOrCancelledSubscription(SubCreatedAttrs.UserEmail) + exists, err := h.UserService.UserRepo.HasActiveOrCancelledSubscription(SubCreatedAttrs.UserEmail) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Subscription check failed") return @@ -1092,7 +1101,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.CreateSubscription(h.UserRepo, SubCreatedAttrs, subscriptionID) + err = h.BillingService.CreateSubscription(SubCreatedAttrs, subscriptionID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return @@ -1103,8 +1112,9 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.CancelSubscription(h.UserRepo, emailAttribute.UserEmail) + err = h.BillingService.CancelSubscription(emailAttribute.UserEmail) if err != nil { + h.Logger.Error("h.Billing.CancelSubscription failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return } @@ -1114,8 +1124,9 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.ResumeSubscription(h.UserRepo, emailAttribute.UserEmail) + err = h.BillingService.ResumeSubscription(emailAttribute.UserEmail) if err != nil { + h.Logger.Error("h.Billing.ResumeSubscription failed", "error", err) RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return } @@ -1125,7 +1136,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.ExpireSubscription(h.UserRepo, h.BillingRepo, emailAttribute.UserEmail) + err = h.BillingService.ExpireSubscription(emailAttribute.UserEmail) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return @@ -1141,7 +1152,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.RenewSubscription(h.UserRepo, h.BillingRepo, SubRenewAttrs) + err = h.BillingService.RenewSubscription(SubRenewAttrs) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return @@ -1153,7 +1164,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.ChangeSubscription(h.UserRepo, h.BillingRepo, SubChangedAttrs) + err = h.BillingService.ChangeSubscription(SubChangedAttrs) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return @@ -1165,7 +1176,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.UpdateSubscription(h.UserRepo, SubChangedAttrs, subscriptionID) + err = h.BillingService.UpdateSubscription(SubChangedAttrs, subscriptionID) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return @@ -1177,7 +1188,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.Billing.DeductCredits(h.UserRepo, h.BillingRepo, orderAttrs) + err = h.BillingService.DeductCredits(orderAttrs) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Failed to update user") return @@ -1198,7 +1209,7 @@ func (h *Handler) BillingWebhookHandler(w http.ResponseWriter, r *http.Request) return } - err = h.BillingRepo.MarkWebhookProcessed(webhookID, eventType) + err = h.BillingService.BillingRepo.MarkWebhookProcessed(webhookID, eventType) if err != nil { w.WriteHeader(http.StatusOK) return @@ -1219,7 +1230,7 @@ func (h *Handler) DashboardHandler(w http.ResponseWriter, r *http.Request) { return } - dashboardData, err := dashboard.GetDashboardData(userID, h.UserRepo, h.InterviewRepo) + dashboardData, err := h.DashboardService.GetDashboardData(userID) if err != nil { if errors.Is(err, sql.ErrNoRows) { RespondWithError(w, http.StatusUnauthorized, "User not found") @@ -1246,7 +1257,7 @@ func (h *Handler) JDInputHandler(w http.ResponseWriter, r *http.Request) { return } - jdInput, err := h.OpenAI.ExtractJDInput(input.JobDescription) + jdInput, err := h.AIService.ExtractJDInput(input.JobDescription) if err != nil { var openaiErr *chatgpt.OpenAIError if errors.As(err, &openaiErr) { @@ -1257,7 +1268,7 @@ func (h *Handler) JDInputHandler(w http.ResponseWriter, r *http.Request) { return } - jdSummary, err := h.OpenAI.ExtractJDSummary(jdInput) + jdSummary, err := h.AIService.ExtractJDSummary(jdInput) if err != nil { var openaiErr *chatgpt.OpenAIError if errors.As(err, &openaiErr) { diff --git a/handlers/handlers_test.go b/handlers/handlers_test.go index 7e1d311..9b5fdbc 100644 --- a/handlers/handlers_test.go +++ b/handlers/handlers_test.go @@ -19,7 +19,6 @@ import ( "github.com/michaelboegner/interviewer/internal/mocks" "github.com/michaelboegner/interviewer/internal/testutil" "github.com/michaelboegner/interviewer/interview" - "github.com/michaelboegner/interviewer/token" "github.com/michaelboegner/interviewer/user" ) @@ -47,7 +46,7 @@ type TestCase struct { var ( Handler *handlers.Handler conversationBuilder *testutil.ConversationBuilder - mockAI *mocks.MockOpenAIClient + mockAI *mocks.MockAIService ) var logger *slog.Logger @@ -79,7 +78,7 @@ func TestMain(m *testing.M) { logger.Info("Test server started", "url", testutil.TestServerURL) - mockAI = Handler.OpenAI.(*mocks.MockOpenAIClient) + mockAI = Handler.AIService.(*mocks.MockAIService) conversationBuilder = testutil.NewConversationBuilder() code := m.Run() @@ -178,7 +177,7 @@ func Test_RequestVerificationHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - user, err := user.GetUser(Handler.UserRepo, got.UserID) + user, err := Handler.UserService.GetUser(got.UserID) if err != nil { t.Fatalf("Assert Database: GetUser failed: %v", err) } @@ -227,7 +226,7 @@ func Test_CreateUsersHandler_Integration(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - verificationJWT, err := user.VerificationToken(tc.email, tc.username, tc.password) + verificationJWT, err := Handler.UserService.VerificationToken(tc.email, tc.username, tc.password) if err != nil { t.Fatalf("GenerateEmailVerificationToken failed: %v", err) } @@ -269,7 +268,7 @@ func Test_CreateUsersHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - user, err := user.GetUser(Handler.UserRepo, got.UserID) + user, err := Handler.UserService.GetUser(got.UserID) if err != nil { t.Fatalf("Assert Database: GetUser failed: %v", err) } @@ -288,7 +287,7 @@ func Test_CreateUsersHandler_Integration(t *testing.T) { func Test_GetUsersHandler_Integration(t *testing.T) { cleanDBOrFail(t) - jwtoken, userID := testutil.CreateTestUserAndJWT(logger) + jwtoken, userID := testutil.CreateTestUserAndJWT(Handler.UserService, Handler.TokenService, logger) tests := []TestCase{ { @@ -359,7 +358,7 @@ func Test_GetUsersHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - user, err := user.GetUser(Handler.UserRepo, got.UserID) + user, err := Handler.UserService.GetUser(got.UserID) if err != nil { t.Fatalf("Assert Database: GetUser failed: %v", err) } @@ -378,7 +377,7 @@ func Test_GetUsersHandler_Integration(t *testing.T) { func Test_LoginHandler_Integration(t *testing.T) { cleanDBOrFail(t) - _, _ = testutil.CreateTestUserAndJWT(logger) + _, _ = testutil.CreateTestUserAndJWT(Handler.UserService, Handler.TokenService, logger) tests := []TestCase{ { @@ -480,7 +479,7 @@ func Test_LoginHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - refreshToken, err := token.GetStoredRefreshToken(Handler.TokenRepo, respUnmarshalled.UserID) + refreshToken, err := Handler.TokenService.GetStoredRefreshToken(respUnmarshalled.UserID) if err != nil { t.Fatalf("Assert Database: GetUser failed: %v", err) } @@ -499,8 +498,8 @@ func Test_LoginHandler_Integration(t *testing.T) { func Test_RefreshTokensHandler_Integration(t *testing.T) { cleanDBOrFail(t) - _, userID := testutil.CreateTestUserAndJWT(logger) - refreshToken, err := token.GetStoredRefreshToken(Handler.TokenRepo, userID) + _, userID := testutil.CreateTestUserAndJWT(Handler.UserService, Handler.TokenService, logger) + refreshToken, err := Handler.TokenService.GetStoredRefreshToken(userID) if err != nil { t.Fatalf("TC GetStoredRefreshToken failed: %v", err) } @@ -611,7 +610,7 @@ func Test_RefreshTokensHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - refreshToken, err := token.GetStoredRefreshToken(Handler.TokenRepo, userID) + refreshToken, err := Handler.TokenService.GetStoredRefreshToken(userID) if err != nil { t.Fatalf("Assert Database: GetUser failed: %v", err) } @@ -630,7 +629,7 @@ func Test_RefreshTokensHandler_Integration(t *testing.T) { func Test_InterviewsHandler_Integration(t *testing.T) { cleanDBOrFail(t) - jwtoken, userID := testutil.CreateTestUserAndJWT(logger) + jwtoken, userID := testutil.CreateTestUserAndJWT(Handler.UserService, Handler.TokenService, logger) expiredJWT := testutil.CreateTestExpiredJWT(userID, -1, logger) tests := []TestCase{ @@ -749,7 +748,7 @@ func Test_InterviewsHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - interviewReturned, err := interview.GetInterview(Handler.InterviewRepo, respUnmarshalled.InterviewID) + interviewReturned, err := Handler.InterviewService.GetInterview(respUnmarshalled.InterviewID) if err != nil { t.Fatalf("Assert Database: GetInterview failed: %v", err) } @@ -768,7 +767,7 @@ func Test_InterviewsHandler_Integration(t *testing.T) { func Test_CreateConversationsHandler_Integration(t *testing.T) { cleanDBOrFail(t) - jwtoken, _ := testutil.CreateTestUserAndJWT(logger) + jwtoken, _ := testutil.CreateTestUserAndJWT(Handler.UserService, Handler.TokenService, logger) mockAI.Scenario = mocks.ScenarioInterview interviewID := testutil.CreateTestInterview(jwtoken, logger) conversationsURL := testutil.TestServerURL + fmt.Sprintf("/api/conversations/create/%d", interviewID) @@ -860,7 +859,7 @@ func Test_CreateConversationsHandler_Integration(t *testing.T) { // Assert Database if tc.DBCheck { - conversation, err := conversation.GetConversation(Handler.ConversationRepo, got.Conversation.ID) + conversation, err := Handler.ConversationService.GetConversation(got.Conversation.ID) if err != nil { t.Fatalf("Assert Database: GetConversation failed: %v", err) } @@ -878,7 +877,7 @@ func Test_CreateConversationsHandler_Integration(t *testing.T) { func Test_AppendConversationsHandler_Integration(t *testing.T) { cleanDBOrFail(t) - jwtoken, _ := testutil.CreateTestUserAndJWT(logger) + jwtoken, _ := testutil.CreateTestUserAndJWT(Handler.UserService, Handler.TokenService, logger) mockAI.Scenario = mocks.ScenarioInterview interviewID := testutil.CreateTestInterview(jwtoken, logger) @@ -1000,7 +999,7 @@ func Test_AppendConversationsHandler_Integration(t *testing.T) { // DB validation if tc.DBCheck { - gotDB, err := conversation.GetConversation(Handler.ConversationRepo, respUnmarshalled.Conversation.ID) + gotDB, err := Handler.ConversationService.GetConversation(respUnmarshalled.Conversation.ID) if err != nil { t.Fatalf("DB check failed: %v", err) } diff --git a/handlers/model.go b/handlers/model.go index f28abac..3b77ac9 100644 --- a/handlers/model.go +++ b/handlers/model.go @@ -2,10 +2,12 @@ package handlers import ( "database/sql" + "log/slog" "github.com/michaelboegner/interviewer/billing" "github.com/michaelboegner/interviewer/chatgpt" "github.com/michaelboegner/interviewer/conversation" + "github.com/michaelboegner/interviewer/dashboard" "github.com/michaelboegner/interviewer/interview" "github.com/michaelboegner/interviewer/mailer" "github.com/michaelboegner/interviewer/token" @@ -52,36 +54,39 @@ type ReturnVals struct { } type Handler struct { - UserRepo user.UserRepo - InterviewRepo interview.InterviewRepo - ConversationRepo conversation.ConversationRepo - TokenRepo token.TokenRepo - BillingRepo billing.BillingRepo - Billing *billing.Billing - Mailer mailer.MailerClient - OpenAI chatgpt.AIClient - DB *sql.DB + UserService *user.UserService + InterviewService *interview.InterviewService + ConversationService *conversation.ConversationService + TokenService *token.TokenService + BillingService *billing.BillingService + Mailer mailer.MailerClient + AIService chatgpt.AIClient + DashboardService *dashboard.DashboardService + DB *sql.DB + Logger *slog.Logger } func NewHandler( - interviewRepo interview.InterviewRepo, - userRepo user.UserRepo, - tokenRepo token.TokenRepo, - conversationRepo conversation.ConversationRepo, - billingRepo billing.BillingRepo, - billing *billing.Billing, + interviewService *interview.InterviewService, + userService *user.UserService, + tokenService *token.TokenService, + conversationService *conversation.ConversationService, + billingService *billing.BillingService, mailer mailer.MailerClient, - openAI chatgpt.AIClient, - db *sql.DB) *Handler { + aiService chatgpt.AIClient, + dashboardService *dashboard.DashboardService, + db *sql.DB, + logger *slog.Logger) *Handler { return &Handler{ - InterviewRepo: interviewRepo, - UserRepo: userRepo, - TokenRepo: tokenRepo, - ConversationRepo: conversationRepo, - BillingRepo: billingRepo, - Billing: billing, - Mailer: mailer, - OpenAI: openAI, - DB: db, + InterviewService: interviewService, + UserService: userService, + TokenService: tokenService, + ConversationService: conversationService, + BillingService: billingService, + Mailer: mailer, + AIService: aiService, + DashboardService: dashboardService, + DB: db, + Logger: logger, } } diff --git a/internal/mocks/mailer_mock.go b/internal/mocks/mailer_mock.go index 05e1c4d..a98d0ef 100644 --- a/internal/mocks/mailer_mock.go +++ b/internal/mocks/mailer_mock.go @@ -1,25 +1,23 @@ package mocks -type MockMailer struct{} +type MockMailerService struct{} -func NewMockMailer() *MockMailer { - mockMailer := &MockMailer{} - - return mockMailer +func NewMockMailerService() *MockMailerService { + return &MockMailerService{} } -func (m *MockMailer) SendPasswordReset(email, resetURL string) error { +func (m *MockMailerService) SendPasswordReset(email, resetURL string) error { return nil } -func (m *MockMailer) SendVerificationEmail(email, verifyURL string) error { +func (m *MockMailerService) SendVerificationEmail(email, verifyURL string) error { return nil } -func (m *MockMailer) SendWelcome(email string) error { +func (m *MockMailerService) SendWelcome(email string) error { return nil } -func (m *MockMailer) SendDeletionConfirmation(email string) error { +func (m *MockMailerService) SendDeletionConfirmation(email string) error { return nil } diff --git a/internal/mocks/openai_mock.go b/internal/mocks/openai_mock.go index 93bd72f..7814016 100644 --- a/internal/mocks/openai_mock.go +++ b/internal/mocks/openai_mock.go @@ -73,21 +73,21 @@ var responseFixtures = map[string]*chatgpt.ChatGPTResponse{ }, } -type MockOpenAIClient struct { +type MockAIService struct { Scenario string } -func NewMockOpenAIClient() *MockOpenAIClient { - mockOpenAIClient := &MockOpenAIClient{} +func NewMockAIService() *MockAIService { + mockAIService := &MockAIService{} - return mockOpenAIClient + return mockAIService } -func (m *MockOpenAIClient) GetChatGPTResponse(prompt string) (*chatgpt.ChatGPTResponse, error) { +func (m *MockAIService) GetChatGPTResponse(prompt string) (*chatgpt.ChatGPTResponse, error) { return responseFixtures[ScenarioInterview], nil } -func (m *MockOpenAIClient) GetChatGPTResponseConversation(_ []map[string]string) (*chatgpt.ChatGPTResponse, error) { +func (m *MockAIService) GetChatGPTResponseConversation(_ []map[string]string) (*chatgpt.ChatGPTResponse, error) { resp, ok := responseFixtures[m.Scenario] if !ok { return nil, fmt.Errorf("invalid scenario: %s", m.Scenario) @@ -95,15 +95,15 @@ func (m *MockOpenAIClient) GetChatGPTResponseConversation(_ []map[string]string) return resp, nil } -func (m *MockOpenAIClient) GetChatGPT35Response(prompt string) (*chatgpt.ChatGPTResponse, error) { +func (m *MockAIService) GetChatGPT35Response(prompt string) (*chatgpt.ChatGPTResponse, error) { return &chatgpt.ChatGPTResponse{}, nil } -func (m *MockOpenAIClient) ExtractJDInput(jd string) (*chatgpt.JDParsedOutput, error) { +func (m *MockAIService) ExtractJDInput(jd string) (*chatgpt.JDParsedOutput, error) { return &chatgpt.JDParsedOutput{}, nil } -func (m *MockOpenAIClient) ExtractJDSummary(jdInput *chatgpt.JDParsedOutput) (string, error) { +func (m *MockAIService) ExtractJDSummary(jdInput *chatgpt.JDParsedOutput) (string, error) { return "", nil } diff --git a/internal/server/server.go b/internal/server/server.go index 6fcfad9..058e594 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/michaelboegner/interviewer/billing" "github.com/michaelboegner/interviewer/chatgpt" "github.com/michaelboegner/interviewer/conversation" + "github.com/michaelboegner/interviewer/dashboard" "github.com/michaelboegner/interviewer/database" "github.com/michaelboegner/interviewer/handlers" "github.com/michaelboegner/interviewer/interview" @@ -35,15 +36,21 @@ func NewServer(logger *slog.Logger) (*Server, error) { tokenRepo := token.NewRepository(db) conversationRepo := conversation.NewRepository(db) billingRepo := billing.NewRepository(db) - openAI := chatgpt.NewOpenAI(logger) - mailer := mailer.NewMailer(logger) - billing, err := billing.NewBilling(logger) + + aiService := chatgpt.NewAIService(logger) + interviewService := interview.NewInterviewService(interviewRepo, userRepo, billingRepo, aiService, logger) + userService := user.NewUserService(userRepo, logger) + tokenService := token.NewTokenService(tokenRepo, logger) + conversationService := conversation.NewConversationService(conversationRepo, interviewRepo, aiService, logger) + mailerService := mailer.NewMailerService(logger) + dashboardService := dashboard.NewDashboardService(userRepo, interviewRepo, logger) + billingService, err := billing.NewBillingService(billingRepo, userRepo, logger) if err != nil { logger.Error("billing.NewBilling failed", "error", err) return nil, err } - handler := handlers.NewHandler(interviewRepo, userRepo, tokenRepo, conversationRepo, billingRepo, billing, mailer, openAI, db) + handler := handlers.NewHandler(interviewService, userService, tokenService, conversationService, billingService, mailerService, aiService, dashboardService, db, logger) mux.Handle("/api/users", http.HandlerFunc(handler.CreateUsersHandler)) mux.Handle("/api/auth/login", http.HandlerFunc(handler.LoginHandler)) diff --git a/internal/testutil/helpers.go b/internal/testutil/helpers.go index 50587d6..ecde95f 100644 --- a/internal/testutil/helpers.go +++ b/internal/testutil/helpers.go @@ -18,7 +18,7 @@ import ( "github.com/michaelboegner/interviewer/user" ) -func CreateTestUserAndJWT(logger *slog.Logger) (string, int) { +func CreateTestUserAndJWT(userService *user.UserService, tokenService *token.TokenService, logger *slog.Logger) (string, int) { var ( jwt string userID int @@ -28,7 +28,7 @@ func CreateTestUserAndJWT(logger *slog.Logger) (string, int) { email := "test@test.com" password := "test" - verificationJWT, err := user.VerificationToken(email, username, password) + verificationJWT, err := userService.VerificationToken(email, username, password) if err != nil { logger.Error("GenerateEmailVerificationToken failed", "error", err) } @@ -61,7 +61,7 @@ func CreateTestUserAndJWT(logger *slog.Logger) (string, int) { jwt = returnVals.JWToken //test userID extract - userID, err = token.ExtractUserIDFromToken(jwt) + userID, err = tokenService.ExtractUserIDFromToken(jwt) if err != nil { logger.Error("CreateTestUserandJWT userID extraction failed", "error", err) } diff --git a/internal/testutil/server.go b/internal/testutil/server.go index 368c1f0..011f5ac 100644 --- a/internal/testutil/server.go +++ b/internal/testutil/server.go @@ -7,6 +7,7 @@ import ( "github.com/michaelboegner/interviewer/billing" "github.com/michaelboegner/interviewer/conversation" + "github.com/michaelboegner/interviewer/dashboard" "github.com/michaelboegner/interviewer/database" "github.com/michaelboegner/interviewer/handlers" "github.com/michaelboegner/interviewer/internal/mocks" @@ -37,15 +38,21 @@ func InitTestServer(logger *slog.Logger) (*handlers.Handler, error) { tokenRepo := token.NewRepository(db) conversationRepo := conversation.NewRepository(db) billingRepo := billing.NewRepository(db) - openAI := mocks.NewMockOpenAIClient() - mailer := mocks.NewMockMailer() - billing, err := billing.NewBilling(logger) + + mockAIService := mocks.NewMockAIService() + mockMailerService := mocks.NewMockMailerService() + interviewService := interview.NewInterviewService(interviewRepo, userRepo, billingRepo, mockAIService, logger) + userService := user.NewUserService(userRepo, logger) + tokenService := token.NewTokenService(tokenRepo, logger) + conversationService := conversation.NewConversationService(conversationRepo, interviewRepo, mockAIService, logger) + dashboardService := dashboard.NewDashboardService(userRepo, interviewRepo, logger) + billingService, err := billing.NewBillingService(billingRepo, userRepo, logger) if err != nil { logger.Error("billing.NewBilling failed", "error", err) return nil, err } - handler := handlers.NewHandler(interviewRepo, userRepo, tokenRepo, conversationRepo, billingRepo, billing, mailer, openAI, db) + handler := handlers.NewHandler(interviewService, userService, tokenService, conversationService, billingService, mockMailerService, mockAIService, dashboardService, db, logger) TestMux = http.NewServeMux() TestMux.Handle("/api/users", http.HandlerFunc(handler.CreateUsersHandler)) diff --git a/interview/model.go b/interview/model.go index aa76139..a82e080 100644 --- a/interview/model.go +++ b/interview/model.go @@ -2,7 +2,12 @@ package interview import ( "errors" + "log/slog" "time" + + "github.com/michaelboegner/interviewer/billing" + "github.com/michaelboegner/interviewer/chatgpt" + "github.com/michaelboegner/interviewer/user" ) type Interview struct { @@ -31,8 +36,26 @@ type Summary struct { Score *int `json:"score,omitempty"` } +type InterviewService struct { + InterviewRepo InterviewRepo + UserRepo user.UserRepo + BillingRepo billing.BillingRepo + AI chatgpt.AIClient + Logger *slog.Logger +} + var ErrNoValidCredits = errors.New("no valid credits") +func NewInterviewService(interviewRepo InterviewRepo, userRepo user.UserRepo, billingRepo billing.BillingRepo, ai chatgpt.AIClient, logger *slog.Logger) *InterviewService { + return &InterviewService{ + InterviewRepo: interviewRepo, + UserRepo: userRepo, + BillingRepo: billingRepo, + AI: ai, + Logger: logger, + } +} + type InterviewRepo interface { LinkConversation(interviewID, conversationID int) error CreateInterview(interview *Interview) (int, error) diff --git a/interview/service.go b/interview/service.go index 6278c4b..f0ffda6 100644 --- a/interview/service.go +++ b/interview/service.go @@ -2,7 +2,6 @@ package interview import ( "fmt" - "log" "time" "github.com/michaelboegner/interviewer/billing" @@ -10,20 +9,16 @@ import ( "github.com/michaelboegner/interviewer/user" ) -func StartInterview( - interviewRepo InterviewRepo, - userRepo user.UserRepo, - billingRepo billing.BillingRepo, - ai chatgpt.AIClient, +func (i *InterviewService) StartInterview( user *user.User, length, numberQuestions int, difficulty string, jd string) (*Interview, error) { - err := deductAndLogCredit(user, userRepo, billingRepo) + err := i.deductAndLogCredit(user) if err != nil { - log.Printf("checkCreditsLogTransaction failed: %v", err) + i.Logger.Error("checkCreditsLogTransaction failed", "error", err) return nil, err } @@ -31,23 +26,23 @@ func StartInterview( jdSummary := "" if jd != "" { - jdInput, err := ai.ExtractJDInput(jd) + jdInput, err := i.AI.ExtractJDInput(jd) if err != nil { - fmt.Printf("ai.ExtractJDInput() failed: %v", err) + i.Logger.Error("ai.ExtractJDInput() failed", "error", err) return nil, err } - jdSummary, err = ai.ExtractJDSummary(jdInput) + jdSummary, err = i.AI.ExtractJDSummary(jdInput) if err != nil { - fmt.Printf("ai.ExtractJDSummary() failed: %v", err) + i.Logger.Error("ai.ExtractJDSummary() failed", "error", err) return nil, err } } prompt := chatgpt.BuildPrompt([]string{}, "Introduction", 1, jdSummary) - chatGPTResponse, err := ai.GetChatGPTResponse(prompt) + chatGPTResponse, err := i.AI.GetChatGPTResponse(prompt) if err != nil { - log.Printf("getChatGPTResponse err: %v\n", err) + i.Logger.Error("getChatGPTResponse err", "error", err) return nil, err } @@ -67,9 +62,9 @@ func StartInterview( UpdatedAt: now, } - id, err := interviewRepo.CreateInterview(interview) + id, err := i.InterviewRepo.CreateInterview(interview) if err != nil { - log.Printf("CreateInterview err: %v", err) + i.Logger.Error("CreateInterview err", "error", err) return nil, err } interview.Id = id @@ -77,54 +72,40 @@ func StartInterview( return interview, nil } -func LinkConversation(interviewRepo InterviewRepo, interviewID, conversationID int) error { - err := interviewRepo.LinkConversation(interviewID, conversationID) +func (i *InterviewService) LinkConversation(interviewID, conversationID int) error { + err := i.InterviewRepo.LinkConversation(interviewID, conversationID) if err != nil { - log.Printf("interviewRepo.LinkConversation failed: %v", err) + i.Logger.Error("interviewRepo.LinkConversation failed", "error", err) return err } return nil } -func GetInterview(interviewRepo InterviewRepo, interviewID int) (*Interview, error) { - interview, err := interviewRepo.GetInterview(interviewID) +func (i *InterviewService) GetInterview(interviewID int) (*Interview, error) { + interview, err := i.InterviewRepo.GetInterview(interviewID) if err != nil { + i.Logger.Error("interviewRepo.GetInterview failed", "error", err) return nil, err } return interview, nil } -func canUseCredit(user *user.User) (string, error) { - now := time.Now() - - switch { - case user.SubscriptionEndDate != nil && - user.SubscriptionEndDate.After(now) && - user.SubscriptionStatus != "expired" && - user.SubscriptionCredits > 0: - return "subscription", nil - case user.IndividualCredits > 0: - return "individual", nil - default: - return "", ErrNoValidCredits - } -} - -func deductAndLogCredit(user *user.User, userRepo user.UserRepo, billingRepo billing.BillingRepo) error { - creditType, err := canUseCredit(user) +func (i *InterviewService) deductAndLogCredit(user *user.User) error { + creditType, err := i.canUseCredit(user) if err != nil { - log.Print("canUseCredit failed", err) + i.Logger.Error("canUseCredit failed", "error", err) return err } - if creditType != "" { - + if creditType == "" { + i.Logger.Info("user doesn't have a valid plan or credits") + return fmt.Errorf("user doesn't have a valid plan or credits") } - err = userRepo.AddCredits(user.ID, -1, creditType) + err = i.UserRepo.AddCredits(user.ID, -1, creditType) if err != nil { - log.Printf("AddCredits failed: %v", err) + i.Logger.Error("AddCredits failed", "error", err) return err } @@ -135,10 +116,29 @@ func deductAndLogCredit(user *user.User, userRepo user.UserRepo, billingRepo bil CreditType: creditType, Reason: reason, } - if err := billingRepo.LogCreditTransaction(tx); err != nil { - log.Printf("billingRepo.LogCreditTransaction failed: %v", err) + if err := i.BillingRepo.LogCreditTransaction(tx); err != nil { + i.Logger.Error("billingRepo.LogCreditTransaction failed", "error", err) return err } return nil } + +func (i *InterviewService) canUseCredit(user *user.User) (string, error) { + now := time.Now() + + switch { + case user.SubscriptionEndDate != nil && + user.SubscriptionEndDate.After(now) && + user.SubscriptionStatus != "expired" && + user.SubscriptionCredits > 0: + i.Logger.Info("subscription plan in canUseCredit check") + return "subscription", nil + case user.IndividualCredits > 0: + i.Logger.Info("individual plan in canUseCredit check") + return "individual", nil + default: + i.Logger.Info("no valid credits in canUseCredit check") + return "", ErrNoValidCredits + } +} diff --git a/interview/service_test.go b/interview/service_test.go index d84e7a1..1b51534 100644 --- a/interview/service_test.go +++ b/interview/service_test.go @@ -1,9 +1,7 @@ package interview_test import ( - "fmt" - "log" - "os" + "log/slog" "strings" "testing" "time" @@ -24,7 +22,7 @@ func TestStartInterview(t *testing.T) { length int numQuestions int difficulty string - aiClient *mocks.MockOpenAIClient + aiClient *mocks.MockAIService failRepo bool expected *interview.Interview expectError bool @@ -41,7 +39,7 @@ func TestStartInterview(t *testing.T) { length: 30, numQuestions: 3, difficulty: "easy", - aiClient: &mocks.MockOpenAIClient{}, + aiClient: &mocks.MockAIService{}, expected: &interview.Interview{ UserId: 1, Length: 30, @@ -67,7 +65,7 @@ func TestStartInterview(t *testing.T) { length: 30, numQuestions: 3, difficulty: "easy", - aiClient: &mocks.MockOpenAIClient{}, + aiClient: &mocks.MockAIService{}, failRepo: true, expectError: true, jdSummary: "", @@ -77,19 +75,14 @@ func TestStartInterview(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) - + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) repo := interview.NewMockRepo() userRepo := user.NewMockRepo() billingRepo := billing.NewMockRepo() + interviewService := interview.NewInterviewService(repo, userRepo, billingRepo, tc.aiClient, logger) repo.FailRepo = tc.failRepo - interviewStarted, err := interview.StartInterview( - repo, - userRepo, - billingRepo, - tc.aiClient, + interviewStarted, err := interviewService.StartInterview( tc.user, tc.length, tc.numQuestions, @@ -169,10 +162,11 @@ func TestGetInterview(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) - + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) repo := interview.NewMockRepo() + userRepo := user.NewMockRepo() + billingRepo := billing.NewMockRepo() + interviewService := interview.NewInterviewService(repo, userRepo, billingRepo, &mocks.MockAIService{}, logger) repo.FailRepo = tc.failRepo if tc.setup != nil { @@ -182,7 +176,7 @@ func TestGetInterview(t *testing.T) { } } - got, err := interview.GetInterview(repo, tc.interviewID) + got, err := interviewService.GetInterview(tc.interviewID) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -203,10 +197,3 @@ func TestGetInterview(t *testing.T) { }) } } - -func showLogsIfFail(t *testing.T, name string, buf strings.Builder) { - log.SetOutput(os.Stderr) - if t.Failed() { - fmt.Printf("---- logs for test: %s ----\n%s\n", name, buf.String()) - } -} diff --git a/learninglog/03_29_2025.md b/learninglog/03_29_2025.md index 8ca7220..dad0af9 100644 --- a/learninglog/03_29_2025.md +++ b/learninglog/03_29_2025.md @@ -10,7 +10,7 @@ just being hardcoded in. As a result, I removed the param and wrote a const vari 2. In needing to write another mock chatGPT response, I'm realizing that I also need to abstract the current interview GetChatGPTResponse method, likely to its own package so that CreateConversations() and AppendConversations() in the conversations package can also access it to both to be able to continue to mock chatGPT responses for my integration tests AND also because the redundancy of the function in both the interview package and the conversation package is just gross and inefficient. ^_^ -However, I currently already have a models package, where the ChatgptResponse struct lives that models the response we get back from OpenAI. I'm thinking I need to get rid of that models package and replace it with a ChatGPT package that could then have its own model and service and interface files. The model file would house the current ChatGPTResponse struct as well as the OpenAIClient struct and AIClient interface currently housed in interview/models. Then I would just import that package and use the resulting method inside the interview package and the two times in the conversation package. +However, I currently already have a models package, where the ChatgptResponse struct lives that models the response we get back from OpenAI. I'm thinking I need to get rid of that models package and replace it with a ChatGPT package that could then have its own model and service and interface files. The model file would house the current ChatGPTResponse struct as well as the AIService struct and AIClient interface currently housed in interview/models. Then I would just import that package and use the resulting method inside the interview package and the two times in the conversation package. ### 🔁 TODO diff --git a/mailer/model.go b/mailer/model.go index 36170c3..b892f96 100644 --- a/mailer/model.go +++ b/mailer/model.go @@ -5,7 +5,7 @@ import ( "os" ) -type Mailer struct { +type MailerService struct { APIKey string BaseURL string Logger *slog.Logger @@ -23,8 +23,8 @@ const signature = `

` -func NewMailer(logger *slog.Logger) *Mailer { - return &Mailer{ +func NewMailerService(logger *slog.Logger) *MailerService { + return &MailerService{ APIKey: os.Getenv("RESEND_API_KEY"), BaseURL: "https://api.resend.com", Logger: logger, diff --git a/mailer/service.go b/mailer/service.go index 70c3b6e..aa11a10 100644 --- a/mailer/service.go +++ b/mailer/service.go @@ -7,7 +7,7 @@ import ( "net/http" ) -func (m *Mailer) SendPasswordReset(email, resetURL string) error { +func (m *MailerService) SendPasswordReset(email, resetURL string) error { payload := map[string]any{ "from": "Interviewer Support ", "to": email, @@ -48,7 +48,7 @@ func (m *Mailer) SendPasswordReset(email, resetURL string) error { return nil } -func (m *Mailer) SendVerificationEmail(email, verifyURL string) error { +func (m *MailerService) SendVerificationEmail(email, verifyURL string) error { payload := map[string]any{ "from": "Interviewer Support ", @@ -79,7 +79,7 @@ func (m *Mailer) SendVerificationEmail(email, verifyURL string) error { return nil } -func (m *Mailer) SendWelcome(email string) error { +func (m *MailerService) SendWelcome(email string) error { payload := map[string]any{ "from": "Interviewer Support ", "to": email, @@ -128,7 +128,7 @@ func (m *Mailer) SendWelcome(email string) error { return nil } -func (m *Mailer) SendDeletionConfirmation(email string) error { +func (m *MailerService) SendDeletionConfirmation(email string) error { payload := map[string]any{ "from": "Interviewer Support ", "to": email, diff --git a/token/model.go b/token/model.go index 4cd8ad8..90d7b71 100644 --- a/token/model.go +++ b/token/model.go @@ -1,11 +1,17 @@ package token import ( + "log/slog" "time" "github.com/golang-jwt/jwt/v5" ) +type TokenService struct { + TokenRepo TokenRepo + Logger *slog.Logger +} + type RefreshToken struct { UserID int RefreshToken string @@ -14,6 +20,13 @@ type RefreshToken struct { UpdatedAt time.Time } +func NewTokenService(tokenRepo TokenRepo, logger *slog.Logger) *TokenService { + return &TokenService{ + TokenRepo: tokenRepo, + Logger: logger, + } +} + type CustomClaims struct { UserID string `json:"sub"` jwt.RegisteredClaims diff --git a/token/service.go b/token/service.go index f1c07f5..f8f3b44 100644 --- a/token/service.go +++ b/token/service.go @@ -5,7 +5,6 @@ import ( "crypto/subtle" "encoding/hex" "fmt" - "log" "os" "strconv" "time" @@ -13,7 +12,7 @@ import ( "github.com/golang-jwt/jwt/v5" ) -func CreateJWT(subject string, expires int) (string, error) { +func (t *TokenService) CreateJWT(subject string, expires int) (string, error) { var ( key []byte jwtoken *jwt.Token @@ -35,20 +34,21 @@ func CreateJWT(subject string, expires int) (string, error) { jwtoken = jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := jwtoken.SignedString(key) if err != nil { - log.Fatalf("Bad SignedString: %s", err) + t.Logger.Error("Bad SignedString", "error", err) return "", err } return tokenString, nil } -func CreateRefreshToken(repo TokenRepo, userID int) (string, error) { +func (t *TokenService) CreateRefreshToken(userID int) (string, error) { now := time.Now().UTC() refreshLength := 32 refreshBytes := make([]byte, refreshLength) _, err := rand.Read([]byte(refreshBytes)) if err != nil { + t.Logger.Error("rand.Read failed", "error", err) return "", err } token := hex.EncodeToString(refreshBytes) @@ -62,37 +62,39 @@ func CreateRefreshToken(repo TokenRepo, userID int) (string, error) { UpdatedAt: now, } - err = repo.AddRefreshToken(refreshToken) + err = t.TokenRepo.AddRefreshToken(refreshToken) if err != nil { + t.Logger.Error("t.TokenRepo.AddRefreshToken failed", "error", err) return "", err } return refreshToken.RefreshToken, nil } -func DeleteRefreshToken(repo TokenRepo, userID int) error { - err := repo.DeleteRefreshToken(userID) +func (t *TokenService) DeleteRefreshToken(userID int) error { + err := t.TokenRepo.DeleteRefreshToken(userID) if err != nil { - log.Printf("repo.DeleteRefreshToken failed: %v", err) + t.Logger.Error("t.TokenRepo.DeleteRefreshToken failed", "error", err) return err } return nil } -func GetStoredRefreshToken(repo TokenRepo, userID int) (string, error) { - storedToken, err := repo.GetStoredRefreshToken(userID) +func (t *TokenService) GetStoredRefreshToken(userID int) (string, error) { + storedToken, err := t.TokenRepo.GetStoredRefreshToken(userID) if err != nil { + t.Logger.Error("t.TokenRepo.GetStoredRefreshToken failed", "error", err) return "", err } return storedToken, nil } -func VerifyRefreshToken(storedToken, providedToken string) bool { +func (t *TokenService) VerifyRefreshToken(storedToken, providedToken string) bool { return subtle.ConstantTimeCompare([]byte(storedToken), []byte(providedToken)) == 1 } -func ExtractUserIDFromToken(tokenString string) (int, error) { +func (t *TokenService) ExtractUserIDFromToken(tokenString string) (int, error) { jwtSecret := os.Getenv("JWT_SECRET") token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(tokenString *jwt.Token) (interface{}, error) { @@ -102,7 +104,7 @@ func ExtractUserIDFromToken(tokenString string) (int, error) { return []byte(jwtSecret), nil }) if err != nil { - log.Printf("ParseWithClaims failed: %v", err) + t.Logger.Error("jwt.ParseWithClaims failed", "error", err) return 0, err } @@ -110,6 +112,7 @@ func ExtractUserIDFromToken(tokenString string) (int, error) { if ok && token.Valid { userID, err := strconv.Atoi(claims.UserID) if err != nil { + t.Logger.Error("strconv.Atoi failed", "error", err) return 0, err } diff --git a/token/service_test.go b/token/service_test.go index e202a6c..fcc8db2 100644 --- a/token/service_test.go +++ b/token/service_test.go @@ -1,8 +1,7 @@ package token import ( - "fmt" - "log" + "log/slog" "os" "strconv" "strings" @@ -34,15 +33,12 @@ func TestCreateRefreshToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + tokenRepo := NewMockRepo() + tokenService := NewTokenService(tokenRepo, logger) + tokenRepo.failRepo = tc.failRepo - repo := NewMockRepo() - if tc.failRepo { - repo.failRepo = true - } - - token, err := CreateRefreshToken(repo, tc.userID) + token, err := tokenService.CreateRefreshToken(tc.userID) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -88,15 +84,12 @@ func TestGetStoredRefreshToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) - - repo := NewMockRepo() - if tc.failRepo { - repo.failRepo = true - } + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + tokenRepo := NewMockRepo() + tokenService := NewTokenService(tokenRepo, logger) + tokenRepo.failRepo = tc.failRepo - token, err := GetStoredRefreshToken(repo, tc.userID) + token, err := tokenService.GetStoredRefreshToken(tc.userID) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -141,10 +134,11 @@ func TestVerifyRefreshToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + tokenRepo := NewMockRepo() + tokenService := NewTokenService(tokenRepo, logger) - result := VerifyRefreshToken(tc.storedToken, tc.inputToken) + result := tokenService.VerifyRefreshToken(tc.storedToken, tc.inputToken) if result != tc.expected { t.Errorf("expected %v but got %v", tc.expected, result) @@ -178,20 +172,20 @@ func TestExtractUserIDFromToken(t *testing.T) { var buf strings.Builder var token string var err error - - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + tokenRepo := NewMockRepo() + tokenService := NewTokenService(tokenRepo, logger) if tc.invalid { token = "invalid.token.value" } else { - token, err = CreateJWT(strconv.Itoa(tc.userID), 3600) + token, err = tokenService.CreateJWT(strconv.Itoa(tc.userID), 3600) if err != nil { t.Fatalf("failed to create JWT: %v", err) } } - uid, err := ExtractUserIDFromToken(token) + uid, err := tokenService.ExtractUserIDFromToken(token) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -211,10 +205,3 @@ func TestExtractUserIDFromToken(t *testing.T) { }) } } - -func showLogsIfFail(t *testing.T, name string, buf strings.Builder) { - log.SetOutput(os.Stderr) - if t.Failed() { - fmt.Printf("---- logs for test: %s ----\n%s\n", name, buf.String()) - } -} diff --git a/user/model.go b/user/model.go index bdaceed..6931341 100644 --- a/user/model.go +++ b/user/model.go @@ -2,6 +2,7 @@ package user import ( "errors" + "log/slog" "time" "github.com/golang-jwt/jwt/v5" @@ -36,6 +37,18 @@ type EmailClaims struct { jwt.RegisteredClaims } +type UserService struct { + UserRepo UserRepo + Logger *slog.Logger +} + +func NewUserService(userRepo UserRepo, logger *slog.Logger) *UserService { + return &UserService{ + UserRepo: userRepo, + Logger: logger, + } +} + type UserRepo interface { CreateUser(user *User) (int, error) MarkUserDeleted(userID int) error diff --git a/user/repository_mock.go b/user/repository_mock.go index b0dbbf0..fad130d 100644 --- a/user/repository_mock.go +++ b/user/repository_mock.go @@ -10,7 +10,7 @@ import ( type MockRepo struct { Users map[int]User - failRepo bool + FailRepo bool FailGetUserByEmail bool FailAddCredits bool } @@ -27,12 +27,13 @@ func NewMockRepo() *MockRepo { } return &MockRepo{ - Users: map[int]User{}, + Users: map[int]User{}, + FailGetUserByEmail: false, } } func (m *MockRepo) CreateUser(user *User) (int, error) { - if m.failRepo { + if m.FailRepo { return 0, errors.New("mocked DB failure") } @@ -40,7 +41,7 @@ func (m *MockRepo) CreateUser(user *User) (int, error) { } func (m *MockRepo) MarkUserDeleted(userID int) error { - if m.failRepo { + if m.FailRepo { return errors.New("mocked DB failure") } @@ -48,7 +49,7 @@ func (m *MockRepo) MarkUserDeleted(userID int) error { } func (m *MockRepo) GetUser(userID int) (*User, error) { - if m.failRepo { + if m.FailRepo { return nil, errors.New("mocked DB failure") } @@ -64,7 +65,7 @@ func (m *MockRepo) GetUser(userID int) (*User, error) { } func (m *MockRepo) GetPasswordandID(username string) (int, string, error) { - if m.failRepo { + if m.FailRepo { return 0, "", errors.New("mocked DB failure") } @@ -75,7 +76,7 @@ func (m *MockRepo) GetUserByEmail(email string) (*User, error) { if m.FailGetUserByEmail { return nil, errors.New("mocked GetUserByEmail failure") } - if m.failRepo { + if m.FailRepo { return nil, errors.New("mocked DB failure") } @@ -90,7 +91,7 @@ func (m *MockRepo) GetUserByEmail(email string) (*User, error) { } func (m *MockRepo) GetUserByCustomerID(customerID string) (*User, error) { - if m.failRepo { + if m.FailRepo { return nil, errors.New("mocked DB failure") } @@ -105,7 +106,7 @@ func (m *MockRepo) GetUserByCustomerID(customerID string) (*User, error) { } func (m *MockRepo) UpdatePasswordByEmail(email string, password []byte) error { - if m.failRepo { + if m.FailRepo { return errors.New("mocked DB failure") } @@ -116,7 +117,7 @@ func (m *MockRepo) AddCredits(userID, credits int, creditType string) error { if m.FailAddCredits { return errors.New("mocked AddCredits failure") } - if m.failRepo { + if m.FailRepo { return errors.New("mocked DB failure") } @@ -124,7 +125,7 @@ func (m *MockRepo) AddCredits(userID, credits int, creditType string) error { } func (m *MockRepo) UpdateSubscriptionData(userID int, status, tier, subscriptionID string, startsAt, endsAt time.Time) error { - if m.failRepo { + if m.FailRepo { return errors.New("mocked DB failure") } @@ -132,7 +133,7 @@ func (m *MockRepo) UpdateSubscriptionData(userID int, status, tier, subscription } func (m *MockRepo) UpdateSubscriptionStatusData(userID int, status string) error { - if m.failRepo { + if m.FailRepo { return errors.New("mocked DB failure") } @@ -140,7 +141,7 @@ func (m *MockRepo) UpdateSubscriptionStatusData(userID int, status string) error } func (m *MockRepo) HasActiveOrCancelledSubscription(email string) (bool, error) { - if m.failRepo { + if m.FailRepo { return false, errors.New("mocked DB failure") } diff --git a/user/service.go b/user/service.go index 85a9583..44fd4dc 100644 --- a/user/service.go +++ b/user/service.go @@ -4,15 +4,13 @@ import ( "errors" "log" "os" - "strconv" "time" "github.com/golang-jwt/jwt/v5" - "github.com/michaelboegner/interviewer/token" "golang.org/x/crypto/bcrypt" ) -func VerificationToken(email, username, password string) (string, error) { +func (u *UserService) VerificationToken(email, username, password string) (string, error) { passwordHashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) if err != nil { return "", err @@ -30,7 +28,7 @@ func VerificationToken(email, username, password string) (string, error) { SignedString([]byte(os.Getenv("JWT_SECRET"))) } -func CreateUser(repo UserRepo, tokenStr string) (*User, error) { +func (u *UserService) CreateUser(tokenStr string) (*User, error) { claims := &EmailClaims{} tkn, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { return []byte(os.Getenv("JWT_SECRET")), nil @@ -52,7 +50,7 @@ func CreateUser(repo UserRepo, tokenStr string) (*User, error) { UpdatedAt: time.Now().UTC(), } - id, err := repo.CreateUser(user) + id, err := u.UserRepo.CreateUser(user) if err != nil { return nil, err } @@ -60,38 +58,32 @@ func CreateUser(repo UserRepo, tokenStr string) (*User, error) { return user, nil } -func LoginUser(repo UserRepo, email, password string) (string, string, int, error) { - userID, hashedPassword, err := repo.GetPasswordandID(email) +func (u *UserService) LoginUser(email, password string) (string, int, error) { + userID, hashedPassword, err := u.UserRepo.GetPasswordandID(email) if err != nil { - return "", "", 0, err + return "", 0, err } - user, err := repo.GetUser(userID) + user, err := u.UserRepo.GetUser(userID) if err != nil { - log.Printf("repo.GetUser failed: %v", err) - return "", "", 0, err + log.Printf("u.UserRepo.GetUser failed: %v", err) + return "", 0, err } if user.AccountStatus != "active" { - return "", "", 0, ErrAccountDeleted + return "", 0, ErrAccountDeleted } err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) if err != nil { - return "", "", 0, err + return "", 0, err } - jwToken, err := token.CreateJWT(strconv.Itoa(userID), 0) - if err != nil { - log.Printf("JWT creation failed: %v", err) - return "", "", 0, err - } - - return jwToken, user.Username, userID, nil + return user.Username, userID, nil } -func GetUser(repo UserRepo, userID int) (*User, error) { - userReturned, err := repo.GetUser(userID) +func (u *UserService) GetUser(userID int) (*User, error) { + userReturned, err := u.UserRepo.GetUser(userID) if err != nil { log.Printf("GetUser failed: %v", err) return nil, err @@ -99,43 +91,27 @@ func GetUser(repo UserRepo, userID int) (*User, error) { return userReturned, nil } -func MarkUserDeleted(repo UserRepo, userId int) error { - err := repo.MarkUserDeleted(userId) +func (u *UserService) MarkUserDeleted(userId int) error { + err := u.UserRepo.MarkUserDeleted(userId) if err != nil { - log.Printf("repo.DeleteUser failed: %v", err) + log.Printf("u.UserRepo.DeleteUser failed: %v", err) return err } return nil } -func GetUserByEmail(repo UserRepo, email string) error { - _, err := repo.GetUserByEmail(email) +func (u *UserService) GetUserByEmail(email string) error { + _, err := u.UserRepo.GetUserByEmail(email) if err != nil { - log.Printf("repo.GetUserByEmail failed: %v", err) + log.Printf("u.UserRepo.GetUserByEmail failed: %v", err) return err } return nil } -func RequestPasswordReset(repo UserRepo, email string) (string, error) { - user, err := repo.GetUserByEmail(email) - if err != nil { - log.Printf("GetUserByEmail failed: %v", err) - return "", err - } - - resetJWT, err := token.CreateJWT(user.Email, 900) - if err != nil { - log.Printf("CreateJWT failed: %v", err) - return "", err - } - - return resetJWT, nil -} - -func ResetPassword(repo UserRepo, newPassword string, resetJWT string) error { +func (u *UserService) ResetPassword(newPassword string, resetJWT string) error { email, err := verifyResetToken(resetJWT) if err != nil { return err @@ -147,7 +123,7 @@ func ResetPassword(repo UserRepo, newPassword string, resetJWT string) error { return err } - err = repo.UpdatePasswordByEmail(email, passwordHashed) + err = u.UserRepo.UpdatePasswordByEmail(email, passwordHashed) if err != nil { log.Printf("UpdatePasswordByEmail failed: %v", err) return err @@ -156,29 +132,8 @@ func ResetPassword(repo UserRepo, newPassword string, resetJWT string) error { return nil } -func verifyResetToken(tokenString string) (string, error) { - jwtSecret := os.Getenv("JWT_SECRET") - if jwtSecret == "" { - log.Printf("JWT secret is not set") - err := errors.New("jwt secret is not set") - return "", err - } - - token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte(jwtSecret), nil - }) - if err != nil { - return "", err - } - if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid { - return claims.Subject, nil - } else { - return "", errors.New("invalid token") - } -} - -func GetOrCreateByEmail(repo UserRepo, email, username string) (*User, error) { - user, err := repo.GetUserByEmail(email) +func (u *UserService) GetOrCreateByEmail(email, username string) (*User, error) { + user, err := u.UserRepo.GetUserByEmail(email) if err == nil { return user, nil } @@ -193,7 +148,7 @@ func GetOrCreateByEmail(repo UserRepo, email, username string) (*User, error) { UpdatedAt: time.Now().UTC(), } - id, err := repo.CreateUser(newUser) + id, err := u.UserRepo.CreateUser(newUser) if err != nil { log.Printf("CreateUser failed: %v", err) return nil, err @@ -202,3 +157,24 @@ func GetOrCreateByEmail(repo UserRepo, email, username string) (*User, error) { newUser.ID = id return newUser, nil } + +func verifyResetToken(tokenString string) (string, error) { + jwtSecret := os.Getenv("JWT_SECRET") + if jwtSecret == "" { + log.Printf("JWT secret is not set") + err := errors.New("jwt secret is not set") + return "", err + } + + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }) + if err != nil { + return "", err + } + if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid { + return claims.Subject, nil + } else { + return "", errors.New("invalid token") + } +} diff --git a/user/service_test.go b/user/service_test.go index 3888198..169f12f 100644 --- a/user/service_test.go +++ b/user/service_test.go @@ -3,12 +3,15 @@ package user import ( "fmt" "log" + "log/slog" "os" + "strconv" "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/michaelboegner/interviewer/token" "golang.org/x/crypto/bcrypt" ) @@ -49,17 +52,16 @@ func TestCreateUser(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + userRepo := NewMockRepo() + userService := NewUserService(userRepo, logger) + userRepo.FailRepo = tc.failRepo - repo := NewMockRepo() - repo.failRepo = tc.failRepo - - jwt, err := VerificationToken(tc.email, tc.username, tc.password) + jwt, err := userService.VerificationToken(tc.email, tc.username, tc.password) if err != nil { t.Fatalf("VerificationToken failed: %v", err) } - user, err := CreateUser(repo, jwt) + user, err := userService.CreateUser(jwt) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -114,22 +116,34 @@ func TestLoginUser(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) - - repo := NewMockRepo() - repo.failRepo = tc.failRepo - - jwtoken, username, userID, err := LoginUser(repo, tc.email, tc.password) - - if tc.expectError && err == nil { - t.Fatalf("expected error but got nil") + var ( + buf strings.Builder + jwToken string + ) + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + userRepo := NewMockRepo() + tokenRepo := token.NewMockRepo() + userService := NewUserService(userRepo, logger) + tokenService := token.NewTokenService(tokenRepo, logger) + userRepo.FailRepo = tc.failRepo + + username, userID, err := userService.LoginUser(tc.email, tc.password) + if tc.expectError { + if err == nil { + t.Fatalf("expected error but got nil") + } + return } - if !tc.expectError && err != nil { + + if err != nil { t.Fatalf("did not expect error but got: %v", err) } + jwToken, err = tokenService.CreateJWT(strconv.Itoa(userID), 0) + if err != nil { + t.Fatalf("JWT creation failed: %v", err) + } + if !tc.expectError { expected := tc.userID got := userID @@ -137,7 +151,7 @@ func TestLoginUser(t *testing.T) { if diff := cmp.Diff(expected, got); diff != "" { t.Errorf("User mismatch (-want +got):\n%s", diff) } - if jwtoken == "" { + if jwToken == "" { t.Errorf("Expected jwtoken but got empty string") } if username == "" { @@ -178,13 +192,12 @@ func TestGetUser(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + userRepo := NewMockRepo() + userService := NewUserService(userRepo, logger) + userRepo.FailRepo = tc.failRepo - repo := NewMockRepo() - repo.failRepo = tc.failRepo - - user, err := GetUser(repo, tc.userID) + user, err := userService.GetUser(tc.userID) if tc.expectError && err == nil { t.Fatalf("expected error but got nil") @@ -231,13 +244,12 @@ func TestUpdateSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var buf strings.Builder - log.SetOutput(&buf) - defer showLogsIfFail(t, tc.name, buf) - - repo := NewMockRepo() - repo.failRepo = tc.failRepo + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: true})) + userRepo := NewMockRepo() + userService := NewUserService(userRepo, logger) + userRepo.FailRepo = tc.failRepo - user, err := GetUser(repo, tc.userID) + user, err := userService.GetUser(tc.userID) if tc.expectError && err == nil { t.Fatalf("expected error but got nil")