Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 45 additions & 17 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func newInternalServerError() *refreshError {
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
}

func newUpstreamRefreshError() *refreshError {
return &refreshError{msg: errInvalidGrant, desc: "Upstream identity provider refresh failed.", code: http.StatusBadGateway}
}

func newBadRequestError(desc string) *refreshError {
return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
}
Expand Down Expand Up @@ -271,7 +275,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext,
newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
if err != nil {
s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err)
return ident, newInternalServerError()
return ident, newUpstreamRefreshError()
}

return newIdent, nil
Expand Down Expand Up @@ -308,7 +312,7 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr
}

// updateRefreshToken updates refresh token and offline session in the storage
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) {
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext, userIdent *storage.UserIdentity) (*internal.RefreshToken, connector.Identity, *refreshError) {
var rerr *refreshError

newToken := &internal.RefreshToken{
Expand All @@ -327,6 +331,12 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
Groups: rCtx.storageToken.Claims.Groups,
}

// When sessions are enabled, downstream token refresh is disconnected from the upstream
// identity provider. Instead of calling the connector's Refresh method (which would contact
// the upstream IdP and may fail if the upstream refresh token has expired), we use the claims
// stored in UserIdentity at the time of the last interactive login. This aligns with the
// behavior of other identity brokers (e.g., Keycloak, Auth0) that treat downstream sessions
// independently from the upstream provider session lifetime.
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)
Expand Down Expand Up @@ -371,9 +381,21 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
// Call only once if there is a request which is not in the reuse interval.
// This is required to avoid multiple calls to the external IdP for concurrent requests.
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return old, rerr
// When sessions are enabled, use cached identity instead of refreshing upstream.
if userIdent != nil {
ident = connector.Identity{
UserID: userIdent.Claims.UserID,
Username: userIdent.Claims.Username,
PreferredUsername: userIdent.Claims.PreferredUsername,
Email: userIdent.Claims.Email,
EmailVerified: userIdent.Claims.EmailVerified,
Groups: userIdent.Claims.Groups,
}
} else {
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
if rerr != nil {
return old, rerr
}
}

// Update the claims of the refresh token.
Expand Down Expand Up @@ -424,7 +446,24 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return
}

newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx)
var userIdent *storage.UserIdentity

if s.sessionConfig != nil {
ui, err := s.storage.GetUserIdentity(r.Context(), rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}
userIdent = &ui
}

authTime := time.Time{}
if userIdent != nil {
authTime = userIdent.LastLogin
}

newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx, userIdent)
if rerr != nil {
s.refreshTokenErrHelper(w, rerr)
return
Expand All @@ -439,17 +478,6 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}

authTime := time.Time{}
if s.sessionConfig != nil {
ui, err := s.storage.GetUserIdentity(r.Context(), ident.UserID, rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}
authTime = ui.LastLogin
}

accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, authTime)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err)
Expand Down
156 changes: 149 additions & 7 deletions server/refreshhandlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package server

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
Expand All @@ -16,6 +18,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
Expand Down Expand Up @@ -276,17 +279,17 @@ func TestRefreshTokenAuthTime(t *testing.T) {
mockRefreshTokenTestStorage(t, s.storage, false)

if tc.createUserIdentity {
// The mock connector returns UserID "0-385-28089-0" on Refresh,
// so the UserIdentity must use that ID to be found by handleRefreshToken.
// UserIdentity must match the refresh token's Claims.UserID ("1")
// because updateRefreshToken looks it up by that ID.
err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{
UserID: "0-385-28089-0",
UserID: "1",
ConnectorID: "test",
Claims: storage.Claims{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"authors"},
Groups: []string{"a", "b"},
},
CreatedAt: loginTime,
LastLogin: loginTime,
Expand Down Expand Up @@ -347,6 +350,145 @@ func TestRefreshTokenAuthTime(t *testing.T) {
}
}

// failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector
// but always returns an error on Refresh, proving that the upstream is not contacted.
type failingRefreshConnector struct {
identity connector.Identity
}

func (f *failingRefreshConnector) LoginURL(_ connector.Scopes, callbackURL, state string) (string, []byte, error) {
u, _ := url.Parse(callbackURL)
v := u.Query()
v.Set("state", state)
u.RawQuery = v.Encode()
return u.String(), nil, nil
}

func (f *failingRefreshConnector) HandleCallback(_ connector.Scopes, _ []byte, _ *http.Request) (connector.Identity, error) {
return f.identity, nil
}

func (f *failingRefreshConnector) Refresh(_ context.Context, _ connector.Scopes, _ connector.Identity) (connector.Identity, error) {
return connector.Identity{}, errors.New("upstream: refresh token expired")
}

func TestRefreshDisconnectsUpstreamWhenSessionsEnabled(t *testing.T) {
t0 := time.Now().UTC().Round(time.Second)
loginTime := t0.Add(-10 * time.Minute)

tests := []struct {
name string
sessionsEnabled bool
createUserIdentity bool
wantOK bool
}{
{
name: "sessions enabled - uses user identity, skips upstream",
sessionsEnabled: true,
createUserIdentity: true,
wantOK: true,
},
{
name: "sessions enabled without user identity - fails",
sessionsEnabled: true,
createUserIdentity: false,
wantOK: false,
},
{
name: "sessions disabled - upstream failure returns error",
sessionsEnabled: false,
createUserIdentity: false,
wantOK: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Now = func() time.Time { return t0 }
})
defer httpServer.Close()

if tc.sessionsEnabled {
s.sessionConfig = &SessionConfig{
CookieName: "dex_session",
AbsoluteLifetime: 24 * time.Hour,
}
}

mockRefreshTokenTestStorage(t, s.storage, false)

// Replace the connector with one that always fails on Refresh.
// When sessions are enabled this connector should never be called;
// when sessions are disabled, the failure proves the error path works.
s.mu.Lock()
s.connectors["test"] = Connector{
Connector: &failingRefreshConnector{
identity: connector.Identity{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
},
},
}
s.mu.Unlock()

if tc.createUserIdentity {
err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{
UserID: "1",
ConnectorID: "test",
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
CreatedAt: loginTime,
LastLogin: loginTime,
})
require.NoError(t, err)
}

u, err := url.Parse(s.issuerURL.String())
require.NoError(t, err)

tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"})
require.NoError(t, err)

u.Path = path.Join(u.Path, "/token")
v := url.Values{}
v.Add("grant_type", "refresh_token")
v.Add("refresh_token", tokenData)

req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
req.SetBasicAuth("test", "barfoo")

rr := httptest.NewRecorder()
s.ServeHTTP(rr, req)

if tc.wantOK {
require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String())

var resp struct {
IDToken string `json:"id_token"`
}
err = json.Unmarshal(rr.Body.Bytes(), &resp)
require.NoError(t, err)

// Verify the returned claims match UserIdentity, not the connector.
claims := decodeJWTClaims(t, resp.IDToken)
assert.Equal(t, "jane.doe@example.com", claims["email"])
assert.Equal(t, "jane", claims["name"])
} else {
require.NotEqual(t, http.StatusOK, rr.Code,
"expected error when sessions disabled or user identity missing")
}
})
}
}

func TestRefreshTokenPolicy(t *testing.T) {
lastTime := time.Now()
l := slog.New(slog.DiscardHandler)
Expand Down
Loading