diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 07e371ddd9..6063c42b19 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -228,6 +228,7 @@ var brokenAuthHeaderDomains = []string{ // connectorData stores information for sessions authenticated by this connector type connectorData struct { RefreshToken []byte + IDToken []byte // raw upstream id_token JWT for RP-Initiated logout } // Detect auth header provider issues for known providers. This lets users @@ -736,6 +737,9 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I cd := connectorData{ RefreshToken: []byte(token.RefreshToken), } + if rawIDToken, ok := token.Extra("id_token").(string); ok { + cd.IDToken = []byte(rawIDToken) + } connData, err := json.Marshal(&cd) if err != nil { @@ -766,7 +770,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I // LogoutURL returns the upstream OIDC provider's end_session_endpoint URL. // Per the OIDC RP-Initiated Logout spec, the post_logout_redirect_uri parameter // tells the upstream where to redirect after logout. -func (c *oidcConnector) LogoutURL(_ context.Context, _ []byte, postLogoutRedirectURI string) (string, error) { +func (c *oidcConnector) LogoutURL(_ context.Context, rawConnectorData []byte, postLogoutRedirectURI string) (string, error) { if c.endSessionURL == "" { return "", nil } @@ -781,6 +785,16 @@ func (c *oidcConnector) LogoutURL(_ context.Context, _ []byte, postLogoutRedirec q.Set("post_logout_redirect_uri", postLogoutRedirectURI) q.Set("client_id", c.oauth2Config.ClientID) } + // Per the RP-Initiated Logout spec, id_token_hint is independently valid + // of post_logout_redirect_uri — include it whenever we have one. + if len(rawConnectorData) > 0 { + var cd connectorData + if err := json.Unmarshal(rawConnectorData, &cd); err == nil { + if len(cd.IDToken) > 0 { + q.Set("id_token_hint", string(cd.IDToken)) + } + } + } u.RawQuery = q.Encode() return u.String(), nil diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 77473b0d9f..29ec09f021 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -979,10 +979,22 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) { } func TestLogoutURL(t *testing.T) { + idTokenConnData, err := json.Marshal(connectorData{ + RefreshToken: []byte("refresh"), + IDToken: []byte("id-token-jwt"), + }) + require.NoError(t, err) + + noIDTokenConnData, err := json.Marshal(connectorData{ + RefreshToken: []byte("refresh"), + }) + require.NoError(t, err) + tests := []struct { name string endSessionURL string postLogoutRedirectURI string + connectorData []byte wantURL string wantEmpty bool }{ @@ -1008,6 +1020,33 @@ func TestLogoutURL(t *testing.T) { postLogoutRedirectURI: "https://dex.example.com/callback", wantURL: "https://provider.example.com/logout?client_id=clientID&existing=param&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Fcallback", }, + { + name: "with id_token_hint from connector data", + endSessionURL: "https://provider.example.com/logout", + postLogoutRedirectURI: "https://dex.example.com/logout/callback", + connectorData: idTokenConnData, + wantURL: "https://provider.example.com/logout?client_id=clientID&id_token_hint=id-token-jwt&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Flogout%2Fcallback", + }, + { + name: "id_token_hint included without post_logout_redirect_uri", + endSessionURL: "https://provider.example.com/logout", + connectorData: idTokenConnData, + wantURL: "https://provider.example.com/logout?id_token_hint=id-token-jwt", + }, + { + name: "connector data without IDToken omits id_token_hint", + endSessionURL: "https://provider.example.com/logout", + postLogoutRedirectURI: "https://dex.example.com/logout/callback", + connectorData: noIDTokenConnData, + wantURL: "https://provider.example.com/logout?client_id=clientID&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Flogout%2Fcallback", + }, + { + name: "malformed connector data is ignored", + endSessionURL: "https://provider.example.com/logout", + postLogoutRedirectURI: "https://dex.example.com/logout/callback", + connectorData: []byte("not-json"), + wantURL: "https://provider.example.com/logout?client_id=clientID&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Flogout%2Fcallback", + }, } for _, tc := range tests { @@ -1019,7 +1058,7 @@ func TestLogoutURL(t *testing.T) { }, } - got, err := conn.LogoutURL(context.Background(), nil, tc.postLogoutRedirectURI) + got, err := conn.LogoutURL(context.Background(), tc.connectorData, tc.postLogoutRedirectURI) require.NoError(t, err) if tc.wantEmpty { diff --git a/server/logout.go b/server/logout.go index d48f7a59fe..12f5aa0704 100644 --- a/server/logout.go +++ b/server/logout.go @@ -251,13 +251,21 @@ func (s *Server) tryUpstreamLogout(ctx context.Context, userID, connectorID stri } // Check that the session exists — we need it to store logout state. - _, err = s.storage.GetAuthSession(ctx, userID, connectorID) + session, err := s.storage.GetAuthSession(ctx, userID, connectorID) if err != nil { s.logger.DebugContext(ctx, "logout: no auth session for upstream logout, skipping", "user_id", userID, "connector_id", connectorID) return "", false } + // The auth session connector data should keep an id_token that will be used as hint for RP-Initiated logout + if len(session.ConnectorData) > 0 { + connectorData = session.ConnectorData + s.logger.DebugContext(ctx, "logout: using auth_session.ConnectorData", "connector_id", connectorID) + } else if len(connectorData) == 0 { + s.logger.DebugContext(ctx, "logout: no connector data available", "connector_id", connectorID) + } + // Store logout parameters in the session. if err := s.storage.UpdateAuthSession(ctx, userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) { old.LogoutState = &storage.LogoutState{ diff --git a/server/logout_test.go b/server/logout_test.go index 0554cb0ff1..05ce93ef3f 100644 --- a/server/logout_test.go +++ b/server/logout_test.go @@ -12,9 +12,29 @@ import ( "github.com/stretchr/testify/require" + "github.com/dexidp/dex/connector" "github.com/dexidp/dex/storage" ) +// recordingLogoutConnector implements connector.LogoutCallbackConnector and +// records the connectorData it was invoked with so tests can assert what was +// passed down. +type recordingLogoutConnector struct { + gotConnectorData []byte + returnURL string +} + +func (c *recordingLogoutConnector) LogoutURL(_ context.Context, connectorData []byte, _ string) (string, error) { + c.gotConnectorData = connectorData + return c.returnURL, nil +} + +func (c *recordingLogoutConnector) HandleLogoutCallback(_ context.Context, _ *http.Request) error { + return nil +} + +var _ connector.LogoutCallbackConnector = (*recordingLogoutConnector)(nil) + func TestHandleLogoutNoSessions(t *testing.T) { httpServer, server := newTestServer(t, nil) defer httpServer.Close() @@ -380,3 +400,68 @@ func TestRevokeRefreshTokensReturnsConnectorData(t *testing.T) { require.Empty(t, os.Refresh) require.Equal(t, expectedConnData, os.ConnectorData) } + +// TestTryUpstreamLogoutPrefersSessionConnectorData verifies that when the auth +// session has ConnectorData stored (from login), it takes precedence over the +// connectorData the caller passes in (which originates from the offline session). +func TestTryUpstreamLogoutPrefersSessionConnectorData(t *testing.T) { + tests := []struct { + name string + sessionConnData []byte + callerConnData []byte + wantConnData []byte + }{ + { + name: "session data wins over caller data", + sessionConnData: []byte(`{"IDToken":"session-token"}`), + callerConnData: []byte(`{"IDToken":"caller-token"}`), + wantConnData: []byte(`{"IDToken":"session-token"}`), + }, + { + name: "caller data used when session data is empty", + sessionConnData: nil, + callerConnData: []byte(`{"IDToken":"caller-token"}`), + wantConnData: []byte(`{"IDToken":"caller-token"}`), + }, + { + name: "empty when neither source has data", + sessionConnData: nil, + callerConnData: nil, + wantConnData: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + httpServer, server := newTestServerWithSessions(t, nil) + defer httpServer.Close() + + ctx := t.Context() + userID := "test-user" + connectorID := "mock" + + // Inject a recording connector with matching ResourceVersion so + // getConnector returns our mock instead of re-opening from storage. + rec := &recordingLogoutConnector{returnURL: "https://upstream.example.com/logout"} + server.mu.Lock() + server.connectors[connectorID] = Connector{ + Type: "mockCallback", + ResourceVersion: "1", + Connector: rec, + } + server.mu.Unlock() + + require.NoError(t, server.storage.CreateAuthSession(ctx, storage.AuthSession{ + UserID: userID, ConnectorID: connectorID, Nonce: "nonce", + CreatedAt: time.Now(), LastActivity: time.Now(), + ConnectorData: tc.sessionConnData, + })) + + redirectURL, ok := server.tryUpstreamLogout(ctx, userID, connectorID, tc.callerConnData, + "https://dex.example.com/cb", "state-123", "client-123") + require.True(t, ok) + require.Equal(t, "https://upstream.example.com/logout", redirectURL) + require.Equal(t, tc.wantConnData, rec.gotConnectorData) + }) + } +} diff --git a/server/session.go b/server/session.go index b5adbeb900..87795893ce 100644 --- a/server/session.go +++ b/server/session.go @@ -261,6 +261,7 @@ func (s *Server) createOrUpdateAuthSession(ctx context.Context, r *http.Request, old.ClientStates = make(map[string]*storage.ClientAuthState) } old.ClientStates[authReq.ClientID] = clientState + old.ConnectorData = authReq.ConnectorData return old, nil }); err != nil { return fmt.Errorf("update auth session: %w", err) @@ -289,6 +290,7 @@ func (s *Server) createOrUpdateAuthSession(ctx context.Context, r *http.Request, UserAgent: r.UserAgent(), AbsoluteExpiry: now.Add(s.sessionConfig.AbsoluteLifetime), IdleExpiry: now.Add(s.sessionConfig.ValidIfNotUsedFor), + ConnectorData: authReq.ConnectorData, } if err := s.storage.CreateAuthSession(ctx, newSession); err != nil { diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 33ec950cfd..1810421b76 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -2,6 +2,7 @@ package conformance import ( + "bytes" "crypto/ecdsa" "reflect" "sort" @@ -1388,6 +1389,7 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { UserAgent: "TestBrowser/1.0", AbsoluteExpiry: now.Add(24 * time.Hour), IdleExpiry: now.Add(1 * time.Hour), + ConnectorData: []byte(`{"RefreshToken":"dGVzdA==","IDToken":"ZXlKaGJHY21PaUpTVXpJMU5pSjk="}`), } // Create. @@ -1418,8 +1420,9 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { t.Errorf("auth session retrieved from storage did not match: %s", diff) } - // Update: add a new client state. + // Update: add a new client state and rotate connector data. newNow := now.Add(time.Minute) + updatedConnectorData := []byte(`{"RefreshToken":"bmV3","IDToken":"bmV3LWlk"}`) if err := s.UpdateAuthSession(ctx, session.UserID, session.ConnectorID, func(old storage.AuthSession) (storage.AuthSession, error) { old.ClientStates["client2"] = &storage.ClientAuthState{ Active: true, @@ -1427,6 +1430,7 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { LastActivity: newNow, } old.LastActivity = newNow + old.ConnectorData = updatedConnectorData return old, nil }); err != nil { t.Fatalf("update auth session: %v", err) @@ -1443,6 +1447,9 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { if got.ClientStates["client2"] == nil { t.Fatal("expected client2 state to exist") } + if !bytes.Equal(got.ConnectorData, updatedConnectorData) { + t.Fatalf("expected updated connector data %q, got %q", updatedConnectorData, got.ConnectorData) + } // List and verify. sessions, err := s.ListAuthSessions(ctx) diff --git a/storage/ent/client/authsession.go b/storage/ent/client/authsession.go index b4cdfe8147..5ed9dc86b3 100644 --- a/storage/ent/client/authsession.go +++ b/storage/ent/client/authsession.go @@ -31,6 +31,7 @@ func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSe SetUserAgent(session.UserAgent). SetAbsoluteExpiry(session.AbsoluteExpiry.UTC()). SetIdleExpiry(session.IdleExpiry.UTC()). + SetConnectorData(session.ConnectorData). Save(ctx) if err != nil { return convertDBError("create auth session: %w", err) @@ -106,6 +107,7 @@ func (d *Database) UpdateAuthSession(ctx context.Context, userID, connectorID st SetUserAgent(newSession.UserAgent). SetAbsoluteExpiry(newSession.AbsoluteExpiry.UTC()). SetIdleExpiry(newSession.IdleExpiry.UTC()). + SetConnectorData(newSession.ConnectorData). Save(ctx) if err != nil { return rollback(tx, "update auth session updating: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 4a6e2bc740..dc40e27135 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -244,6 +244,10 @@ func toStorageAuthSession(s *db.AuthSession) storage.AuthSession { IdleExpiry: s.IdleExpiry, } + if s.ConnectorData != nil { + result.ConnectorData = *s.ConnectorData + } + if s.ClientStates != nil { if err := json.Unmarshal(s.ClientStates, &result.ClientStates); err != nil { panic(err) diff --git a/storage/ent/db/authsession.go b/storage/ent/db/authsession.go index 6ced0680fc..a20a49fd35 100644 --- a/storage/ent/db/authsession.go +++ b/storage/ent/db/authsession.go @@ -36,8 +36,10 @@ type AuthSession struct { // AbsoluteExpiry holds the value of the "absolute_expiry" field. AbsoluteExpiry time.Time `json:"absolute_expiry,omitempty"` // IdleExpiry holds the value of the "idle_expiry" field. - IdleExpiry time.Time `json:"idle_expiry,omitempty"` - selectValues sql.SelectValues + IdleExpiry time.Time `json:"idle_expiry,omitempty"` + // ConnectorData holds the value of the "connector_data" field. + ConnectorData *[]byte `json:"connector_data,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -45,7 +47,7 @@ func (*AuthSession) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case authsession.FieldClientStates: + case authsession.FieldClientStates, authsession.FieldConnectorData: values[i] = new([]byte) case authsession.FieldID, authsession.FieldUserID, authsession.FieldConnectorID, authsession.FieldNonce, authsession.FieldIPAddress, authsession.FieldUserAgent: values[i] = new(sql.NullString) @@ -132,6 +134,12 @@ func (_m *AuthSession) assignValues(columns []string, values []any) error { } else if value.Valid { _m.IdleExpiry = value.Time } + case authsession.FieldConnectorData: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field connector_data", values[i]) + } else if value != nil { + _m.ConnectorData = value + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -197,6 +205,11 @@ func (_m *AuthSession) String() string { builder.WriteString(", ") builder.WriteString("idle_expiry=") builder.WriteString(_m.IdleExpiry.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.ConnectorData; v != nil { + builder.WriteString("connector_data=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authsession/authsession.go b/storage/ent/db/authsession/authsession.go index fc1cd5ed3b..6315f63d08 100644 --- a/storage/ent/db/authsession/authsession.go +++ b/storage/ent/db/authsession/authsession.go @@ -31,6 +31,8 @@ const ( FieldAbsoluteExpiry = "absolute_expiry" // FieldIdleExpiry holds the string denoting the idle_expiry field in the database. FieldIdleExpiry = "idle_expiry" + // FieldConnectorData holds the string denoting the connector_data field in the database. + FieldConnectorData = "connector_data" // Table holds the table name of the authsession in the database. Table = "auth_sessions" ) @@ -48,6 +50,7 @@ var Columns = []string{ FieldUserAgent, FieldAbsoluteExpiry, FieldIdleExpiry, + FieldConnectorData, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/storage/ent/db/authsession/where.go b/storage/ent/db/authsession/where.go index 193f1133e5..7ef61428f1 100644 --- a/storage/ent/db/authsession/where.go +++ b/storage/ent/db/authsession/where.go @@ -114,6 +114,11 @@ func IdleExpiry(v time.Time) predicate.AuthSession { return predicate.AuthSession(sql.FieldEQ(FieldIdleExpiry, v)) } +// ConnectorData applies equality check predicate on the "connector_data" field. It's identical to ConnectorDataEQ. +func ConnectorData(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldConnectorData, v)) +} + // UserIDEQ applies the EQ predicate on the "user_id" field. func UserIDEQ(v string) predicate.AuthSession { return predicate.AuthSession(sql.FieldEQ(FieldUserID, v)) @@ -639,6 +644,56 @@ func IdleExpiryLTE(v time.Time) predicate.AuthSession { return predicate.AuthSession(sql.FieldLTE(FieldIdleExpiry, v)) } +// ConnectorDataEQ applies the EQ predicate on the "connector_data" field. +func ConnectorDataEQ(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldConnectorData, v)) +} + +// ConnectorDataNEQ applies the NEQ predicate on the "connector_data" field. +func ConnectorDataNEQ(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldConnectorData, v)) +} + +// ConnectorDataIn applies the In predicate on the "connector_data" field. +func ConnectorDataIn(vs ...[]byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldConnectorData, vs...)) +} + +// ConnectorDataNotIn applies the NotIn predicate on the "connector_data" field. +func ConnectorDataNotIn(vs ...[]byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldConnectorData, vs...)) +} + +// ConnectorDataGT applies the GT predicate on the "connector_data" field. +func ConnectorDataGT(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldConnectorData, v)) +} + +// ConnectorDataGTE applies the GTE predicate on the "connector_data" field. +func ConnectorDataGTE(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldConnectorData, v)) +} + +// ConnectorDataLT applies the LT predicate on the "connector_data" field. +func ConnectorDataLT(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldConnectorData, v)) +} + +// ConnectorDataLTE applies the LTE predicate on the "connector_data" field. +func ConnectorDataLTE(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldConnectorData, v)) +} + +// ConnectorDataIsNil applies the IsNil predicate on the "connector_data" field. +func ConnectorDataIsNil() predicate.AuthSession { + return predicate.AuthSession(sql.FieldIsNull(FieldConnectorData)) +} + +// ConnectorDataNotNil applies the NotNil predicate on the "connector_data" field. +func ConnectorDataNotNil() predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotNull(FieldConnectorData)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthSession) predicate.AuthSession { return predicate.AuthSession(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authsession_create.go b/storage/ent/db/authsession_create.go index 0dc99e7615..db8e36f29e 100644 --- a/storage/ent/db/authsession_create.go +++ b/storage/ent/db/authsession_create.go @@ -96,6 +96,12 @@ func (_c *AuthSessionCreate) SetIdleExpiry(v time.Time) *AuthSessionCreate { return _c } +// SetConnectorData sets the "connector_data" field. +func (_c *AuthSessionCreate) SetConnectorData(v []byte) *AuthSessionCreate { + _c.mutation.SetConnectorData(v) + return _c +} + // SetID sets the "id" field. func (_c *AuthSessionCreate) SetID(v string) *AuthSessionCreate { _c.mutation.SetID(v) @@ -274,6 +280,10 @@ func (_c *AuthSessionCreate) createSpec() (*AuthSession, *sqlgraph.CreateSpec) { _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) _node.IdleExpiry = value } + if value, ok := _c.mutation.ConnectorData(); ok { + _spec.SetField(authsession.FieldConnectorData, field.TypeBytes, value) + _node.ConnectorData = &value + } return _node, _spec } diff --git a/storage/ent/db/authsession_update.go b/storage/ent/db/authsession_update.go index d80e682b91..5af22877c1 100644 --- a/storage/ent/db/authsession_update.go +++ b/storage/ent/db/authsession_update.go @@ -160,6 +160,18 @@ func (_u *AuthSessionUpdate) SetNillableIdleExpiry(v *time.Time) *AuthSessionUpd return _u } +// SetConnectorData sets the "connector_data" field. +func (_u *AuthSessionUpdate) SetConnectorData(v []byte) *AuthSessionUpdate { + _u.mutation.SetConnectorData(v) + return _u +} + +// ClearConnectorData clears the value of the "connector_data" field. +func (_u *AuthSessionUpdate) ClearConnectorData() *AuthSessionUpdate { + _u.mutation.ClearConnectorData() + return _u +} + // Mutation returns the AuthSessionMutation object of the builder. func (_u *AuthSessionUpdate) Mutation() *AuthSessionMutation { return _u.mutation @@ -254,6 +266,12 @@ func (_u *AuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) if value, ok := _u.mutation.IdleExpiry(); ok { _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) } + if value, ok := _u.mutation.ConnectorData(); ok { + _spec.SetField(authsession.FieldConnectorData, field.TypeBytes, value) + } + if _u.mutation.ConnectorDataCleared() { + _spec.ClearField(authsession.FieldConnectorData, field.TypeBytes) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authsession.Label} @@ -406,6 +424,18 @@ func (_u *AuthSessionUpdateOne) SetNillableIdleExpiry(v *time.Time) *AuthSession return _u } +// SetConnectorData sets the "connector_data" field. +func (_u *AuthSessionUpdateOne) SetConnectorData(v []byte) *AuthSessionUpdateOne { + _u.mutation.SetConnectorData(v) + return _u +} + +// ClearConnectorData clears the value of the "connector_data" field. +func (_u *AuthSessionUpdateOne) ClearConnectorData() *AuthSessionUpdateOne { + _u.mutation.ClearConnectorData() + return _u +} + // Mutation returns the AuthSessionMutation object of the builder. func (_u *AuthSessionUpdateOne) Mutation() *AuthSessionMutation { return _u.mutation @@ -530,6 +560,12 @@ func (_u *AuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *AuthSession if value, ok := _u.mutation.IdleExpiry(); ok { _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) } + if value, ok := _u.mutation.ConnectorData(); ok { + _spec.SetField(authsession.FieldConnectorData, field.TypeBytes, value) + } + if _u.mutation.ConnectorDataCleared() { + _spec.ClearField(authsession.FieldConnectorData, field.TypeBytes) + } _node = &AuthSession{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index a6050cb333..f66b1254d4 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -82,6 +82,7 @@ var ( {Name: "user_agent", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "absolute_expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "idle_expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, } // AuthSessionsTable holds the schema information for the "auth_sessions" table. AuthSessionsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index a21c65765c..49a651b2ce 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -3154,6 +3154,7 @@ type AuthSessionMutation struct { user_agent *string absolute_expiry *time.Time idle_expiry *time.Time + connector_data *[]byte clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthSession, error) @@ -3624,6 +3625,55 @@ func (m *AuthSessionMutation) ResetIdleExpiry() { m.idle_expiry = nil } +// SetConnectorData sets the "connector_data" field. +func (m *AuthSessionMutation) SetConnectorData(b []byte) { + m.connector_data = &b +} + +// ConnectorData returns the value of the "connector_data" field in the mutation. +func (m *AuthSessionMutation) ConnectorData() (r []byte, exists bool) { + v := m.connector_data + if v == nil { + return + } + return *v, true +} + +// OldConnectorData returns the old "connector_data" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldConnectorData(ctx context.Context) (v *[]byte, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConnectorData is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConnectorData requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConnectorData: %w", err) + } + return oldValue.ConnectorData, nil +} + +// ClearConnectorData clears the value of the "connector_data" field. +func (m *AuthSessionMutation) ClearConnectorData() { + m.connector_data = nil + m.clearedFields[authsession.FieldConnectorData] = struct{}{} +} + +// ConnectorDataCleared returns if the "connector_data" field was cleared in this mutation. +func (m *AuthSessionMutation) ConnectorDataCleared() bool { + _, ok := m.clearedFields[authsession.FieldConnectorData] + return ok +} + +// ResetConnectorData resets all changes to the "connector_data" field. +func (m *AuthSessionMutation) ResetConnectorData() { + m.connector_data = nil + delete(m.clearedFields, authsession.FieldConnectorData) +} + // Where appends a list predicates to the AuthSessionMutation builder. func (m *AuthSessionMutation) Where(ps ...predicate.AuthSession) { m.predicates = append(m.predicates, ps...) @@ -3658,7 +3708,7 @@ func (m *AuthSessionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthSessionMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 11) if m.user_id != nil { fields = append(fields, authsession.FieldUserID) } @@ -3689,6 +3739,9 @@ func (m *AuthSessionMutation) Fields() []string { if m.idle_expiry != nil { fields = append(fields, authsession.FieldIdleExpiry) } + if m.connector_data != nil { + fields = append(fields, authsession.FieldConnectorData) + } return fields } @@ -3717,6 +3770,8 @@ func (m *AuthSessionMutation) Field(name string) (ent.Value, bool) { return m.AbsoluteExpiry() case authsession.FieldIdleExpiry: return m.IdleExpiry() + case authsession.FieldConnectorData: + return m.ConnectorData() } return nil, false } @@ -3746,6 +3801,8 @@ func (m *AuthSessionMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldAbsoluteExpiry(ctx) case authsession.FieldIdleExpiry: return m.OldIdleExpiry(ctx) + case authsession.FieldConnectorData: + return m.OldConnectorData(ctx) } return nil, fmt.Errorf("unknown AuthSession field %s", name) } @@ -3825,6 +3882,13 @@ func (m *AuthSessionMutation) SetField(name string, value ent.Value) error { } m.SetIdleExpiry(v) return nil + case authsession.FieldConnectorData: + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConnectorData(v) + return nil } return fmt.Errorf("unknown AuthSession field %s", name) } @@ -3854,7 +3918,11 @@ func (m *AuthSessionMutation) AddField(name string, value ent.Value) error { // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *AuthSessionMutation) ClearedFields() []string { - return nil + var fields []string + if m.FieldCleared(authsession.FieldConnectorData) { + fields = append(fields, authsession.FieldConnectorData) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was @@ -3867,6 +3935,11 @@ func (m *AuthSessionMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *AuthSessionMutation) ClearField(name string) error { + switch name { + case authsession.FieldConnectorData: + m.ClearConnectorData() + return nil + } return fmt.Errorf("unknown AuthSession nullable field %s", name) } @@ -3904,6 +3977,9 @@ func (m *AuthSessionMutation) ResetField(name string) error { case authsession.FieldIdleExpiry: m.ResetIdleExpiry() return nil + case authsession.FieldConnectorData: + m.ResetConnectorData() + return nil } return fmt.Errorf("unknown AuthSession field %s", name) } diff --git a/storage/ent/schema/authsession.go b/storage/ent/schema/authsession.go index 0b641b7f7a..aaa8ec3678 100644 --- a/storage/ent/schema/authsession.go +++ b/storage/ent/schema/authsession.go @@ -41,6 +41,9 @@ func (AuthSession) Fields() []ent.Field { SchemaType(timeSchema), field.Time("idle_expiry"). SchemaType(timeSchema), + field.Bytes("connector_data"). + Nillable(). + Optional(), } } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index aabb16f6c1..8bf5eff71b 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -336,6 +336,7 @@ type AuthSession struct { UserAgent string `json:"user_agent,omitempty"` AbsoluteExpiry time.Time `json:"absolute_expiry"` IdleExpiry time.Time `json:"idle_expiry"` + ConnectorData []byte `json:"connector_data,omitempty"` } func fromStorageAuthSession(s storage.AuthSession) AuthSession { @@ -350,6 +351,7 @@ func fromStorageAuthSession(s storage.AuthSession) AuthSession { UserAgent: s.UserAgent, AbsoluteExpiry: s.AbsoluteExpiry, IdleExpiry: s.IdleExpiry, + ConnectorData: s.ConnectorData, } } @@ -365,6 +367,7 @@ func toStorageAuthSession(s AuthSession) storage.AuthSession { UserAgent: s.UserAgent, AbsoluteExpiry: s.AbsoluteExpiry, IdleExpiry: s.IdleExpiry, + ConnectorData: s.ConnectorData, } if result.ClientStates == nil { result.ClientStates = make(map[string]*storage.ClientAuthState) diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index d936865ad8..b03f7f2475 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -1022,6 +1022,7 @@ type AuthSession struct { AbsoluteExpiry time.Time `json:"absoluteExpiry,omitempty"` IdleExpiry time.Time `json:"idleExpiry,omitempty"` LogoutState *storage.LogoutState `json:"logoutState,omitempty"` + ConnectorData []byte `json:"connectorData,omitempty"` } // AuthSessionList is a list of AuthSessions. @@ -1052,6 +1053,7 @@ func (cli *client) fromStorageAuthSession(s storage.AuthSession) AuthSession { AbsoluteExpiry: s.AbsoluteExpiry, IdleExpiry: s.IdleExpiry, LogoutState: s.LogoutState, + ConnectorData: s.ConnectorData, } } @@ -1068,6 +1070,7 @@ func toStorageAuthSession(s AuthSession) storage.AuthSession { AbsoluteExpiry: s.AbsoluteExpiry, IdleExpiry: s.IdleExpiry, LogoutState: s.LogoutState, + ConnectorData: s.ConnectorData, } if result.ClientStates == nil { result.ClientStates = make(map[string]*storage.ClientAuthState) diff --git a/storage/sql/crud.go b/storage/sql/crud.go index a8eaf2fd4b..5ef6dddfed 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -998,15 +998,17 @@ func (c *conn) CreateAuthSession(ctx context.Context, s storage.AuthSession) err created_at, last_activity, ip_address, user_agent, absolute_expiry, idle_expiry, + connector_data, logout_state ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12); `, s.UserID, s.ConnectorID, s.Nonce, encoder(s.ClientStates), s.CreatedAt, s.LastActivity, s.IPAddress, s.UserAgent, s.AbsoluteExpiry, s.IdleExpiry, + s.ConnectorData, encoder(s.LogoutState), ) if err != nil { @@ -1036,12 +1038,14 @@ func (c *conn) UpdateAuthSession(ctx context.Context, userID, connectorID string last_activity = $2, ip_address = $3, user_agent = $4, - logout_state = $5 - where user_id = $6 AND connector_id = $7; + connector_data = $5, + logout_state = $6 + where user_id = $7 AND connector_id = $8; `, encoder(newSession.ClientStates), newSession.LastActivity, newSession.IPAddress, newSession.UserAgent, + newSession.ConnectorData, encoder(newSession.LogoutState), userID, connectorID, ) @@ -1062,6 +1066,7 @@ const authSessionColumns = ` created_at, last_activity, ip_address, user_agent, absolute_expiry, idle_expiry, + connector_data, logout_state ` @@ -1081,6 +1086,7 @@ func scanAuthSession(s scanner) (session storage.AuthSession, err error) { &session.CreatedAt, &session.LastActivity, &session.IPAddress, &session.UserAgent, &session.AbsoluteExpiry, &session.IdleExpiry, + &session.ConnectorData, &logoutState, ) if err != nil { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index d9b3dbed5c..caa488ac72 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -467,4 +467,10 @@ var migrations = []migration{ add column sso_shared_with bytea;`, }, }, + { + stmts: []string{ + `alter table auth_session + add column connector_data bytea;`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 6d5f40b427..2a99a9bf45 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -455,6 +455,10 @@ type AuthSession struct { // upstream provider. The callback handler reads it back to complete the flow. // Nil when no logout is in progress. LogoutState *LogoutState + + // Connector data is set during login, meant to store information from the + // upstream OIDC connector to be used later on logout (id_token) + ConnectorData []byte } // OfflineSessions objects are sessions pertaining to users with refresh tokens.