fix(backend): make access tokens revocable via session validation

The auth middleware trusted any unexpired, well-signed access token, so
logout, session termination and admin blocks had no effect until the
15-minute token expired. The middleware now validates that the token's
session is still active on every request (SessionRepo.GetByID), and
blocking a user deactivates all of their sessions, immediately revoking
their outstanding access tokens.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-10 14:09:25 +03:00
parent fa2acca858
commit 4645107ea1
7 changed files with 47 additions and 11 deletions
+1 -1
View File
@@ -92,7 +92,7 @@ func main() {
transactor, transactor,
cfg.ImportPath, cfg.ImportPath,
) )
userSvc := service.NewUserService(userRepo, auditSvc) userSvc := service.NewUserService(userRepo, sessionRepo, auditSvc)
// Bootstrap the initial administrator (idempotent). // Bootstrap the initial administrator (idempotent).
if err := userSvc.EnsureAdmin(context.Background(), cfg.AdminUsername, cfg.AdminPassword); err != nil { if err := userSvc.EnsureAdmin(context.Background(), cfg.AdminUsername, cfg.AdminPassword); err != nil {
@@ -74,6 +74,28 @@ func (r *SessionRepo) Create(ctx context.Context, s *domain.Session) (*domain.Se
return &created, nil return &created, nil
} }
func (r *SessionRepo) GetByID(ctx context.Context, id int) (*domain.Session, error) {
const sql = `
SELECT id, token_hash, user_id, user_agent, started_at, expires_at, last_activity
FROM activity.sessions
WHERE id = $1 AND is_active = true`
q := connOrTx(ctx, r.pool)
rows, err := q.Query(ctx, sql, id)
if err != nil {
return nil, fmt.Errorf("SessionRepo.GetByID: %w", err)
}
row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[sessionRow])
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, domain.ErrNotFound
}
return nil, fmt.Errorf("SessionRepo.GetByID scan: %w", err)
}
s := toSession(row)
return &s, nil
}
func (r *SessionRepo) GetByTokenHash(ctx context.Context, hash string) (*domain.Session, error) { func (r *SessionRepo) GetByTokenHash(ctx context.Context, hash string) (*domain.Session, error) {
const sql = ` const sql = `
SELECT id, token_hash, user_id, user_agent, started_at, expires_at, last_activity SELECT id, token_hash, user_id, user_agent, started_at, expires_at, last_activity
+1 -1
View File
@@ -35,7 +35,7 @@ func (m *AuthMiddleware) Handle() gin.HandlerFunc {
} }
token := strings.TrimPrefix(raw, "Bearer ") token := strings.TrimPrefix(raw, "Bearer ")
claims, err := m.authSvc.ParseAccessToken(token) claims, err := m.authSvc.ValidateAccessToken(c.Request.Context(), token)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, errorBody{ c.JSON(http.StatusUnauthorized, errorBody{
Code: domain.ErrUnauthorized.Code(), Code: domain.ErrUnauthorized.Code(),
+1 -1
View File
@@ -131,7 +131,7 @@ func setupSuite(t *testing.T) *harness {
categorySvc := service.NewCategoryService(categoryRepo, tagRepo, aclSvc, auditSvc) categorySvc := service.NewCategoryService(categoryRepo, tagRepo, aclSvc, auditSvc)
poolSvc := service.NewPoolService(poolRepo, aclSvc, auditSvc) poolSvc := service.NewPoolService(poolRepo, aclSvc, auditSvc)
fileSvc := service.NewFileService(fileRepo, mimeRepo, diskStorage, aclSvc, auditSvc, tagSvc, transactor, filesDir) fileSvc := service.NewFileService(fileRepo, mimeRepo, diskStorage, aclSvc, auditSvc, tagSvc, transactor, filesDir)
userSvc := service.NewUserService(userRepo, auditSvc) userSvc := service.NewUserService(userRepo, sessionRepo, auditSvc)
// Bootstrap the admin account the suite logs in with (replaces the old // Bootstrap the admin account the suite logs in with (replaces the old
// hardcoded seed credentials). // hardcoded seed credentials).
+3
View File
@@ -132,6 +132,9 @@ type UserRepo interface {
type SessionRepo interface { type SessionRepo interface {
// ListByUser returns all active sessions for a user. // ListByUser returns all active sessions for a user.
ListByUser(ctx context.Context, userID int16) (*domain.SessionList, error) ListByUser(ctx context.Context, userID int16) (*domain.SessionList, error)
// GetByID returns an active session by its ID, or ErrNotFound if it does not
// exist or has been deactivated.
GetByID(ctx context.Context, id int) (*domain.Session, error)
// GetByTokenHash looks up a session by the hashed refresh token. // GetByTokenHash looks up a session by the hashed refresh token.
GetByTokenHash(ctx context.Context, hash string) (*domain.Session, error) GetByTokenHash(ctx context.Context, hash string) (*domain.Session, error)
Create(ctx context.Context, s *domain.Session) (*domain.Session, error) Create(ctx context.Context, s *domain.Session) (*domain.Session, error)
+9 -3
View File
@@ -214,9 +214,12 @@ func (s *AuthService) TerminateSession(ctx context.Context, callerID int16, isAd
return nil return nil
} }
// ParseAccessToken parses and validates an access token, returning its claims. // ValidateAccessToken parses and validates an access token, returning its
// A refresh token presented here is rejected (wrong token type). // claims. A refresh token is rejected (wrong type), and the token's session
func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) { // must still be active — so logout, session termination, an admin block, or a
// refresh rotation revoke any outstanding access tokens immediately rather than
// only at expiry.
func (s *AuthService) ValidateAccessToken(ctx context.Context, tokenStr string) (*Claims, error) {
claims, err := s.parseToken(tokenStr) claims, err := s.parseToken(tokenStr)
if err != nil { if err != nil {
return nil, domain.ErrUnauthorized return nil, domain.ErrUnauthorized
@@ -224,6 +227,9 @@ func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) {
if claims.TokenType != tokenTypeAccess { if claims.TokenType != tokenTypeAccess {
return nil, domain.ErrUnauthorized return nil, domain.ErrUnauthorized
} }
if _, err := s.sessions.GetByID(ctx, claims.SessionID); err != nil {
return nil, domain.ErrUnauthorized
}
return claims, nil return claims, nil
} }
+10 -5
View File
@@ -13,13 +13,14 @@ import (
// UserService handles user CRUD and profile management. // UserService handles user CRUD and profile management.
type UserService struct { type UserService struct {
users port.UserRepo users port.UserRepo
audit *AuditService sessions port.SessionRepo
audit *AuditService
} }
// NewUserService creates a UserService. // NewUserService creates a UserService.
func NewUserService(users port.UserRepo, audit *AuditService) *UserService { func NewUserService(users port.UserRepo, sessions port.SessionRepo, audit *AuditService) *UserService {
return &UserService{users: users, audit: audit} return &UserService{users: users, sessions: sessions, audit: audit}
} }
// EnsureAdmin creates the initial administrator account if it does not already // EnsureAdmin creates the initial administrator account if it does not already
@@ -166,11 +167,15 @@ func (s *UserService) UpdateAdmin(ctx context.Context, id int16, p UpdateAdminPa
return nil, err return nil, err
} }
// Log block/unblock specifically. // Log block/unblock specifically, and revoke all sessions on block so the
// user's outstanding access tokens stop working immediately.
if p.IsBlocked != nil { if p.IsBlocked != nil {
action := "user_unblock" action := "user_unblock"
if *p.IsBlocked { if *p.IsBlocked {
action = "user_block" action = "user_block"
if err := s.sessions.DeleteByUserID(ctx, id); err != nil {
return nil, fmt.Errorf("UserService.UpdateAdmin revoke sessions: %w", err)
}
} }
_ = s.audit.Log(ctx, action, nil, nil, map[string]any{"target_user_id": id}) _ = s.audit.Log(ctx, action, nil, nil, map[string]any{"target_user_id": id})
} }