diff --git a/README.md b/README.md index bfa4f9d..f7e4735 100644 --- a/README.md +++ b/README.md @@ -87,4 +87,9 @@ tbls doc Generate swagger docs via: ```bash swag init -g main.go -o docs +``` + +Make db migration via: +```bash +migrate create -ext sql -dir db/migration -seq ``` \ No newline at end of file diff --git a/api/middleware_test.go b/api/middleware_test.go index 6372193..8279363 100644 --- a/api/middleware_test.go +++ b/api/middleware_test.go @@ -20,8 +20,9 @@ func addAuthorization( username string, duration time.Duration, ) { - token, err := tokenMaker.CreateToken(username, duration) + token, payload, err := tokenMaker.CreateToken(username, duration) require.NoError(t, err) + require.NotEmpty(t, payload) authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) request.Header.Set(authorizationHeaderKey, authorizationHeader) diff --git a/api/server.go b/api/server.go index 0499231..2818f7d 100644 --- a/api/server.go +++ b/api/server.go @@ -51,6 +51,7 @@ func (server *Server) setupRouter() { router.POST("/users", server.createUser) router.POST("/users/login", server.loginUser) + router.POST("/tokens/renew_access", server.renewAccessToken) router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) diff --git a/api/token.go b/api/token.go new file mode 100644 index 0000000..a52234b --- /dev/null +++ b/api/token.go @@ -0,0 +1,80 @@ +package api + +import ( + "database/sql" + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" +) + +type renewAccessTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` +} + +type renewAccessTokenResponse struct { + AccessToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` +} + +func (server *Server) renewAccessToken(ctx *gin.Context) { + var req renewAccessTokenRequest + if err := ctx.ShouldBindBodyWith(&req, binding.JSON); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken) + if err != nil { + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + session, err := server.store.GetSession(ctx, refreshPayload.ID) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + if time.Now().After(session.ExpiresAt) { + err := fmt.Errorf("session expired") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.IsBlocked { + err := fmt.Errorf("blocked session") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.RefreshToken != req.RefreshToken { + err := fmt.Errorf("session token mismatch") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.Username != refreshPayload.Username { + err := fmt.Errorf("incorrect session user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken, accessPayload, err := server.tokenMaker.CreateToken(refreshPayload.Username, server.config.AccessTokenDuration) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + resp := renewAccessTokenResponse{ + AccessToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + } + ctx.JSON(http.StatusOK, resp) +} diff --git a/api/users.go b/api/users.go index feb2a60..527c370 100644 --- a/api/users.go +++ b/api/users.go @@ -9,6 +9,7 @@ import ( "github.com/HyperNaser/gobank/util" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" + "github.com/google/uuid" "github.com/lib/pq" ) @@ -93,8 +94,12 @@ type loginUserRequest struct { } type LoginUserResponse struct { - AccessToken string `json:"access_token"` - User UserResponse `json:"user"` + SessionID uuid.UUID `json:"session_id"` + AccessToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` + RefreshToken string `json:"refresh_token"` + RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"` + User UserResponse `json:"user"` } // loginUser authenticates a user and returns a JWT access token. @@ -133,15 +138,39 @@ func (server *Server) loginUser(ctx *gin.Context) { return } - accessToken, err := server.tokenMaker.CreateToken(user.Username, server.config.AccessTokenDuration) + accessToken, accessPayload, err := server.tokenMaker.CreateToken(user.Username, server.config.AccessTokenDuration) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + refreshToken, refreshPayload, err := server.tokenMaker.CreateToken(user.Username, server.config.RefreshTokenDuration) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + session, err := server.store.CreateSession(ctx, db.CreateSessionParams{ + ID: refreshPayload.ID, + Username: user.Username, + RefreshToken: refreshToken, + UserAgent: ctx.Request.UserAgent(), + ClientIp: ctx.ClientIP(), + IsBlocked: false, + ExpiresAt: refreshPayload.ExpiredAt, + }) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return } resp := LoginUserResponse{ - AccessToken: accessToken, - User: newUserResponse(user), + SessionID: session.ID, + AccessToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + RefreshToken: refreshToken, + RefreshTokenExpiresAt: refreshPayload.ExpiredAt, + User: newUserResponse(user), } ctx.JSON(http.StatusOK, resp) } diff --git a/app.env b/app.env index 730c02b..59ab00f 100644 --- a/app.env +++ b/app.env @@ -3,3 +3,4 @@ DB_SOURCE=postgresql://root:root@localhost:5432/gobank?sslmode=disable SERVER_ADDRESS=0.0.0.0:8080 TOKEN_SYMMETRIC_KEY=M7vTmAXwYdAzZc48p9I5Dq7VnyzYKtLS ACCESS_TOKEN_DURATION=15m +REFRESH_TOKEN_DURATION=24h \ No newline at end of file diff --git a/db/migration/000004_add_sessions.down.sql b/db/migration/000004_add_sessions.down.sql new file mode 100644 index 0000000..9a8955b --- /dev/null +++ b/db/migration/000004_add_sessions.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "sessions"; \ No newline at end of file diff --git a/db/migration/000004_add_sessions.up.sql b/db/migration/000004_add_sessions.up.sql new file mode 100644 index 0000000..36b8654 --- /dev/null +++ b/db/migration/000004_add_sessions.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE "sessions" ( + "id" uuid PRIMARY KEY, + "username" varchar NOT NULL, + "refresh_token" varchar NOT NULL, + "user_agent" varchar NOT NULL, + "client_ip" varchar NOT NULL, + "is_blocked" boolean NOT NULL DEFAULT false, + "expires_at" timestamptz NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()) +); + +ALTER TABLE "sessions" ADD FOREIGN KEY ("username") REFERENCES "users" ("username") DEFERRABLE INITIALLY IMMEDIATE; \ No newline at end of file diff --git a/db/mock/store.go b/db/mock/store.go index ffdef84..ec0f15b 100644 --- a/db/mock/store.go +++ b/db/mock/store.go @@ -14,6 +14,7 @@ import ( reflect "reflect" db "github.com/HyperNaser/gobank/db/sqlc" + uuid "github.com/google/uuid" gomock "go.uber.org/mock/gomock" ) @@ -86,6 +87,21 @@ func (mr *MockStoreMockRecorder) CreateEntry(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEntry", reflect.TypeOf((*MockStore)(nil).CreateEntry), ctx, arg) } +// CreateSession mocks base method. +func (m *MockStore) CreateSession(ctx context.Context, arg db.CreateSessionParams) (db.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSession", ctx, arg) + ret0, _ := ret[0].(db.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSession indicates an expected call of CreateSession. +func (mr *MockStoreMockRecorder) CreateSession(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockStore)(nil).CreateSession), ctx, arg) +} + // CreateTransfer mocks base method. func (m *MockStore) CreateTransfer(ctx context.Context, arg db.CreateTransferParams) (db.Transfer, error) { m.ctrl.T.Helper() @@ -190,6 +206,21 @@ func (mr *MockStoreMockRecorder) GetEntry(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntry", reflect.TypeOf((*MockStore)(nil).GetEntry), ctx, id) } +// GetSession mocks base method. +func (m *MockStore) GetSession(ctx context.Context, id uuid.UUID) (db.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSession", ctx, id) + ret0, _ := ret[0].(db.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSession indicates an expected call of GetSession. +func (mr *MockStoreMockRecorder) GetSession(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockStore)(nil).GetSession), ctx, id) +} + // GetTransfer mocks base method. func (m *MockStore) GetTransfer(ctx context.Context, id int64) (db.Transfer, error) { m.ctrl.T.Helper() diff --git a/db/query/sessions.sql b/db/query/sessions.sql new file mode 100644 index 0000000..df164c8 --- /dev/null +++ b/db/query/sessions.sql @@ -0,0 +1,14 @@ +-- name: CreateSession :one +INSERT INTO sessions ( + id, + username, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( $1, $2, $3, $4, $5, $6, $7) RETURNING *; + +-- name: GetSession :one +SELECT * FROM sessions +WHERE id = $1 LIMIT 1; diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index c4b360b..28b814a 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: account.sql package db diff --git a/db/sqlc/db.go b/db/sqlc/db.go index cd5bbb8..f43598b 100644 --- a/db/sqlc/db.go +++ b/db/sqlc/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package db diff --git a/db/sqlc/entries.sql.go b/db/sqlc/entries.sql.go index 826f584..47bf0ac 100644 --- a/db/sqlc/entries.sql.go +++ b/db/sqlc/entries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: entries.sql package db diff --git a/db/sqlc/models.go b/db/sqlc/models.go index 988bb22..d5139a3 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -1,12 +1,14 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package db import ( "database/sql" "time" + + "github.com/google/uuid" ) type Account struct { @@ -26,6 +28,17 @@ type Entry struct { CreatedAt time.Time `json:"created_at"` } +type Session struct { + ID uuid.UUID `json:"id"` + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} + type Transfer struct { ID int64 `json:"id"` FromAccountID int64 `json:"from_account_id"` diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 6b880d8..fa62e3e 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -1,23 +1,27 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package db import ( "context" + + "github.com/google/uuid" ) type Querier interface { AddAccountBalance(ctx context.Context, arg AddAccountBalanceParams) (Account, error) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) + CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteAccount(ctx context.Context, id int64) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) GetEntry(ctx context.Context, id int64) (Entry, error) + GetSession(ctx context.Context, id uuid.UUID) (Session, error) GetTransfer(ctx context.Context, id int64) (Transfer, error) GetUser(ctx context.Context, username string) (User, error) ListAccountEntries(ctx context.Context, arg ListAccountEntriesParams) ([]Entry, error) diff --git a/db/sqlc/sessions.sql.go b/db/sqlc/sessions.sql.go new file mode 100644 index 0000000..acb61d2 --- /dev/null +++ b/db/sqlc/sessions.sql.go @@ -0,0 +1,80 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: sessions.sql + +package db + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO sessions ( + id, + username, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( $1, $2, $3, $4, $5, $6, $7) RETURNING id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at +` + +type CreateSessionParams struct { + ID uuid.UUID `json:"id"` + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, createSession, + arg.ID, + arg.Username, + arg.RefreshToken, + arg.UserAgent, + arg.ClientIp, + arg.IsBlocked, + arg.ExpiresAt, + ) + var i Session + err := row.Scan( + &i.ID, + &i.Username, + &i.RefreshToken, + &i.UserAgent, + &i.ClientIp, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} + +const getSession = `-- name: GetSession :one +SELECT id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at FROM sessions +WHERE id = $1 LIMIT 1 +` + +func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (Session, error) { + row := q.db.QueryRowContext(ctx, getSession, id) + var i Session + err := row.Scan( + &i.ID, + &i.Username, + &i.RefreshToken, + &i.UserAgent, + &i.ClientIp, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/db/sqlc/transfers.sql.go b/db/sqlc/transfers.sql.go index 319d73e..cb38e57 100644 --- a/db/sqlc/transfers.sql.go +++ b/db/sqlc/transfers.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: transfers.sql package db diff --git a/db/sqlc/users.sql.go b/db/sqlc/users.sql.go index d81ea84..3858141 100644 --- a/db/sqlc/users.sql.go +++ b/db/sqlc/users.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: users.sql package db diff --git a/token/jwt_maker.go b/token/jwt_maker.go index 4e72e18..7238434 100644 --- a/token/jwt_maker.go +++ b/token/jwt_maker.go @@ -22,10 +22,10 @@ func NewJWTMaker(secretKey string) (Maker, error) { return &JWTMaker{secretKey: secretKey}, nil } -func (maker *JWTMaker) CreateToken(username string, duration time.Duration) (string, error) { +func (maker *JWTMaker) CreateToken(username string, duration time.Duration) (string, *Payload, error) { payload, err := NewPayload(username, duration) if err != nil { - return "", err + return "", payload, err } claims := jwt.RegisteredClaims{ @@ -36,7 +36,8 @@ func (maker *JWTMaker) CreateToken(username string, duration time.Duration) (str } jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return jwtToken.SignedString([]byte(maker.secretKey)) + token, err := jwtToken.SignedString([]byte(maker.secretKey)) + return token, payload, err } func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) { diff --git a/token/jwt_maker_test.go b/token/jwt_maker_test.go index 22c19c7..a2aeb6c 100644 --- a/token/jwt_maker_test.go +++ b/token/jwt_maker_test.go @@ -19,11 +19,12 @@ func TestJWTMaker(t *testing.T) { issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, err := maker.CreateToken(username, duration) + token, payload, err := maker.CreateToken(username, duration) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) - payload, err := maker.VerifyToken(token) + payload, err = maker.VerifyToken(token) require.NoError(t, err) require.NotEmpty(t, payload) @@ -37,11 +38,12 @@ func TestExpiredJWTToken(t *testing.T) { maker, err := NewJWTMaker(util.RandomString(minSecretKeySize)) require.NoError(t, err) - token, err := maker.CreateToken(util.RandomOwner(), -time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), -time.Minute) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) - payload, err := maker.VerifyToken(token) + payload, err = maker.VerifyToken(token) require.Error(t, err) require.EqualError(t, err, ErrExpiredToken.Error()) require.Nil(t, payload) diff --git a/token/maker.go b/token/maker.go index 836bd1d..574eccb 100644 --- a/token/maker.go +++ b/token/maker.go @@ -3,6 +3,6 @@ package token import "time" type Maker interface { - CreateToken(username string, duration time.Duration) (string, error) + CreateToken(username string, duration time.Duration) (string, *Payload, error) VerifyToken(token string) (*Payload, error) } diff --git a/token/paseto_maker.go b/token/paseto_maker.go index 53f526f..9b9a4cb 100644 --- a/token/paseto_maker.go +++ b/token/paseto_maker.go @@ -25,13 +25,14 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) { return maker, nil } -func (maker *PasetoMaker) CreateToken(username string, duration time.Duration) (string, error) { +func (maker *PasetoMaker) CreateToken(username string, duration time.Duration) (string, *Payload, error) { payload, err := NewPayload(username, duration) if err != nil { - return "", err + return "", payload, err } - return maker.paseto.Encrypt(maker.symmetricKey, payload, nil) + token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil) + return token, payload, err } func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go index 18985db..e38f14a 100644 --- a/token/paseto_maker_test.go +++ b/token/paseto_maker_test.go @@ -19,11 +19,12 @@ func TestPasetoMaker(t *testing.T) { issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, err := maker.CreateToken(username, duration) + token, payload, err := maker.CreateToken(username, duration) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) - payload, err := maker.VerifyToken(token) + payload, err = maker.VerifyToken(token) require.NoError(t, err) require.NotEmpty(t, payload) @@ -37,11 +38,12 @@ func TestExpiredPasetoToken(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(chacha20poly1305.KeySize)) require.NoError(t, err) - token, err := maker.CreateToken(util.RandomOwner(), -time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), -time.Minute) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) - payload, err := maker.VerifyToken(token) + payload, err = maker.VerifyToken(token) require.Error(t, err) require.EqualError(t, err, ErrExpiredToken.Error()) require.Nil(t, payload) @@ -58,13 +60,14 @@ func TestPasetoInvalidToken(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(chacha20poly1305.KeySize)) require.NoError(t, err) - token, err := maker.CreateToken(util.RandomOwner(), time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), time.Minute) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) invalidToken := token + "corruption" - payload, err := maker.VerifyToken(invalidToken) + payload, err = maker.VerifyToken(invalidToken) require.Error(t, err) require.EqualError(t, err, ErrInvalidToken.Error()) require.Nil(t, payload) diff --git a/util/config.go b/util/config.go index d709cdb..640d8cb 100644 --- a/util/config.go +++ b/util/config.go @@ -7,11 +7,12 @@ import ( ) type Config struct { - DBDriver string `mapstructure:"DB_DRIVER"` - DBSource string `mapstructure:"DB_SOURCE"` - ServerAddress string `mapstructure:"SERVER_ADDRESS"` - TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` - AccessTokenDuration time.Duration `mapstructure:"ACCESS_TOKEN_DURATION"` + DBDriver string `mapstructure:"DB_DRIVER"` + DBSource string `mapstructure:"DB_SOURCE"` + ServerAddress string `mapstructure:"SERVER_ADDRESS"` + TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` + AccessTokenDuration time.Duration `mapstructure:"ACCESS_TOKEN_DURATION"` + RefreshTokenDuration time.Duration `mapstructure:"REFRESH_TOKEN_DURATION"` } func LoadConfig(path string) (config Config, err error) {