Skip to content

Commit 8459b34

Browse files
authored
feat: forward headers to grpc subgraphs (#1382)
This is part of a Cosmo Connect related feature. It allows to send HTTP headers from user requests down to gRPC sources. Technically it converts headers to [gRPC metadata](https://grpc.io/docs/guides/metadata/) on `datasource.Load`. A header `X-My-Custom-Header` will be sent as a grpc metadata field `x-my-custom-header` with all grpc calls the datasource will make. The change is implemented solely on the grpc datasource. All headers are mapped one-to-one as they are given to the datasources `Load` method. ## Checklist - [x] I have discussed my proposed changes in an issue and have received approval to proceed. - [x] I have followed the coding standards of the project. - [x] Tests or benchmarks have been added or updated. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * HTTP headers are now forwarded as gRPC metadata, allowing request headers to influence downstream behavior. * Metadata-driven overrides enabled for user IDs, names, and per-item counts via incoming headers. * Existing context metadata is preserved and composed with forwarded headers. * **Tests** * Added comprehensive tests validating header-driven overrides, metadata preservation, operation coverage (queries/mutations), and error scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent aa8904f commit 8459b34

4 files changed

Lines changed: 351 additions & 5 deletions

File tree

v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ import (
1111
"encoding/binary"
1212
"fmt"
1313
"net/http"
14+
"strings"
1415

1516
"github.com/cespare/xxhash/v2"
1617
"github.com/tidwall/gjson"
1718
"golang.org/x/sync/errgroup"
1819
"google.golang.org/grpc"
20+
"google.golang.org/grpc/metadata"
1921

2022
"github.com/wundergraph/astjson"
2123
"github.com/wundergraph/go-arena"
@@ -90,6 +92,8 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
9092
// It processes the input JSON data to make gRPC calls and returns
9193
// the response data.
9294
//
95+
// Headers are converted to gRPC metadata and part of gRPC calls.
96+
//
9397
// The input is expected to contain the necessary information to make
9498
// a gRPC call, including service name, method name, and request data.
9599
func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) {
@@ -111,6 +115,19 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte
111115
return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil
112116
}
113117

118+
// convert headers to grpc metadata and attach to ctx
119+
if len(headers) > 0 {
120+
// assume that each header has exactly one value for default pairs size
121+
pairs := make([]string, 0, len(headers)*2)
122+
for headerName, headerValues := range headers {
123+
headerName = strings.ToLower(headerName)
124+
for _, v := range headerValues {
125+
pairs = append(pairs, headerName, v)
126+
}
127+
}
128+
ctx = metadata.AppendToOutgoingContext(ctx, pairs...)
129+
}
130+
114131
graph := NewDependencyGraph(d.plan)
115132

116133
root := astjson.ObjectValue(nil)

v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import (
66
"fmt"
77
"math"
88
"net"
9+
"net/http"
910
"strings"
1011
"testing"
1112

1213
"github.com/stretchr/testify/require"
1314
"github.com/tidwall/gjson"
1415
"google.golang.org/grpc"
1516
"google.golang.org/grpc/credentials/insecure"
17+
"google.golang.org/grpc/metadata"
1618
"google.golang.org/grpc/test/bufconn"
1719
"google.golang.org/protobuf/encoding/protojson"
1820
protoref "google.golang.org/protobuf/reflect/protoreflect"
@@ -5237,3 +5239,276 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) {
52375239
})
52385240
}
52395241
}
5242+
5243+
func Test_Datasource_Load_WithHeaders(t *testing.T) {
5244+
conn, cleanup := setupTestGRPCServer(t)
5245+
t.Cleanup(cleanup)
5246+
5247+
type graphqlError struct {
5248+
Message string `json:"message"`
5249+
}
5250+
type graphqlResponse struct {
5251+
Data map[string]interface{} `json:"data"`
5252+
Errors []graphqlError `json:"errors,omitempty"`
5253+
}
5254+
5255+
testCases := []struct {
5256+
name string
5257+
query string
5258+
vars string
5259+
headers http.Header
5260+
validate func(t *testing.T, data map[string]interface{})
5261+
validateError func(t *testing.T, errData []graphqlError)
5262+
}{
5263+
{
5264+
name: "QueryUser with header override",
5265+
query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`,
5266+
vars: `{"variables":{"id":"original-user-123"}}`,
5267+
headers: func() http.Header {
5268+
h := make(http.Header)
5269+
h.Set("X-User-ID", "header-user-42")
5270+
return h
5271+
}(),
5272+
validate: func(t *testing.T, data map[string]interface{}) {
5273+
user, ok := data["user"].(map[string]interface{})
5274+
require.True(t, ok, "user should be an object")
5275+
require.Equal(t, "header-user-42", user["id"], "user ID should come from header")
5276+
require.Equal(t, "User header-user-42", user["name"], "user name should use header-derived ID")
5277+
},
5278+
validateError: func(t *testing.T, errData []graphqlError) {
5279+
require.Empty(t, errData)
5280+
},
5281+
},
5282+
{
5283+
name: "QueryUser with header triggering error",
5284+
query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`,
5285+
vars: `{"variables":{"id":"valid-user-123"}}`,
5286+
headers: func() http.Header {
5287+
h := make(http.Header)
5288+
h.Set("X-User-ID", "error-user")
5289+
return h
5290+
}(),
5291+
validate: func(t *testing.T, data map[string]interface{}) {
5292+
// Data might be present but should have errors
5293+
},
5294+
validateError: func(t *testing.T, errData []graphqlError) {
5295+
require.NotEmpty(t, errData, "should have errors")
5296+
require.Contains(t, errData[0].Message, "user not found: error-user")
5297+
},
5298+
},
5299+
{
5300+
name: "QueryUser without headers (nil) - baseline behavior",
5301+
query: `query UserQuery($id: ID!) { user(id: $id) { id name } }`,
5302+
vars: `{"variables":{"id":"baseline-user-99"}}`,
5303+
headers: nil,
5304+
validate: func(t *testing.T, data map[string]interface{}) {
5305+
user, ok := data["user"].(map[string]interface{})
5306+
require.True(t, ok, "user should be an object")
5307+
require.Equal(t, "baseline-user-99", user["id"], "user ID should come from query variable")
5308+
require.Equal(t, "User baseline-user-99", user["name"], "user name should use variable-derived ID")
5309+
},
5310+
validateError: func(t *testing.T, errData []graphqlError) {
5311+
require.Empty(t, errData)
5312+
},
5313+
},
5314+
{
5315+
name: "QueryUsers with custom prefix header",
5316+
query: `query UsersQuery { users { id name } }`,
5317+
vars: `{"variables":{}}`,
5318+
headers: func() http.Header {
5319+
h := make(http.Header)
5320+
h.Set("X-User-Prefix", "Admin")
5321+
return h
5322+
}(),
5323+
validate: func(t *testing.T, data map[string]interface{}) {
5324+
users, ok := data["users"].([]interface{})
5325+
require.True(t, ok, "users should be an array")
5326+
require.Len(t, users, 3, "should return 3 users")
5327+
5328+
for i, u := range users {
5329+
user, ok := u.(map[string]interface{})
5330+
require.True(t, ok, "each user should be an object")
5331+
require.Equal(t, fmt.Sprintf("user-%d", i+1), user["id"])
5332+
require.Equal(t, fmt.Sprintf("Admin %d", i+1), user["name"], "user name should use custom prefix from header")
5333+
}
5334+
},
5335+
validateError: func(t *testing.T, errData []graphqlError) {
5336+
require.Empty(t, errData)
5337+
},
5338+
},
5339+
{
5340+
name: "MutationCreateUser with name override header",
5341+
query: `mutation CreateUser($input: UserInput!) { createUser(input: $input) { id name } }`,
5342+
vars: `{"variables":{"input":{"name":"OriginalName"}}}`,
5343+
headers: func() http.Header {
5344+
h := make(http.Header)
5345+
h.Set("X-Custom-Name", "HeaderName")
5346+
return h
5347+
}(),
5348+
validate: func(t *testing.T, data map[string]interface{}) {
5349+
createUser, ok := data["createUser"].(map[string]interface{})
5350+
require.True(t, ok, "createUser should be an object")
5351+
require.NotEmpty(t, createUser["id"], "created user should have an ID")
5352+
require.Equal(t, "HeaderName", createUser["name"], "created user name should come from header")
5353+
},
5354+
validateError: func(t *testing.T, errData []graphqlError) {
5355+
require.Empty(t, errData)
5356+
},
5357+
},
5358+
{
5359+
name: "Categories with productCount field resolver and header offset",
5360+
query: `query CategoriesWithProductCount($filters: ProductCountFilter) { categories { id name kind productCount(filters: $filters) } }`,
5361+
vars: `{"variables":{"filters":{"minPrice":100}}}`,
5362+
headers: func() http.Header {
5363+
h := make(http.Header)
5364+
h.Set("X-Count-Offset", "100")
5365+
return h
5366+
}(),
5367+
validate: func(t *testing.T, data map[string]interface{}) {
5368+
categories, ok := data["categories"].([]interface{})
5369+
require.True(t, ok, "categories should be an array")
5370+
require.Len(t, categories, 4, "should return 4 categories")
5371+
5372+
// Verify that productCount for each category is offset by 100
5373+
expectedCounts := []float64{100, 101, 102, 103}
5374+
for i, c := range categories {
5375+
category, ok := c.(map[string]interface{})
5376+
require.True(t, ok, "category should be an object")
5377+
require.NotEmpty(t, category["id"])
5378+
require.NotEmpty(t, category["name"])
5379+
require.Equal(t, expectedCounts[i], category["productCount"], "productCount should be offset by header value")
5380+
}
5381+
},
5382+
validateError: func(t *testing.T, errData []graphqlError) {
5383+
require.Empty(t, errData)
5384+
},
5385+
},
5386+
{
5387+
name: "Categories with productCount without headers - baseline behavior",
5388+
query: `query CategoriesWithProductCount($filters: ProductCountFilter) { categories { id name kind productCount(filters: $filters) } }`,
5389+
vars: `{"variables":{"filters":{"minPrice":100}}}`,
5390+
headers: nil,
5391+
validate: func(t *testing.T, data map[string]interface{}) {
5392+
categories, ok := data["categories"].([]interface{})
5393+
require.True(t, ok, "categories should be an array")
5394+
require.Len(t, categories, 4, "should return 4 categories")
5395+
5396+
// Verify default productCount values (no offset)
5397+
expectedCounts := []float64{0, 1, 2, 3}
5398+
for i, c := range categories {
5399+
category, ok := c.(map[string]interface{})
5400+
require.True(t, ok, "category should be an object")
5401+
require.NotEmpty(t, category["id"])
5402+
require.NotEmpty(t, category["name"])
5403+
require.Equal(t, expectedCounts[i], category["productCount"], "productCount should use default values without header")
5404+
}
5405+
},
5406+
validateError: func(t *testing.T, errData []graphqlError) {
5407+
require.Empty(t, errData)
5408+
},
5409+
},
5410+
}
5411+
5412+
for _, tc := range testCases {
5413+
t.Run(tc.name, func(t *testing.T) {
5414+
// Parse the GraphQL schema
5415+
schemaDoc := grpctest.MustGraphQLSchema(t)
5416+
5417+
// Parse the GraphQL query
5418+
queryDoc, report := astparser.ParseGraphqlDocumentString(tc.query)
5419+
require.False(t, report.HasErrors(), "failed to parse query: %s", report.Error())
5420+
5421+
compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping())
5422+
require.NoError(t, err)
5423+
5424+
// Create the datasource
5425+
ds, err := NewDataSource(conn, DataSourceConfig{
5426+
Operation: &queryDoc,
5427+
Definition: &schemaDoc,
5428+
SubgraphName: "Products",
5429+
Mapping: testMapping(),
5430+
Compiler: compiler,
5431+
})
5432+
require.NoError(t, err)
5433+
5434+
// Execute the query with headers
5435+
input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars)
5436+
output, err := ds.Load(context.Background(), tc.headers, []byte(input))
5437+
require.NoError(t, err)
5438+
5439+
// Parse the response
5440+
var resp graphqlResponse
5441+
err = json.Unmarshal(output, &resp)
5442+
require.NoError(t, err, "Failed to unmarshal response")
5443+
5444+
tc.validate(t, resp.Data)
5445+
tc.validateError(t, resp.Errors)
5446+
})
5447+
}
5448+
}
5449+
5450+
func Test_Datasource_Load_PreservesExistingContextMetadata(t *testing.T) {
5451+
conn, cleanup := setupTestGRPCServer(t)
5452+
t.Cleanup(cleanup)
5453+
5454+
type graphqlError struct {
5455+
Message string `json:"message"`
5456+
}
5457+
type graphqlResponse struct {
5458+
Data map[string]interface{} `json:"data"`
5459+
Errors []graphqlError `json:"errors,omitempty"`
5460+
}
5461+
5462+
// Parse the GraphQL schema
5463+
schemaDoc := grpctest.MustGraphQLSchema(t)
5464+
5465+
query := `query UserQuery($id: ID!) { user(id: $id) { id name } }`
5466+
vars := `{"variables":{"id":"test-user-123"}}`
5467+
5468+
// Parse the GraphQL query
5469+
queryDoc, report := astparser.ParseGraphqlDocumentString(query)
5470+
require.False(t, report.HasErrors(), "failed to parse query: %s", report.Error())
5471+
5472+
compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping())
5473+
require.NoError(t, err)
5474+
5475+
// Create the datasource
5476+
ds, err := NewDataSource(conn, DataSourceConfig{
5477+
Operation: &queryDoc,
5478+
Definition: &schemaDoc,
5479+
SubgraphName: "Products",
5480+
Mapping: testMapping(),
5481+
Compiler: compiler,
5482+
})
5483+
require.NoError(t, err)
5484+
5485+
// Create a context with existing metadata
5486+
ctx := metadata.NewOutgoingContext(
5487+
context.Background(),
5488+
metadata.Pairs("x-existing-key", "existing-value"),
5489+
)
5490+
5491+
// Create HTTP headers to be forwarded
5492+
headers := make(http.Header)
5493+
headers.Set("X-User-ID", "header-user-456")
5494+
5495+
// Execute the query with both existing context metadata and new HTTP headers
5496+
input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, vars)
5497+
output, err := ds.Load(ctx, headers, []byte(input))
5498+
require.NoError(t, err)
5499+
5500+
// Parse the response
5501+
var resp graphqlResponse
5502+
err = json.Unmarshal(output, &resp)
5503+
require.NoError(t, err, "Failed to unmarshal response")
5504+
5505+
// Verify no errors
5506+
require.Empty(t, resp.Errors, "Should not have GraphQL errors")
5507+
5508+
// Verify the response includes both the header-derived ID and the existing metadata value
5509+
user, ok := resp.Data["user"].(map[string]interface{})
5510+
require.True(t, ok, "user should be an object")
5511+
require.Equal(t, "header-user-456", user["id"], "user ID should come from HTTP header")
5512+
require.Equal(t, "User header-user-456 (existing: existing-value)", user["name"],
5513+
"user name should include both header-derived ID and existing context metadata")
5514+
}

0 commit comments

Comments
 (0)