Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/featureflags"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
Expand Down Expand Up @@ -107,6 +108,10 @@ func newInternalServerError() *refreshError {
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
}

func newUpstreamRefreshError(desc string) *refreshError {
return &refreshError{msg: errInvalidGrant, desc: desc, code: http.StatusBadGateway}
Comment thread
nabokihms marked this conversation as resolved.
Outdated
}

func newBadRequestError(desc string) *refreshError {
return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
}
Expand Down Expand Up @@ -271,7 +276,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext,
newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
if err != nil {
s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err)
return ident, newInternalServerError()
return ident, newUpstreamRefreshError(err.Error())
Comment thread
nabokihms marked this conversation as resolved.
Outdated
}

return newIdent, nil
Expand Down Expand Up @@ -327,6 +332,20 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
Groups: rCtx.storageToken.Claims.Groups,
}

// Pre-fetch UserIdentity outside the storage transaction to avoid deadlocks with
// storage backends that use a single lock (e.g., memory storage).
// This is used as a fallback when the upstream connector refresh fails.
var cachedIdentity *storage.UserIdentity
if featureflags.SessionsEnabled.Enabled() {
ui, err := s.storage.GetUserIdentity(ctx, rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.WarnContext(ctx, "failed to pre-fetch user identity for upstream refresh fallback",
"user_id", rCtx.storageToken.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID, "err", err)
} else {
cachedIdentity = &ui
}
}

refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)
Expand Down Expand Up @@ -373,7 +392,26 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return old, rerr
// When sessions are enabled and the upstream provider fails (e.g., expired upstream
// refresh token), fall back to claims stored in UserIdentity instead of failing the
// entire refresh. This matches the behavior of other identity brokers (Keycloak, Auth0)
// that do not contact the upstream on every downstream refresh.
if cachedIdentity != nil {
s.logger.WarnContext(ctx, "upstream refresh failed, using cached identity from last login",
"err", rerr, "user_id", cachedIdentity.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID)
ident = connector.Identity{
UserID: cachedIdentity.Claims.UserID,
Username: cachedIdentity.Claims.Username,
PreferredUsername: cachedIdentity.Claims.PreferredUsername,
Email: cachedIdentity.Claims.Email,
EmailVerified: cachedIdentity.Claims.EmailVerified,
Groups: cachedIdentity.Claims.Groups,
Comment thread
nabokihms marked this conversation as resolved.
Outdated
}
rerr = nil
}
if rerr != nil {
return old, rerr
}
}

// Update the claims of the refresh token.
Expand Down
144 changes: 144 additions & 0 deletions server/refreshhandlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package server

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
Expand All @@ -16,6 +18,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
Expand Down Expand Up @@ -347,6 +350,147 @@ func TestRefreshTokenAuthTime(t *testing.T) {
}
}

// failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector
// but always returns an error on Refresh, simulating an upstream provider failure.
type failingRefreshConnector struct {
identity connector.Identity
}

func (f *failingRefreshConnector) LoginURL(_ connector.Scopes, callbackURL, state string) (string, []byte, error) {
u, _ := url.Parse(callbackURL)
v := u.Query()
v.Set("state", state)
u.RawQuery = v.Encode()
return u.String(), nil, nil
}

func (f *failingRefreshConnector) HandleCallback(_ connector.Scopes, _ []byte, _ *http.Request) (connector.Identity, error) {
return f.identity, nil
}

func (f *failingRefreshConnector) Refresh(_ context.Context, _ connector.Scopes, _ connector.Identity) (connector.Identity, error) {
return connector.Identity{}, errors.New("upstream: refresh token expired")
}

func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) {
t0 := time.Now().UTC().Round(time.Second)
loginTime := t0.Add(-10 * time.Minute)

tests := []struct {
name string
sessionsEnabled bool
createUserIdentity bool
wantOK bool
}{
{
name: "sessions enabled with user identity - fallback succeeds",
sessionsEnabled: true,
createUserIdentity: true,
wantOK: true,
},
{
name: "sessions enabled without user identity - fallback fails",
sessionsEnabled: true,
createUserIdentity: false,
wantOK: false,
},
{
name: "sessions disabled - no fallback, error returned",
sessionsEnabled: false,
createUserIdentity: false,
wantOK: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
setSessionsEnabled(t, tc.sessionsEnabled)

httpServer, s := newTestServer(t, func(c *Config) {
c.Now = func() time.Time { return t0 }
})
defer httpServer.Close()

if tc.sessionsEnabled {
s.sessionConfig = &SessionConfig{
CookieName: "dex_session",
AbsoluteLifetime: 24 * time.Hour,
}
}

mockRefreshTokenTestStorage(t, s.storage, false)

// Replace the connector with one that always fails on Refresh.
// ResourceVersion must match the storage connector (empty by default in
// mockRefreshTokenTestStorage) to prevent getConnector from re-opening it.
s.mu.Lock()
s.connectors["test"] = Connector{
Connector: &failingRefreshConnector{
identity: connector.Identity{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
},
},
}
s.mu.Unlock()

if tc.createUserIdentity {
err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{
UserID: "1",
ConnectorID: "test",
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
CreatedAt: loginTime,
LastLogin: loginTime,
})
require.NoError(t, err)
}

u, err := url.Parse(s.issuerURL.String())
require.NoError(t, err)

tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"})
require.NoError(t, err)

u.Path = path.Join(u.Path, "/token")
v := url.Values{}
v.Add("grant_type", "refresh_token")
v.Add("refresh_token", tokenData)

req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
req.SetBasicAuth("test", "barfoo")

rr := httptest.NewRecorder()
s.ServeHTTP(rr, req)

if tc.wantOK {
require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String())

var resp struct {
IDToken string `json:"id_token"`
}
err = json.Unmarshal(rr.Body.Bytes(), &resp)
require.NoError(t, err)

// Verify the returned claims match UserIdentity, not the connector.
claims := decodeJWTClaims(t, resp.IDToken)
assert.Equal(t, "jane.doe@example.com", claims["email"])
assert.Equal(t, "jane", claims["name"])
} else {
require.NotEqual(t, http.StatusOK, rr.Code,
"expected error when upstream fails without fallback")
}
})
}
}

func TestRefreshTokenPolicy(t *testing.T) {
lastTime := time.Now()
l := slog.New(slog.DiscardHandler)
Expand Down
Loading