diff --git a/router-tests/go.mod b/router-tests/go.mod index 28d13da9e5..5876460d32 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20260213130455-6e3277e7b850 github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255.0.20260223080430-55c1a82aa3bb go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 diff --git a/router-tests/go.sum b/router-tests/go.sum index 8740bb0404..902fb0daec 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -356,8 +356,8 @@ github.com/wundergraph/astjson v1.1.0 h1:xORDosrZ87zQFJwNGe/HIHXqzpdHOFmqWgykCLV github.com/wundergraph/astjson v1.1.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255 h1:lN+D5OWay3U1mwtRlA+j7kJqP5ksKdRFMvYA+8XLJ1E= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255/go.mod h1:gfmmrPd2khZONmwYE8RIfnGjwIG+RqL52jYiBzcUST8= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255.0.20260223080430-55c1a82aa3bb h1:9ZqCuPqE4x3N8tu8jbm+mtPFP+z7H1Fdf3Mjuq5Qv9s= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255.0.20260223080430-55c1a82aa3bb/go.mod h1:gfmmrPd2khZONmwYE8RIfnGjwIG+RqL52jYiBzcUST8= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index b517c287f1..a63083e414 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -711,4 +711,76 @@ func TestStartSubscriptionHook(t *testing.T) { assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) + + t.Run("Test StartSubscription hook can access field arguments", func(t *testing.T) { + t.Parallel() + + // This test verifies that the subscription start hook can access GraphQL field arguments + // via ctx.Operation().Arguments(). + + var capturedEmployeeID int + + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + args := ctx.Operation().Arguments() + if args != nil { + employeeIDArg := args.Get("subscription.employeeUpdatedMyKafka.employeeID") + if employeeIDArg != nil { + capturedEmployeeID = employeeIDArg.GetInt() + } + } + return nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 7, + } + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + assert.Equal(t, 7, capturedEmployeeID, "expected to capture employeeID argument value") + }) + }) } diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index bc5187d400..af93b4c6f6 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -358,4 +358,52 @@ func TestPublishHook(t *testing.T) { require.Equal(t, []byte("3"), header.Value) }) }) + + t.Run("Test Publish hook can access field arguments", func(t *testing.T) { + t.Parallel() + + // This test verifies that the publish hook can access GraphQL field arguments + // via ctx.Operation().Arguments(). + + var capturedEmployeeID int + + customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + args := ctx.Operation().Arguments() + if args != nil { + employeeIDArg := args.Get("mutation.updateEmployeeMyKafka.employeeID") + if employeeIDArg != nil { + capturedEmployeeID = employeeIDArg.GetInt() + } + } + return events, nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 5, update: {name: "test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + assert.Equal(t, 5, capturedEmployeeID, "expected to capture employeeID argument value") + }) + }) } diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index c71d3019bb..3936d92b6c 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -963,4 +963,91 @@ func TestReceiveHook(t *testing.T) { assert.Equal(t, int32(3), customModule.HookCallCount.Load()) }) }) + + t.Run("Test Receive hook can access field arguments", func(t *testing.T) { + t.Parallel() + + // This test verifies that the receive hook can access GraphQL field arguments + // via ctx.Operation().Arguments(). + + var capturedEmployeeID int + + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + args := ctx.Operation().Arguments() + if args != nil { + employeeIDArg := args.Get("subscription.employeeUpdatedMyKafka.employeeID") + if employeeIDArg != nil { + capturedEmployeeID = employeeIDArg.GetInt() + } + } + return events, nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + assert.Equal(t, 3, capturedEmployeeID, "expected to capture employeeID argument value") + }) + }) } diff --git a/router/core/arguments.go b/router/core/arguments.go new file mode 100644 index 0000000000..eab6788119 --- /dev/null +++ b/router/core/arguments.go @@ -0,0 +1,106 @@ +package core + +import ( + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" +) + +// Arguments allow access to GraphQL field arguments used by clients. +type Arguments struct { + // mapping maps "fieldPath.argumentName" to "variableName". + // For example: {"user.posts.limit": "a", "user.id": "userId"} + mapping astnormalization.FieldArgumentMapping + + // variables contains the JSON-parsed variables from the request. + variables *astjson.Value +} + +// NewArguments creates an Arguments instance. +func NewArguments( + mapping astnormalization.FieldArgumentMapping, + variables *astjson.Value, +) Arguments { + return Arguments{ + mapping: mapping, + variables: variables, + } +} + +// Get will return the value of the field argument at path. +// +// To access a specific field argument you need to provide +// the path in it's GraphQL operation via dot notation, +// prefixed by the root levels type. +// +// Get("rootfield_operation_type.rootfield_name.other.fields.argument_name") +// +// To access the storeId field argument of the operation +// +// subscription { +// orderUpdated(storeId: 1) { +// id +// status +// } +// } +// +// you need to call Get("subscription.orderUpdated.storeId") . +// You can also access deeper nested fields. +// For example you can access the categoryId field of the operation +// +// subscription { +// orderUpdated(storeId: 1) { +// lineItems(categoryId: 2) { +// id +// name +// } +// } +// } +// +// by calling Get("subscription.orderUpdated.lineItems.categoryId") . +// +// If you use aliases in operation you need to provide the alias name +// instead of the field name. +// +// query { +// a: user(id: "1") { name } +// b: user(id: "2") { name } +// } +// +// You need to call Get("query.a.id") or Get("query.b.id") respectively. +// +// If you want to access field arguments of fragments, you need to +// access it on one of the fields where the fragment is resolved. +// +// fragment GoldTrophies on RaceDrivers { +// trophies(color:"gold") { +// title +// } +// } +// +// subscription { +// driversFinish { +// name +// ... GoldTrophies +// } +// } +// +// If you want to access the "color" field argument, you need to +// call Get("subscription.driversFinish.trophies.color") . +// The same concept applies to inline fragments. +// +// If fa is nil, or f or a cannot be found, nil is returned. +func (fa *Arguments) Get(path string) *astjson.Value { + if fa == nil || len(fa.mapping) == 0 || fa.variables == nil { + return nil + } + + // Look up variable name from field argument map + varName, ok := fa.mapping[path] + if !ok { + return nil + } + + // Use the name to get the actual value from + // the operation contexts variables. + return fa.variables.Get(varName) +} diff --git a/router/core/arguments_test.go b/router/core/arguments_test.go new file mode 100644 index 0000000000..f6252c1c26 --- /dev/null +++ b/router/core/arguments_test.go @@ -0,0 +1,590 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" +) + +func TestArgumentMapping(t *testing.T) { + testCases := []struct { + name string + schema string + operation string + variables string + assertions func(t *testing.T, result Arguments) + }{ + { + name: "root field arguments with variables are accessible", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUser($userId: ID!) { + user(id: $userId) { + id + name + } + } + `, + variables: `{"userId": "123"}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.user.id": "123", + } + assertFieldArgMap(t, expected, result) + }, + }, + { + name: "root field arguments without variables are accessible", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUser { + user(id: "123") { + id + name + } + } + `, + variables: `{}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.user.id": "123", + } + assertFieldArgMap(t, expected, result) + }, + }, + { + name: "nested field arguments are accessible", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + posts(limit: Int!, offset: Int): [Post!]! + } + type Post { + id: ID! + title: String! + } + `, + operation: ` + query GetUserPosts($userId: ID!, $limit: Int!, $offset: Int) { + user(id: $userId) { + id + posts(limit: $limit, offset: $offset) { + id + title + } + } + } + `, + variables: `{"userId": "user-1", "limit": 10, "offset": 5}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.user.id": "user-1", + "query.user.posts.limit": 10, + "query.user.posts.offset": 5, + } + assertFieldArgMap(t, expected, result) + }, + }, + { + name: "non-existent field returns nil", + schema: ` + type Query { + hello: String + } + `, + operation: ` + query { + hello + } + `, + variables: `{}`, + assertions: func(t *testing.T, result Arguments) { + arg := result.Get("query.hello.someArg") + require.Nil(t, arg, "expected nil for non-existent argument") + + arg = result.Get("query.nonExistent.arg") + require.Nil(t, arg, "expected nil for non-existent field") + }, + }, + { + name: "multiple root fields with arguments", + schema: ` + type Query { + user(id: ID!): User + post(slug: String!): Post + } + type User { + id: ID! + } + type Post { + slug: String! + } + `, + operation: ` + query GetUserAndPost($userId: ID!, $postSlug: String!) { + user(id: $userId) { + id + } + post(slug: $postSlug) { + slug + } + } + `, + variables: `{"userId": "user-123", "postSlug": "my-post"}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.user.id": "user-123", + "query.post.slug": "my-post", + } + assertFieldArgMap(t, expected, result) + }, + }, + { + name: "array argument is accessible", + schema: ` + type Query { + users(ids: [ID!]!): [User!]! + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUsers($userIds: [ID!]!) { + users(ids: $userIds) { + id + name + } + } + `, + variables: `{"userIds": ["user-1", "user-2", "user-3"]}`, + assertions: func(t *testing.T, result Arguments) { + idsArg := result.Get("query.users.ids") + require.NotNil(t, idsArg, "expected 'ids' argument on 'users' field") + + // Verify it's an array + arr := idsArg.GetArray() + require.Len(t, arr, 3) + assert.Equal(t, "user-1", string(arr[0].GetStringBytes())) + assert.Equal(t, "user-2", string(arr[1].GetStringBytes())) + assert.Equal(t, "user-3", string(arr[2].GetStringBytes())) + }, + }, + { + name: "object argument is accessible", + schema: ` + type Query { + users(filter: UserFilter!): [User!]! + } + input UserFilter { + name: String + age: Int + active: Boolean! + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUsers($filter: UserFilter!) { + users(filter: $filter) { + id + name + } + } + `, + variables: `{"filter": {"name": "John", "age": 30, "active": true}}`, + assertions: func(t *testing.T, result Arguments) { + filterArg := result.Get("query.users.filter") + require.NotNil(t, filterArg, "expected 'filter' argument on 'users' field") + + // Verify it's an object and access its fields + obj := filterArg.GetObject() + require.NotNil(t, obj) + + nameVal := filterArg.Get("name") + require.NotNil(t, nameVal) + assert.Equal(t, "John", string(nameVal.GetStringBytes())) + + ageVal := filterArg.Get("age") + require.NotNil(t, ageVal) + assert.Equal(t, 30, ageVal.GetInt()) + + activeVal := filterArg.Get("active") + require.NotNil(t, activeVal) + assert.True(t, activeVal.GetBool()) + }, + }, + { + name: "aliased fields have unique paths", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUsers($id1: ID!, $id2: ID!) { + a: user(id: $id1) { + id + name + } + b: user(id: $id2) { + id + name + } + } + `, + variables: `{"id1": "user-1", "id2": "user-2"}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.a.id": "user-1", + "query.b.id": "user-2", + } + assertFieldArgMap(t, expected, result) + + // Using the field name should not find the arguments + userIdArg := result.Get("query.user.id") + assert.Nil(t, userIdArg, "expected nil when using field name instead of alias") + }, + }, + { + // After normalization, named fragments are inlined, so arguments should be + // accessible via the normal field path (not fragment definition path) + name: "arguments from named fragments are accessible via spreaded path", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + posts(limit: Int!, offset: Int): [Post!]! + friends(first: Int!): [User!]! + } + type Post { + id: ID! + title: String! + } + `, + operation: ` + fragment UserPosts on User { + posts(limit: $postsLimit, offset: $postsOffset) { + id + title + } + } + + fragment UserFriends on User { + friends(first: $friendsCount) { + id + name + } + } + + query GetUser($userId: ID!, $postsLimit: Int!, $postsOffset: Int, $friendsCount: Int!) { + user(id: $userId) { + id + name + ...UserPosts + ...UserFriends + } + } + `, + variables: `{"userId": "user-1", "postsLimit": 10, "postsOffset": 5, "friendsCount": 20}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.user.id": "user-1", + "query.user.posts.limit": 10, + "query.user.posts.offset": 5, + "query.user.friends.first": 20, + } + assertFieldArgMap(t, expected, result) + }, + }, + { + // Inline fragments remain in the AST after normalization and must be accessible + // with $TypeName notation. + name: "arguments within inline fragments are accessible with $TypeName prefix", + schema: ` + type Query { + search(query: String!): [SearchResult!]! + } + + union SearchResult = User | Post + + type User { + id: ID! + name(format: String): String! + email(verified: Boolean): String! + } + + type Post { + id: ID! + title(truncate: Int): String! + content: String! + } + `, + operation: ` + query GetSearchResults($searchQuery: String!, $nameFormat: String, $verifiedOnly: Boolean) { + search(query: $searchQuery) { + ... on User { + id + name(format: $nameFormat) + email(verified: $verifiedOnly) + } + ... on Post { + id + title(truncate: 100) + content + } + } + } + `, + variables: `{"searchQuery": "test", "nameFormat": "uppercase", "verifiedOnly": true}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.search.query": "test", + "query.search.$User.name.format": "uppercase", + "query.search.$User.email.verified": true, + "query.search.$Post.title.truncate": 100, + } + assertFieldArgMap(t, expected, result) + }, + }, + { + name: "arguments in nested inline fragments are accessible", + schema: ` + interface Titleable { + title(f1: Int): String + } + + interface Nameable { + name(f2: Int): String + } + + type Trophie implements Titleable { + title(f1: Int): String + } + + type Doctor implements Titleable & Nameable { + title(f1: Int): String + name(f2: Int): String + profession(f3: Int): String + } + + type Person implements Nameable { + name(f2: Int): String + hobby(f4: Int): String + } + + type Query { + title(f1: Int): Titleable + } + `, + operation: ` + query { + title(f1: 1) { + ... on Nameable { + name(f2: 2) + ... on Doctor { + profession(f3: 3) + } + ... on Person { + hobby(f4: 4) + } + } + } + } + `, + variables: ``, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.title.f1": 1, + "query.title.$Nameable.name.f2": 2, + "query.title.$Nameable.$Doctor.profession.f3": 3, + "query.title.$Nameable.$Person.hobby.f4": 4, + } + assertFieldArgMap(t, expected, result) + }, + }, + { + // The engine removes inline fragments from operations, + // if they are inaccessable. This can happen on nested interface selections + // where a fragment type implements an interface but not the other. + // We expect a field argument inside such a fragment to still be part + // of the mapping. + name: "arguments in unreachable inline fragments are accessible", + schema: ` + interface Titleable { + title(f1: Int): String + } + + interface Nameable { + name(f2: Int): String + } + + type Trophie implements Titleable { + title(f1: Int): String + } + + type Doctor implements Titleable & Nameable { + title(f1: Int): String + name(f2: Int): String + profession(f3: Int): String + } + + type Person implements Nameable { + name(f2: Int): String + hobby(f4: Int): String + } + + type Query { + title(f1: Int): Titleable + } + `, + operation: ` + query($v1: Int, $v2: Int, $v3: Int, $v4: Int) { + title { # returns Titleable + title(f1: $v1) + ... on Nameable { + name(f2: $v2) + ... on Doctor { + profession(f3: $v3) + } + ... on Person { # implements Nameable but not Titleable + hobby(f4: $v4) + } + } + } + } + `, + variables: `{"v1": 1, "v2": 2, "v3": 3, "v4": 4}`, + assertions: func(t *testing.T, result Arguments) { + expected := map[string]any{ + "query.title.title.f1": 1, + "query.title.$Nameable.name.f2": 2, + "query.title.$Nameable.$Doctor.profession.f3": 3, + "query.title.$Nameable.$Person.hobby.f4": 4, // should exist + } + assertFieldArgMap(t, expected, result) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Mimic what the router is doing by first parsing the schema and operation, + // then normalize the query and only then normalize the variables (and create + // field argument mapping) + + schema, report := astparser.ParseGraphqlDocumentString(tc.schema) + require.False(t, report.HasErrors(), "failed to parse schema") + err := asttransform.MergeDefinitionWithBaseSchema(&schema) + require.NoError(t, err) + + // Parse operation + operation, report := astparser.ParseGraphqlDocumentString(tc.operation) + require.False(t, report.HasErrors(), "failed to parse operation") + + // Set variables before normalization (like the router does) + operation.Input.Variables = []byte(tc.variables) + + // Normalize operation (merges provided variables with extracted inline literals) + rep := &operationreport.Report{} + norm := astnormalization.NewNormalizer(true, true) + norm.NormalizeOperation(&operation, &schema, rep) + require.False(t, rep.HasErrors(), "failed to normalize operation") + + // Then normalize variables using VariablesNormalizer which returns the field argument mapping + varNorm := astnormalization.NewVariablesNormalizer( + astnormalization.VariablesNormalizerOptions{EnableFieldArgumentMapping: true}, + ) + result := varNorm.NormalizeOperation(&operation, &schema, rep) + require.False(t, rep.HasErrors(), "failed to normalize variables") + + // Use normalized variables (includes both provided and extracted variables) + vars, err := astjson.ParseBytes(operation.Input.Variables) + require.NoError(t, err) + + arguments := NewArguments(result.FieldArgumentMapping, vars) + + tc.assertions(t, arguments) + }) + } +} + +func TestNewArguments_NilMapping(t *testing.T) { + // Test that nil mapping returns empty Arguments + result := NewArguments(nil, nil) + assert.Nil(t, result.Get("query.user.id")) +} + +func TestNewArguments_EmptyMapping(t *testing.T) { + // Test that empty mapping returns empty Arguments + result := NewArguments(astnormalization.FieldArgumentMapping{}, nil) + assert.Nil(t, result.Get("query.user.id")) +} + +func TestArguments_Get_NonExistentPath(t *testing.T) { + vars, err := astjson.ParseBytes([]byte(`{"userId": "123"}`)) + require.NoError(t, err) + + mapping := astnormalization.FieldArgumentMapping{ + "query.user.id": "userId", + } + args := NewArguments(mapping, vars) + + assert.Nil(t, args.Get("query.user.nonexistent")) + assert.Nil(t, args.Get("mutation.createUser.id")) + assert.Nil(t, args.Get("")) +} + +func assertFieldArgMap(t *testing.T, expected map[string]any, result Arguments) { + for path, expectedValue := range expected { + jsonValue := result.Get(path) + require.NotNil(t, jsonValue, "no value found at path '%s'", path) + + switch valType := jsonValue.Type(); valType { + case astjson.TypeNumber: + // in tests we assume its always int + assert.Equal(t, expectedValue, jsonValue.GetInt()) + case astjson.TypeString: + assert.Equal(t, expectedValue, string(jsonValue.GetStringBytes())) + case astjson.TypeFalse, astjson.TypeTrue: + assert.Equal(t, expectedValue, jsonValue.GetBool()) + default: + t.Fatalf("can't assert on unknown astjson type '%s'", valType) + } + } +} diff --git a/router/core/cache_warmup.go b/router/core/cache_warmup.go index 56645126a2..153fe26fbc 100644 --- a/router/core/cache_warmup.go +++ b/router/core/cache_warmup.go @@ -323,7 +323,7 @@ func (c *CacheWarmupPlanningProcessor) ProcessOperation(ctx context.Context, ope return nil, err } - _, _, err = k.NormalizeVariables() + _, _, _, err = k.NormalizeVariables() if err != nil { return nil, err } diff --git a/router/core/context.go b/router/core/context.go index 6162f96499..9e92729e16 100644 --- a/router/core/context.go +++ b/router/core/context.go @@ -564,16 +564,16 @@ type OperationContext interface { Hash() uint64 // Content is the content of the operation Content() string - // Variables is the variables of the operation + // Arguments allow access to GraphQL operation field arguments. + Arguments() *Arguments + // Variables allow access to GraphQL operation variables. Variables() *astjson.Value // ClientInfo returns information about the client that initiated this operation ClientInfo() ClientInfo - // Sha256Hash returns the SHA256 hash of the original operation // It is important to note that this hash is not calculated just because this method has been called // and is only calculated based on other existing logic (such as if sha256Hash is used in expressions) Sha256Hash() string - // QueryPlanStats returns some statistics about the query plan for the operation // if called too early in request chain, it may be inaccurate for modules, using // in Middleware is recommended @@ -611,11 +611,13 @@ type operationContext struct { // RawContent is the raw content of the operation rawContent string // Content is the normalized content of the operation - content string - variables *astjson.Value - variablesHash uint64 - files []*httpclient.FileUpload - clientInfo *ClientInfo + content string + // These are not mapped by default, only when certain custom modules require them. + fieldArguments Arguments + variables *astjson.Value + variablesHash uint64 + files []*httpclient.FileUpload + clientInfo *ClientInfo // preparedPlan is the prepared plan of the operation preparedPlan *planWithMetaData traceOptions resolve.TraceOptions @@ -648,6 +650,10 @@ func (o *operationContext) Variables() *astjson.Value { return o.variables } +func (o *operationContext) Arguments() *Arguments { + return &o.fieldArguments +} + func (o *operationContext) Files() []*httpclient.FileUpload { return o.files } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 8413bcfca6..33ee2565d0 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1305,12 +1305,13 @@ func (s *graphServer) buildGraphMux( MaxDepth: s.securityConfiguration.ParserLimits.ApproximateDepthLimit, MaxFields: s.securityConfiguration.ParserLimits.TotalFieldsLimit, }, - OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit, - ApolloCompatibilityFlags: s.apolloCompatibilityFlags, - ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, + OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit, + ApolloCompatibilityFlags: s.apolloCompatibilityFlags, + ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError, RelaxSubgraphOperationFieldSelectionMergingNullability: s.engineExecutionConfiguration.RelaxSubgraphOperationFieldSelectionMergingNullability, ComplexityLimits: s.securityConfiguration.ComplexityLimits, + EnableFieldArgumentMapping: s.subscriptionHooks.needFieldArgumentMapping(), }) operationPlanner := NewOperationPlanner(executor, gm.planCache, opts.ReloadPersistentState.inMemoryPlanCacheFallback.IsEnabled()) @@ -1516,6 +1517,7 @@ func (s *graphServer) buildGraphMux( HasPreOriginHandlers: len(s.preOriginHandlers) != 0, HeaderPropagation: s.headerPropagation, OperationContentAttributes: s.traceConfig.OperationContentAttributes, + MapFieldArguments: s.subscriptionHooks.needFieldArgumentMapping(), }) if s.webSocketConfiguration != nil && s.webSocketConfiguration.Enabled { @@ -1537,6 +1539,7 @@ func (s *graphServer) buildGraphMux( WebSocketConfiguration: s.webSocketConfiguration, ClientHeader: s.clientHeader, DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping, + MapFieldArguments: s.subscriptionHooks.needFieldArgumentMapping(), ApolloCompatibilityFlags: s.apolloCompatibilityFlags, }) diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index 817e27c2dd..c5ab2d28a3 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -67,6 +67,7 @@ type PreHandlerOptions struct { ComputeOperationSha256 bool ApolloCompatibilityFlags *config.ApolloCompatibilityFlags DisableVariablesRemapping bool + MapFieldArguments bool ExprManager *expr.Manager OmitBatchExtensions bool OperationContentAttributes bool @@ -107,6 +108,7 @@ type PreHandler struct { computeOperationSha256 bool apolloCompatibilityFlags *config.ApolloCompatibilityFlags disableVariablesRemapping bool + mapFieldArguments bool exprManager *expr.Manager omitBatchExtensions bool operationContentAttributes bool @@ -169,6 +171,7 @@ func NewPreHandler(opts *PreHandlerOptions) *PreHandler { computeOperationSha256: opts.ComputeOperationSha256, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, disableVariablesRemapping: opts.DisableVariablesRemapping, + mapFieldArguments: opts.MapFieldArguments, exprManager: opts.ExprManager, omitBatchExtensions: opts.OmitBatchExtensions, operationContentAttributes: opts.OperationContentAttributes, @@ -803,7 +806,7 @@ func (h *PreHandler) handleOperation(req *http.Request, httpOperation *httpOpera * Normalize the variables */ - cached, uploadsMapping, err := operationKit.NormalizeVariables() + cached, uploadsMapping, fieldArgMapping, err := operationKit.NormalizeVariables() if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) @@ -818,6 +821,7 @@ func (h *PreHandler) handleOperation(req *http.Request, httpOperation *httpOpera engineNormalizeSpan.End() return err } + // Store the field argument mapping for later use when creating Arguments engineNormalizeSpan.SetAttributes(otel.WgVariablesNormalizationCacheHit.Bool(cached)) requestContext.operation.variablesNormalizationCacheHit = cached @@ -936,6 +940,14 @@ func (h *PreHandler) handleOperation(req *http.Request, httpOperation *httpOpera engineNormalizeSpan.End() return err } + + if h.mapFieldArguments { + requestContext.operation.fieldArguments = NewArguments( + fieldArgMapping, + requestContext.operation.variables, + ) + } + requestContext.operation.normalizationTime = time.Since(startNormalization) requestContext.expressionContext.Request.Operation.NormalizationTime = requestContext.operation.normalizationTime setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketNormalizationTime) diff --git a/router/core/operation_processor.go b/router/core/operation_processor.go index a4842d6b57..58dae46826 100644 --- a/router/core/operation_processor.go +++ b/router/core/operation_processor.go @@ -111,23 +111,24 @@ type OperationProcessorOptions struct { PersistedOperationClient *persistedoperation.Client AutomaticPersistedOperationCacheTtl int - EnablePersistedOperationsCache bool - PersistedOpsNormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] - NormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] - QueryDepthCache *ristretto.Cache[uint64, ComplexityCacheEntry] - VariablesNormalizationCache *ristretto.Cache[uint64, VariablesNormalizationCacheEntry] - RemapVariablesCache *ristretto.Cache[uint64, RemapVariablesCacheEntry] - ValidationCache *ristretto.Cache[uint64, bool] - OperationHashCache *ristretto.Cache[uint64, string] - ParseKitPoolSize int - IntrospectionEnabled bool - ApolloCompatibilityFlags config.ApolloCompatibilityFlags - ApolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags - DisableExposingVariablesContentOnValidationError bool - RelaxSubgraphOperationFieldSelectionMergingNullability bool - ComplexityLimits *config.ComplexityLimits - ParserTokenizerLimits astparser.TokenizerLimits - OperationNameLengthLimit int + EnablePersistedOperationsCache bool + PersistedOpsNormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] + NormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] + QueryDepthCache *ristretto.Cache[uint64, ComplexityCacheEntry] + VariablesNormalizationCache *ristretto.Cache[uint64, VariablesNormalizationCacheEntry] + RemapVariablesCache *ristretto.Cache[uint64, RemapVariablesCacheEntry] + ValidationCache *ristretto.Cache[uint64, bool] + OperationHashCache *ristretto.Cache[uint64, string] + ParseKitPoolSize int + IntrospectionEnabled bool + ApolloCompatibilityFlags config.ApolloCompatibilityFlags + ApolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags + DisableExposingVariablesContentOnValidationError bool + RelaxSubgraphOperationFieldSelectionMergingNullability bool + ComplexityLimits *config.ComplexityLimits + ParserTokenizerLimits astparser.TokenizerLimits + OperationNameLengthLimit int + EnableFieldArgumentMapping bool } // OperationProcessor provides shared resources to the parseKit and OperationKit. @@ -783,6 +784,10 @@ type VariablesNormalizationCacheEntry struct { // request spec for file uploads. uploadsMapping []uploads.UploadPathMapping + // fieldArgumentMapping maps field arguments to their variable names for fast lookup. + // This is populated during variable normalization and cached to avoid repeated AST walks. + fieldArgumentMapping astnormalization.FieldArgumentMapping + // reparse indicates whether the operation document needs to be reparsed from // its string representation when retrieved from the cache. reparse bool @@ -914,10 +919,10 @@ func (o *OperationKit) normalizeVariablesCacheKey() uint64 { } // NormalizeVariables normalizes variables and returns a slice of upload mappings -// if any of them were present in a query. +// if any of them were present in a query, as well as the field argument mapping. // If normalized values were found in the cache, it skips normalization and returns the caching set to true. // If an error is returned, then caching is set to false. -func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.UploadPathMapping, err error) { +func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.UploadPathMapping, fieldArgMapping astnormalization.FieldArgumentMapping, err error) { cacheKey := o.normalizeVariablesCacheKey() if o.cache != nil && o.cache.variablesNormalizationCache != nil { entry, ok := o.cache.variablesNormalizationCache.Get(cacheKey) @@ -931,10 +936,10 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo if entry.reparse { if err = o.setAndParseOperationDoc(); err != nil { - return false, nil, err + return false, nil, nil, err } } - return true, entry.uploadsMapping, nil + return true, entry.uploadsMapping, entry.fieldArgumentMapping, nil } } @@ -947,11 +952,14 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo o.kit.keyGen.Reset() report := &operationreport.Report{} - uploadsMapping := o.kit.variablesNormalizer.NormalizeOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, report) + normalizerResult := o.kit.variablesNormalizer.NormalizeOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, report) if report.HasErrors() { - return false, nil, &reportError{report: report} + return false, nil, nil, &reportError{report: report} } + uploadsMapping := normalizerResult.UploadsMapping + fieldArgumentMapping := normalizerResult.FieldArgumentMapping + // Assuming the user sends a multi-operation document // During normalization, we removed the unused operations from the document // This will always lead to operation definitions of a length of 1 even when multiple operations are sent @@ -971,14 +979,14 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { - return false, nil, err + return false, nil, nil, err } // Reset the doc with the original name o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = nameRef _, err = o.kit.keyGen.Write(o.kit.normalizedOperation.Bytes()) if err != nil { - return false, nil, err + return false, nil, nil, err } o.parsedOperation.ID = o.kit.keyGen.Sum64() o.kit.keyGen.Reset() @@ -993,6 +1001,7 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo if o.cache != nil && o.cache.variablesNormalizationCache != nil { entry := VariablesNormalizationCacheEntry{ uploadsMapping: uploadsMapping, + fieldArgumentMapping: fieldArgumentMapping, id: o.parsedOperation.ID, normalizedRepresentation: o.parsedOperation.NormalizedRepresentation, variables: o.parsedOperation.Request.Variables, @@ -1002,14 +1011,14 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo o.cache.variablesNormalizationCache.Set(cacheKey, entry, 1) } - return false, uploadsMapping, nil + return false, uploadsMapping, fieldArgumentMapping, nil } o.kit.normalizedOperation.Reset() err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { - return false, nil, err + return false, nil, nil, err } o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String() @@ -1018,6 +1027,7 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo if o.cache != nil && o.cache.variablesNormalizationCache != nil { entry := VariablesNormalizationCacheEntry{ uploadsMapping: uploadsMapping, + fieldArgumentMapping: fieldArgumentMapping, id: o.parsedOperation.ID, normalizedRepresentation: o.parsedOperation.NormalizedRepresentation, variables: o.parsedOperation.Request.Variables, @@ -1027,7 +1037,7 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo o.cache.variablesNormalizationCache.Set(cacheKey, entry, 1) } - return false, uploadsMapping, nil + return false, uploadsMapping, fieldArgumentMapping, nil } func (o *OperationKit) remapVariablesCacheKey(disabled bool) uint64 { @@ -1422,10 +1432,11 @@ func (o *OperationKit) skipIncludeVariableNames() []string { } type parseKitOptions struct { - apolloCompatibilityFlags config.ApolloCompatibilityFlags - apolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags - disableExposingVariablesContentOnValidationError bool - relaxSubgraphOperationFieldSelectionMergingNullability bool + apolloCompatibilityFlags config.ApolloCompatibilityFlags + apolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags + disableExposingVariablesContentOnValidationError bool + relaxSubgraphOperationFieldSelectionMergingNullability bool + enableFieldArgumentMapping bool } func createParseKit(i int, options *parseKitOptions) *parseKit { @@ -1441,7 +1452,11 @@ func createParseKit(i int, options *parseKitOptions) *parseKit { astnormalization.WithRemoveFragmentDefinitions(), astnormalization.WithRemoveUnusedVariables(), ), - variablesNormalizer: astnormalization.NewVariablesNormalizer(), + variablesNormalizer: astnormalization.NewVariablesNormalizer( + astnormalization.VariablesNormalizerOptions{ + EnableFieldArgumentMapping: options.enableFieldArgumentMapping, + }, + ), variablesRemapper: astnormalization.NewVariablesMapper(), printer: &astprinter.Printer{}, normalizedOperation: &bytes.Buffer{}, @@ -1486,10 +1501,11 @@ func NewOperationProcessor(opts OperationProcessorOptions) *OperationProcessor { operationNameLengthLimit: opts.OperationNameLengthLimit, complexityLimits: opts.ComplexityLimits, parseKitOptions: &parseKitOptions{ - apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, - apolloRouterCompatibilityFlags: opts.ApolloRouterCompatibilityFlags, - disableExposingVariablesContentOnValidationError: opts.DisableExposingVariablesContentOnValidationError, + apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, + apolloRouterCompatibilityFlags: opts.ApolloRouterCompatibilityFlags, + disableExposingVariablesContentOnValidationError: opts.DisableExposingVariablesContentOnValidationError, relaxSubgraphOperationFieldSelectionMergingNullability: opts.RelaxSubgraphOperationFieldSelectionMergingNullability, + enableFieldArgumentMapping: opts.EnableFieldArgumentMapping, }, } for i := 0; i < opts.ParseKitPoolSize; i++ { diff --git a/router/core/operation_processor_test.go b/router/core/operation_processor_test.go index 390577add2..afca2ab9b1 100644 --- a/router/core/operation_processor_test.go +++ b/router/core/operation_processor_test.go @@ -328,7 +328,7 @@ func TestNormalizeVariablesOperationProcessor(t *testing.T) { _, err = kit.NormalizeOperation("test", false) require.NoError(t, err) - _, _, err = kit.NormalizeVariables() + _, _, _, err = kit.NormalizeVariables() require.NoError(t, err) assert.Equal(t, tc.ExpectedNormalizedRepresentation, kit.parsedOperation.NormalizedRepresentation) diff --git a/router/core/router_config.go b/router/core/router_config.go index c3af6c201d..eeee7ed9cd 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -47,6 +47,10 @@ type onReceiveEventsHooks struct { timeout time.Duration } +func (h *subscriptionHooks) needFieldArgumentMapping() bool { + return len(h.onStart.handlers) > 0 || len(h.onPublishEvents.handlers) > 0 || len(h.onReceiveEvents.handlers) > 0 +} + type Config struct { clusterName string instanceID string diff --git a/router/core/websocket.go b/router/core/websocket.go index 0aa5ca5588..26c43b5621 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -63,6 +63,7 @@ type WebsocketMiddlewareOptions struct { ClientHeader config.ClientHeader DisableVariablesRemapping bool + MapFieldArguments bool ApolloCompatibilityFlags config.ApolloCompatibilityFlags } @@ -85,6 +86,7 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions config: opts.WebSocketConfiguration, clientHeader: opts.ClientHeader, disableVariablesRemapping: opts.DisableVariablesRemapping, + mapFieldArguments: opts.MapFieldArguments, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, } if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled { @@ -265,6 +267,7 @@ type WebsocketHandler struct { clientHeader config.ClientHeader disableVariablesRemapping bool + mapFieldArguments bool apolloCompatibilityFlags config.ApolloCompatibilityFlags } @@ -372,6 +375,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R ForwardUpgradeHeaders: h.forwardUpgradeHeadersConfig, ForwardQueryParams: h.forwardQueryParamsConfig, DisableVariablesRemapping: h.disableVariablesRemapping, + MapFieldArguments: h.mapFieldArguments, ApolloCompatibilityFlags: h.apolloCompatibilityFlags, }) err = handler.Initialize() @@ -713,6 +717,7 @@ type WebSocketConnectionHandlerOptions struct { ForwardUpgradeHeaders forwardConfig ForwardQueryParams forwardConfig DisableVariablesRemapping bool + MapFieldArguments bool ApolloCompatibilityFlags config.ApolloCompatibilityFlags } @@ -750,6 +755,7 @@ type WebSocketConnectionHandler struct { forwardQueryParams *forwardConfig disableVariablesRemapping bool + mapFieldArguments bool apolloCompatibilityFlags config.ApolloCompatibilityFlags @@ -791,6 +797,7 @@ func NewWebsocketConnectionHandler(ctx context.Context, opts WebSocketConnection forwardInitialPayload: opts.ForwardInitialPayload, plannerOptions: opts.PlanOptions, disableVariablesRemapping: opts.DisableVariablesRemapping, + mapFieldArguments: opts.MapFieldArguments, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, clientInfoFromInitialPayload: opts.ClientInfoFromInitialPayload, } @@ -907,7 +914,7 @@ func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegi } opContext.normalizationCacheHit = operationKit.parsedOperation.NormalizationCacheHit - cached, _, err := operationKit.NormalizeVariables() + cached, _, fieldArgMapping, err := operationKit.NormalizeVariables() if err != nil { opContext.normalizationTime = time.Since(startNormalization) return nil, nil, err @@ -933,6 +940,13 @@ func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegi return nil, nil, err } + if h.mapFieldArguments { + opContext.fieldArguments = NewArguments( + fieldArgMapping, + opContext.variables, + ) + } + startValidation := time.Now() _, _, err = operationKit.ValidateQueryComplexity() diff --git a/router/go.mod b/router/go.mod index fde3dffcb3..2c88c7dd5c 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255.0.20260223080430-55c1a82aa3bb // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index b3f1a180a2..14e3ed1a5d 100644 --- a/router/go.sum +++ b/router/go.sum @@ -326,8 +326,8 @@ github.com/wundergraph/astjson v1.1.0 h1:xORDosrZ87zQFJwNGe/HIHXqzpdHOFmqWgykCLV github.com/wundergraph/astjson v1.1.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255 h1:lN+D5OWay3U1mwtRlA+j7kJqP5ksKdRFMvYA+8XLJ1E= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255/go.mod h1:gfmmrPd2khZONmwYE8RIfnGjwIG+RqL52jYiBzcUST8= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255.0.20260223080430-55c1a82aa3bb h1:9ZqCuPqE4x3N8tu8jbm+mtPFP+z7H1Fdf3Mjuq5Qv9s= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.255.0.20260223080430-55c1a82aa3bb/go.mod h1:gfmmrPd2khZONmwYE8RIfnGjwIG+RqL52jYiBzcUST8= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=