diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index f02700cfd..77d01558f 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -306,12 +306,10 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st multiRouter.Routers = append(multiRouter.Routers, &scdV1Router) } - handler := logging.HTTPMiddleware(logger, *dumpRequests, - healthyEndpointMiddleware(logger, - &multiRouter, - )) - - handler = authDecoderMiddleware(authorizer, handler) + // the middlewares are wrapped and, therefore, executed in the opposite order + handler := healthyEndpointMiddleware(logger, &multiRouter) + handler = logging.HTTPMiddleware(logger, *dumpRequests, handler) + handler = authorizer.TokenMiddleware(handler) httpServer := &http.Server{ Addr: address, @@ -373,23 +371,6 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl }) } -// authDecoderMiddleware decodes the authentication token and adds the Subject claim to the context. -func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var ctx context.Context - - claims, err := authorizer.ExtractClaims(r) - if err != nil { - //remove the stacktrace using the formatting specifier "%#s" - ctx = context.WithValue(r.Context(), logging.CtxAuthError{}, fmt.Sprintf("%#s", err)) - } else { - ctx = context.WithValue(r.Context(), logging.CtxAuthSubject{}, claims.Subject) - } - - handler.ServeHTTP(w, r.WithContext(ctx)) - }) -} - func SetDeprecatingHttpFlag(logger *zap.Logger, newFlag **bool, deprecatedFlag **bool) { if **deprecatedFlag { logger.Warn("DEPRECATED: enable_http has been renamed to allow_http_base_urls.") diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 696ecf7f2..a071f1d25 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -15,6 +15,7 @@ import ( "time" "github.com/interuss/dss/pkg/api" + "github.com/interuss/dss/pkg/auth/claims" dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/dss/pkg/logging" "github.com/interuss/stacktrace" @@ -182,11 +183,26 @@ func (a *Authorizer) setKeys(keys []interface{}) { a.keyGuard.Unlock() } -// Authorize extracts and verifies bearer tokens from a http.Request. +// TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging. +func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + claimsValue, err := a.extractClaims(r) + if err != nil { + ctx = claims.NewContextFromError(ctx, err) + } else { + ctx = claims.NewContext(ctx, claimsValue) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - keyClaims, err := a.ExtractClaims(r) + keyClaims, err := claims.FromContext(r.Context()) if err != nil { - return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Failed to extract claims from access token")} + return api.AuthorizationResult{Error: stacktrace.Propagate(err, "Error retrieving claims from context")} } if !a.acceptedAudiences[keyClaims.Audience] { @@ -205,10 +221,10 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio } } -func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { +func (a *Authorizer) extractClaims(r *http.Request) (claims.Claims, error) { tknStr, ok := getToken(r) if !ok { - return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") + return claims.Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token") } a.keyGuard.RLock() @@ -216,10 +232,10 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { a.keyGuard.RUnlock() validated := false var err error - var keyClaims claims + var keyClaims claims.Claims for _, key := range keys { - keyClaims = claims{} + keyClaims = claims.Claims{} key := key _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { return key, nil @@ -234,7 +250,7 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) { if err == nil { // If we have no keys, errs may be nil err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against") } - return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") + return claims.Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed") } return keyClaims, nil @@ -278,7 +294,7 @@ func describeAuthorizationExpectations(authOptions []api.AuthorizationOption) st // validateScopes matches scopes against a set of authorization options. Validation against a single one of those is // enough for the validation to succeed. Returns true if it succeeds, or returns false and a string describing the // missing scopes if it fails. Empty authorization options means that the validation passes. -func validateScopes(authOptions []api.AuthorizationOption, clientScopes ScopeSet) (bool, string) { +func validateScopes(authOptions []api.AuthorizationOption, clientScopes claims.ScopeSet) (bool, string) { if len(authOptions) == 0 { return true, "" } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 065c66ca0..02141713c 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -12,6 +12,7 @@ import ( "github.com/interuss/dss/pkg/api" "github.com/interuss/dss/pkg/api/scdv1" + "github.com/interuss/dss/pkg/auth/claims" dsserr "github.com/interuss/dss/pkg/errors" "github.com/interuss/stacktrace" @@ -52,7 +53,7 @@ func rsaTokenReqWithMissingIssuer(key *rsa.PrivateKey, exp, nbf int64) *http.Req } func TestNewRSAAuthClient(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() tmpfile, err := os.CreateTemp("/tmp", "bad.pem") @@ -103,7 +104,7 @@ func TestRSAAuthInterceptor(t *testing.T) { {rsaTokenReq(key, 100, 50), dsserr.Unauthenticated}, } - a, err := NewRSAAuthorizer(context.Background(), Configuration{ + a, err := NewRSAAuthorizer(t.Context(), Configuration{ KeyResolver: &fromMemoryKeyResolver{ Keys: []interface{}{&key.PublicKey}, }, @@ -115,7 +116,15 @@ func TestRSAAuthInterceptor(t *testing.T) { for i, test := range authTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - res := a.Authorize(nil, test.req, []api.AuthorizationOption{}) + ctx := t.Context() + claimsValue, err := a.extractClaims(test.req) + if err != nil { + ctx = claims.NewContextFromError(ctx, err) + } else { + ctx = claims.NewContext(ctx, claimsValue) + } + + res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{}) if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code { t.Logf("%v", res.Error) t.Errorf("expected: %v, got: %v, with message %s", test.code, stacktrace.GetCode(res.Error), res.Error.Error()) @@ -193,17 +202,17 @@ func TestMissingScopes(t *testing.T) { } func TestClaimsValidation(t *testing.T) { - Now = func() time.Time { + claims.Now = func() time.Time { return time.Unix(42, 0) } - jwt.TimeFunc = Now + jwt.TimeFunc = claims.Now defer func() { jwt.TimeFunc = time.Now - Now = time.Now + claims.Now = time.Now }() - claims := &claims{} + claims := &claims.Claims{} require.Error(t, claims.Valid()) diff --git a/pkg/auth/claims.go b/pkg/auth/claims/claims.go similarity index 72% rename from pkg/auth/claims.go rename to pkg/auth/claims/claims.go index 31a33f066..92f2b11f7 100644 --- a/pkg/auth/claims.go +++ b/pkg/auth/claims/claims.go @@ -1,6 +1,7 @@ -package auth +package claims import ( + "context" "encoding/json" "errors" "strings" @@ -18,6 +19,34 @@ var ( Now = time.Now ) +type ctxKey string + +var ( + claimsKey = ctxKey("claims") + errKey = ctxKey("error") +) + +func NewContext(ctx context.Context, claims Claims) context.Context { + return context.WithValue(ctx, claimsKey, claims) +} + +func NewContextFromError(ctx context.Context, err error) context.Context { + return context.WithValue(ctx, errKey, err) +} + +func FromContext(ctx context.Context) (Claims, error) { + claims, ok := ctx.Value(claimsKey).(Claims) + if !ok { + err, ok := ctx.Value(errKey).(error) + if ok { + return Claims{}, err + } + return Claims{}, stacktrace.NewError("No claims or error in context") + } + + return claims, nil +} + // ScopeSet models a set of scopes. type ScopeSet map[string]struct{} @@ -61,12 +90,12 @@ func (s *ScopeSet) ToStringSlice() []string { return scopes } -type claims struct { +type Claims struct { jwt.StandardClaims Scopes ScopeSet `json:"scope"` } -func (c *claims) Valid() error { +func (c *Claims) Valid() error { if c.Subject == "" { return errMissingOrEmptySubject } diff --git a/pkg/auth/claims_test.go b/pkg/auth/claims/claims_test.go similarity index 96% rename from pkg/auth/claims_test.go rename to pkg/auth/claims/claims_test.go index 9005ee1d5..e0cdcbb6f 100644 --- a/pkg/auth/claims_test.go +++ b/pkg/auth/claims/claims_test.go @@ -1,4 +1,4 @@ -package auth +package claims import ( "encoding/json" @@ -8,7 +8,7 @@ import ( ) func TestScopesJSONUnmarshaling(t *testing.T) { - claims := &claims{} + claims := &Claims{} require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims)) require.Contains(t, claims.Scopes, "one") require.Contains(t, claims.Scopes, "two") diff --git a/pkg/logging/http.go b/pkg/logging/http.go index b41a4d6eb..ab50d874c 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/interuss/dss/pkg/auth/claims" "go.uber.org/zap" ) @@ -39,9 +40,6 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) { w.next.WriteHeader(statusCode) } -type CtxAuthError struct{} -type CtxAuthSubject struct{} - // HTTPMiddleware installs a logging http.Handler that logs requests and // selected aspects of responses to 'logger'. func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Handler { @@ -72,12 +70,8 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha } } - subject, ok := r.Context().Value(CtxAuthSubject{}).(string) - if !ok { - authErrorMsg := r.Context().Value(CtxAuthError{}).(string) - logger = logger.With(zap.String("resp_sub_err", authErrorMsg)) - } else { - logger = logger.With(zap.String("req_sub", subject)) + if claimsValue, _ := claims.FromContext(r.Context()); claimsValue.Subject != "" { + logger = logger.With(zap.String("req_sub", claimsValue.Subject)) } handler.ServeHTTP(trw, r)