From 515000b4bd4ca34c9d3b49bca66c8c22c012db1b Mon Sep 17 00:00:00 2001 From: masonsxu Date: Sun, 8 Mar 2026 22:22:23 +0800 Subject: [PATCH 1/2] fix: normalize expired token error in GetClaimsFromJWT to ErrExpiredToken GetClaimsFromJWT returns the raw *jwt.ValidationError from ParseToken when a token is expired. This makes it impossible for downstream HTTPStatusMessageFunc to identify expired tokens via == ErrExpiredToken, since *jwt.ValidationError and the sentinel ErrExpiredToken are different types. CheckIfTokenExpire already handles this correctly by checking validationErr.Errors == jwt.ValidationErrorExpired and continuing execution. However, GetClaimsFromJWT lacks the same normalization, causing an inconsistency between the auth flow and the refresh flow. This fix normalizes *jwt.ValidationError with only ValidationErrorExpired to the sentinel ErrExpiredToken in GetClaimsFromJWT, making error handling consistent across both paths. Before this fix: - middlewareImpl receives *jwt.ValidationError for expired tokens - HTTPStatusMessageFunc cannot match it with == ErrExpiredToken - The manual exp check in middlewareImpl (lines 475-496) is dead code for expired tokens since GetClaimsFromJWT already returns an error After this fix: - middlewareImpl receives ErrExpiredToken for expired tokens - HTTPStatusMessageFunc can correctly identify expired tokens --- auth_jwt.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/auth_jwt.go b/auth_jwt.go index 6ee1f59..220be3a 100644 --- a/auth_jwt.go +++ b/auth_jwt.go @@ -514,6 +514,15 @@ func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.Request func (mw *HertzJWTMiddleware) GetClaimsFromJWT(ctx context.Context, c *app.RequestContext) (MapClaims, error) { token, err := mw.ParseToken(ctx, c) if err != nil { + // Normalize expired token errors to the sentinel ErrExpiredToken, + // consistent with CheckIfTokenExpire behavior. + // Without this, *jwt.ValidationError is returned as-is, + // making it impossible to match with == ErrExpiredToken downstream. + validationErr, ok := err.(*jwt.ValidationError) + if ok && validationErr.Errors == jwt.ValidationErrorExpired { + return nil, ErrExpiredToken + } + return nil, err } From ef7528631f526ab6d9cc3d8d6580cadaca6cd1d0 Mon Sep 17 00:00:00 2001 From: masonsxu Date: Mon, 9 Mar 2026 00:53:05 +0800 Subject: [PATCH 2/2] feat: add transparent token refresh in middleware When EnableTransparentRefresh is true and the token is expired but within the MaxRefresh window, the middleware automatically generates a new token and continues processing the request instead of returning 401. This allows seamless token renewal without requiring a separate refresh endpoint call. The feature is opt-in (defaults to false) and fully backward compatible. --- auth_jwt.go | 92 ++++++++++++++++++- auth_jwt_test.go | 224 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 314 insertions(+), 2 deletions(-) diff --git a/auth_jwt.go b/auth_jwt.go index 220be3a..5b82912 100644 --- a/auth_jwt.go +++ b/auth_jwt.go @@ -191,6 +191,13 @@ type HertzJWTMiddleware struct { // ParseOptions allow to modify jwt's parser methods ParseOptions []jwt.ParserOption + + // EnableTransparentRefresh enables automatic token refresh during middleware processing. + // When true and the token is expired but within the MaxRefresh window, + // the middleware will automatically generate a new token and continue processing + // the request instead of returning 401. + // Optional, defaults to false for backward compatibility. + EnableTransparentRefresh bool } var ( @@ -468,8 +475,17 @@ func (mw *HertzJWTMiddleware) MiddlewareFunc() app.HandlerFunc { func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.RequestContext) { claims, err := mw.GetClaimsFromJWT(ctx, c) if err != nil { - mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, ctx, c)) - return + if mw.EnableTransparentRefresh && mw.MaxRefresh > 0 && mw.isExpiredTokenError(err) { + if refreshedClaims, refreshErr := mw.tryTransparentRefresh(ctx, c); refreshErr == nil { + claims = refreshedClaims + err = nil + } + } + + if err != nil { + mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, ctx, c)) + return + } } switch v := claims["exp"].(type) { @@ -540,6 +556,78 @@ func (mw *HertzJWTMiddleware) GetClaimsFromJWT(ctx context.Context, c *app.Reque return claims, nil } +// isExpiredTokenError checks if the error indicates an expired token. +func (mw *HertzJWTMiddleware) isExpiredTokenError(err error) bool { + if errors.Is(err, ErrExpiredToken) { + return true + } + + var ve *jwt.ValidationError + if errors.As(err, &ve) { + return ve.Errors == jwt.ValidationErrorExpired + } + + return false +} + +// tryTransparentRefresh attempts to refresh an expired token that is still within the MaxRefresh window. +// On success, it returns the new claims and sets the refreshed token in the response (cookie/header). +func (mw *HertzJWTMiddleware) tryTransparentRefresh(ctx context.Context, c *app.RequestContext) (MapClaims, error) { + claims, err := mw.CheckIfTokenExpire(ctx, c) + if err != nil { + return nil, err + } + + // Create new token with refreshed expiry + newToken := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) + newClaims := newToken.Claims.(jwt.MapClaims) + copyClaims := make(jwt.MapClaims, len(claims)) + + for k, v := range claims { + newClaims[k] = v + copyClaims[k] = v + } + + expire := mw.TimeFunc().Add(mw.TimeoutFunc(copyClaims)) + newClaims["exp"] = expire.Unix() + + // Preserve original orig_iat to maintain MaxRefresh window + if origIat, exists := claims["orig_iat"]; exists { + newClaims["orig_iat"] = origIat + } else { + newClaims["orig_iat"] = mw.TimeFunc().Unix() + } + + tokenString, err := mw.signedString(newToken) + if err != nil { + return nil, err + } + + // Set cookie if enabled + if mw.SendCookie { + expireCookie := mw.TimeFunc().Add(mw.CookieMaxAge) + maxage := int(expireCookie.Unix() - mw.TimeFunc().Unix()) + c.SetCookie(mw.CookieName, tokenString, maxage, "/", mw.CookieDomain, mw.CookieSameSite, mw.SecureCookie, mw.CookieHTTPOnly) + } + + // Set Authorization header if enabled + if mw.SendAuthorization { + c.Header("Authorization", mw.TokenHeadName+" "+tokenString) + } + + // Store new token in context + c.Set("JWT_TOKEN", tokenString) + + // Build result MapClaims with float64 exp for middlewareImpl compatibility + result := make(MapClaims, len(newClaims)) + for k, v := range newClaims { + result[k] = v + } + result["exp"] = float64(expire.Unix()) + + return result, nil +} + // LoginHandler can be used by clients to get a jwt token. // Payload needs to be json in the form of {"username": "USERNAME", "password": "PASSWORD"}. // Reply will be of the form {"token": "TOKEN"}. diff --git a/auth_jwt_test.go b/auth_jwt_test.go index 4730dd4..2779780 100644 --- a/auth_jwt_test.go +++ b/auth_jwt_test.go @@ -1180,3 +1180,227 @@ func TestLogout(t *testing.T) { assert.DeepEqual(t, http.StatusOK, w.Code) assert.DeepEqual(t, fmt.Sprintf("%s=; domain=%s; path=/", cookieName, cookieDomain), w.Header().Get("Set-Cookie")) } + +// makeExpiredTokenString creates a token that expired expiredAgo duration ago, +// with orig_iat set to origIatAgo duration ago. +func makeExpiredTokenString(expiredAgo, origIatAgo time.Duration) string { + token := jwt.New(jwt.GetSigningMethod("HS256")) + claims := token.Claims.(jwt.MapClaims) + claims["identity"] = "admin" + claims["exp"] = time.Now().Add(-expiredAgo).Unix() + claims["orig_iat"] = time.Now().Add(-origIatAgo).Unix() + tokenString, _ := token.SignedString(key) + return tokenString +} + +func transparentRefreshHandler(auth *HertzJWTMiddleware) *route.Engine { + r := route.NewEngine(config.NewOptions([]config.Option{})) + r.Use(auth.MiddlewareFunc()) + r.GET("/protected", helloHandler) + return r +} + +func TestTransparentRefresh_ExpiredWithinMaxRefresh(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + EnableTransparentRefresh: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Token expired 1 minute ago, orig_iat 30 minutes ago (within 2h MaxRefresh) + tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusOK, w.Code) +} + +func TestTransparentRefresh_ExpiredBeyondMaxRefresh(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + EnableTransparentRefresh: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Token expired 1 minute ago, orig_iat 3 hours ago (beyond 2h MaxRefresh) + tokenString := makeExpiredTokenString(time.Minute, 3*time.Hour) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusUnauthorized, w.Code) +} + +func TestTransparentRefresh_InvalidSignature(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + EnableTransparentRefresh: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Create token with wrong key + token := jwt.New(jwt.GetSigningMethod("HS256")) + claims := token.Claims.(jwt.MapClaims) + claims["identity"] = "admin" + claims["exp"] = time.Now().Add(-time.Minute).Unix() + claims["orig_iat"] = time.Now().Unix() + tokenString, _ := token.SignedString([]byte("wrong key")) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusUnauthorized, w.Code) +} + +func TestTransparentRefresh_ValidTokenNotRefreshed(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + EnableTransparentRefresh: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Valid, non-expired token + tokenString := makeTokenString("HS256", "admin") + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusOK, w.Code) + + // Should not have a new token in response (no refresh needed) + resp := w.Result() + assert.DeepEqual(t, "", resp.Header.Get("Authorization")) +} + +func TestTransparentRefresh_DisabledByDefault(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + Authenticator: defaultAuthenticator, + // EnableTransparentRefresh not set (defaults to false) + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Token expired 1 minute ago, orig_iat 30 minutes ago (within MaxRefresh) + tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusUnauthorized, w.Code) +} + +func TestTransparentRefresh_PreservesOrigIat(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + EnableTransparentRefresh: true, + SendAuthorization: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Create token with specific orig_iat + origIat := time.Now().Add(-30 * time.Minute).Unix() + token := jwt.New(jwt.GetSigningMethod("HS256")) + claims := token.Claims.(jwt.MapClaims) + claims["identity"] = "admin" + claims["exp"] = time.Now().Add(-time.Minute).Unix() + claims["orig_iat"] = origIat + tokenString, _ := token.SignedString(key) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusOK, w.Code) + + // Parse the new token from Authorization header + resp := w.Result() + authHeader := resp.Header.Get("Authorization") + assert.True(t, strings.HasPrefix(authHeader, "Bearer ")) + + newTokenString := strings.TrimPrefix(authHeader, "Bearer ") + newToken, err := jwt.Parse(newTokenString, func(token *jwt.Token) (interface{}, error) { + return key, nil + }) + assert.Nil(t, err) + + newClaims := newToken.Claims.(jwt.MapClaims) + newOrigIat := int64(newClaims["orig_iat"].(float64)) + assert.DeepEqual(t, origIat, newOrigIat) +} + +func TestTransparentRefresh_CookieAndHeader(t *testing.T) { + cookieName := "jwt" + cookieDomain := "example.com" + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 2 * time.Hour, + EnableTransparentRefresh: true, + SendCookie: true, + CookieName: cookieName, + CookieDomain: cookieDomain, + SendAuthorization: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Token expired 1 minute ago, orig_iat 30 minutes ago + tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusOK, w.Code) + + // Check cookie is set + resp := w.Result() + assert.True(t, strings.HasPrefix(string(resp.Header.FullCookie()), cookieName+"=")) + + // Check Authorization header is set + authHeader := resp.Header.Get("Authorization") + assert.True(t, strings.HasPrefix(authHeader, "Bearer ")) +} + +func TestTransparentRefresh_MaxRefreshZero(t *testing.T) { + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: 0, // MaxRefresh is zero + EnableTransparentRefresh: true, + Authenticator: defaultAuthenticator, + }) + + handler := transparentRefreshHandler(authMiddleware) + + // Token expired 1 minute ago + tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute) + + w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil, + ut.Header{Key: "Authorization", Value: "Bearer " + tokenString}) + assert.DeepEqual(t, http.StatusUnauthorized, w.Code) +}