Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions plugins/contrib/internal/oauthutil/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ const defaultExpiryMargin = 1 * time.Minute

// NewTokenFetcher creates a TokenFetcher with defaults applied for nil fields.
func NewTokenFetcher(cfg TokenFetcher) *TokenFetcher {
if cfg.Logger == nil {
cfg.Logger = slog.Default()
}
if cfg.Client == nil {
cfg.Client = DefaultHTTPClient()
}
Expand All @@ -103,6 +100,16 @@ func NewTokenFetcher(cfg TokenFetcher) *TokenFetcher {
return &cfg
}

// log returns the configured logger, or slog.Default() if none was set.
// Called at log-emit time so the current global default is always used
// when no explicit logger is provided.
func (tf *TokenFetcher) log() *slog.Logger {
if tf.Logger != nil {
return tf.Logger
}
return slog.Default()
}

// Exchange sends a token request with the given form parameters and returns
// the parsed result. The caller is responsible for setting grant-type-specific
// form parameters (e.g., refresh_token) before calling Exchange.
Expand Down Expand Up @@ -169,7 +176,7 @@ func (tf *TokenFetcher) doRequest(ctx context.Context, req *http.Request) (statu
if ctx.Err() != nil {
return 0, nil, nil, fmt.Errorf("token request for %s: %w", tf.TokenURL, ctx.Err())
}
tf.Logger.LogAttrs(ctx, slog.LevelWarn, "token endpoint request failed",
tf.log().LogAttrs(ctx, slog.LevelWarn, "token endpoint request failed",
slog.String("token_url", tf.TokenURL),
slog.String("error", err.Error()))
return 0, nil, nil, fmt.Errorf("token endpoint request for %s: %w",
Expand All @@ -193,7 +200,7 @@ func (tf *TokenFetcher) doRequest(ctx context.Context, req *http.Request) (statu
func (tf *TokenFetcher) handleErrorResponse(ctx context.Context, statusCode int, header http.Header, body []byte) error {
contentType := header.Get("Content-Type")

if tf.Logger.Enabled(ctx, slog.LevelDebug) {
if tf.log().Enabled(ctx, slog.LevelDebug) {
attrs := []slog.Attr{
slog.Int("status", statusCode),
slog.String("content_type", contentType),
Expand All @@ -209,7 +216,7 @@ func (tf *TokenFetcher) handleErrorResponse(ctx context.Context, statusCode int,
}
}

tf.Logger.LogAttrs(ctx, slog.LevelDebug, "token endpoint error response", attrs...)
tf.log().LogAttrs(ctx, slog.LevelDebug, "token endpoint error response", attrs...)
}

if statusCode == http.StatusUnauthorized {
Expand All @@ -218,7 +225,7 @@ func (tf *TokenFetcher) handleErrorResponse(ctx context.Context, statusCode int,
}

if statusCode == http.StatusTooManyRequests || statusCode >= http.StatusInternalServerError {
tf.Logger.LogAttrs(ctx, slog.LevelWarn, "token endpoint unavailable",
tf.log().LogAttrs(ctx, slog.LevelWarn, "token endpoint unavailable",
slog.String("token_url", tf.TokenURL),
slog.Int("status", statusCode))
return fmt.Errorf("token endpoint returned %d (content-type: %s): %w",
Expand Down
28 changes: 26 additions & 2 deletions plugins/contrib/internal/oauthutil/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package oauthutil

import (
"encoding/json"
"io"
"log/slog"
"testing"
"time"
)
Expand All @@ -30,8 +32,10 @@ func TestNewTokenFetcher_Defaults(t *testing.T) {
TokenURL: "https://example.com/token",
})

if f.Logger == nil {
t.Error("Logger should default to non-nil")
// Logger field must be nil — no eager slog.Default() capture at construction.
// Lazy resolution happens via log() at call time.
if f.Logger != nil {
t.Error("Logger field should be nil when not provided; lazy resolution via log()")
}
if f.Client == nil {
t.Error("Client should default to non-nil")
Expand All @@ -41,6 +45,26 @@ func TestNewTokenFetcher_Defaults(t *testing.T) {
}
}

func TestNewTokenFetcher_NilLogger_LazyResolution(t *testing.T) {
f := NewTokenFetcher(TokenFetcher{TokenURL: "https://example.com/token"})

if f.log() != slog.Default() {
t.Error("log() should return slog.Default() when Logger is nil")
}
}

func TestNewTokenFetcher_ExplicitLogger_UsesExplicitLogger(t *testing.T) {
custom := slog.New(slog.NewTextHandler(io.Discard, nil))
f := NewTokenFetcher(TokenFetcher{
TokenURL: "https://example.com/token",
Logger: custom,
})

if f.log() != custom {
t.Error("log() should return the explicitly provided logger")
}
}

func TestNewTokenFetcher_PreservesExplicitValues(t *testing.T) {
margin := 5 * time.Minute
f := NewTokenFetcher(TokenFetcher{
Expand Down
25 changes: 15 additions & 10 deletions plugins/contrib/microsoft/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,6 @@ func NewRefreshTokenSource(cfg Config) *RefreshTokenSource {
margin = defaultExpiryMargin
}

logger := cfg.Logger
if logger == nil {
logger = slog.Default()
}

client := cfg.HTTPClient
if client == nil {
client = oauthutil.DefaultHTTPClient()
Expand All @@ -182,13 +177,23 @@ func NewRefreshTokenSource(cfg Config) *RefreshTokenSource {
maxPoolSize: maxPool,
expiryMargin: margin,
httpClient: client,
logger: logger,
logger: cfg.Logger,
onSaveError: cfg.OnSaveError,
pool: make(map[string]*list.Element),
lru: list.New(),
}
}

// log returns the configured logger, or slog.Default() if none was set.
// Called at log-emit time so the current global default is always used
// when no explicit logger is provided.
func (s *RefreshTokenSource) log() *slog.Logger {
if s.logger != nil {
return s.logger
}
return slog.Default()
}

// GetCredentials extracts TenantID and Resource from the transaction context,
// resolves a per-tenant entry from the pool, checks the per-resource access
// token cache, and if needed exchanges the tenant's MRRT for a new access
Expand Down Expand Up @@ -240,14 +245,14 @@ func (s *RefreshTokenSource) getAccessToken(
entry.mu.RLock()
if ct, ok := entry.tokens[resource]; ok && time.Now().Before(ct.expiresAt) {
entry.mu.RUnlock()
s.logger.LogAttrs(ctx, slog.LevelDebug, "access token cache hit",
s.log().LogAttrs(ctx, slog.LevelDebug, "access token cache hit",
slog.String("tenant_id", entry.tenantID),
slog.String("resource", resource))
return ct.accessToken, ct.expiresAt, nil
}
entry.mu.RUnlock()

s.logger.LogAttrs(ctx, slog.LevelDebug, "access token cache miss",
s.log().LogAttrs(ctx, slog.LevelDebug, "access token cache miss",
slog.String("tenant_id", entry.tenantID),
slog.String("resource", resource))

Expand Down Expand Up @@ -317,7 +322,7 @@ func (s *RefreshTokenSource) exchange(

if result.RefreshToken != "" {
if saveErr := s.store.Save(ctx, tenantID, result.RefreshToken); saveErr != nil {
s.logger.LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token",
s.log().LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token",
slog.String("tenant_id", tenantID),
slog.String("resource", resource),
slog.String("error", saveErr.Error()))
Expand Down Expand Up @@ -371,7 +376,7 @@ func (s *RefreshTokenSource) evictLRU(ctx context.Context) {
}

entry, _ := back.Value.(*tenantEntry)
s.logger.LogAttrs(ctx, slog.LevelDebug, "evicting LRU pool entry",
s.log().LogAttrs(ctx, slog.LevelDebug, "evicting LRU pool entry",
slog.String("tenant_id", entry.tenantID))

delete(s.pool, entry.tenantID)
Expand Down
30 changes: 30 additions & 0 deletions plugins/contrib/microsoft/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1446,3 +1446,33 @@ func TestGetCredentials_KeyResolver_ValidTenantIDCheck_RejectsBadResolverValue(t
t.Errorf("error = %v, want errors.Is(ErrInvalidContextData)", err)
}
}

func TestNewRefreshTokenSource_NilLogger_LazyResolution(t *testing.T) {
src := NewRefreshTokenSource(Config{
ClientID: "id",
ClientSecret: "secret",
Store: newMemoryTokenStore(),
})

// logger field must be nil — no eager slog.Default() at construction.
if src.logger != nil {
t.Error("logger field should be nil when not provided; lazy resolution via log()")
}
if src.log() != slog.Default() {
t.Error("log() should return slog.Default() when logger is nil")
}
}

func TestNewRefreshTokenSource_ExplicitLogger_UsesExplicitLogger(t *testing.T) {
custom := slog.New(slog.NewTextHandler(io.Discard, nil))
src := NewRefreshTokenSource(Config{
ClientID: "id",
ClientSecret: "secret",
Store: newMemoryTokenStore(),
Logger: custom,
})

if src.log() != custom {
t.Error("log() should return the explicitly provided logger")
}
}
16 changes: 12 additions & 4 deletions plugins/contrib/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,23 @@ func WithLogger(l *slog.Logger) MuxOption {
// NewMux creates a new request multiplexer. Use [Mux.Handle] and
// [Mux.Default] to register routes before serving traffic.
func NewMux(opts ...MuxOption) *Mux {
m := &Mux{
logger: slog.Default(),
}
m := &Mux{}
for _, opt := range opts {
opt(m)
}
return m
}

// log returns the configured logger, or slog.Default() if none was set.
// Called at log-emit time so the current global default is always used
// when no explicit logger is provided.
func (m *Mux) log() *slog.Logger {
if m.logger != nil {
return m.logger
}
return slog.Default()
}

// Handle registers a route that dispatches matching requests to the
// given provider. Routes are evaluated by specificity at dispatch time,
// not registration order — but registration order breaks ties.
Expand All @@ -81,7 +89,7 @@ func (m *Mux) Handle(route Route, provider sdk.CredentialProvider) {
newSpec := route.Specificity()
for _, e := range m.entries {
if e.route.Specificity() == newSpec && routesMayOverlap(e.route, route) {
m.logger.Warn("routes registered with equal specificity may overlap, first registered wins on tie",
m.log().Warn("routes registered with equal specificity may overlap, first registered wins on tie",
"existing_route", routeString(e.route),
"new_route", routeString(route),
)
Expand Down
22 changes: 22 additions & 0 deletions plugins/contrib/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package contrib
import (
"context"
"errors"
"io"
"log/slog"
"net/http"
"sync"
Expand Down Expand Up @@ -602,3 +603,24 @@ func TestMux_Compliance(t *testing.T) {
mux.Default(&namedProvider{name: "compliance"})
compliance.VerifyContract(t, mux)
}

func TestNewMux_NilLogger_LazyResolution(t *testing.T) {
m := NewMux()

// logger field must be nil — no eager slog.Default() at construction.
if m.logger != nil {
t.Error("logger field should be nil when not provided; lazy resolution via log()")
}
if m.log() != slog.Default() {
t.Error("log() should return slog.Default() when logger is nil")
}
}

func TestNewMux_WithLogger_UsesExplicitLogger(t *testing.T) {
custom := slog.New(slog.NewTextHandler(io.Discard, nil))
m := NewMux(WithLogger(custom))

if m.log() != custom {
t.Error("log() should return the explicitly provided logger")
}
}
12 changes: 11 additions & 1 deletion plugins/contrib/oauth/refreshtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ func NewRefreshToken(cfg RefreshTokenConfig) *RefreshToken {
return rt
}

// log returns the configured logger, or slog.Default() if none was set.
// Called at log-emit time so the current global default is always used
// when no explicit logger is provided.
func (rt *RefreshToken) log() *slog.Logger {
if rt.logger != nil {
return rt.logger
}
return slog.Default()
}

// GetCredentials fetches an OAuth2 bearer token using the refresh token grant
// and returns it as a cacheable credential (Fast Path).
//
Expand Down Expand Up @@ -151,7 +161,7 @@ func (rt *RefreshToken) fetch(ctx context.Context) (*cachedToken, error) {

if result.RefreshToken != "" {
if saveErr := rt.store.Save(ctx, result.RefreshToken); saveErr != nil {
rt.logger.LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token",
rt.log().LogAttrs(ctx, slog.LevelError, "failed to save rotated refresh token",
slog.String("token_url", rt.fetcher.TokenURL),
slog.String("error", saveErr.Error()))

Expand Down
34 changes: 34 additions & 0 deletions plugins/contrib/oauth/refreshtoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -975,3 +975,37 @@ func TestRefreshToken_Compliance(t *testing.T) {
plugin := contrib.AsPlugin(rt)
compliance.VerifyContract(t, plugin)
}

func TestNewRefreshToken_NilLogger_LazyResolution(t *testing.T) {
store := newMemoryStore("tok")
rt := NewRefreshToken(RefreshTokenConfig{
TokenURL: "https://example.com/token",
ClientID: "id",
ClientSecret: "secret",
Store: store,
})

// logger field must be nil — no eager slog.Default() at construction.
if rt.logger != nil {
t.Error("logger field should be nil when not provided; lazy resolution via log()")
}
if rt.log() != slog.Default() {
t.Error("log() should return slog.Default() when logger is nil")
}
}

func TestNewRefreshToken_ExplicitLogger_UsesExplicitLogger(t *testing.T) {
custom := slog.New(slog.NewTextHandler(io.Discard, nil))
store := newMemoryStore("tok")
rt := NewRefreshToken(RefreshTokenConfig{
TokenURL: "https://example.com/token",
ClientID: "id",
ClientSecret: "secret",
Store: store,
Logger: custom,
})

if rt.log() != custom {
t.Error("log() should return the explicitly provided logger")
}
}
16 changes: 13 additions & 3 deletions plugins/contrib/oauth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,30 @@ func newTokenManager(
}
}

// log returns the configured logger, or slog.Default() if none was set.
// Called at log-emit time so the current global default is always used
// when no explicit logger is provided.
func (tm *tokenManager) log() *slog.Logger {
if tm.logger != nil {
return tm.logger
}
return slog.Default()
}

// getToken returns a valid access token, fetching a new one if needed.
// Concurrent callers are deduplicated via singleflight.
func (tm *tokenManager) getToken(ctx context.Context) (string, time.Time, error) {
tm.mu.RLock()
if tm.token != nil && time.Now().Before(tm.token.expiresAt) {
t := tm.token
tm.mu.RUnlock()
tm.logger.LogAttrs(ctx, slog.LevelDebug, "token cache hit",
tm.log().LogAttrs(ctx, slog.LevelDebug, "token cache hit",
slog.String("token_url", tm.tokenURL))
return t.accessToken, t.expiresAt, nil
}
tm.mu.RUnlock()

tm.logger.LogAttrs(ctx, slog.LevelDebug, "token cache miss",
tm.log().LogAttrs(ctx, slog.LevelDebug, "token cache miss",
slog.String("token_url", tm.tokenURL))

// Use context.WithoutCancel so that a single caller's cancellation
Expand All @@ -72,7 +82,7 @@ func (tm *tokenManager) getToken(ctx context.Context) (string, time.Time, error)
tm.token = cached
tm.mu.Unlock()

tm.logger.LogAttrs(ctx, slog.LevelDebug, "token fetched",
tm.log().LogAttrs(ctx, slog.LevelDebug, "token fetched",
slog.String("token_url", tm.tokenURL),
slog.Time("expires_at", cached.expiresAt))

Expand Down
Loading
Loading