feat: implement auth service with JWT and session management
Login: bcrypt credential validation, session creation, JWT pair issuance. Logout/TerminateSession: soft-delete session (is_active = false). Refresh: token rotation — deactivate old session, issue new pair. ListSessions: marks IsCurrent by comparing session IDs. ParseAccessToken: for use by auth middleware. Claims carry uid (int16), adm (bool), sid (int). Refresh tokens are stored as SHA-256 hashes; raw tokens never reach the database. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f7cf8cb914
commit
296f44b4ed
@ -23,6 +23,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.22.0 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
|
||||
@ -28,6 +28,8 @@ github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4
|
||||
github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
|
||||
282
backend/internal/service/auth_service.go
Normal file
282
backend/internal/service/auth_service.go
Normal file
@ -0,0 +1,282 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"tanabata/backend/internal/domain"
|
||||
"tanabata/backend/internal/port"
|
||||
)
|
||||
|
||||
// Claims is the JWT payload for both access and refresh tokens.
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
UserID int16 `json:"uid"`
|
||||
IsAdmin bool `json:"adm"`
|
||||
SessionID int `json:"sid"`
|
||||
}
|
||||
|
||||
// TokenPair holds an issued access/refresh token pair with the access TTL.
|
||||
type TokenPair struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn int // access token TTL in seconds
|
||||
}
|
||||
|
||||
// AuthService handles authentication and session lifecycle.
|
||||
type AuthService struct {
|
||||
users port.UserRepo
|
||||
sessions port.SessionRepo
|
||||
secret []byte
|
||||
accessTTL time.Duration
|
||||
refreshTTL time.Duration
|
||||
}
|
||||
|
||||
// NewAuthService creates an AuthService.
|
||||
func NewAuthService(
|
||||
users port.UserRepo,
|
||||
sessions port.SessionRepo,
|
||||
jwtSecret string,
|
||||
accessTTL time.Duration,
|
||||
refreshTTL time.Duration,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
users: users,
|
||||
sessions: sessions,
|
||||
secret: []byte(jwtSecret),
|
||||
accessTTL: accessTTL,
|
||||
refreshTTL: refreshTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// Login validates credentials, creates a session, and returns a token pair.
|
||||
func (s *AuthService) Login(ctx context.Context, name, password, userAgent string) (*TokenPair, error) {
|
||||
user, err := s.users.GetByName(ctx, name)
|
||||
if err != nil {
|
||||
// Return ErrUnauthorized regardless of whether the user exists,
|
||||
// to avoid username enumeration.
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
|
||||
if user.IsBlocked {
|
||||
return nil, domain.ErrForbidden
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
|
||||
var expiresAt *time.Time
|
||||
if s.refreshTTL > 0 {
|
||||
t := time.Now().Add(s.refreshTTL)
|
||||
expiresAt = &t
|
||||
}
|
||||
|
||||
// Issue the refresh token first so we can store its hash.
|
||||
refreshToken, err := s.issueToken(user.ID, user.IsAdmin, 0, s.refreshTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
||||
}
|
||||
|
||||
session, err := s.sessions.Create(ctx, &domain.Session{
|
||||
TokenHash: hashToken(refreshToken),
|
||||
UserID: user.ID,
|
||||
UserAgent: userAgent,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
accessToken, err := s.issueToken(user.ID, user.IsAdmin, session.ID, s.accessTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue access token: %w", err)
|
||||
}
|
||||
|
||||
// Re-issue the refresh token with the real session ID now that we have it.
|
||||
refreshToken, err = s.issueToken(user.ID, user.IsAdmin, session.ID, s.refreshTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: int(s.accessTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Logout deactivates the session identified by sessionID.
|
||||
func (s *AuthService) Logout(ctx context.Context, sessionID int) error {
|
||||
if err := s.sessions.Delete(ctx, sessionID); err != nil {
|
||||
return fmt.Errorf("logout: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh validates a refresh token, issues a new token pair, and deactivates
|
||||
// the old session.
|
||||
func (s *AuthService) Refresh(ctx context.Context, refreshToken, userAgent string) (*TokenPair, error) {
|
||||
claims, err := s.parseToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
|
||||
session, err := s.sessions.GetByTokenHash(ctx, hashToken(refreshToken))
|
||||
if err != nil {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
|
||||
if session.ExpiresAt != nil && time.Now().After(*session.ExpiresAt) {
|
||||
_ = s.sessions.Delete(ctx, session.ID)
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
|
||||
// Rotate: deactivate old session.
|
||||
if err := s.sessions.Delete(ctx, session.ID); err != nil {
|
||||
return nil, fmt.Errorf("deactivate old session: %w", err)
|
||||
}
|
||||
|
||||
user, err := s.users.GetByID(ctx, claims.UserID)
|
||||
if err != nil {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
|
||||
if user.IsBlocked {
|
||||
return nil, domain.ErrForbidden
|
||||
}
|
||||
|
||||
var expiresAt *time.Time
|
||||
if s.refreshTTL > 0 {
|
||||
t := time.Now().Add(s.refreshTTL)
|
||||
expiresAt = &t
|
||||
}
|
||||
|
||||
newRefresh, err := s.issueToken(user.ID, user.IsAdmin, 0, s.refreshTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
||||
}
|
||||
|
||||
newSession, err := s.sessions.Create(ctx, &domain.Session{
|
||||
TokenHash: hashToken(newRefresh),
|
||||
UserID: user.ID,
|
||||
UserAgent: userAgent,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
accessToken, err := s.issueToken(user.ID, user.IsAdmin, newSession.ID, s.accessTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue access token: %w", err)
|
||||
}
|
||||
|
||||
newRefresh, err = s.issueToken(user.ID, user.IsAdmin, newSession.ID, s.refreshTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: newRefresh,
|
||||
ExpiresIn: int(s.accessTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListSessions returns all active sessions for the given user.
|
||||
func (s *AuthService) ListSessions(ctx context.Context, userID int16, currentSessionID int) (*domain.SessionList, error) {
|
||||
list, err := s.sessions.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
for i := range list.Items {
|
||||
list.Items[i].IsCurrent = list.Items[i].ID == currentSessionID
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// TerminateSession deactivates a specific session, enforcing ownership.
|
||||
func (s *AuthService) TerminateSession(ctx context.Context, callerID int16, isAdmin bool, sessionID int) error {
|
||||
if !isAdmin {
|
||||
list, err := s.sessions.ListByUser(ctx, callerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("terminate session: %w", err)
|
||||
}
|
||||
owned := false
|
||||
for _, sess := range list.Items {
|
||||
if sess.ID == sessionID {
|
||||
owned = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !owned {
|
||||
return domain.ErrForbidden
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.sessions.Delete(ctx, sessionID); err != nil {
|
||||
return fmt.Errorf("terminate session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseAccessToken parses and validates an access token, returning its claims.
|
||||
func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) {
|
||||
claims, err := s.parseToken(tokenStr)
|
||||
if err != nil {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// issueToken signs a JWT with the given parameters.
|
||||
func (s *AuthService) issueToken(userID int16, isAdmin bool, sessionID int, ttl time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
||||
},
|
||||
UserID: userID,
|
||||
IsAdmin: isAdmin,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString(s.secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign token: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
// parseToken verifies the signature and parses claims from a token string.
|
||||
func (s *AuthService) parseToken(tokenStr string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return s.secret, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return nil, domain.ErrUnauthorized
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// hashToken returns the SHA-256 hex digest of a token string.
|
||||
// The raw token is never stored; only the hash goes to the database.
|
||||
func hashToken(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user