fix(backend): harden JWT handling and login
Three related auth weaknesses: - Access and refresh tokens were structurally identical, so a 30-day refresh token was accepted as a bearer access token. Tokens now carry a "typ" claim; the access path rejects refresh tokens and /refresh rejects access tokens. - Login stored the hash of a throwaway refresh token (sid=0) but returned a re-issued one, so the stored hash never matched and /refresh always 401'd. Tokens are no longer re-issued: the refresh token is located by hash and carries no session id, while the access token embeds the real session id. A random jti keeps tokens unique within the same second. - Login skipped bcrypt for unknown users (a timing oracle) and returned 403 for blocked accounts before checking the password (leaking account existence). It now always runs a bcrypt comparison and verifies the password before disclosing blocked state. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -14,12 +15,25 @@ import (
|
|||||||
"tanabata/backend/internal/port"
|
"tanabata/backend/internal/port"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Token types distinguish short-lived access tokens from long-lived refresh
|
||||||
|
// tokens so the two cannot be substituted for one another.
|
||||||
|
const (
|
||||||
|
tokenTypeAccess = "access"
|
||||||
|
tokenTypeRefresh = "refresh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dummyPasswordHash is a valid bcrypt hash used to equalise the cost of a login
|
||||||
|
// attempt against a non-existent user, preventing username enumeration via
|
||||||
|
// response timing. It is the hash of a random string no one knows.
|
||||||
|
const dummyPasswordHash = "$2a$10$N9qo8uLOickgx2ZMRZoMyeIjZAgcfl7p92ldGxad68LJZdL17lhWy"
|
||||||
|
|
||||||
// Claims is the JWT payload for both access and refresh tokens.
|
// Claims is the JWT payload for both access and refresh tokens.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
UserID int16 `json:"uid"`
|
UserID int16 `json:"uid"`
|
||||||
IsAdmin bool `json:"adm"`
|
IsAdmin bool `json:"adm"`
|
||||||
SessionID int `json:"sid"`
|
SessionID int `json:"sid"`
|
||||||
|
TokenType string `json:"typ"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenPair holds an issued access/refresh token pair with the access TTL.
|
// TokenPair holds an issued access/refresh token pair with the access TTL.
|
||||||
@@ -31,9 +45,9 @@ type TokenPair struct {
|
|||||||
|
|
||||||
// AuthService handles authentication and session lifecycle.
|
// AuthService handles authentication and session lifecycle.
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
users port.UserRepo
|
users port.UserRepo
|
||||||
sessions port.SessionRepo
|
sessions port.SessionRepo
|
||||||
secret []byte
|
secret []byte
|
||||||
accessTTL time.Duration
|
accessTTL time.Duration
|
||||||
refreshTTL time.Duration
|
refreshTTL time.Duration
|
||||||
}
|
}
|
||||||
@@ -59,8 +73,16 @@ func NewAuthService(
|
|||||||
func (s *AuthService) Login(ctx context.Context, name, password, userAgent string) (*TokenPair, error) {
|
func (s *AuthService) Login(ctx context.Context, name, password, userAgent string) (*TokenPair, error) {
|
||||||
user, err := s.users.GetByName(ctx, name)
|
user, err := s.users.GetByName(ctx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return ErrUnauthorized regardless of whether the user exists,
|
// Compare against a dummy hash so a missing user costs the same as a
|
||||||
// to avoid username enumeration.
|
// wrong password, and return ErrUnauthorized either way to avoid
|
||||||
|
// username enumeration.
|
||||||
|
_ = bcrypt.CompareHashAndPassword([]byte(dummyPasswordHash), []byte(password))
|
||||||
|
return nil, domain.ErrUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the password before disclosing anything about account state, so a
|
||||||
|
// caller cannot distinguish "blocked" from "wrong password".
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||||
return nil, domain.ErrUnauthorized
|
return nil, domain.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,48 +90,7 @@ func (s *AuthService) Login(ctx context.Context, name, password, userAgent strin
|
|||||||
return nil, domain.ErrForbidden
|
return nil, domain.ErrForbidden
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
return s.issuePair(ctx, user, userAgent)
|
||||||
return nil, domain.ErrUnauthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
var expiresAt *time.Time
|
|
||||||
if s.refreshTTL > 0 {
|
|
||||||
t := time.Now().Add(s.refreshTTL)
|
|
||||||
expiresAt = &t
|
|
||||||
}
|
|
||||||
|
|
||||||
// Issue the refresh token first so we can store its hash.
|
|
||||||
refreshToken, err := s.issueToken(user.ID, user.IsAdmin, 0, s.refreshTTL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := s.sessions.Create(ctx, &domain.Session{
|
|
||||||
TokenHash: hashToken(refreshToken),
|
|
||||||
UserID: user.ID,
|
|
||||||
UserAgent: userAgent,
|
|
||||||
ExpiresAt: expiresAt,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create session: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken, err := s.issueToken(user.ID, user.IsAdmin, session.ID, s.accessTTL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("issue access token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-issue the refresh token with the real session ID now that we have it.
|
|
||||||
refreshToken, err = s.issueToken(user.ID, user.IsAdmin, session.ID, s.refreshTTL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &TokenPair{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
ExpiresIn: int(s.accessTTL.Seconds()),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout deactivates the session identified by sessionID.
|
// Logout deactivates the session identified by sessionID.
|
||||||
@@ -124,7 +105,7 @@ func (s *AuthService) Logout(ctx context.Context, sessionID int) error {
|
|||||||
// the old session.
|
// the old session.
|
||||||
func (s *AuthService) Refresh(ctx context.Context, refreshToken, userAgent string) (*TokenPair, error) {
|
func (s *AuthService) Refresh(ctx context.Context, refreshToken, userAgent string) (*TokenPair, error) {
|
||||||
claims, err := s.parseToken(refreshToken)
|
claims, err := s.parseToken(refreshToken)
|
||||||
if err != nil {
|
if err != nil || claims.TokenType != tokenTypeRefresh {
|
||||||
return nil, domain.ErrUnauthorized
|
return nil, domain.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,19 +133,30 @@ func (s *AuthService) Refresh(ctx context.Context, refreshToken, userAgent strin
|
|||||||
return nil, domain.ErrForbidden
|
return nil, domain.ErrForbidden
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s.issuePair(ctx, user, userAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// issuePair creates a session and the access/refresh token pair for user.
|
||||||
|
//
|
||||||
|
// The refresh token is issued first and its hash is stored as the session's
|
||||||
|
// identity; the refresh token is located on /refresh purely by that hash, so it
|
||||||
|
// carries no session ID. The access token then embeds the real session ID so it
|
||||||
|
// can be revoked on logout. Because the stored hash is the hash of the token
|
||||||
|
// actually returned, /refresh works (unlike the previous re-issue approach).
|
||||||
|
func (s *AuthService) issuePair(ctx context.Context, user *domain.User, userAgent string) (*TokenPair, error) {
|
||||||
var expiresAt *time.Time
|
var expiresAt *time.Time
|
||||||
if s.refreshTTL > 0 {
|
if s.refreshTTL > 0 {
|
||||||
t := time.Now().Add(s.refreshTTL)
|
t := time.Now().Add(s.refreshTTL)
|
||||||
expiresAt = &t
|
expiresAt = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
newRefresh, err := s.issueToken(user.ID, user.IsAdmin, 0, s.refreshTTL)
|
refreshToken, err := s.issueToken(user.ID, user.IsAdmin, 0, s.refreshTTL, tokenTypeRefresh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
return nil, fmt.Errorf("issue refresh token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession, err := s.sessions.Create(ctx, &domain.Session{
|
session, err := s.sessions.Create(ctx, &domain.Session{
|
||||||
TokenHash: hashToken(newRefresh),
|
TokenHash: hashToken(refreshToken),
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
ExpiresAt: expiresAt,
|
ExpiresAt: expiresAt,
|
||||||
@@ -173,19 +165,14 @@ func (s *AuthService) Refresh(ctx context.Context, refreshToken, userAgent strin
|
|||||||
return nil, fmt.Errorf("create session: %w", err)
|
return nil, fmt.Errorf("create session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := s.issueToken(user.ID, user.IsAdmin, newSession.ID, s.accessTTL)
|
accessToken, err := s.issueToken(user.ID, user.IsAdmin, session.ID, s.accessTTL, tokenTypeAccess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("issue access token: %w", err)
|
return nil, fmt.Errorf("issue access token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newRefresh, err = s.issueToken(user.ID, user.IsAdmin, newSession.ID, s.refreshTTL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &TokenPair{
|
return &TokenPair{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: newRefresh,
|
RefreshToken: refreshToken,
|
||||||
ExpiresIn: int(s.accessTTL.Seconds()),
|
ExpiresIn: int(s.accessTTL.Seconds()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -228,25 +215,36 @@ func (s *AuthService) TerminateSession(ctx context.Context, callerID int16, isAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ParseAccessToken parses and validates an access token, returning its claims.
|
// ParseAccessToken parses and validates an access token, returning its claims.
|
||||||
|
// A refresh token presented here is rejected (wrong token type).
|
||||||
func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) {
|
func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) {
|
||||||
claims, err := s.parseToken(tokenStr)
|
claims, err := s.parseToken(tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, domain.ErrUnauthorized
|
return nil, domain.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
if claims.TokenType != tokenTypeAccess {
|
||||||
|
return nil, domain.ErrUnauthorized
|
||||||
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// issueToken signs a JWT with the given parameters.
|
// issueToken signs a JWT with the given parameters. A random JWT ID guarantees
|
||||||
func (s *AuthService) issueToken(userID int16, isAdmin bool, sessionID int, ttl time.Duration) (string, error) {
|
// uniqueness even for tokens minted within the same second.
|
||||||
|
func (s *AuthService) issueToken(userID int16, isAdmin bool, sessionID int, ttl time.Duration, tokenType string) (string, error) {
|
||||||
|
jti, err := randomJTI()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := Claims{
|
claims := Claims{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ID: jti,
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
||||||
},
|
},
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
IsAdmin: isAdmin,
|
IsAdmin: isAdmin,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
|
TokenType: tokenType,
|
||||||
}
|
}
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
signed, err := token.SignedString(s.secret)
|
signed, err := token.SignedString(s.secret)
|
||||||
@@ -280,3 +278,12 @@ func hashToken(token string) string {
|
|||||||
sum := sha256.Sum256([]byte(token))
|
sum := sha256.Sum256([]byte(token))
|
||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// randomJTI returns a 128-bit random hex string for use as a JWT ID.
|
||||||
|
func randomJTI() (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", fmt.Errorf("generate jti: %w", err)
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user