Skip to content
Open
41 changes: 35 additions & 6 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

const (
tokenIdArg = "tokenId"
subjectArg = "subject"
)

type UnauthorizedError struct {
Expand Down Expand Up @@ -42,19 +43,45 @@ func newError(msg string, args ...any) error {

func NewVehicleTokenCheck(requiredAddr common.Address) func(context.Context, any, graphql.Resolver) (any, error) {
return func(ctx context.Context, _ any, next graphql.Resolver) (any, error) {
vehicleTokenID, err := getArg[int](ctx, tokenIdArg)
if err != nil {
return nil, UnauthorizedError{err: err}
tokenID, err := getArg[*int](ctx, tokenIdArg)
if err != nil && !errors.Is(err, errArgNotFound) {
return nil, UnauthorizedError{err: fmt.Errorf("failed to get %s arg: %w", tokenIdArg, err)}
}
subject, err := getArg[*string](ctx, subjectArg)
if err != nil && !errors.Is(err, errArgNotFound) {
return nil, UnauthorizedError{err: fmt.Errorf("failed to get %s arg: %w", subjectArg, err)}
}

if err := validateHeader(ctx, requiredAddr, vehicleTokenID); err != nil {
return nil, UnauthorizedError{err: err}
switch {
case tokenID != nil && subject != nil:
return nil, UnauthorizedError{message: "provide either tokenId or subject, not both"}
case tokenID != nil:
if err := validateHeader(ctx, requiredAddr, *tokenID); err != nil {
return nil, UnauthorizedError{err: err}
}
case subject != nil:
if err := validateSubject(ctx, *subject); err != nil {
return nil, UnauthorizedError{err: err}
}
default:
return nil, UnauthorizedError{message: "tokenId or subject is required"}
}

return next(ctx)
}
}

func validateSubject(ctx context.Context, subject string) error {
claim, err := getTelemetryClaim(ctx)
if err != nil {
return err
}
if subject != claim.Asset {
return newError("subject does not match token claim")
}
return nil
}

func validateHeader(ctx context.Context, requiredAddr common.Address, tokenID int) error {
claim, err := getTelemetryClaim(ctx)
if err != nil {
Expand Down Expand Up @@ -104,6 +131,8 @@ func OneOfPrivilegeCheck(ctx context.Context, _ any, next graphql.Resolver, requ
return nil, newError("requires at least one of the following privileges %v", requiredPrivs)
}

var errArgNotFound = errors.New("arg not found")

func getArg[T any](ctx context.Context, name string) (T, error) {
var resp T
fCtx := graphql.GetFieldContext(ctx)
Expand All @@ -113,7 +142,7 @@ func getArg[T any](ctx context.Context, name string) (T, error) {

val, ok := fCtx.Args[name]
if !ok {
return resp, fmt.Errorf("no argument named %s", name)
return resp, errArgNotFound
}

resp, ok = val.(T)
Expand Down
78 changes: 72 additions & 6 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {

vehicleNFTAddr := common.HexToAddress("0x1")

tokenID123 := 123
tokenID456 := 456
validSubject := "did:erc721:1:0x0000000000000000000000000000000000000001:123"
wrongSubject := "did:erc721:1:0x0000000000000000000000000000000000000001:456"

testCases := []struct {
name string
args map[string]any
Expand All @@ -34,7 +39,7 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {
{
name: "valid_token",
args: map[string]any{
"tokenId": 123,
"tokenId": &tokenID123,
},
telemetryClaim: &TelemetryClaim{
AssetDID: cloudevent.ERC721DID{
Expand All @@ -47,7 +52,7 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {
{
name: "invalid_token",
args: map[string]any{
"tokenId": 456,
"tokenId": &tokenID456,
},
telemetryClaim: &TelemetryClaim{
AssetDID: cloudevent.ERC721DID{
Expand All @@ -59,7 +64,7 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {
expectedError: true,
},
{
name: "missing_tokenId",
name: "missing_both",
args: map[string]any{},
expectedError: true,
telemetryClaim: &TelemetryClaim{
Expand All @@ -71,8 +76,10 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {
},
},
{
name: "wrong_contract",
args: map[string]any{},
name: "wrong_contract",
args: map[string]any{
"tokenId": &tokenID123,
},
expectedError: true,
telemetryClaim: &TelemetryClaim{
AssetDID: cloudevent.ERC721DID{
Expand All @@ -85,7 +92,66 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {
{
name: "missing claim",
args: map[string]any{
"tokenId": 123,
"tokenId": &tokenID123,
},
expectedError: true,
telemetryClaim: nil,
},
{
name: "valid_subject",
args: map[string]any{
"subject": &validSubject,
},
telemetryClaim: &TelemetryClaim{
CustomClaims: tokenclaims.CustomClaims{
Asset: validSubject,
},
AssetDID: cloudevent.ERC721DID{
ChainID: 1,
ContractAddress: vehicleNFTAddr,
TokenID: big.NewInt(123),
},
},
},
{
name: "wrong_subject",
args: map[string]any{
"subject": &wrongSubject,
},
telemetryClaim: &TelemetryClaim{
CustomClaims: tokenclaims.CustomClaims{
Asset: validSubject,
},
AssetDID: cloudevent.ERC721DID{
ChainID: 1,
ContractAddress: vehicleNFTAddr,
TokenID: big.NewInt(123),
},
},
expectedError: true,
},
{
name: "both_tokenId_and_subject",
args: map[string]any{
"tokenId": &tokenID123,
"subject": &validSubject,
},
telemetryClaim: &TelemetryClaim{
CustomClaims: tokenclaims.CustomClaims{
Asset: validSubject,
},
AssetDID: cloudevent.ERC721DID{
ChainID: 1,
ContractAddress: vehicleNFTAddr,
TokenID: big.NewInt(123),
},
},
expectedError: true,
},
{
name: "subject_missing_claim",
args: map[string]any{
"subject": &validSubject,
},
expectedError: true,
telemetryClaim: nil,
Expand Down
26 changes: 21 additions & 5 deletions internal/graph/arguments.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,40 @@ package graph

import (
"context"
"errors"
"fmt"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/DIMO-Network/model-garage/pkg/vss"
"github.com/DIMO-Network/server-garage/pkg/gql/errorhandler"
"github.com/DIMO-Network/telemetry-api/internal/graph/model"
)

// resolveSubject resolves the subject from the provided tokenID or subject arguments.
// Exactly one of tokenID or subject must be provided.
func (r *queryResolver) resolveSubject(ctx context.Context, tokenID *int, subject *string) (string, error) {
if subject != nil && tokenID != nil {
return "", errorhandler.NewBadRequestError(ctx, errors.New("provide either tokenId or subject, not both"))
}
if subject != nil {
return *subject, nil
}
if tokenID != nil {
return r.BaseRepo.ToSubject(uint32(*tokenID)), nil
}
return "", errorhandler.NewBadRequestError(ctx, errors.New("tokenId or subject is required"))
}

// aggregationArgsFromContext creates an aggregated signals arguments from the context and the provided arguments.
func aggregationArgsFromContext(ctx context.Context, tokenID int, interval string, from time.Time, to time.Time, filter *model.SignalFilter) (*model.AggregatedSignalArgs, error) {
// 1h 1s
func aggregationArgsFromContext(ctx context.Context, subject string, interval string, from time.Time, to time.Time, filter *model.SignalFilter) (*model.AggregatedSignalArgs, error) {
intervalInt, err := getIntervalMicroseconds(interval)
if err != nil {
return nil, err
}
aggArgs := model.AggregatedSignalArgs{
SignalArgs: model.SignalArgs{
TokenID: uint32(tokenID),
Subject: subject,
Filter: filter,
},
FromTS: from,
Expand Down Expand Up @@ -86,11 +102,11 @@ func addSignalAggregation(aggArgs *model.AggregatedSignalArgs, child *graphql.Fi
}

// latestArgsFromContext creates a latest signals arguments from the context and the provided arguments.
func latestArgsFromContext(ctx context.Context, tokenID int, filter *model.SignalFilter) (*model.LatestSignalsArgs, error) {
func latestArgsFromContext(ctx context.Context, subject string, filter *model.SignalFilter) (*model.LatestSignalsArgs, error) {
fields := graphql.CollectFieldsCtx(ctx, nil)
latestArgs := model.LatestSignalsArgs{
SignalArgs: model.SignalArgs{
TokenID: uint32(tokenID),
Subject: subject,
Filter: filter,
},
SignalNames: make(map[string]struct{}),
Expand Down
32 changes: 24 additions & 8 deletions internal/graph/base.resolvers.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions internal/graph/events.resolvers.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading