Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 99 additions & 2 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ type HertzJWTMiddleware struct {

// ParseOptions allow to modify jwt's parser methods
ParseOptions []jwt.ParserOption

// EnableTransparentRefresh enables automatic token refresh during middleware processing.
// When true and the token is expired but within the MaxRefresh window,
// the middleware will automatically generate a new token and continue processing
// the request instead of returning 401.
// Optional, defaults to false for backward compatibility.
EnableTransparentRefresh bool
}

var (
Expand Down Expand Up @@ -468,8 +475,17 @@ func (mw *HertzJWTMiddleware) MiddlewareFunc() app.HandlerFunc {
func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.RequestContext) {
claims, err := mw.GetClaimsFromJWT(ctx, c)
if err != nil {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, ctx, c))
return
if mw.EnableTransparentRefresh && mw.MaxRefresh > 0 && mw.isExpiredTokenError(err) {
if refreshedClaims, refreshErr := mw.tryTransparentRefresh(ctx, c); refreshErr == nil {
claims = refreshedClaims
err = nil
}
}

if err != nil {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, ctx, c))
return
}
}

switch v := claims["exp"].(type) {
Expand Down Expand Up @@ -514,6 +530,15 @@ func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.Request
func (mw *HertzJWTMiddleware) GetClaimsFromJWT(ctx context.Context, c *app.RequestContext) (MapClaims, error) {
token, err := mw.ParseToken(ctx, c)
if err != nil {
// Normalize expired token errors to the sentinel ErrExpiredToken,
// consistent with CheckIfTokenExpire behavior.
// Without this, *jwt.ValidationError is returned as-is,
// making it impossible to match with == ErrExpiredToken downstream.
validationErr, ok := err.(*jwt.ValidationError)
if ok && validationErr.Errors == jwt.ValidationErrorExpired {
return nil, ErrExpiredToken
}

return nil, err
}

Expand All @@ -531,6 +556,78 @@ func (mw *HertzJWTMiddleware) GetClaimsFromJWT(ctx context.Context, c *app.Reque
return claims, nil
}

// isExpiredTokenError checks if the error indicates an expired token.
func (mw *HertzJWTMiddleware) isExpiredTokenError(err error) bool {
if errors.Is(err, ErrExpiredToken) {
return true
}

var ve *jwt.ValidationError
if errors.As(err, &ve) {
return ve.Errors == jwt.ValidationErrorExpired
}

return false
}

// tryTransparentRefresh attempts to refresh an expired token that is still within the MaxRefresh window.
// On success, it returns the new claims and sets the refreshed token in the response (cookie/header).
func (mw *HertzJWTMiddleware) tryTransparentRefresh(ctx context.Context, c *app.RequestContext) (MapClaims, error) {
claims, err := mw.CheckIfTokenExpire(ctx, c)
if err != nil {
return nil, err
}

// Create new token with refreshed expiry
newToken := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm))
newClaims := newToken.Claims.(jwt.MapClaims)
copyClaims := make(jwt.MapClaims, len(claims))

for k, v := range claims {
newClaims[k] = v
copyClaims[k] = v
}

expire := mw.TimeFunc().Add(mw.TimeoutFunc(copyClaims))
newClaims["exp"] = expire.Unix()

// Preserve original orig_iat to maintain MaxRefresh window
if origIat, exists := claims["orig_iat"]; exists {
newClaims["orig_iat"] = origIat
} else {
newClaims["orig_iat"] = mw.TimeFunc().Unix()
}

tokenString, err := mw.signedString(newToken)
if err != nil {
return nil, err
}

// Set cookie if enabled
if mw.SendCookie {
expireCookie := mw.TimeFunc().Add(mw.CookieMaxAge)
maxage := int(expireCookie.Unix() - mw.TimeFunc().Unix())
c.SetCookie(mw.CookieName, tokenString, maxage, "/", mw.CookieDomain, mw.CookieSameSite, mw.SecureCookie, mw.CookieHTTPOnly)
}

// Set Authorization header if enabled
if mw.SendAuthorization {
c.Header("Authorization", mw.TokenHeadName+" "+tokenString)
}

// Store new token in context
c.Set("JWT_TOKEN", tokenString)

// Build result MapClaims with float64 exp for middlewareImpl compatibility
result := make(MapClaims, len(newClaims))
for k, v := range newClaims {
result[k] = v
}
result["exp"] = float64(expire.Unix())

return result, nil
}

// LoginHandler can be used by clients to get a jwt token.
// Payload needs to be json in the form of {"username": "USERNAME", "password": "PASSWORD"}.
// Reply will be of the form {"token": "TOKEN"}.
Expand Down
224 changes: 224 additions & 0 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1180,3 +1180,227 @@ func TestLogout(t *testing.T) {
assert.DeepEqual(t, http.StatusOK, w.Code)
assert.DeepEqual(t, fmt.Sprintf("%s=; domain=%s; path=/", cookieName, cookieDomain), w.Header().Get("Set-Cookie"))
}

// makeExpiredTokenString creates a token that expired expiredAgo duration ago,
// with orig_iat set to origIatAgo duration ago.
func makeExpiredTokenString(expiredAgo, origIatAgo time.Duration) string {
token := jwt.New(jwt.GetSigningMethod("HS256"))
claims := token.Claims.(jwt.MapClaims)
claims["identity"] = "admin"
claims["exp"] = time.Now().Add(-expiredAgo).Unix()
claims["orig_iat"] = time.Now().Add(-origIatAgo).Unix()
tokenString, _ := token.SignedString(key)
return tokenString
}

func transparentRefreshHandler(auth *HertzJWTMiddleware) *route.Engine {
r := route.NewEngine(config.NewOptions([]config.Option{}))
r.Use(auth.MiddlewareFunc())
r.GET("/protected", helloHandler)
return r
}

func TestTransparentRefresh_ExpiredWithinMaxRefresh(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
EnableTransparentRefresh: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Token expired 1 minute ago, orig_iat 30 minutes ago (within 2h MaxRefresh)
tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute)

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusOK, w.Code)
}

func TestTransparentRefresh_ExpiredBeyondMaxRefresh(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
EnableTransparentRefresh: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Token expired 1 minute ago, orig_iat 3 hours ago (beyond 2h MaxRefresh)
tokenString := makeExpiredTokenString(time.Minute, 3*time.Hour)

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusUnauthorized, w.Code)
}

func TestTransparentRefresh_InvalidSignature(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
EnableTransparentRefresh: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Create token with wrong key
token := jwt.New(jwt.GetSigningMethod("HS256"))
claims := token.Claims.(jwt.MapClaims)
claims["identity"] = "admin"
claims["exp"] = time.Now().Add(-time.Minute).Unix()
claims["orig_iat"] = time.Now().Unix()
tokenString, _ := token.SignedString([]byte("wrong key"))

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusUnauthorized, w.Code)
}

func TestTransparentRefresh_ValidTokenNotRefreshed(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
EnableTransparentRefresh: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Valid, non-expired token
tokenString := makeTokenString("HS256", "admin")

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusOK, w.Code)

// Should not have a new token in response (no refresh needed)
resp := w.Result()
assert.DeepEqual(t, "", resp.Header.Get("Authorization"))
}

func TestTransparentRefresh_DisabledByDefault(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
Authenticator: defaultAuthenticator,
// EnableTransparentRefresh not set (defaults to false)
})

handler := transparentRefreshHandler(authMiddleware)

// Token expired 1 minute ago, orig_iat 30 minutes ago (within MaxRefresh)
tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute)

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusUnauthorized, w.Code)
}

func TestTransparentRefresh_PreservesOrigIat(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
EnableTransparentRefresh: true,
SendAuthorization: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Create token with specific orig_iat
origIat := time.Now().Add(-30 * time.Minute).Unix()
token := jwt.New(jwt.GetSigningMethod("HS256"))
claims := token.Claims.(jwt.MapClaims)
claims["identity"] = "admin"
claims["exp"] = time.Now().Add(-time.Minute).Unix()
claims["orig_iat"] = origIat
tokenString, _ := token.SignedString(key)

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusOK, w.Code)

// Parse the new token from Authorization header
resp := w.Result()
authHeader := resp.Header.Get("Authorization")
assert.True(t, strings.HasPrefix(authHeader, "Bearer "))

newTokenString := strings.TrimPrefix(authHeader, "Bearer ")
newToken, err := jwt.Parse(newTokenString, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
assert.Nil(t, err)

newClaims := newToken.Claims.(jwt.MapClaims)
newOrigIat := int64(newClaims["orig_iat"].(float64))
assert.DeepEqual(t, origIat, newOrigIat)
}

func TestTransparentRefresh_CookieAndHeader(t *testing.T) {
cookieName := "jwt"
cookieDomain := "example.com"
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 2 * time.Hour,
EnableTransparentRefresh: true,
SendCookie: true,
CookieName: cookieName,
CookieDomain: cookieDomain,
SendAuthorization: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Token expired 1 minute ago, orig_iat 30 minutes ago
tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute)

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusOK, w.Code)

// Check cookie is set
resp := w.Result()
assert.True(t, strings.HasPrefix(string(resp.Header.FullCookie()), cookieName+"="))

// Check Authorization header is set
authHeader := resp.Header.Get("Authorization")
assert.True(t, strings.HasPrefix(authHeader, "Bearer "))
}

func TestTransparentRefresh_MaxRefreshZero(t *testing.T) {
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: 0, // MaxRefresh is zero
EnableTransparentRefresh: true,
Authenticator: defaultAuthenticator,
})

handler := transparentRefreshHandler(authMiddleware)

// Token expired 1 minute ago
tokenString := makeExpiredTokenString(time.Minute, 30*time.Minute)

w := ut.PerformRequest(handler, http.MethodGet, "/protected", nil,
ut.Header{Key: "Authorization", Value: "Bearer " + tokenString})
assert.DeepEqual(t, http.StatusUnauthorized, w.Code)
}