diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 7befc13..1724132 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -92,7 +92,7 @@ func main() { transactor, cfg.ImportPath, ) - userSvc := service.NewUserService(userRepo, auditSvc) + userSvc := service.NewUserService(userRepo, sessionRepo, auditSvc) // Bootstrap the initial administrator (idempotent). if err := userSvc.EnsureAdmin(context.Background(), cfg.AdminUsername, cfg.AdminPassword); err != nil { diff --git a/backend/internal/db/postgres/session_repo.go b/backend/internal/db/postgres/session_repo.go index a7a05de..3b24930 100644 --- a/backend/internal/db/postgres/session_repo.go +++ b/backend/internal/db/postgres/session_repo.go @@ -74,6 +74,28 @@ func (r *SessionRepo) Create(ctx context.Context, s *domain.Session) (*domain.Se return &created, nil } +func (r *SessionRepo) GetByID(ctx context.Context, id int) (*domain.Session, error) { + const sql = ` + SELECT id, token_hash, user_id, user_agent, started_at, expires_at, last_activity + FROM activity.sessions + WHERE id = $1 AND is_active = true` + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, sql, id) + if err != nil { + return nil, fmt.Errorf("SessionRepo.GetByID: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[sessionRow]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrNotFound + } + return nil, fmt.Errorf("SessionRepo.GetByID scan: %w", err) + } + s := toSession(row) + return &s, nil +} + func (r *SessionRepo) GetByTokenHash(ctx context.Context, hash string) (*domain.Session, error) { const sql = ` SELECT id, token_hash, user_id, user_agent, started_at, expires_at, last_activity diff --git a/backend/internal/handler/middleware.go b/backend/internal/handler/middleware.go index d5b0869..b5c95b4 100644 --- a/backend/internal/handler/middleware.go +++ b/backend/internal/handler/middleware.go @@ -35,7 +35,7 @@ func (m *AuthMiddleware) Handle() gin.HandlerFunc { } token := strings.TrimPrefix(raw, "Bearer ") - claims, err := m.authSvc.ParseAccessToken(token) + claims, err := m.authSvc.ValidateAccessToken(c.Request.Context(), token) if err != nil { c.JSON(http.StatusUnauthorized, errorBody{ Code: domain.ErrUnauthorized.Code(), diff --git a/backend/internal/integration/server_test.go b/backend/internal/integration/server_test.go index a0b1c84..a2922e2 100644 --- a/backend/internal/integration/server_test.go +++ b/backend/internal/integration/server_test.go @@ -131,7 +131,7 @@ func setupSuite(t *testing.T) *harness { categorySvc := service.NewCategoryService(categoryRepo, tagRepo, aclSvc, auditSvc) poolSvc := service.NewPoolService(poolRepo, aclSvc, auditSvc) fileSvc := service.NewFileService(fileRepo, mimeRepo, diskStorage, aclSvc, auditSvc, tagSvc, transactor, filesDir) - userSvc := service.NewUserService(userRepo, auditSvc) + userSvc := service.NewUserService(userRepo, sessionRepo, auditSvc) // Bootstrap the admin account the suite logs in with (replaces the old // hardcoded seed credentials). diff --git a/backend/internal/port/repository.go b/backend/internal/port/repository.go index e1ca9ef..a7d3649 100644 --- a/backend/internal/port/repository.go +++ b/backend/internal/port/repository.go @@ -132,6 +132,9 @@ type UserRepo interface { type SessionRepo interface { // ListByUser returns all active sessions for a user. ListByUser(ctx context.Context, userID int16) (*domain.SessionList, error) + // GetByID returns an active session by its ID, or ErrNotFound if it does not + // exist or has been deactivated. + GetByID(ctx context.Context, id int) (*domain.Session, error) // GetByTokenHash looks up a session by the hashed refresh token. GetByTokenHash(ctx context.Context, hash string) (*domain.Session, error) Create(ctx context.Context, s *domain.Session) (*domain.Session, error) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index c139922..a23abb1 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -214,9 +214,12 @@ func (s *AuthService) TerminateSession(ctx context.Context, callerID int16, isAd return nil } -// ParseAccessToken parses and validates an access token, returning its claims. -// A refresh token presented here is rejected (wrong token type). -func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) { +// ValidateAccessToken parses and validates an access token, returning its +// claims. A refresh token is rejected (wrong type), and the token's session +// must still be active — so logout, session termination, an admin block, or a +// refresh rotation revoke any outstanding access tokens immediately rather than +// only at expiry. +func (s *AuthService) ValidateAccessToken(ctx context.Context, tokenStr string) (*Claims, error) { claims, err := s.parseToken(tokenStr) if err != nil { return nil, domain.ErrUnauthorized @@ -224,6 +227,9 @@ func (s *AuthService) ParseAccessToken(tokenStr string) (*Claims, error) { if claims.TokenType != tokenTypeAccess { return nil, domain.ErrUnauthorized } + if _, err := s.sessions.GetByID(ctx, claims.SessionID); err != nil { + return nil, domain.ErrUnauthorized + } return claims, nil } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 1e899f2..0123888 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -13,13 +13,14 @@ import ( // UserService handles user CRUD and profile management. type UserService struct { - users port.UserRepo - audit *AuditService + users port.UserRepo + sessions port.SessionRepo + audit *AuditService } // NewUserService creates a UserService. -func NewUserService(users port.UserRepo, audit *AuditService) *UserService { - return &UserService{users: users, audit: audit} +func NewUserService(users port.UserRepo, sessions port.SessionRepo, audit *AuditService) *UserService { + return &UserService{users: users, sessions: sessions, audit: audit} } // EnsureAdmin creates the initial administrator account if it does not already @@ -166,11 +167,15 @@ func (s *UserService) UpdateAdmin(ctx context.Context, id int16, p UpdateAdminPa return nil, err } - // Log block/unblock specifically. + // Log block/unblock specifically, and revoke all sessions on block so the + // user's outstanding access tokens stop working immediately. if p.IsBlocked != nil { action := "user_unblock" if *p.IsBlocked { action = "user_block" + if err := s.sessions.DeleteByUserID(ctx, id); err != nil { + return nil, fmt.Errorf("UserService.UpdateAdmin revoke sessions: %w", err) + } } _ = s.audit.Log(ctx, action, nil, nil, map[string]any{"target_user_id": id}) }