diff --git a/plugins/contrib/internal/oauthutil/token.go b/plugins/contrib/internal/oauthutil/token.go index 01c1fa2..2dc10b3 100644 --- a/plugins/contrib/internal/oauthutil/token.go +++ b/plugins/contrib/internal/oauthutil/token.go @@ -91,9 +91,6 @@ const defaultExpiryMargin = 1 * time.Minute // NewTokenFetcher creates a TokenFetcher with defaults applied for nil fields. func NewTokenFetcher(cfg TokenFetcher) *TokenFetcher { - if cfg.Logger == nil { - cfg.Logger = slog.Default() - } if cfg.Client == nil { cfg.Client = DefaultHTTPClient() } @@ -103,6 +100,16 @@ func NewTokenFetcher(cfg TokenFetcher) *TokenFetcher { return &cfg } +// log returns the configured logger, or slog.Default() if none was set. +// Called at log-emit time so the current global default is always used +// when no explicit logger is provided. +func (tf *TokenFetcher) log() *slog.Logger { + if tf.Logger != nil { + return tf.Logger + } + return slog.Default() +} + // Exchange sends a token request with the given form parameters and returns // the parsed result. The caller is responsible for setting grant-type-specific // form parameters (e.g., refresh_token) before calling Exchange. @@ -169,7 +176,7 @@ func (tf *TokenFetcher) doRequest(ctx context.Context, req *http.Request) (statu if ctx.Err() != nil { return 0, nil, nil, fmt.Errorf("token request for %s: %w", tf.TokenURL, ctx.Err()) } - tf.Logger.LogAttrs(ctx, slog.LevelWarn, "token endpoint request failed", + tf.log().LogAttrs(ctx, slog.LevelWarn, "token endpoint request failed", slog.String("token_url", tf.TokenURL), slog.String("error", err.Error())) return 0, nil, nil, fmt.Errorf("token endpoint request for %s: %w", @@ -193,7 +200,7 @@ func (tf *TokenFetcher) doRequest(ctx context.Context, req *http.Request) (statu func (tf *TokenFetcher) handleErrorResponse(ctx context.Context, statusCode int, header http.Header, body []byte) error { contentType := header.Get("Content-Type") - if tf.Logger.Enabled(ctx, slog.LevelDebug) { + if tf.log().Enabled(ctx, slog.LevelDebug) { attrs := []slog.Attr{ slog.Int("status", statusCode), slog.String("content_type", contentType), @@ -209,7 +216,7 @@ func (tf *TokenFetcher) handleErrorResponse(ctx context.Context, statusCode int, } } - tf.Logger.LogAttrs(ctx, slog.LevelDebug, "token endpoint error response", attrs...) + tf.log().LogAttrs(ctx, slog.LevelDebug, "token endpoint error response", attrs...) } if statusCode == http.StatusUnauthorized { @@ -218,7 +225,7 @@ func (tf *TokenFetcher) handleErrorResponse(ctx context.Context, statusCode int, } if statusCode == http.StatusTooManyRequests || statusCode >= http.StatusInternalServerError { - tf.Logger.LogAttrs(ctx, slog.LevelWarn, "token endpoint unavailable", + tf.log().LogAttrs(ctx, slog.LevelWarn, "token endpoint unavailable", slog.String("token_url", tf.TokenURL), slog.Int("status", statusCode)) return fmt.Errorf("token endpoint returned %d (content-type: %s): %w", diff --git a/plugins/contrib/internal/oauthutil/token_test.go b/plugins/contrib/internal/oauthutil/token_test.go index f942095..3363ac2 100644 --- a/plugins/contrib/internal/oauthutil/token_test.go +++ b/plugins/contrib/internal/oauthutil/token_test.go @@ -5,6 +5,8 @@ package oauthutil import ( "encoding/json" + "io" + "log/slog" "testing" "time" ) @@ -30,8 +32,10 @@ func TestNewTokenFetcher_Defaults(t *testing.T) { TokenURL: "https://example.com/token", }) - if f.Logger == nil { - t.Error("Logger should default to non-nil") + // Logger field must be nil — no eager slog.Default() capture at construction. + // Lazy resolution happens via log() at call time. + if f.Logger != nil { + t.Error("Logger field should be nil when not provided; lazy resolution via log()") } if f.Client == nil { t.Error("Client should default to non-nil") @@ -41,6 +45,26 @@ func TestNewTokenFetcher_Defaults(t *testing.T) { } } +func TestNewTokenFetcher_NilLogger_LazyResolution(t *testing.T) { + f := NewTokenFetcher(TokenFetcher{TokenURL: "https://example.com/token"}) + + if f.log() != slog.Default() { + t.Error("log() should return slog.Default() when Logger is nil") + } +} + +func TestNewTokenFetcher_ExplicitLogger_UsesExplicitLogger(t *testing.T) { + custom := slog.New(slog.NewTextHandler(io.Discard, nil)) + f := NewTokenFetcher(TokenFetcher{ + TokenURL: "https://example.com/token", + Logger: custom, + }) + + if f.log() != custom { + t.Error("log() should return the explicitly provided logger") + } +} + func TestNewTokenFetcher_PreservesExplicitValues(t *testing.T) { margin := 5 * time.Minute f := NewTokenFetcher(TokenFetcher{ diff --git a/plugins/contrib/microsoft/token.go b/plugins/contrib/microsoft/token.go index 46f90a1..a57dd83 100644 --- a/plugins/contrib/microsoft/token.go +++ b/plugins/contrib/microsoft/token.go @@ -163,11 +163,6 @@ func NewRefreshTokenSource(cfg Config) *RefreshTokenSource { margin = defaultExpiryMargin } - logger := cfg.Logger - if logger == nil { - logger = slog.Default() - } - client := cfg.HTTPClient if client == nil { client = oauthutil.DefaultHTTPClient() @@ -182,13 +177,23 @@ func NewRefreshTokenSource(cfg Config) *RefreshTokenSource { maxPoolSize: maxPool, expiryMargin: margin, httpClient: client, - logger: logger, + logger: cfg.Logger, onSaveError: cfg.OnSaveError, pool: make(map[string]*list.Element), lru: list.New(), } } +// log returns the configured logger, or slog.Default() if none was set. +// Called at log-emit time so the current global default is always used +// when no explicit logger is provided. +func (s *RefreshTokenSource) log() *slog.Logger { + if s.logger != nil { + return s.logger + } + return slog.Default() +} + // GetCredentials extracts TenantID and Resource from the transaction context, // resolves a per-tenant entry from the pool, checks the per-resource access // token cache, and if needed exchanges the tenant's MRRT for a new access @@ -240,14 +245,14 @@ func (s *RefreshTokenSource) getAccessToken( entry.mu.RLock() if ct, ok := entry.tokens[resource]; ok && time.Now().Before(ct.expiresAt) { entry.mu.RUnlock() - s.logger.LogAttrs(ctx, slog.LevelDebug, "access token cache hit", + s.log().LogAttrs(ctx, slog.LevelDebug, "access token cache hit", slog.String("tenant_id", entry.tenantID), slog.String("resource", resource)) return ct.accessToken, ct.expiresAt, nil } entry.mu.RUnlock() - s.logger.LogAttrs(ctx, slog.LevelDebug, "access token cache miss", + s.log().LogAttrs(ctx, slog.LevelDebug, "access token cache miss", slog.String("tenant_id", entry.tenantID), slog.String("resource", resource)) @@ -317,7 +322,7 @@ func (s *RefreshTokenSource) exchange( if result.RefreshToken != "" { if saveErr := s.store.Save(ctx, tenantID, result.RefreshToken); saveErr != nil { - s.logger.LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token", + s.log().LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token", slog.String("tenant_id", tenantID), slog.String("resource", resource), slog.String("error", saveErr.Error())) @@ -371,7 +376,7 @@ func (s *RefreshTokenSource) evictLRU(ctx context.Context) { } entry, _ := back.Value.(*tenantEntry) - s.logger.LogAttrs(ctx, slog.LevelDebug, "evicting LRU pool entry", + s.log().LogAttrs(ctx, slog.LevelDebug, "evicting LRU pool entry", slog.String("tenant_id", entry.tenantID)) delete(s.pool, entry.tenantID) diff --git a/plugins/contrib/microsoft/token_test.go b/plugins/contrib/microsoft/token_test.go index f276911..99e4f55 100644 --- a/plugins/contrib/microsoft/token_test.go +++ b/plugins/contrib/microsoft/token_test.go @@ -1446,3 +1446,33 @@ func TestGetCredentials_KeyResolver_ValidTenantIDCheck_RejectsBadResolverValue(t t.Errorf("error = %v, want errors.Is(ErrInvalidContextData)", err) } } + +func TestNewRefreshTokenSource_NilLogger_LazyResolution(t *testing.T) { + src := NewRefreshTokenSource(Config{ + ClientID: "id", + ClientSecret: "secret", + Store: newMemoryTokenStore(), + }) + + // logger field must be nil — no eager slog.Default() at construction. + if src.logger != nil { + t.Error("logger field should be nil when not provided; lazy resolution via log()") + } + if src.log() != slog.Default() { + t.Error("log() should return slog.Default() when logger is nil") + } +} + +func TestNewRefreshTokenSource_ExplicitLogger_UsesExplicitLogger(t *testing.T) { + custom := slog.New(slog.NewTextHandler(io.Discard, nil)) + src := NewRefreshTokenSource(Config{ + ClientID: "id", + ClientSecret: "secret", + Store: newMemoryTokenStore(), + Logger: custom, + }) + + if src.log() != custom { + t.Error("log() should return the explicitly provided logger") + } +} diff --git a/plugins/contrib/mux.go b/plugins/contrib/mux.go index b7c06ec..42f444a 100644 --- a/plugins/contrib/mux.go +++ b/plugins/contrib/mux.go @@ -57,15 +57,23 @@ func WithLogger(l *slog.Logger) MuxOption { // NewMux creates a new request multiplexer. Use [Mux.Handle] and // [Mux.Default] to register routes before serving traffic. func NewMux(opts ...MuxOption) *Mux { - m := &Mux{ - logger: slog.Default(), - } + m := &Mux{} for _, opt := range opts { opt(m) } return m } +// log returns the configured logger, or slog.Default() if none was set. +// Called at log-emit time so the current global default is always used +// when no explicit logger is provided. +func (m *Mux) log() *slog.Logger { + if m.logger != nil { + return m.logger + } + return slog.Default() +} + // Handle registers a route that dispatches matching requests to the // given provider. Routes are evaluated by specificity at dispatch time, // not registration order — but registration order breaks ties. @@ -81,7 +89,7 @@ func (m *Mux) Handle(route Route, provider sdk.CredentialProvider) { newSpec := route.Specificity() for _, e := range m.entries { if e.route.Specificity() == newSpec && routesMayOverlap(e.route, route) { - m.logger.Warn("routes registered with equal specificity may overlap, first registered wins on tie", + m.log().Warn("routes registered with equal specificity may overlap, first registered wins on tie", "existing_route", routeString(e.route), "new_route", routeString(route), ) diff --git a/plugins/contrib/mux_test.go b/plugins/contrib/mux_test.go index f965b60..28301fa 100644 --- a/plugins/contrib/mux_test.go +++ b/plugins/contrib/mux_test.go @@ -6,6 +6,7 @@ package contrib import ( "context" "errors" + "io" "log/slog" "net/http" "sync" @@ -602,3 +603,24 @@ func TestMux_Compliance(t *testing.T) { mux.Default(&namedProvider{name: "compliance"}) compliance.VerifyContract(t, mux) } + +func TestNewMux_NilLogger_LazyResolution(t *testing.T) { + m := NewMux() + + // logger field must be nil — no eager slog.Default() at construction. + if m.logger != nil { + t.Error("logger field should be nil when not provided; lazy resolution via log()") + } + if m.log() != slog.Default() { + t.Error("log() should return slog.Default() when logger is nil") + } +} + +func TestNewMux_WithLogger_UsesExplicitLogger(t *testing.T) { + custom := slog.New(slog.NewTextHandler(io.Discard, nil)) + m := NewMux(WithLogger(custom)) + + if m.log() != custom { + t.Error("log() should return the explicitly provided logger") + } +} diff --git a/plugins/contrib/oauth/refreshtoken.go b/plugins/contrib/oauth/refreshtoken.go index 09b7054..28ed868 100644 --- a/plugins/contrib/oauth/refreshtoken.go +++ b/plugins/contrib/oauth/refreshtoken.go @@ -116,6 +116,16 @@ func NewRefreshToken(cfg RefreshTokenConfig) *RefreshToken { return rt } +// log returns the configured logger, or slog.Default() if none was set. +// Called at log-emit time so the current global default is always used +// when no explicit logger is provided. +func (rt *RefreshToken) log() *slog.Logger { + if rt.logger != nil { + return rt.logger + } + return slog.Default() +} + // GetCredentials fetches an OAuth2 bearer token using the refresh token grant // and returns it as a cacheable credential (Fast Path). // @@ -151,7 +161,7 @@ func (rt *RefreshToken) fetch(ctx context.Context) (*cachedToken, error) { if result.RefreshToken != "" { if saveErr := rt.store.Save(ctx, result.RefreshToken); saveErr != nil { - rt.logger.LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token", + rt.log().LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token", slog.String("token_url", rt.fetcher.TokenURL), slog.String("error", saveErr.Error())) diff --git a/plugins/contrib/oauth/refreshtoken_test.go b/plugins/contrib/oauth/refreshtoken_test.go index cb678d3..3d8ee6d 100644 --- a/plugins/contrib/oauth/refreshtoken_test.go +++ b/plugins/contrib/oauth/refreshtoken_test.go @@ -975,3 +975,37 @@ func TestRefreshToken_Compliance(t *testing.T) { plugin := contrib.AsPlugin(rt) compliance.VerifyContract(t, plugin) } + +func TestNewRefreshToken_NilLogger_LazyResolution(t *testing.T) { + store := newMemoryStore("tok") + rt := NewRefreshToken(RefreshTokenConfig{ + TokenURL: "https://example.com/token", + ClientID: "id", + ClientSecret: "secret", + Store: store, + }) + + // logger field must be nil — no eager slog.Default() at construction. + if rt.logger != nil { + t.Error("logger field should be nil when not provided; lazy resolution via log()") + } + if rt.log() != slog.Default() { + t.Error("log() should return slog.Default() when logger is nil") + } +} + +func TestNewRefreshToken_ExplicitLogger_UsesExplicitLogger(t *testing.T) { + custom := slog.New(slog.NewTextHandler(io.Discard, nil)) + store := newMemoryStore("tok") + rt := NewRefreshToken(RefreshTokenConfig{ + TokenURL: "https://example.com/token", + ClientID: "id", + ClientSecret: "secret", + Store: store, + Logger: custom, + }) + + if rt.log() != custom { + t.Error("log() should return the explicitly provided logger") + } +} diff --git a/plugins/contrib/oauth/token.go b/plugins/contrib/oauth/token.go index 996f579..750b65e 100644 --- a/plugins/contrib/oauth/token.go +++ b/plugins/contrib/oauth/token.go @@ -44,6 +44,16 @@ func newTokenManager( } } +// log returns the configured logger, or slog.Default() if none was set. +// Called at log-emit time so the current global default is always used +// when no explicit logger is provided. +func (tm *tokenManager) log() *slog.Logger { + if tm.logger != nil { + return tm.logger + } + return slog.Default() +} + // getToken returns a valid access token, fetching a new one if needed. // Concurrent callers are deduplicated via singleflight. func (tm *tokenManager) getToken(ctx context.Context) (string, time.Time, error) { @@ -51,13 +61,13 @@ func (tm *tokenManager) getToken(ctx context.Context) (string, time.Time, error) if tm.token != nil && time.Now().Before(tm.token.expiresAt) { t := tm.token tm.mu.RUnlock() - tm.logger.LogAttrs(ctx, slog.LevelDebug, "token cache hit", + tm.log().LogAttrs(ctx, slog.LevelDebug, "token cache hit", slog.String("token_url", tm.tokenURL)) return t.accessToken, t.expiresAt, nil } tm.mu.RUnlock() - tm.logger.LogAttrs(ctx, slog.LevelDebug, "token cache miss", + tm.log().LogAttrs(ctx, slog.LevelDebug, "token cache miss", slog.String("token_url", tm.tokenURL)) // Use context.WithoutCancel so that a single caller's cancellation @@ -72,7 +82,7 @@ func (tm *tokenManager) getToken(ctx context.Context) (string, time.Time, error) tm.token = cached tm.mu.Unlock() - tm.logger.LogAttrs(ctx, slog.LevelDebug, "token fetched", + tm.log().LogAttrs(ctx, slog.LevelDebug, "token fetched", slog.String("token_url", tm.tokenURL), slog.Time("expires_at", cached.expiresAt)) diff --git a/plugins/contrib/oauth/token_test.go b/plugins/contrib/oauth/token_test.go new file mode 100644 index 0000000..2abcf72 --- /dev/null +++ b/plugins/contrib/oauth/token_test.go @@ -0,0 +1,36 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package oauth + +import ( + "context" + "io" + "log/slog" + "testing" +) + +func TestNewTokenManager_NilLogger_LazyResolution(t *testing.T) { + tm := newTokenManager("https://example.com/token", nil, func(_ context.Context) (*cachedToken, error) { + return nil, nil + }) + + // logger field must be nil — no eager slog.Default() at construction. + if tm.logger != nil { + t.Error("logger field should be nil when not provided; lazy resolution via log()") + } + if tm.log() != slog.Default() { + t.Error("log() should return slog.Default() when logger is nil") + } +} + +func TestNewTokenManager_ExplicitLogger_UsesExplicitLogger(t *testing.T) { + custom := slog.New(slog.NewTextHandler(io.Discard, nil)) + tm := newTokenManager("https://example.com/token", custom, func(_ context.Context) (*cachedToken, error) { + return nil, nil + }) + + if tm.log() != custom { + t.Error("log() should return the explicitly provided logger") + } +} diff --git a/plugins/contrib/static_mapping.go b/plugins/contrib/static_mapping.go index e81e561..6270c51 100644 --- a/plugins/contrib/static_mapping.go +++ b/plugins/contrib/static_mapping.go @@ -104,8 +104,7 @@ func NewStaticMapping(rules []MappingRule, opts ...StaticMappingOption) *StaticM } sm := &StaticMapping{ - rules: make([]MappingRule, len(rules)), - logger: slog.Default(), + rules: make([]MappingRule, len(rules)), } copy(sm.rules, rules) @@ -118,7 +117,7 @@ func NewStaticMapping(rules []MappingRule, opts ...StaticMappingOption) *StaticM for j := i + 1; j < len(sm.rules); j++ { if sm.rules[i].Specificity() == sm.rules[j].Specificity() && rulesMayOverlap(sm.rules[i], sm.rules[j]) { - sm.logger.Warn( + sm.log().Warn( "mapping rules with equal specificity may overlap, first registered wins on tie", "rule_a_index", i, "rule_a_key", sm.rules[i].Key, @@ -133,6 +132,16 @@ func NewStaticMapping(rules []MappingRule, opts ...StaticMappingOption) *StaticM return sm } +// log returns the configured logger, or slog.Default() if none was set. +// Called at log-emit time so the current global default is always used +// when no explicit logger is provided. +func (sm *StaticMapping) log() *slog.Logger { + if sm.logger != nil { + return sm.logger + } + return slog.Default() +} + // ResolveKey finds the best matching rule for the transaction context and // returns its Key. Returns ErrNoMappingMatch if no rule matches. func (sm *StaticMapping) ResolveKey(_ context.Context, tx sdk.TransactionContext) (string, error) { diff --git a/plugins/contrib/static_mapping_test.go b/plugins/contrib/static_mapping_test.go index ad9df0b..e1ad4a1 100644 --- a/plugins/contrib/static_mapping_test.go +++ b/plugins/contrib/static_mapping_test.go @@ -6,6 +6,7 @@ package contrib import ( "context" "errors" + "io" "log/slog" "testing" @@ -373,3 +374,24 @@ func (lc *mappingLogCapture) Handle(_ context.Context, r slog.Record) error { func (lc *mappingLogCapture) WithAttrs(_ []slog.Attr) slog.Handler { return lc } func (lc *mappingLogCapture) WithGroup(_ string) slog.Handler { return lc } func (lc *mappingLogCapture) hasWarning() bool { return lc.warned } + +func TestNewStaticMapping_NilLogger_LazyResolution(t *testing.T) { + sm := NewStaticMapping([]MappingRule{{Key: "k"}}) + + // logger field must be nil — no eager slog.Default() at construction. + if sm.logger != nil { + t.Error("logger field should be nil when not provided; lazy resolution via log()") + } + if sm.log() != slog.Default() { + t.Error("log() should return slog.Default() when logger is nil") + } +} + +func TestNewStaticMapping_WithMappingLogger_UsesExplicitLogger(t *testing.T) { + custom := slog.New(slog.NewTextHandler(io.Discard, nil)) + sm := NewStaticMapping([]MappingRule{{Key: "k"}}, WithMappingLogger(custom)) + + if sm.log() != custom { + t.Error("log() should return the explicitly provided logger") + } +}