From 7fd1d7e30cb9e26419722814c473ace5368de4ac Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 13:27:52 +0200 Subject: [PATCH 01/35] decode only once in middleware --- cmds/core-service/main.go | 22 +++++++++++++------ pkg/auth/auth.go | 45 +++++++++++++++++++++------------------ pkg/auth/auth_test.go | 2 +- pkg/auth/claims.go | 4 ++-- pkg/auth/claims_test.go | 2 +- pkg/logging/http.go | 16 ++++++++------ 6 files changed, 53 insertions(+), 38 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 20b72e643..893242fc5 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -27,6 +27,7 @@ import ( "github.com/interuss/dss/pkg/build" "github.com/interuss/dss/pkg/datastore" "github.com/interuss/dss/pkg/datastore/flags" // Force command line flag registration + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/rid/application" rid_v1 "github.com/interuss/dss/pkg/rid/server/v1" @@ -385,15 +386,24 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx context.Context - claims, err := authorizer.ExtractClaims(r) - if err != nil { - //remove the stacktrace using the formatting specifier "%#s" - ctx = context.WithValue(r.Context(), logging.CtxAuthError{}, fmt.Sprintf("%#s", err)) - } else { - ctx = context.WithValue(r.Context(), logging.CtxAuthSubject{}, claims.Subject) + if err == nil { + if !authorizer.AcceptedAudiences[claims.Audience] { + err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) + } } + ctx = context.WithValue(r.Context(), auth.CtxAuthKey{}, auth.CtxAuthValue{ + Claims: claims, + Error: err, + }) + + ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ + Subject: claims.Subject, + //remove the stacktrace using the formatting specifier "%#s" + ErrMsg: fmt.Sprintf("%#s", err), + }) + handler.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 696ecf7f2..48fb5229f 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -120,10 +120,11 @@ func (r *JWKSResolver) ResolveKeys(ctx context.Context) ([]interface{}, error) { // Authorizer authorizes incoming requests. type Authorizer struct { - logger *zap.Logger - keys []interface{} - keyGuard sync.RWMutex - acceptedAudiences map[string]bool + logger *zap.Logger + keys []interface{} + keyGuard sync.RWMutex + + AcceptedAudiences map[string]bool } // Configuration bundles up creation-time parameters for an Authorizer instance. @@ -148,7 +149,7 @@ func NewRSAAuthorizer(ctx context.Context, configuration Configuration) (*Author } authorizer := &Authorizer{ - acceptedAudiences: auds, + AcceptedAudiences: auds, logger: logger, keys: keys, } @@ -182,33 +183,35 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } +type CtxAuthKey struct{} +type CtxAuthValue struct { + Claims Claims + Error error +} + // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - keyClaims, err := a.ExtractClaims(r) - if err != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Failed to extract claims from access token")} - } - - if !a.acceptedAudiences[keyClaims.Audience] { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)} + v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if v.Error != nil { + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Failed to extract claims from access token")} } - if pass, missing := validateScopes(authOptions, keyClaims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, v.Claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(keyClaims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(v.Claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &keyClaims.Subject, - Scopes: keyClaims.Scopes.ToStringSlice(), + ClientID: &v.Claims.Subject, + Scopes: v.Claims.Scopes.ToStringSlice(), } } -func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { +func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { tknStr, ok := getToken(r) if !ok { - return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") + return Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") } a.keyGuard.RLock() @@ -216,10 +219,10 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { a.keyGuard.RUnlock() validated := false var err error - var keyClaims claims + var keyClaims Claims for _, key := range keys { - keyClaims = claims{} + keyClaims = Claims{} key := key _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { return key, nil @@ -234,7 +237,7 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } - return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") + return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") } return keyClaims, nil diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 065c66ca0..95c3c7306 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -203,7 +203,7 @@ func TestClaimsValidation(t *testing.T) { Now = time.Now }() - claims := &claims{} + claims := &Claims{} require.Error(t, claims.Valid()) diff --git a/pkg/auth/claims.go b/pkg/auth/claims.go index 31a33f066..c19140688 100644 --- a/pkg/auth/claims.go +++ b/pkg/auth/claims.go @@ -61,12 +61,12 @@ func (s *ScopeSet) ToStringSlice() []string { return scopes } -type claims struct { +type Claims struct { jwt.StandardClaims Scopes ScopeSet `json:"scope"` } -func (c *claims) Valid() error { +func (c *Claims) Valid() error { if c.Subject == "" { return errMissingOrEmptySubject } diff --git a/pkg/auth/claims_test.go b/pkg/auth/claims_test.go index 9005ee1d5..c9ec8c206 100644 --- a/pkg/auth/claims_test.go +++ b/pkg/auth/claims_test.go @@ -8,7 +8,7 @@ import ( ) func TestScopesJSONUnmarshaling(t *testing.T) { - claims := &claims{} + claims := &Claims{} require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims)) require.Contains(t, claims.Scopes, "one") require.Contains(t, claims.Scopes, "two") diff --git a/pkg/logging/http.go b/pkg/logging/http.go index b41a4d6eb..c17587b07 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -39,8 +39,11 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { w.next.WriteHeader(statusCode) } -type CtxAuthError struct{} -type CtxAuthSubject struct{} +type CtxAuthKey struct{} +type CtxAuthValue struct { + Subject string + ErrMsg string +} // HTTPMiddleware installs a logging http.Handler that logs requests and // selected aspects of responses to 'logger'. @@ -72,12 +75,11 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - subject, ok := r.Context().Value(CtxAuthSubject{}).(string) - if !ok { - authErrorMsg := r.Context().Value(CtxAuthError{}).(string) - logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) + v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if v.ErrMsg != "" { + logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) } else { - logger = logger.With(zap.String("req_sub", subject)) + logger = logger.With(zap.String("req_sub", v.Subject)) } handler.ServeHTTP(trw, r) From 03a06f3612e6c82f4b283c6e5894c7693414a2c5 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 14:05:39 +0200 Subject: [PATCH 02/35] fix formatting error --- cmds/core-service/main.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 893242fc5..8a4bc0336 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -398,10 +398,15 @@ func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) ht Error: err, }) + var errMsg string + if err != nil { + errMsg = fmt.Sprintf("%#s", err) + } + ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ Subject: claims.Subject, //remove the stacktrace using the formatting specifier "%#s" - ErrMsg: fmt.Sprintf("%#s", err), + ErrMsg: errMsg, }) handler.ServeHTTP(w, r.WithContext(ctx)) From b20dc4caf577595ead83d702235e22c90ae56e23 Mon Sep 17 00:00:00 2001 From: Mariem Baccari <53703829+MariemBaccari@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:47:03 +0200 Subject: [PATCH 03/35] [logging] Add subject field to log output (#1263) Remove duplicate decoding --- pkg/auth/auth.go | 4 ++++ pkg/logging/http.go | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 48fb5229f..ff40fa33b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -237,7 +237,11 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } +<<<<<<< HEAD return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") +======= + return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") +>>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } return keyClaims, nil diff --git a/pkg/logging/http.go b/pkg/logging/http.go index c17587b07..029acad2e 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -75,11 +75,20 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } +<<<<<<< HEAD v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) if v.ErrMsg != "" { logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) } else { logger = logger.With(zap.String("req_sub", v.Subject)) +======= + subject, ok := r.Context().Value(CtxAuthSubject{}).(string) + if !ok { + authErrorMsg := r.Context().Value(CtxAuthError{}).(string) + logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) + } else { + logger = logger.With(zap.String("req_sub", subject)) +>>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } handler.ServeHTTP(trw, r) From fc27a16923354aa77f3189342f13cab6425c4c73 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 15:14:12 +0200 Subject: [PATCH 04/35] remove merge messages --- pkg/auth/auth.go | 4 ---- pkg/logging/http.go | 9 --------- 2 files changed, 13 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index ff40fa33b..48fb5229f 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -237,11 +237,7 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } -<<<<<<< HEAD return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") -======= - return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") ->>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } return keyClaims, nil diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 029acad2e..c17587b07 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -75,20 +75,11 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } -<<<<<<< HEAD v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) if v.ErrMsg != "" { logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) } else { logger = logger.With(zap.String("req_sub", v.Subject)) -======= - subject, ok := r.Context().Value(CtxAuthSubject{}).(string) - if !ok { - authErrorMsg := r.Context().Value(CtxAuthError{}).(string) - logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) - } else { - logger = logger.With(zap.String("req_sub", subject)) ->>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } handler.ServeHTTP(trw, r) From 625f6845d6d2aa2bef6ffea0e58dca22a654b311 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 15:46:35 +0200 Subject: [PATCH 05/35] fix unit test --- pkg/auth/auth_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 95c3c7306..80fae62b4 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -115,7 +115,14 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - res := a.Authorize(nil, test.req, []api.AuthorizationOption{}) + claims, err := a.ExtractClaims(test.req) + + ctx := context.WithValue(test.req.Context(), CtxAuthKey{}, CtxAuthValue{ + Error: err, + Claims: claims, + }) + + res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) t.Errorf("expected: %v, got: %v, with message %s", test.code, stacktrace.GetCode(res.Error), res.Error.Error()) From 1ef90baf73aa0d6694d9f2a2118e7ccaaa10da3c Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 16:08:08 +0200 Subject: [PATCH 06/35] edit error message --- pkg/auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 48fb5229f..e413fb1ca 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -193,7 +193,7 @@ type CtxAuthValue struct { func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) if v.Error != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Failed to extract claims from access token")} + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Invalid access token")} } if pass, missing := validateScopes(authOptions, v.Claims.Scopes); !pass { From 5860a2e487bd312711e05abca544ca4794c22de5 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Thu, 18 Sep 2025 11:22:41 +0200 Subject: [PATCH 07/35] Revert changes and improve middleware declarations --- cmds/core-service/main.go | 18 ++++++++---------- pkg/auth/auth.go | 14 +++++++------- pkg/logging/http.go | 8 ++++---- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 8a4bc0336..e8e7734c5 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -315,12 +315,10 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st multiRouter.Routers = append(multiRouter.Routers, &scdV1Router) } - handler := logging.HTTPMiddleware(logger, *dumpRequests, - healthyEndpointMiddleware(logger, - &multiRouter, - )) - - handler = authDecoderMiddleware(authorizer, handler) + // the middlewares are wrapped and, therefore, executed in the opposite order + handler := healthyEndpointMiddleware(logger, &multiRouter) + handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) + handler = authMiddleware(authorizer, handler) httpServer := &http.Server{ Addr: address, @@ -382,8 +380,8 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl }) } -// authDecoderMiddleware decodes the authentication token and adds the Subject claim to the context. -func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { +// authMiddleware decodes the authentication token and passes the claims to the context. +func authMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx context.Context claims, err := authorizer.ExtractClaims(r) @@ -400,13 +398,13 @@ func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) ht var errMsg string if err != nil { + //remove the stacktrace using the formatting specifier "%#s" errMsg = fmt.Sprintf("%#s", err) } ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ Subject: claims.Subject, - //remove the stacktrace using the formatting specifier "%#s" - ErrMsg: errMsg, + ErrMsg: errMsg, }) handler.ServeHTTP(w, r.WithContext(ctx)) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index e413fb1ca..81025e970 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -191,20 +191,20 @@ type CtxAuthValue struct { // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) - if v.Error != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Invalid access token")} + authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if authResults.Error != nil { + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(authResults.Error, dsserr.Unauthenticated, "Invalid access token")} } - if pass, missing := validateScopes(authOptions, v.Claims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, authResults.Claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(v.Claims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(authResults.Claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &v.Claims.Subject, - Scopes: v.Claims.Scopes.ToStringSlice(), + ClientID: &authResults.Claims.Subject, + Scopes: authResults.Claims.Scopes.ToStringSlice(), } } diff --git a/pkg/logging/http.go b/pkg/logging/http.go index c17587b07..f7d8c7945 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -75,11 +75,11 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) - if v.ErrMsg != "" { - logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) + authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if authResults.ErrMsg != "" { + logger = logger.With(zap.String("resp_sub_err", authResults.ErrMsg)) } else { - logger = logger.With(zap.String("req_sub", v.Subject)) + logger = logger.With(zap.String("req_sub", authResults.Subject)) } handler.ServeHTTP(trw, r) From 30dd60f3626e43816c95fac5933a9fc9a183aee4 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:05:46 +0200 Subject: [PATCH 08/35] address pr comments --- cmds/core-service/main.go | 34 +-------------- pkg/auth/auth.go | 91 +++++++++++++++++++++++++++------------ pkg/auth/auth_test.go | 4 +- pkg/auth/claims.go | 4 +- pkg/auth/claims_test.go | 2 +- 5 files changed, 69 insertions(+), 66 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index e8e7734c5..13246e5e0 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -27,7 +27,6 @@ import ( "github.com/interuss/dss/pkg/build" "github.com/interuss/dss/pkg/datastore" "github.com/interuss/dss/pkg/datastore/flags" // Force command line flag registration - dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/rid/application" rid_v1 "github.com/interuss/dss/pkg/rid/server/v1" @@ -318,7 +317,7 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st // the middlewares are wrapped and, therefore, executed in the opposite order handler := healthyEndpointMiddleware(logger, &multiRouter) handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) - handler = authMiddleware(authorizer, handler) + handler = authorizer.TokenMiddleware(handler) httpServer := &http.Server{ Addr: address, @@ -380,37 +379,6 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl }) } -// authMiddleware decodes the authentication token and passes the claims to the context. -func authMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context - claims, err := authorizer.ExtractClaims(r) - if err == nil { - if !authorizer.AcceptedAudiences[claims.Audience] { - err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) - } - } - - ctx = context.WithValue(r.Context(), auth.CtxAuthKey{}, auth.CtxAuthValue{ - Claims: claims, - Error: err, - }) - - var errMsg string - if err != nil { - //remove the stacktrace using the formatting specifier "%#s" - errMsg = fmt.Sprintf("%#s", err) - } - - ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ - Subject: claims.Subject, - ErrMsg: errMsg, - }) - - handler.ServeHTTP(w, r.WithContext(ctx)) - }) -} - type RIDGarbageCollectorJob struct { name string gc ridc.GarbageCollector diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 81025e970..dfa26aa18 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -124,7 +124,14 @@ type Authorizer struct { keys []interface{} keyGuard sync.RWMutex - AcceptedAudiences map[string]bool + acceptedAudiences map[string]bool + + decodedClaims *middlewareResult +} + +type middlewareResult struct { + claims claims + err error } // Configuration bundles up creation-time parameters for an Authorizer instance. @@ -149,7 +156,7 @@ func NewRSAAuthorizer(ctx context.Context, configuration Configuration) (*Author } authorizer := &Authorizer{ - AcceptedAudiences: auds, + acceptedAudiences: auds, logger: logger, keys: keys, } @@ -183,35 +190,72 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } -type CtxAuthKey struct{} -type CtxAuthValue struct { - Claims Claims - Error error +// TokenMiddleware decodes the authentication token and passes the claims to the context. +func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ctx context.Context + claims, err := a.extractClaims(r) + if err == nil { + if !a.acceptedAudiences[claims.Audience] { + err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) + } + } + + a.decodedClaims = &middlewareResult{ + claims: claims, + err: err, + } + + var errMsg string + if err != nil { + //remove the stacktrace using the formatting specifier "%#s" + errMsg = fmt.Sprintf("%#s", err) + } + + ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ + Subject: claims.Subject, + ErrMsg: errMsg, + }) + + handler.ServeHTTP(w, r.WithContext(ctx)) + }) } -// Authorize extracts and verifies bearer tokens from a http.Request. +// Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) - if authResults.Error != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(authResults.Error, dsserr.Unauthenticated, "Invalid access token")} + if a.decodedClaims == nil { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Access token not found")} + } + + if a.decodedClaims.err != nil { + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(a.decodedClaims.err, dsserr.Unauthenticated, "Invalid access token")} } - if pass, missing := validateScopes(authOptions, authResults.Claims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, a.decodedClaims.claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(authResults.Claims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(a.decodedClaims.claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &authResults.Claims.Subject, - Scopes: authResults.Claims.Scopes.ToStringSlice(), + ClientID: &a.decodedClaims.claims.Subject, + Scopes: a.decodedClaims.claims.Scopes.ToStringSlice(), + } +} + +func HasScope(scopes []string, requiredScope api.RequiredScope) bool { + for _, scope := range scopes { + if scope == string(requiredScope) { + return true + } } + return false } -func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { +func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { tknStr, ok := getToken(r) if !ok { - return Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") + return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") } a.keyGuard.RLock() @@ -219,10 +263,10 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { a.keyGuard.RUnlock() validated := false var err error - var keyClaims Claims + var keyClaims claims for _, key := range keys { - keyClaims = Claims{} + keyClaims = claims{} key := key _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { return key, nil @@ -237,21 +281,12 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } - return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") + return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") } return keyClaims, nil } -func HasScope(scopes []string, requiredScope api.RequiredScope) bool { - for _, scope := range scopes { - if scope == string(requiredScope) { - return true - } - } - return false -} - // describeAuthorizationExpectations builds a human-readable string describing the expectations of the authorization options. func describeAuthorizationExpectations(authOptions []api.AuthorizationOption) string { if len(authOptions) == 0 { diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 80fae62b4..f1de0a978 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -115,7 +115,7 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - claims, err := a.ExtractClaims(test.req) + claims, err := a.extractClaims(test.req) ctx := context.WithValue(test.req.Context(), CtxAuthKey{}, CtxAuthValue{ Error: err, @@ -210,7 +210,7 @@ func TestClaimsValidation(t *testing.T) { Now = time.Now }() - claims := &Claims{} + claims := &claims{} require.Error(t, claims.Valid()) diff --git a/pkg/auth/claims.go b/pkg/auth/claims.go index c19140688..31a33f066 100644 --- a/pkg/auth/claims.go +++ b/pkg/auth/claims.go @@ -61,12 +61,12 @@ func (s *ScopeSet) ToStringSlice() []string { return scopes } -type Claims struct { +type claims struct { jwt.StandardClaims Scopes ScopeSet `json:"scope"` } -func (c *Claims) Valid() error { +func (c *claims) Valid() error { if c.Subject == "" { return errMissingOrEmptySubject } diff --git a/pkg/auth/claims_test.go b/pkg/auth/claims_test.go index c9ec8c206..9005ee1d5 100644 --- a/pkg/auth/claims_test.go +++ b/pkg/auth/claims_test.go @@ -8,7 +8,7 @@ import ( ) func TestScopesJSONUnmarshaling(t *testing.T) { - claims := &Claims{} + claims := &claims{} require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims)) require.Contains(t, claims.Scopes, "one") require.Contains(t, claims.Scopes, "two") From 06b3ae21efb7f98a68d121198b4d48b5b90dacb4 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:10:17 +0200 Subject: [PATCH 09/35] fix context key nit --- pkg/auth/auth.go | 2 +- pkg/logging/http.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index dfa26aa18..c9571f0fb 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -212,7 +212,7 @@ func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { errMsg = fmt.Sprintf("%#s", err) } - ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ + ctx = context.WithValue(ctx, logging.CtxKey("sub"), logging.CtxAuthValue{ Subject: claims.Subject, ErrMsg: errMsg, }) diff --git a/pkg/logging/http.go b/pkg/logging/http.go index f7d8c7945..f5fc7295f 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -39,7 +39,7 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { w.next.WriteHeader(statusCode) } -type CtxAuthKey struct{} +type CtxKey string type CtxAuthValue struct { Subject string ErrMsg string @@ -75,7 +75,7 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + authResults := r.Context().Value(CtxKey("sub")).(CtxAuthValue) if authResults.ErrMsg != "" { logger = logger.With(zap.String("resp_sub_err", authResults.ErrMsg)) } else { From 363c8dad0d0181c97dc4098fd10d73cdabf62900 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:21:55 +0200 Subject: [PATCH 10/35] add missing context --- pkg/auth/auth.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index c9571f0fb..2974d7fe7 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -193,7 +193,6 @@ func (a *Authorizer) setKeys(keys []interface{}) { // TokenMiddleware decodes the authentication token and passes the claims to the context. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context claims, err := a.extractClaims(r) if err == nil { if !a.acceptedAudiences[claims.Audience] { @@ -212,7 +211,7 @@ func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { errMsg = fmt.Sprintf("%#s", err) } - ctx = context.WithValue(ctx, logging.CtxKey("sub"), logging.CtxAuthValue{ + ctx := context.WithValue(r.Context(), logging.CtxKey("sub"), logging.CtxAuthValue{ Subject: claims.Subject, ErrMsg: errMsg, }) From 9fd7ae0681a547d9b2d6e62a76c1d4498b92e06d Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:25:11 +0200 Subject: [PATCH 11/35] clarify middleware doc --- pkg/auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 2974d7fe7..46ba5331b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -190,7 +190,7 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } -// TokenMiddleware decodes the authentication token and passes the claims to the context. +// TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, err := a.extractClaims(r) From 36df8dfdf1cf4f5965822e41cd0e2a6f05ea1be4 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 13:27:52 +0200 Subject: [PATCH 12/35] decode only once in middleware --- cmds/core-service/main.go | 22 +++++++++++++------ pkg/auth/auth.go | 45 +++++++++++++++++++++------------------ pkg/auth/auth_test.go | 2 +- pkg/auth/claims.go | 4 ++-- pkg/auth/claims_test.go | 2 +- pkg/logging/http.go | 16 ++++++++------ 6 files changed, 53 insertions(+), 38 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index f02700cfd..fa870d27a 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -26,6 +26,7 @@ import ( "github.com/interuss/dss/pkg/build" "github.com/interuss/dss/pkg/datastore" "github.com/interuss/dss/pkg/datastore/flags" // Force command line flag registration + dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/rid/application" rid_v1 "github.com/interuss/dss/pkg/rid/server/v1" @@ -377,15 +378,24 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx context.Context - claims, err := authorizer.ExtractClaims(r) - if err != nil { - //remove the stacktrace using the formatting specifier "%#s" - ctx = context.WithValue(r.Context(), logging.CtxAuthError{}, fmt.Sprintf("%#s", err)) - } else { - ctx = context.WithValue(r.Context(), logging.CtxAuthSubject{}, claims.Subject) + if err == nil { + if !authorizer.AcceptedAudiences[claims.Audience] { + err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) + } } + ctx = context.WithValue(r.Context(), auth.CtxAuthKey{}, auth.CtxAuthValue{ + Claims: claims, + Error: err, + }) + + ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ + Subject: claims.Subject, + //remove the stacktrace using the formatting specifier "%#s" + ErrMsg: fmt.Sprintf("%#s", err), + }) + handler.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 696ecf7f2..48fb5229f 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -120,10 +120,11 @@ func (r *JWKSResolver) ResolveKeys(ctx context.Context) ([]interface{}, error) { // Authorizer authorizes incoming requests. type Authorizer struct { - logger *zap.Logger - keys []interface{} - keyGuard sync.RWMutex - acceptedAudiences map[string]bool + logger *zap.Logger + keys []interface{} + keyGuard sync.RWMutex + + AcceptedAudiences map[string]bool } // Configuration bundles up creation-time parameters for an Authorizer instance. @@ -148,7 +149,7 @@ func NewRSAAuthorizer(ctx context.Context, configuration Configuration) (*Author } authorizer := &Authorizer{ - acceptedAudiences: auds, + AcceptedAudiences: auds, logger: logger, keys: keys, } @@ -182,33 +183,35 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } +type CtxAuthKey struct{} +type CtxAuthValue struct { + Claims Claims + Error error +} + // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - keyClaims, err := a.ExtractClaims(r) - if err != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Failed to extract claims from access token")} - } - - if !a.acceptedAudiences[keyClaims.Audience] { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)} + v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if v.Error != nil { + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Failed to extract claims from access token")} } - if pass, missing := validateScopes(authOptions, keyClaims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, v.Claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(keyClaims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(v.Claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &keyClaims.Subject, - Scopes: keyClaims.Scopes.ToStringSlice(), + ClientID: &v.Claims.Subject, + Scopes: v.Claims.Scopes.ToStringSlice(), } } -func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { +func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { tknStr, ok := getToken(r) if !ok { - return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") + return Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") } a.keyGuard.RLock() @@ -216,10 +219,10 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { a.keyGuard.RUnlock() validated := false var err error - var keyClaims claims + var keyClaims Claims for _, key := range keys { - keyClaims = claims{} + keyClaims = Claims{} key := key _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { return key, nil @@ -234,7 +237,7 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } - return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") + return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") } return keyClaims, nil diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 065c66ca0..95c3c7306 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -203,7 +203,7 @@ func TestClaimsValidation(t *testing.T) { Now = time.Now }() - claims := &claims{} + claims := &Claims{} require.Error(t, claims.Valid()) diff --git a/pkg/auth/claims.go b/pkg/auth/claims.go index 31a33f066..c19140688 100644 --- a/pkg/auth/claims.go +++ b/pkg/auth/claims.go @@ -61,12 +61,12 @@ func (s *ScopeSet) ToStringSlice() []string { return scopes } -type claims struct { +type Claims struct { jwt.StandardClaims Scopes ScopeSet `json:"scope"` } -func (c *claims) Valid() error { +func (c *Claims) Valid() error { if c.Subject == "" { return errMissingOrEmptySubject } diff --git a/pkg/auth/claims_test.go b/pkg/auth/claims_test.go index 9005ee1d5..c9ec8c206 100644 --- a/pkg/auth/claims_test.go +++ b/pkg/auth/claims_test.go @@ -8,7 +8,7 @@ import ( ) func TestScopesJSONUnmarshaling(t *testing.T) { - claims := &claims{} + claims := &Claims{} require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims)) require.Contains(t, claims.Scopes, "one") require.Contains(t, claims.Scopes, "two") diff --git a/pkg/logging/http.go b/pkg/logging/http.go index b41a4d6eb..c17587b07 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -39,8 +39,11 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { w.next.WriteHeader(statusCode) } -type CtxAuthError struct{} -type CtxAuthSubject struct{} +type CtxAuthKey struct{} +type CtxAuthValue struct { + Subject string + ErrMsg string +} // HTTPMiddleware installs a logging http.Handler that logs requests and // selected aspects of responses to 'logger'. @@ -72,12 +75,11 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - subject, ok := r.Context().Value(CtxAuthSubject{}).(string) - if !ok { - authErrorMsg := r.Context().Value(CtxAuthError{}).(string) - logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) + v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if v.ErrMsg != "" { + logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) } else { - logger = logger.With(zap.String("req_sub", subject)) + logger = logger.With(zap.String("req_sub", v.Subject)) } handler.ServeHTTP(trw, r) From 929fbe6f30ebf751c9a53398d7e272501ecacfd0 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 14:05:39 +0200 Subject: [PATCH 13/35] fix formatting error --- cmds/core-service/main.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index fa870d27a..0527dc6c8 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -390,10 +390,15 @@ func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) ht Error: err, }) + var errMsg string + if err != nil { + errMsg = fmt.Sprintf("%#s", err) + } + ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ Subject: claims.Subject, //remove the stacktrace using the formatting specifier "%#s" - ErrMsg: fmt.Sprintf("%#s", err), + ErrMsg: errMsg, }) handler.ServeHTTP(w, r.WithContext(ctx)) From 0621a008e9e5e7e52a1ca9126a8dc846c32e1c80 Mon Sep 17 00:00:00 2001 From: Mariem Baccari <53703829+MariemBaccari@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:47:03 +0200 Subject: [PATCH 14/35] [logging] Add subject field to log output (#1263) Remove duplicate decoding --- pkg/auth/auth.go | 4 ++++ pkg/logging/http.go | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 48fb5229f..ff40fa33b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -237,7 +237,11 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } +<<<<<<< HEAD return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") +======= + return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") +>>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } return keyClaims, nil diff --git a/pkg/logging/http.go b/pkg/logging/http.go index c17587b07..029acad2e 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -75,11 +75,20 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } +<<<<<<< HEAD v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) if v.ErrMsg != "" { logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) } else { logger = logger.With(zap.String("req_sub", v.Subject)) +======= + subject, ok := r.Context().Value(CtxAuthSubject{}).(string) + if !ok { + authErrorMsg := r.Context().Value(CtxAuthError{}).(string) + logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) + } else { + logger = logger.With(zap.String("req_sub", subject)) +>>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } handler.ServeHTTP(trw, r) From 7316590df9726c7e06a498a138d27a83eaa7b3a1 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 15:14:12 +0200 Subject: [PATCH 15/35] remove merge messages --- pkg/auth/auth.go | 4 ---- pkg/logging/http.go | 9 --------- 2 files changed, 13 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index ff40fa33b..48fb5229f 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -237,11 +237,7 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } -<<<<<<< HEAD return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") -======= - return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") ->>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } return keyClaims, nil diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 029acad2e..c17587b07 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -75,20 +75,11 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } -<<<<<<< HEAD v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) if v.ErrMsg != "" { logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) } else { logger = logger.With(zap.String("req_sub", v.Subject)) -======= - subject, ok := r.Context().Value(CtxAuthSubject{}).(string) - if !ok { - authErrorMsg := r.Context().Value(CtxAuthError{}).(string) - logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) - } else { - logger = logger.With(zap.String("req_sub", subject)) ->>>>>>> 6bc9c3c2 ([logging] Add subject field to log output (#1263)) } handler.ServeHTTP(trw, r) From 1a74ef78eb56be649465bd4d81999dcacea913cb Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 15:46:35 +0200 Subject: [PATCH 16/35] fix unit test --- pkg/auth/auth_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 95c3c7306..80fae62b4 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -115,7 +115,14 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - res := a.Authorize(nil, test.req, []api.AuthorizationOption{}) + claims, err := a.ExtractClaims(test.req) + + ctx := context.WithValue(test.req.Context(), CtxAuthKey{}, CtxAuthValue{ + Error: err, + Claims: claims, + }) + + res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) t.Errorf("expected: %v, got: %v, with message %s", test.code, stacktrace.GetCode(res.Error), res.Error.Error()) From 1325293b544eb24f18b41b1c38be19ed3f28801d Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 9 Sep 2025 16:08:08 +0200 Subject: [PATCH 17/35] edit error message --- pkg/auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 48fb5229f..e413fb1ca 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -193,7 +193,7 @@ type CtxAuthValue struct { func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) if v.Error != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Failed to extract claims from access token")} + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Invalid access token")} } if pass, missing := validateScopes(authOptions, v.Claims.Scopes); !pass { From c11145a9300e02035fa93459b1ce0daba79a0e04 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Thu, 18 Sep 2025 11:22:41 +0200 Subject: [PATCH 18/35] Revert changes and improve middleware declarations --- cmds/core-service/main.go | 18 ++++++++---------- pkg/auth/auth.go | 14 +++++++------- pkg/logging/http.go | 8 ++++---- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 0527dc6c8..e9e48618b 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -307,12 +307,10 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st multiRouter.Routers = append(multiRouter.Routers, &scdV1Router) } - handler := logging.HTTPMiddleware(logger, *dumpRequests, - healthyEndpointMiddleware(logger, - &multiRouter, - )) - - handler = authDecoderMiddleware(authorizer, handler) + // the middlewares are wrapped and, therefore, executed in the opposite order + handler := healthyEndpointMiddleware(logger, &multiRouter) + handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) + handler = authMiddleware(authorizer, handler) httpServer := &http.Server{ Addr: address, @@ -374,8 +372,8 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl }) } -// authDecoderMiddleware decodes the authentication token and adds the Subject claim to the context. -func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { +// authMiddleware decodes the authentication token and passes the claims to the context. +func authMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx context.Context claims, err := authorizer.ExtractClaims(r) @@ -392,13 +390,13 @@ func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) ht var errMsg string if err != nil { + //remove the stacktrace using the formatting specifier "%#s" errMsg = fmt.Sprintf("%#s", err) } ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ Subject: claims.Subject, - //remove the stacktrace using the formatting specifier "%#s" - ErrMsg: errMsg, + ErrMsg: errMsg, }) handler.ServeHTTP(w, r.WithContext(ctx)) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index e413fb1ca..81025e970 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -191,20 +191,20 @@ type CtxAuthValue struct { // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) - if v.Error != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(v.Error, dsserr.Unauthenticated, "Invalid access token")} + authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if authResults.Error != nil { + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(authResults.Error, dsserr.Unauthenticated, "Invalid access token")} } - if pass, missing := validateScopes(authOptions, v.Claims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, authResults.Claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(v.Claims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(authResults.Claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &v.Claims.Subject, - Scopes: v.Claims.Scopes.ToStringSlice(), + ClientID: &authResults.Claims.Subject, + Scopes: authResults.Claims.Scopes.ToStringSlice(), } } diff --git a/pkg/logging/http.go b/pkg/logging/http.go index c17587b07..f7d8c7945 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -75,11 +75,11 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - v := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) - if v.ErrMsg != "" { - logger = logger.With(zap.String("resp_sub_err", v.ErrMsg)) + authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + if authResults.ErrMsg != "" { + logger = logger.With(zap.String("resp_sub_err", authResults.ErrMsg)) } else { - logger = logger.With(zap.String("req_sub", v.Subject)) + logger = logger.With(zap.String("req_sub", authResults.Subject)) } handler.ServeHTTP(trw, r) From d78dc68a80a5c08042a8dee631aaf7750a5a9f21 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:05:46 +0200 Subject: [PATCH 19/35] address pr comments --- cmds/core-service/main.go | 34 +-------------- pkg/auth/auth.go | 91 +++++++++++++++++++++++++++------------ pkg/auth/auth_test.go | 4 +- pkg/auth/claims.go | 4 +- pkg/auth/claims_test.go | 2 +- 5 files changed, 69 insertions(+), 66 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index e9e48618b..77d01558f 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -26,7 +26,6 @@ import ( "github.com/interuss/dss/pkg/build" "github.com/interuss/dss/pkg/datastore" "github.com/interuss/dss/pkg/datastore/flags" // Force command line flag registration - dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/logging" "github.com/interuss/dss/pkg/rid/application" rid_v1 "github.com/interuss/dss/pkg/rid/server/v1" @@ -310,7 +309,7 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st // the middlewares are wrapped and, therefore, executed in the opposite order handler := healthyEndpointMiddleware(logger, &multiRouter) handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) - handler = authMiddleware(authorizer, handler) + handler = authorizer.TokenMiddleware(handler) httpServer := &http.Server{ Addr: address, @@ -372,37 +371,6 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl }) } -// authMiddleware decodes the authentication token and passes the claims to the context. -func authMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context - claims, err := authorizer.ExtractClaims(r) - if err == nil { - if !authorizer.AcceptedAudiences[claims.Audience] { - err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) - } - } - - ctx = context.WithValue(r.Context(), auth.CtxAuthKey{}, auth.CtxAuthValue{ - Claims: claims, - Error: err, - }) - - var errMsg string - if err != nil { - //remove the stacktrace using the formatting specifier "%#s" - errMsg = fmt.Sprintf("%#s", err) - } - - ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ - Subject: claims.Subject, - ErrMsg: errMsg, - }) - - handler.ServeHTTP(w, r.WithContext(ctx)) - }) -} - func SetDeprecatingHttpFlag(logger *zap.Logger, newFlag **bool, deprecatedFlag **bool) { if **deprecatedFlag { logger.Warn("DEPRECATED: enable_http has been renamed to allow_http_base_urls.") diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 81025e970..dfa26aa18 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -124,7 +124,14 @@ type Authorizer struct { keys []interface{} keyGuard sync.RWMutex - AcceptedAudiences map[string]bool + acceptedAudiences map[string]bool + + decodedClaims *middlewareResult +} + +type middlewareResult struct { + claims claims + err error } // Configuration bundles up creation-time parameters for an Authorizer instance. @@ -149,7 +156,7 @@ func NewRSAAuthorizer(ctx context.Context, configuration Configuration) (*Author } authorizer := &Authorizer{ - AcceptedAudiences: auds, + acceptedAudiences: auds, logger: logger, keys: keys, } @@ -183,35 +190,72 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } -type CtxAuthKey struct{} -type CtxAuthValue struct { - Claims Claims - Error error +// TokenMiddleware decodes the authentication token and passes the claims to the context. +func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ctx context.Context + claims, err := a.extractClaims(r) + if err == nil { + if !a.acceptedAudiences[claims.Audience] { + err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) + } + } + + a.decodedClaims = &middlewareResult{ + claims: claims, + err: err, + } + + var errMsg string + if err != nil { + //remove the stacktrace using the formatting specifier "%#s" + errMsg = fmt.Sprintf("%#s", err) + } + + ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ + Subject: claims.Subject, + ErrMsg: errMsg, + }) + + handler.ServeHTTP(w, r.WithContext(ctx)) + }) } -// Authorize extracts and verifies bearer tokens from a http.Request. +// Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) - if authResults.Error != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(authResults.Error, dsserr.Unauthenticated, "Invalid access token")} + if a.decodedClaims == nil { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Access token not found")} + } + + if a.decodedClaims.err != nil { + return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(a.decodedClaims.err, dsserr.Unauthenticated, "Invalid access token")} } - if pass, missing := validateScopes(authOptions, authResults.Claims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, a.decodedClaims.claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(authResults.Claims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(a.decodedClaims.claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &authResults.Claims.Subject, - Scopes: authResults.Claims.Scopes.ToStringSlice(), + ClientID: &a.decodedClaims.claims.Subject, + Scopes: a.decodedClaims.claims.Scopes.ToStringSlice(), + } +} + +func HasScope(scopes []string, requiredScope api.RequiredScope) bool { + for _, scope := range scopes { + if scope == string(requiredScope) { + return true + } } + return false } -func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { +func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { tknStr, ok := getToken(r) if !ok { - return Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") + return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") } a.keyGuard.RLock() @@ -219,10 +263,10 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { a.keyGuard.RUnlock() validated := false var err error - var keyClaims Claims + var keyClaims claims for _, key := range keys { - keyClaims = Claims{} + keyClaims = claims{} key := key _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { return key, nil @@ -237,21 +281,12 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (Claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } - return Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") + return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") } return keyClaims, nil } -func HasScope(scopes []string, requiredScope api.RequiredScope) bool { - for _, scope := range scopes { - if scope == string(requiredScope) { - return true - } - } - return false -} - // describeAuthorizationExpectations builds a human-readable string describing the expectations of the authorization options. func describeAuthorizationExpectations(authOptions []api.AuthorizationOption) string { if len(authOptions) == 0 { diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 80fae62b4..f1de0a978 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -115,7 +115,7 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - claims, err := a.ExtractClaims(test.req) + claims, err := a.extractClaims(test.req) ctx := context.WithValue(test.req.Context(), CtxAuthKey{}, CtxAuthValue{ Error: err, @@ -210,7 +210,7 @@ func TestClaimsValidation(t *testing.T) { Now = time.Now }() - claims := &Claims{} + claims := &claims{} require.Error(t, claims.Valid()) diff --git a/pkg/auth/claims.go b/pkg/auth/claims.go index c19140688..31a33f066 100644 --- a/pkg/auth/claims.go +++ b/pkg/auth/claims.go @@ -61,12 +61,12 @@ func (s *ScopeSet) ToStringSlice() []string { return scopes } -type Claims struct { +type claims struct { jwt.StandardClaims Scopes ScopeSet `json:"scope"` } -func (c *Claims) Valid() error { +func (c *claims) Valid() error { if c.Subject == "" { return errMissingOrEmptySubject } diff --git a/pkg/auth/claims_test.go b/pkg/auth/claims_test.go index c9ec8c206..9005ee1d5 100644 --- a/pkg/auth/claims_test.go +++ b/pkg/auth/claims_test.go @@ -8,7 +8,7 @@ import ( ) func TestScopesJSONUnmarshaling(t *testing.T) { - claims := &Claims{} + claims := &claims{} require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims)) require.Contains(t, claims.Scopes, "one") require.Contains(t, claims.Scopes, "two") From 33cca92dd1c6e1edf2a9c4bc10604dd4f62a741d Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:10:17 +0200 Subject: [PATCH 20/35] fix context key nit --- pkg/auth/auth.go | 2 +- pkg/logging/http.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index dfa26aa18..c9571f0fb 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -212,7 +212,7 @@ func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { errMsg = fmt.Sprintf("%#s", err) } - ctx = context.WithValue(ctx, logging.CtxAuthKey{}, logging.CtxAuthValue{ + ctx = context.WithValue(ctx, logging.CtxKey("sub"), logging.CtxAuthValue{ Subject: claims.Subject, ErrMsg: errMsg, }) diff --git a/pkg/logging/http.go b/pkg/logging/http.go index f7d8c7945..f5fc7295f 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -39,7 +39,7 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { w.next.WriteHeader(statusCode) } -type CtxAuthKey struct{} +type CtxKey string type CtxAuthValue struct { Subject string ErrMsg string @@ -75,7 +75,7 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - authResults := r.Context().Value(CtxAuthKey{}).(CtxAuthValue) + authResults := r.Context().Value(CtxKey("sub")).(CtxAuthValue) if authResults.ErrMsg != "" { logger = logger.With(zap.String("resp_sub_err", authResults.ErrMsg)) } else { From 09f79264ea4f9a89281ed2a38241c0a83bd18a5c Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:21:55 +0200 Subject: [PATCH 21/35] add missing context --- pkg/auth/auth.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index c9571f0fb..2974d7fe7 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -193,7 +193,6 @@ func (a *Authorizer) setKeys(keys []interface{}) { // TokenMiddleware decodes the authentication token and passes the claims to the context. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context claims, err := a.extractClaims(r) if err == nil { if !a.acceptedAudiences[claims.Audience] { @@ -212,7 +211,7 @@ func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { errMsg = fmt.Sprintf("%#s", err) } - ctx = context.WithValue(ctx, logging.CtxKey("sub"), logging.CtxAuthValue{ + ctx := context.WithValue(r.Context(), logging.CtxKey("sub"), logging.CtxAuthValue{ Subject: claims.Subject, ErrMsg: errMsg, }) From 26eae64cf0a159eb3aeb01878e6a93346be3b239 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:25:11 +0200 Subject: [PATCH 22/35] clarify middleware doc --- pkg/auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 2974d7fe7..46ba5331b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -190,7 +190,7 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } -// TokenMiddleware decodes the authentication token and passes the claims to the context. +// TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, err := a.extractClaims(r) From 6e443cacd545cca334080108d7dddc082827cc3b Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 21 Oct 2025 18:34:19 +0200 Subject: [PATCH 23/35] fix auth_test.go --- pkg/auth/auth_test.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index f1de0a978..58dd7a39d 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -116,13 +116,11 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { claims, err := a.extractClaims(test.req) - - ctx := context.WithValue(test.req.Context(), CtxAuthKey{}, CtxAuthValue{ - Error: err, - Claims: claims, - }) - - res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) + a.decodedClaims = &middlewareResult{ + claims: claims, + err: err, + } + res := a.Authorize(nil, test.req.WithContext(context.Background()), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) t.Errorf("expected: %v, got: %v, with message %s", test.code, stacktrace.GetCode(res.Error), res.Error.Error()) From ed2de918da59926eb34a0e2417046dabf94555cd Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Mon, 15 Dec 2025 15:24:58 +0100 Subject: [PATCH 24/35] address review comments --- pkg/auth/auth.go | 88 ++++++++++++++++++++++--------------------- pkg/auth/auth_test.go | 14 ++++--- pkg/logging/http.go | 22 ++++++----- 3 files changed, 66 insertions(+), 58 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 46ba5331b..33e7dbdab 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -125,13 +125,6 @@ type Authorizer struct { keyGuard sync.RWMutex acceptedAudiences map[string]bool - - decodedClaims *middlewareResult -} - -type middlewareResult struct { - claims claims - err error } // Configuration bundles up creation-time parameters for an Authorizer instance. @@ -193,62 +186,62 @@ func (a *Authorizer) setKeys(keys []interface{}) { // TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() claims, err := a.extractClaims(r) - if err == nil { - if !a.acceptedAudiences[claims.Audience] { - err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience) - } - } - - a.decodedClaims = &middlewareResult{ - claims: claims, - err: err, - } - - var errMsg string if err != nil { - //remove the stacktrace using the formatting specifier "%#s" - errMsg = fmt.Sprintf("%#s", err) + // remove the stacktrace using the formatting specifier "%#s" + ctx = context.WithValue(ctx, logging.CtxKeyErrMsg, fmt.Sprintf("%#s", err)) + ctx = context.WithValue(ctx, ctxKeyError, err) + } else { + ctx = context.WithValue(ctx, logging.CtxKeySub, claims.Subject) + ctx = context.WithValue(ctx, ctxKeyClaims, claims) } - ctx := context.WithValue(r.Context(), logging.CtxKey("sub"), logging.CtxAuthValue{ - Subject: claims.Subject, - ErrMsg: errMsg, - }) - handler.ServeHTTP(w, r.WithContext(ctx)) }) } +type ctxKey string + +const ( + ctxKeyClaims ctxKey = "claims" + ctxKeyError ctxKey = "auth_error" +) + // Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - if a.decodedClaims == nil { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Access token not found")} + if errValue := r.Context().Value(ctxKeyError); errValue != nil { + if err, ok := errValue.(error); ok { + return api.AuthorizationResult{Error: err} + } else { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid authentication error type in context")} + } + } + + claimsValue := r.Context().Value(ctxKeyClaims) + if claimsValue == nil { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing authentication claims from context")} + } + + claims, ok := claimsValue.(claims) + if !ok { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid authentication claims type in context")} } - if a.decodedClaims.err != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(a.decodedClaims.err, dsserr.Unauthenticated, "Invalid access token")} + if !a.acceptedAudiences[claims.Audience] { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience)} } - if pass, missing := validateScopes(authOptions, a.decodedClaims.claims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, claims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(a.decodedClaims.claims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(claims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &a.decodedClaims.claims.Subject, - Scopes: a.decodedClaims.claims.Scopes.ToStringSlice(), - } -} - -func HasScope(scopes []string, requiredScope api.RequiredScope) bool { - for _, scope := range scopes { - if scope == string(requiredScope) { - return true - } + ClientID: &claims.Subject, + Scopes: claims.Scopes.ToStringSlice(), } - return false } func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { @@ -286,6 +279,15 @@ func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { return keyClaims, nil } +func HasScope(scopes []string, requiredScope api.RequiredScope) bool { + for _, scope := range scopes { + if scope == string(requiredScope) { + return true + } + } + return false +} + // describeAuthorizationExpectations builds a human-readable string describing the expectations of the authorization options. func describeAuthorizationExpectations(authOptions []api.AuthorizationOption) string { if len(authOptions) == 0 { diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 58dd7a39d..ec4bff97d 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -52,7 +52,7 @@ func rsaTokenReqWithMissingIssuer(key *rsa.PrivateKey, exp, nbf int64) *http.Req } func TestNewRSAAuthClient(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() tmpfile, err := os.CreateTemp("/tmp", "bad.pem") @@ -103,7 +103,7 @@ func TestRSAAuthInterceptor(t *testing.T) { {rsaTokenReq(key, 100, 50), dsserr.Unauthenticated}, } - a, err := NewRSAAuthorizer(context.Background(), Configuration{ + a, err := NewRSAAuthorizer(t.Context(), Configuration{ KeyResolver: &fromMemoryKeyResolver{ Keys: []interface{}{&key.PublicKey}, }, @@ -115,12 +115,14 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := context.Background() claims, err := a.extractClaims(test.req) - a.decodedClaims = &middlewareResult{ - claims: claims, - err: err, + if err != nil { + ctx = context.WithValue(ctx, ctxKeyError, err) + } else { + ctx = context.WithValue(ctx, ctxKeyClaims, claims) } - res := a.Authorize(nil, test.req.WithContext(context.Background()), []api.AuthorizationOption{}) + res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) t.Errorf("expected: %v, got: %v, with message %s", test.code, stacktrace.GetCode(res.Error), res.Error.Error()) diff --git a/pkg/logging/http.go b/pkg/logging/http.go index f5fc7295f..90c8ca9b8 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -40,10 +40,11 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { } type CtxKey string -type CtxAuthValue struct { - Subject string - ErrMsg string -} + +const ( + CtxKeySub CtxKey = "sub" + CtxKeyErrMsg CtxKey = "sub_err_msg" +) // HTTPMiddleware installs a logging http.Handler that logs requests and // selected aspects of responses to 'logger'. @@ -75,11 +76,14 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - authResults := r.Context().Value(CtxKey("sub")).(CtxAuthValue) - if authResults.ErrMsg != "" { - logger = logger.With(zap.String("resp_sub_err", authResults.ErrMsg)) - } else { - logger = logger.With(zap.String("req_sub", authResults.Subject)) + if errMsgValue := r.Context().Value(CtxKeyErrMsg); errMsgValue != nil { + if errMsg, ok := errMsgValue.(string); ok && errMsg != "" { + logger = logger.With(zap.String("resp_sub_err", errMsg)) + } + } else if subValue := r.Context().Value(CtxKeySub); subValue != nil { + if sub, ok := subValue.(string); ok && sub != "" { + logger = logger.With(zap.String("req_sub", sub)) + } } handler.ServeHTTP(trw, r) From 71210a2d40e40657e0ccac006796455b11c37e9c Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Mon, 15 Dec 2025 15:37:07 +0100 Subject: [PATCH 25/35] use keyclaims var name --- pkg/auth/auth.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 33e7dbdab..d861d989e 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -218,29 +218,29 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio } } - claimsValue := r.Context().Value(ctxKeyClaims) - if claimsValue == nil { + keyClaimsValue := r.Context().Value(ctxKeyClaims) + if keyClaimsValue == nil { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing authentication claims from context")} } - claims, ok := claimsValue.(claims) + keyClaims, ok := keyClaimsValue.(claims) if !ok { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid authentication claims type in context")} } - if !a.acceptedAudiences[claims.Audience] { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", claims.Audience)} + if !a.acceptedAudiences[keyClaims.Audience] { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)} } - if pass, missing := validateScopes(authOptions, claims.Scopes); !pass { + if pass, missing := validateScopes(authOptions, keyClaims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", - missing, describeAuthorizationExpectations(authOptions), strings.Join(claims.Scopes.ToStringSlice(), ", "))} + missing, describeAuthorizationExpectations(authOptions), strings.Join(keyClaims.Scopes.ToStringSlice(), ", "))} } return api.AuthorizationResult{ - ClientID: &claims.Subject, - Scopes: claims.Scopes.ToStringSlice(), + ClientID: &keyClaims.Subject, + Scopes: keyClaims.Scopes.ToStringSlice(), } } From 62ed06cb81be1323557febd4dcdeae614403f644 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Mon, 15 Dec 2025 15:46:23 +0100 Subject: [PATCH 26/35] remove added newline --- pkg/auth/auth.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index d861d989e..32052341a 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -120,10 +120,9 @@ func (r *JWKSResolver) ResolveKeys(ctx context.Context) ([]interface{}, error) { // Authorizer authorizes incoming requests. type Authorizer struct { - logger *zap.Logger - keys []interface{} - keyGuard sync.RWMutex - + logger *zap.Logger + keys []interface{} + keyGuard sync.RWMutex acceptedAudiences map[string]bool } From 97891bbe2c909d3181252ff37148e51be8d7f4c1 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Mon, 15 Dec 2025 16:46:39 +0100 Subject: [PATCH 27/35] move ctxkey --- pkg/auth/auth.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 32052341a..dbf30b6fd 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -182,6 +182,13 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } +type ctxKey string + +const ( + ctxKeyClaims ctxKey = "claims" + ctxKeyError ctxKey = "auth_error" +) + // TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -200,13 +207,6 @@ func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { }) } -type ctxKey string - -const ( - ctxKeyClaims ctxKey = "claims" - ctxKeyError ctxKey = "auth_error" -) - // Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { if errValue := r.Context().Value(ctxKeyError); errValue != nil { From 817115b289c22b44372e504a7ab0bdb982e9e700 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Mon, 15 Dec 2025 16:58:48 +0100 Subject: [PATCH 28/35] use t.context in all testing --- pkg/auth/auth_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index ec4bff97d..45b48fe6f 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -115,7 +115,7 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() claims, err := a.extractClaims(test.req) if err != nil { ctx = context.WithValue(ctx, ctxKeyError, err) From 32e429e49d6bae14b40fd03ac936eac219db8821 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Fri, 26 Dec 2025 17:23:54 +0100 Subject: [PATCH 29/35] address review comments --- pkg/auth/auth.go | 54 +++++++--------------------- pkg/auth/auth_test.go | 18 ++++------ pkg/auth/{ => claims}/claims.go | 35 ++++++++++++++++-- pkg/auth/{ => claims}/claims_test.go | 4 +-- pkg/logging/http.go | 9 ++--- 5 files changed, 55 insertions(+), 65 deletions(-) rename pkg/auth/{ => claims}/claims.go (73%) rename pkg/auth/{ => claims}/claims_test.go (96%) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index dbf30b6fd..c5eddb64e 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -15,6 +15,7 @@ import ( "time" "github.com/interuss/dss/pkg/api" + "github.com/interuss/dss/pkg/auth/claims" dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/logging" "github.com/interuss/stacktrace" @@ -182,26 +183,12 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } -type ctxKey string - -const ( - ctxKeyClaims ctxKey = "claims" - ctxKeyError ctxKey = "auth_error" -) - // TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - claims, err := a.extractClaims(r) - if err != nil { - // remove the stacktrace using the formatting specifier "%#s" - ctx = context.WithValue(ctx, logging.CtxKeyErrMsg, fmt.Sprintf("%#s", err)) - ctx = context.WithValue(ctx, ctxKeyError, err) - } else { - ctx = context.WithValue(ctx, logging.CtxKeySub, claims.Subject) - ctx = context.WithValue(ctx, ctxKeyClaims, claims) - } + claimsValue, err := a.extractClaims(r) + ctx = claims.NewContext(ctx, claimsValue, err) handler.ServeHTTP(w, r.WithContext(ctx)) }) @@ -209,26 +196,9 @@ func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { // Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - if errValue := r.Context().Value(ctxKeyError); errValue != nil { - if err, ok := errValue.(error); ok { - return api.AuthorizationResult{Error: err} - } else { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid authentication error type in context")} - } - } - - keyClaimsValue := r.Context().Value(ctxKeyClaims) - if keyClaimsValue == nil { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing authentication claims from context")} - } - - keyClaims, ok := keyClaimsValue.(claims) - if !ok { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid authentication claims type in context")} - } - - if !a.acceptedAudiences[keyClaims.Audience] { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)} + keyClaims, err := claims.FromContext(r.Context()) + if err != nil { + return api.AuthorizationResult{Error: stacktrace.Propagate(err, "Error retrieving claims from context")} } if pass, missing := validateScopes(authOptions, keyClaims.Scopes); !pass { @@ -243,10 +213,10 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio } } -func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { +func (a *Authorizer) extractClaims(r *http.Request) (claims.Claims, error) { tknStr, ok := getToken(r) if !ok { - return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") + return claims.Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") } a.keyGuard.RLock() @@ -254,10 +224,10 @@ func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { a.keyGuard.RUnlock() validated := false var err error - var keyClaims claims + var keyClaims claims.Claims for _, key := range keys { - keyClaims = claims{} + keyClaims = claims.Claims{} key := key _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { return key, nil @@ -272,7 +242,7 @@ func (a *Authorizer) extractClaims(r *http.Request) (claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } - return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") + return claims.Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") } return keyClaims, nil @@ -316,7 +286,7 @@ func describeAuthorizationExpectations(authOptions []api.AuthorizationOption) st // validateScopes matches scopes against a set of authorization options. Validation against a single one of those is // enough for the validation to succeed. Returns true if it succeeds, or returns false and a string describing the // missing scopes if it fails. Empty authorization options means that the validation passes. -func validateScopes(authOptions []api.AuthorizationOption, clientScopes ScopeSet) (bool, string) { +func validateScopes(authOptions []api.AuthorizationOption, clientScopes claims.ScopeSet) (bool, string) { if len(authOptions) == 0 { return true, "" } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 45b48fe6f..c88f85a97 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -12,6 +12,7 @@ import ( "github.com/interuss/dss/pkg/api" "github.com/interuss/dss/pkg/api/scdv1" + "github.com/interuss/dss/pkg/auth/claims" dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/stacktrace" @@ -115,13 +116,8 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - ctx := t.Context() - claims, err := a.extractClaims(test.req) - if err != nil { - ctx = context.WithValue(ctx, ctxKeyError, err) - } else { - ctx = context.WithValue(ctx, ctxKeyClaims, claims) - } + claimsValue, err := a.extractClaims(test.req) + ctx := claims.NewContext(t.Context(), claimsValue, err) res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) @@ -200,17 +196,17 @@ func TestMissingScopes(t *testing.T) { } func TestClaimsValidation(t *testing.T) { - Now = func() time.Time { + claims.Now = func() time.Time { return time.Unix(42, 0) } - jwt.TimeFunc = Now + jwt.TimeFunc = claims.Now defer func() { jwt.TimeFunc = time.Now - Now = time.Now + claims.Now = time.Now }() - claims := &claims{} + claims := &claims.Claims{} require.Error(t, claims.Valid()) diff --git a/pkg/auth/claims.go b/pkg/auth/claims/claims.go similarity index 73% rename from pkg/auth/claims.go rename to pkg/auth/claims/claims.go index 31a33f066..c8a0437af 100644 --- a/pkg/auth/claims.go +++ b/pkg/auth/claims/claims.go @@ -1,6 +1,7 @@ -package auth +package claims import ( + "context" "encoding/json" "errors" "strings" @@ -18,6 +19,34 @@ var ( Now = time.Now ) +type ctxKey string + +var ( + claimsKey = ctxKey("claims") + errKey = ctxKey("error") +) + +func NewContext(ctx context.Context, claims Claims, err error) context.Context { + if err != nil { + return context.WithValue(ctx, errKey, err) + } + + return context.WithValue(ctx, claimsKey, claims) +} + +func FromContext(ctx context.Context) (Claims, error) { + claims, ok := ctx.Value(claimsKey).(Claims) + if !ok { + err, ok := ctx.Value(errKey).(error) + if ok { + return Claims{}, err + } + return Claims{}, stacktrace.NewError("No claims or error in context") + } + + return claims, nil +} + // ScopeSet models a set of scopes. type ScopeSet map[string]struct{} @@ -61,12 +90,12 @@ func (s *ScopeSet) ToStringSlice() []string { return scopes } -type claims struct { +type Claims struct { jwt.StandardClaims Scopes ScopeSet `json:"scope"` } -func (c *claims) Valid() error { +func (c *Claims) Valid() error { if c.Subject == "" { return errMissingOrEmptySubject } diff --git a/pkg/auth/claims_test.go b/pkg/auth/claims/claims_test.go similarity index 96% rename from pkg/auth/claims_test.go rename to pkg/auth/claims/claims_test.go index 9005ee1d5..e0cdcbb6f 100644 --- a/pkg/auth/claims_test.go +++ b/pkg/auth/claims/claims_test.go @@ -1,4 +1,4 @@ -package auth +package claims import ( "encoding/json" @@ -8,7 +8,7 @@ import ( ) func TestScopesJSONUnmarshaling(t *testing.T) { - claims := &claims{} + claims := &Claims{} require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims)) require.Contains(t, claims.Scopes, "one") require.Contains(t, claims.Scopes, "two") diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 90c8ca9b8..821cb722e 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -42,8 +42,7 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { type CtxKey string const ( - CtxKeySub CtxKey = "sub" - CtxKeyErrMsg CtxKey = "sub_err_msg" + CtxKeySub CtxKey = "sub" ) // HTTPMiddleware installs a logging http.Handler that logs requests and @@ -76,11 +75,7 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - if errMsgValue := r.Context().Value(CtxKeyErrMsg); errMsgValue != nil { - if errMsg, ok := errMsgValue.(string); ok && errMsg != "" { - logger = logger.With(zap.String("resp_sub_err", errMsg)) - } - } else if subValue := r.Context().Value(CtxKeySub); subValue != nil { + if subValue := r.Context().Value(CtxKeySub); subValue != nil { if sub, ok := subValue.(string); ok && sub != "" { logger = logger.With(zap.String("req_sub", sub)) } From e33f6b4dab4d968a8450635331e8bf971fd3ce27 Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Fri, 26 Dec 2025 22:31:28 +0100 Subject: [PATCH 30/35] log value from claims.FromContext --- pkg/logging/http.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 821cb722e..3a58729fb 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/interuss/dss/pkg/auth/claims" "go.uber.org/zap" ) @@ -75,10 +76,9 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - if subValue := r.Context().Value(CtxKeySub); subValue != nil { - if sub, ok := subValue.(string); ok && sub != "" { - logger = logger.With(zap.String("req_sub", sub)) - } + claimsValue, _ := claims.FromContext(r.Context()) + if claimsValue.Subject != "" { + logger = logger.With(zap.String("req_sub", claimsValue.Subject)) } handler.ServeHTTP(trw, r) From 1e5755ff159c1fc7f68759b78638d59c7a64f4fc Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Fri, 26 Dec 2025 22:33:11 +0100 Subject: [PATCH 31/35] remove unused ctxkey --- pkg/logging/http.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 3a58729fb..5a681172d 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -40,12 +40,6 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { w.next.WriteHeader(statusCode) } -type CtxKey string - -const ( - CtxKeySub CtxKey = "sub" -) - // HTTPMiddleware installs a logging http.Handler that logs requests and // selected aspects of responses to 'logger'. func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Handler { From 4fdf748a0b5740cb4e276a5154b8875318b11df9 Mon Sep 17 00:00:00 2001 From: Mariem Baccari <53703829+MariemBaccari@users.noreply.github.com> Date: Tue, 30 Dec 2025 09:03:31 +0100 Subject: [PATCH 32/35] nit --- pkg/auth/auth.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index c5eddb64e..12e4a5079 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -186,9 +186,8 @@ func (a *Authorizer) setKeys(keys []interface{}) { // TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() claimsValue, err := a.extractClaims(r) - ctx = claims.NewContext(ctx, claimsValue, err) + ctx := claims.NewContext(r.Context(), claimsValue, err) handler.ServeHTTP(w, r.WithContext(ctx)) }) From 60f3db16bae9ba09f28ebb24e0198a9b010fb1ec Mon Sep 17 00:00:00 2001 From: Mariem Baccari <53703829+MariemBaccari@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:14:28 +0100 Subject: [PATCH 33/35] Update pkg/logging/http.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mickaƫl Misbach --- pkg/logging/http.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 5a681172d..ab50d874c 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -70,8 +70,7 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - claimsValue, _ := claims.FromContext(r.Context()) - if claimsValue.Subject != "" { + if claimsValue, _ := claims.FromContext(r.Context()); claimsValue.Subject != "" { logger = logger.With(zap.String("req_sub", claimsValue.Subject)) } From 799821b0de6bffcc2ab8874b5a06a1ad0706e70f Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 6 Jan 2026 11:19:10 +0100 Subject: [PATCH 34/35] address review comment --- pkg/auth/auth.go | 11 ++++++++++- pkg/auth/auth_test.go | 8 +++++++- pkg/auth/claims/claims.go | 10 +++++----- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 12e4a5079..28d5ce8c3 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -186,8 +186,13 @@ func (a *Authorizer) setKeys(keys []interface{}) { // TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ctx context.Context claimsValue, err := a.extractClaims(r) - ctx := claims.NewContext(r.Context(), claimsValue, err) + if err != nil { + ctx = claims.NewContextFromError(ctx, err) + } else { + ctx = claims.NewContext(ctx, claimsValue) + } handler.ServeHTTP(w, r.WithContext(ctx)) }) @@ -200,6 +205,10 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio return api.AuthorizationResult{Error: stacktrace.Propagate(err, "Error retrieving claims from context")} } + if !a.acceptedAudiences[keyClaims.Audience] { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)} + } + if pass, missing := validateScopes(authOptions, keyClaims.Scopes); !pass { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.PermissionDenied, "Access token missing scopes (%v) while expecting %v and got %v", diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index c88f85a97..4bc900417 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -116,8 +116,14 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { + var ctx context.Context claimsValue, err := a.extractClaims(test.req) - ctx := claims.NewContext(t.Context(), claimsValue, err) + if err != nil { + ctx = claims.NewContextFromError(ctx, err) + } else { + ctx = claims.NewContext(ctx, claimsValue) + } + res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) diff --git a/pkg/auth/claims/claims.go b/pkg/auth/claims/claims.go index c8a0437af..92f2b11f7 100644 --- a/pkg/auth/claims/claims.go +++ b/pkg/auth/claims/claims.go @@ -26,14 +26,14 @@ var ( errKey = ctxKey("error") ) -func NewContext(ctx context.Context, claims Claims, err error) context.Context { - if err != nil { - return context.WithValue(ctx, errKey, err) - } - +func NewContext(ctx context.Context, claims Claims) context.Context { return context.WithValue(ctx, claimsKey, claims) } +func NewContextFromError(ctx context.Context, err error) context.Context { + return context.WithValue(ctx, errKey, err) +} + func FromContext(ctx context.Context) (Claims, error) { claims, ok := ctx.Value(claimsKey).(Claims) if !ok { From bcdebdb3e547de53af50c80a767e4f6b9f662cdb Mon Sep 17 00:00:00 2001 From: Mariem Baccari Date: Tue, 6 Jan 2026 11:23:34 +0100 Subject: [PATCH 35/35] nit --- pkg/auth/auth.go | 2 +- pkg/auth/auth_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 28d5ce8c3..a071f1d25 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -186,7 +186,7 @@ func (a *Authorizer) setKeys(keys []interface{}) { // TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context + ctx := r.Context() claimsValue, err := a.extractClaims(r) if err != nil { ctx = claims.NewContextFromError(ctx, err) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 4bc900417..02141713c 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -116,7 +116,7 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - var ctx context.Context + ctx := t.Context() claimsValue, err := a.extractClaims(test.req) if err != nil { ctx = claims.NewContextFromError(ctx, err)