diff --git a/api/handlers/auth.go b/api/handlers/auth.go index af2ddd2db5..ba86142f3a 100644 --- a/api/handlers/auth.go +++ b/api/handlers/auth.go @@ -25,13 +25,45 @@ import ( func (h *Handler) InitSSO(w http.ResponseWriter, r *http.Request) { configuration := h.A.Cfg + billingEnabled := configuration.Billing.Enabled && h.A.BillingClient != nil + slug := strings.TrimSpace(r.URL.Query().Get("slug")) + + licenseKey := configuration.LicenseKey + if billingEnabled && slug != "" { + orgRepo := organisations.New(h.A.Logger, h.A.DB) + orgMemberRepo := organisation_members.New(h.A.Logger, h.A.DB) + result, err := services.ResolveWorkspaceBySlug(r.Context(), slug, services.ResolveWorkspaceBySlugDeps{ + BillingClient: h.A.BillingClient, + OrgRepo: orgRepo, + Logger: h.A.Logger, + Cfg: configuration, + RefreshDeps: services.RefreshLicenseDataDeps{ + OrgMemberRepo: orgMemberRepo, + OrgRepo: orgRepo, + BillingClient: h.A.BillingClient, + Logger: h.A.Logger, + Cfg: configuration, + }, + }) + if err != nil { + h.A.Logger.WithError(err).WithField("slug", slug).Debug("InitSSO: workspace resolve failed") + _ = render.Render(w, r, util.NewErrorResponse("Workspace not found", http.StatusBadRequest)) + return + } + if !result.SSOAvailable { + _ = render.Render(w, r, util.NewErrorResponse("SSO is not available for this workspace", http.StatusBadRequest)) + return + } + licenseKey = result.LicenseKey + } + lu := services.LoginUserSSOService{ UserRepo: users.New(h.A.Logger, h.A.DB), OrgRepo: organisations.New(h.A.Logger, h.A.DB), OrgMemberRepo: organisation_members.New(h.A.Logger, h.A.DB), JWT: jwt.NewJwt(&configuration.Auth.Jwt, h.A.Cache), ConfigRepo: h.A.ConfigRepo, - LicenseKey: configuration.LicenseKey, + LicenseKey: licenseKey, Host: configuration.Host, Licenser: h.A.Licenser, } diff --git a/api/handlers/configuration.go b/api/handlers/configuration.go index 5fbfe2cf4d..22357e3d45 100644 --- a/api/handlers/configuration.go +++ b/api/handlers/configuration.go @@ -3,6 +3,7 @@ package handlers import ( "errors" "net/http" + "strings" "github.com/go-chi/render" @@ -116,7 +117,31 @@ func (h *Handler) GetAuthConfiguration(w http.ResponseWriter, r *http.Request) { _ = render.Render(w, r, util.NewErrorResponse("failed to load configuration", http.StatusBadRequest)) return } + billingEnabled := cfg.Billing.Enabled && h.A.BillingClient != nil + slug := strings.TrimSpace(r.URL.Query().Get("slug")) + + ssoEnabled := h.A.Licenser.EnterpriseSSO() + if billingEnabled && slug != "" { + result, err := services.ResolveWorkspaceBySlug(r.Context(), slug, services.ResolveWorkspaceBySlugDeps{ + BillingClient: h.A.BillingClient, + OrgRepo: h.A.OrgRepo, + Logger: h.A.Logger, + Cfg: cfg, + RefreshDeps: services.RefreshLicenseDataDeps{ + OrgMemberRepo: h.A.OrgMemberRepo, + OrgRepo: h.A.OrgRepo, + BillingClient: h.A.BillingClient, + Logger: h.A.Logger, + Cfg: cfg, + }, + }) + if err == nil { + ssoEnabled = result.SSOAvailable + } + } + authConfig := map[string]interface{}{ + "billing_enabled": billingEnabled, "is_signup_enabled": cfg.Auth.IsSignupEnabled, "google_oauth": map[string]interface{}{ "enabled": cfg.Auth.GoogleOAuth.Enabled && h.A.Licenser.GoogleOAuth(), @@ -124,7 +149,7 @@ func (h *Handler) GetAuthConfiguration(w http.ResponseWriter, r *http.Request) { "redirect_url": cfg.Auth.GoogleOAuth.RedirectURL, }, "sso": map[string]interface{}{ - "enabled": h.A.Licenser.EnterpriseSSO(), + "enabled": ssoEnabled, "redirect_url": cfg.Auth.SSO.RedirectURL, }, } diff --git a/api/handlers/source.go b/api/handlers/source.go index 6dc495989b..430165178a 100644 --- a/api/handlers/source.go +++ b/api/handlers/source.go @@ -317,6 +317,9 @@ func (h *Handler) LoadSourcesPaged(w http.ResponseWriter, r *http.Request) { } func fillSourceURL(s *datastore.Source, baseUrl, customDomain string) { + if s == nil { + return + } url := baseUrl if len(customDomain) > 0 { url = customDomain diff --git a/internal/pkg/billing/client.go b/internal/pkg/billing/client.go index 0368a69ca3..5438bff327 100644 --- a/internal/pkg/billing/client.go +++ b/internal/pkg/billing/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "time" @@ -24,6 +25,7 @@ type Client interface { CreateOrganisation(ctx context.Context, orgData BillingOrganisation) (*Response[BillingOrganisation], error) GetOrganisationLicense(ctx context.Context, orgID string) (*Response[OrganisationLicense], error) GetOrganisation(ctx context.Context, orgID string) (*Response[BillingOrganisation], error) + GetWorkspaceConfigBySlug(ctx context.Context, slug string) (*Response[WorkspaceConfigData], error) UpdateOrganisation(ctx context.Context, orgID string, orgData BillingOrganisation) (*Response[BillingOrganisation], error) UpdateOrganisationTaxID(ctx context.Context, orgID string, taxData UpdateOrganisationTaxIDRequest) (*Response[BillingOrganisation], error) UpdateOrganisationAddress(ctx context.Context, orgID string, addressData UpdateOrganisationAddressRequest) (*Response[BillingOrganisation], error) @@ -157,6 +159,14 @@ func (c *HTTPClient) GetOrganisation(ctx context.Context, orgID string) (*Respon return makeRequest[BillingOrganisation](ctx, c.httpClient, c.config, "GET", fmt.Sprintf("/organisations/%s", orgID), nil) } +func (c *HTTPClient) GetWorkspaceConfigBySlug(ctx context.Context, slug string) (*Response[WorkspaceConfigData], error) { + if slug == "" { + return nil, fmt.Errorf("slug is required") + } + path := fmt.Sprintf("/api/v1/workspace_config?slug=%s", strings.ReplaceAll(url.QueryEscape(slug), "+", "%20")) + return makeRequest[WorkspaceConfigData](ctx, c.httpClient, c.config, "GET", path, nil) +} + func (c *HTTPClient) UpdateOrganisation(ctx context.Context, orgID string, orgData BillingOrganisation) (*Response[BillingOrganisation], error) { return makeRequest[BillingOrganisation](ctx, c.httpClient, c.config, "PUT", fmt.Sprintf("/organisations/%s", orgID), orgData) } diff --git a/internal/pkg/billing/mock.go b/internal/pkg/billing/mock.go index f5928de284..2ded1677aa 100644 --- a/internal/pkg/billing/mock.go +++ b/internal/pkg/billing/mock.go @@ -155,6 +155,17 @@ func (m *MockBillingClient) GetOrganisationLicense(ctx context.Context, orgID st }, nil } +func (m *MockBillingClient) GetWorkspaceConfigBySlug(ctx context.Context, slug string) (*Response[WorkspaceConfigData], error) { + if slug == "" { + return nil, &Error{Message: "slug is required"} + } + return &Response[WorkspaceConfigData]{ + Status: true, + Message: "OK", + Data: WorkspaceConfigData{ExternalID: slug, SSOAvailable: false}, + }, nil +} + func (m *MockBillingClient) UpdateOrganisation(ctx context.Context, orgID string, orgData BillingOrganisation) (*Response[BillingOrganisation], error) { if orgID == "" || orgData.Name == "" { return nil, &Error{Message: "invalid organisation update"} diff --git a/internal/pkg/billing/models.go b/internal/pkg/billing/models.go index c47dce2123..700bf33e38 100644 --- a/internal/pkg/billing/models.go +++ b/internal/pkg/billing/models.go @@ -1,5 +1,12 @@ package billing +// WorkspaceConfigData is the workspace_config API response. +type WorkspaceConfigData struct { + ExternalID string `json:"external_id"` + LicenseKey string `json:"license_key"` + SSOAvailable bool `json:"sso_available"` +} + type BillingOrganisation struct { ID string `json:"id,omitempty"` Name string `json:"name,omitempty"` diff --git a/services/workspace_slug.go b/services/workspace_slug.go new file mode 100644 index 0000000000..345a0809f2 --- /dev/null +++ b/services/workspace_slug.go @@ -0,0 +1,89 @@ +package services + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/frain-dev/convoy/config" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/internal/pkg/billing" + licensesvc "github.com/frain-dev/convoy/internal/pkg/license/service" + "github.com/frain-dev/convoy/pkg/log" +) + +// ResolveWorkspaceBySlugDeps holds dependencies for ResolveWorkspaceBySlug. +type ResolveWorkspaceBySlugDeps struct { + BillingClient billing.Client + OrgRepo datastore.OrganisationRepository + Logger log.StdLogger + Cfg config.Configuration + RefreshDeps RefreshLicenseDataDeps +} + +// ResolveWorkspaceBySlugResult is the result of ResolveWorkspaceBySlug. +type ResolveWorkspaceBySlugResult struct { + ExternalID string + LicenseKey string + SSOAvailable bool + Org *datastore.Organisation +} + +// ResolveWorkspaceBySlug resolves workspace by slug via billing and syncs license data for the org. +func ResolveWorkspaceBySlug(ctx context.Context, slug string, deps ResolveWorkspaceBySlugDeps) (*ResolveWorkspaceBySlugResult, error) { + if slug == "" { + return nil, errors.New("slug is required") + } + if deps.BillingClient == nil { + return nil, errors.New("billing client is required") + } + if deps.OrgRepo == nil { + return nil, errors.New("org repo is required") + } + + reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + resp, err := deps.BillingClient.GetWorkspaceConfigBySlug(reqCtx, slug) + if err != nil { + if deps.Logger != nil { + deps.Logger.WithError(err).WithField("slug", slug).Debug("workspace_config by slug failed") + } + return nil, fmt.Errorf("workspace not found: %w", err) + } + if !resp.Status { + return nil, errors.New("workspace not found") + } + if resp.Data.ExternalID == "" { + return nil, errors.New("workspace config missing external_id") + } + + org, err := deps.OrgRepo.FetchOrganisationByID(ctx, resp.Data.ExternalID) + if err != nil { + return nil, fmt.Errorf("organisation not found for workspace: %w", err) + } + + defaultKey := deps.Cfg.LicenseKey + billingEnabled := deps.Cfg.Billing.Enabled && deps.RefreshDeps.BillingClient != nil + licClient := licensesvc.NewClient(licensesvc.Config{ + Host: deps.Cfg.LicenseService.Host, + ValidatePath: deps.Cfg.LicenseService.ValidatePath, + Timeout: deps.Cfg.LicenseService.Timeout, + RetryCount: deps.Cfg.LicenseService.RetryCount, + Logger: deps.Logger, + }) + RefreshLicenseDataForOrg(ctx, *org, defaultKey, billingEnabled, deps.RefreshDeps, licClient) + + org, err = deps.OrgRepo.FetchOrganisationByID(ctx, resp.Data.ExternalID) + if err != nil { + org = nil + } + + return &ResolveWorkspaceBySlugResult{ + ExternalID: resp.Data.ExternalID, + LicenseKey: resp.Data.LicenseKey, + SSOAvailable: resp.Data.SSOAvailable, + Org: org, + }, nil +} diff --git a/web/ui/dashboard/src/app/private/pages/project/subscriptions/subscriptions.component.html b/web/ui/dashboard/src/app/private/pages/project/subscriptions/subscriptions.component.html index cf1a2d41d3..3164410c1d 100644 --- a/web/ui/dashboard/src/app/private/pages/project/subscriptions/subscriptions.component.html +++ b/web/ui/dashboard/src/app/private/pages/project/subscriptions/subscriptions.component.html @@ -86,9 +86,9 @@
-
+