diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 48fc39f130..7d3c7d4cd5 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -107,6 +107,10 @@ func newInternalServerError() *refreshError { return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} } +func newUpstreamRefreshError() *refreshError { + return &refreshError{msg: errInvalidGrant, desc: "Upstream identity provider refresh failed.", code: http.StatusBadGateway} +} + func newBadRequestError(desc string) *refreshError { return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} } @@ -271,7 +275,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident) if err != nil { s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err) - return ident, newInternalServerError() + return ident, newUpstreamRefreshError() } return newIdent, nil @@ -308,7 +312,7 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr } // updateRefreshToken updates refresh token and offline session in the storage -func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) { +func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext, userIdent *storage.UserIdentity) (*internal.RefreshToken, connector.Identity, *refreshError) { var rerr *refreshError newToken := &internal.RefreshToken{ @@ -327,6 +331,12 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( Groups: rCtx.storageToken.Claims.Groups, } + // When sessions are enabled, downstream token refresh is disconnected from the upstream + // identity provider. Instead of calling the connector's Refresh method (which would contact + // the upstream IdP and may fail if the upstream refresh token has expired), we use the claims + // stored in UserIdentity at the time of the last interactive login. This aligns with the + // behavior of other identity brokers (e.g., Keycloak, Auth0) that treat downstream sessions + // independently from the upstream provider session lifetime. refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { rotationEnabled := s.refreshTokenPolicy.RotationEnabled() reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) @@ -371,9 +381,21 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( // Call only once if there is a request which is not in the reuse interval. // This is required to avoid multiple calls to the external IdP for concurrent requests. // Dex will call the connector's Refresh method only once if request is not in reuse interval. - ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) - if rerr != nil { - return old, rerr + // When sessions are enabled, use cached identity instead of refreshing upstream. + if userIdent != nil { + ident = connector.Identity{ + UserID: userIdent.Claims.UserID, + Username: userIdent.Claims.Username, + PreferredUsername: userIdent.Claims.PreferredUsername, + Email: userIdent.Claims.Email, + EmailVerified: userIdent.Claims.EmailVerified, + Groups: userIdent.Claims.Groups, + } + } else { + ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) + if rerr != nil { + return old, rerr + } } // Update the claims of the refresh token. @@ -424,7 +446,24 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx) + var userIdent *storage.UserIdentity + + if s.sessionConfig != nil { + ui, err := s.storage.GetUserIdentity(r.Context(), rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err) + s.refreshTokenErrHelper(w, newInternalServerError()) + return + } + userIdent = &ui + } + + authTime := time.Time{} + if userIdent != nil { + authTime = userIdent.LastLogin + } + + newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx, userIdent) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -439,17 +478,6 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - authTime := time.Time{} - if s.sessionConfig != nil { - ui, err := s.storage.GetUserIdentity(r.Context(), ident.UserID, rCtx.storageToken.ConnectorID) - if err != nil { - s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err) - s.refreshTokenErrHelper(w, newInternalServerError()) - return - } - authTime = ui.LastLogin - } - accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, authTime) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index 8db80c31eb..e870da541f 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -2,8 +2,10 @@ package server import ( "bytes" + "context" "encoding/base64" "encoding/json" + "errors" "log/slog" "net/http" "net/http/httptest" @@ -16,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/dexidp/dex/connector" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -276,17 +279,17 @@ func TestRefreshTokenAuthTime(t *testing.T) { mockRefreshTokenTestStorage(t, s.storage, false) if tc.createUserIdentity { - // The mock connector returns UserID "0-385-28089-0" on Refresh, - // so the UserIdentity must use that ID to be found by handleRefreshToken. + // UserIdentity must match the refresh token's Claims.UserID ("1") + // because updateRefreshToken looks it up by that ID. err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ - UserID: "0-385-28089-0", + UserID: "1", ConnectorID: "test", Claims: storage.Claims{ - UserID: "0-385-28089-0", - Username: "Kilgore Trout", - Email: "kilgore@kilgore.trout", + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", EmailVerified: true, - Groups: []string{"authors"}, + Groups: []string{"a", "b"}, }, CreatedAt: loginTime, LastLogin: loginTime, @@ -347,6 +350,145 @@ func TestRefreshTokenAuthTime(t *testing.T) { } } +// failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector +// but always returns an error on Refresh, proving that the upstream is not contacted. +type failingRefreshConnector struct { + identity connector.Identity +} + +func (f *failingRefreshConnector) LoginURL(_ connector.Scopes, callbackURL, state string) (string, []byte, error) { + u, _ := url.Parse(callbackURL) + v := u.Query() + v.Set("state", state) + u.RawQuery = v.Encode() + return u.String(), nil, nil +} + +func (f *failingRefreshConnector) HandleCallback(_ connector.Scopes, _ []byte, _ *http.Request) (connector.Identity, error) { + return f.identity, nil +} + +func (f *failingRefreshConnector) Refresh(_ context.Context, _ connector.Scopes, _ connector.Identity) (connector.Identity, error) { + return connector.Identity{}, errors.New("upstream: refresh token expired") +} + +func TestRefreshDisconnectsUpstreamWhenSessionsEnabled(t *testing.T) { + t0 := time.Now().UTC().Round(time.Second) + loginTime := t0.Add(-10 * time.Minute) + + tests := []struct { + name string + sessionsEnabled bool + createUserIdentity bool + wantOK bool + }{ + { + name: "sessions enabled - uses user identity, skips upstream", + sessionsEnabled: true, + createUserIdentity: true, + wantOK: true, + }, + { + name: "sessions enabled without user identity - fails", + sessionsEnabled: true, + createUserIdentity: false, + wantOK: false, + }, + { + name: "sessions disabled - upstream failure returns error", + sessionsEnabled: false, + createUserIdentity: false, + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + httpServer, s := newTestServer(t, func(c *Config) { + c.Now = func() time.Time { return t0 } + }) + defer httpServer.Close() + + if tc.sessionsEnabled { + s.sessionConfig = &SessionConfig{ + CookieName: "dex_session", + AbsoluteLifetime: 24 * time.Hour, + } + } + + mockRefreshTokenTestStorage(t, s.storage, false) + + // Replace the connector with one that always fails on Refresh. + // When sessions are enabled this connector should never be called; + // when sessions are disabled, the failure proves the error path works. + s.mu.Lock() + s.connectors["test"] = Connector{ + Connector: &failingRefreshConnector{ + identity: connector.Identity{ + UserID: "0-385-28089-0", + Username: "Kilgore Trout", + Email: "kilgore@kilgore.trout", + }, + }, + } + s.mu.Unlock() + + if tc.createUserIdentity { + err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ + UserID: "1", + ConnectorID: "test", + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + CreatedAt: loginTime, + LastLogin: loginTime, + }) + require.NoError(t, err) + } + + u, err := url.Parse(s.issuerURL.String()) + require.NoError(t, err) + + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + + u.Path = path.Join(u.Path, "/token") + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.SetBasicAuth("test", "barfoo") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if tc.wantOK { + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var resp struct { + IDToken string `json:"id_token"` + } + err = json.Unmarshal(rr.Body.Bytes(), &resp) + require.NoError(t, err) + + // Verify the returned claims match UserIdentity, not the connector. + claims := decodeJWTClaims(t, resp.IDToken) + assert.Equal(t, "jane.doe@example.com", claims["email"]) + assert.Equal(t, "jane", claims["name"]) + } else { + require.NotEqual(t, http.StatusOK, rr.Code, + "expected error when sessions disabled or user identity missing") + } + }) + } +} + func TestRefreshTokenPolicy(t *testing.T) { lastTime := time.Now() l := slog.New(slog.DiscardHandler)