diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index d503edd..e4412b8 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -63,6 +63,8 @@ func main() { mimeRepo := postgres.NewMimeRepo(pool) aclRepo := postgres.NewACLRepo(pool) auditRepo := postgres.NewAuditRepo(pool) + tagRepo := postgres.NewTagRepo(pool) + tagRuleRepo := postgres.NewTagRuleRepo(pool) transactor := postgres.NewTransactor(pool) // Services @@ -75,12 +77,14 @@ func main() { ) aclSvc := service.NewACLService(aclRepo) auditSvc := service.NewAuditService(auditRepo) + tagSvc := service.NewTagService(tagRepo, tagRuleRepo, aclSvc, auditSvc, transactor) fileSvc := service.NewFileService( fileRepo, mimeRepo, diskStorage, aclSvc, auditSvc, + tagSvc, transactor, cfg.ImportPath, ) @@ -88,9 +92,10 @@ func main() { // Handlers authMiddleware := handler.NewAuthMiddleware(authSvc) authHandler := handler.NewAuthHandler(authSvc) - fileHandler := handler.NewFileHandler(fileSvc) + fileHandler := handler.NewFileHandler(fileSvc, tagSvc) + tagHandler := handler.NewTagHandler(tagSvc, fileSvc) - r := handler.NewRouter(authMiddleware, authHandler, fileHandler) + r := handler.NewRouter(authMiddleware, authHandler, fileHandler, tagHandler) slog.Info("starting server", "addr", cfg.ListenAddr) if err := r.Run(cfg.ListenAddr); err != nil { diff --git a/backend/internal/db/postgres/tag_repo.go b/backend/internal/db/postgres/tag_repo.go new file mode 100644 index 0000000..458263d --- /dev/null +++ b/backend/internal/db/postgres/tag_repo.go @@ -0,0 +1,607 @@ +package postgres + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/google/uuid" + + "tanabata/backend/internal/domain" + "tanabata/backend/internal/port" +) + +// --------------------------------------------------------------------------- +// Row structs — use pgx-scannable types +// --------------------------------------------------------------------------- + +type tagRow struct { + ID uuid.UUID `db:"id"` + Name string `db:"name"` + Notes *string `db:"notes"` + Color *string `db:"color"` + CategoryID *uuid.UUID `db:"category_id"` + CategoryName *string `db:"category_name"` + CategoryColor *string `db:"category_color"` + Metadata []byte `db:"metadata"` + CreatorID int16 `db:"creator_id"` + CreatorName string `db:"creator_name"` + IsPublic bool `db:"is_public"` +} + +type tagRowWithTotal struct { + tagRow + Total int `db:"total"` +} + +type tagRuleRow struct { + WhenTagID uuid.UUID `db:"when_tag_id"` + ThenTagID uuid.UUID `db:"then_tag_id"` + ThenTagName string `db:"then_tag_name"` + IsActive bool `db:"is_active"` +} + +// --------------------------------------------------------------------------- +// Converters +// --------------------------------------------------------------------------- + +func toTag(r tagRow) domain.Tag { + t := domain.Tag{ + ID: r.ID, + Name: r.Name, + Notes: r.Notes, + Color: r.Color, + CategoryID: r.CategoryID, + CategoryName: r.CategoryName, + CategoryColor: r.CategoryColor, + CreatorID: r.CreatorID, + CreatorName: r.CreatorName, + IsPublic: r.IsPublic, + CreatedAt: domain.UUIDCreatedAt(r.ID), + } + if len(r.Metadata) > 0 && string(r.Metadata) != "null" { + t.Metadata = json.RawMessage(r.Metadata) + } + return t +} + +func toTagRule(r tagRuleRow) domain.TagRule { + return domain.TagRule{ + WhenTagID: r.WhenTagID, + ThenTagID: r.ThenTagID, + ThenTagName: r.ThenTagName, + IsActive: r.IsActive, + } +} + +// --------------------------------------------------------------------------- +// Shared SQL fragments +// --------------------------------------------------------------------------- + +const tagSelectFrom = ` +SELECT + t.id, + t.name, + t.notes, + t.color, + t.category_id, + c.name AS category_name, + c.color AS category_color, + t.metadata, + t.creator_id, + u.name AS creator_name, + t.is_public +FROM data.tags t +LEFT JOIN data.categories c ON c.id = t.category_id +JOIN core.users u ON u.id = t.creator_id` + +func tagSortColumn(s string) string { + switch s { + case "name": + return "t.name" + case "color": + return "t.color" + case "category_name": + return "c.name" + default: // "created" + return "t.id" + } +} + +// isPgUniqueViolation reports whether err is a PostgreSQL unique-constraint error. +func isPgUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + return errors.As(err, &pgErr) && pgErr.Code == "23505" +} + +// --------------------------------------------------------------------------- +// TagRepo — implements port.TagRepo +// --------------------------------------------------------------------------- + +// TagRepo handles tag CRUD and file–tag relations. +type TagRepo struct { + pool *pgxpool.Pool +} + +var _ port.TagRepo = (*TagRepo)(nil) + +// NewTagRepo creates a TagRepo backed by pool. +func NewTagRepo(pool *pgxpool.Pool) *TagRepo { + return &TagRepo{pool: pool} +} + +// --------------------------------------------------------------------------- +// List / ListByCategory +// --------------------------------------------------------------------------- + +func (r *TagRepo) List(ctx context.Context, params port.OffsetParams) (*domain.TagOffsetPage, error) { + return r.listTags(ctx, params, nil) +} + +func (r *TagRepo) ListByCategory(ctx context.Context, categoryID uuid.UUID, params port.OffsetParams) (*domain.TagOffsetPage, error) { + return r.listTags(ctx, params, &categoryID) +} + +func (r *TagRepo) listTags(ctx context.Context, params port.OffsetParams, categoryID *uuid.UUID) (*domain.TagOffsetPage, error) { + order := "ASC" + if strings.ToLower(params.Order) == "desc" { + order = "DESC" + } + sortCol := tagSortColumn(params.Sort) + + args := []any{} + n := 1 + var conditions []string + + if params.Search != "" { + conditions = append(conditions, fmt.Sprintf("lower(t.name) LIKE lower($%d)", n)) + args = append(args, "%"+params.Search+"%") + n++ + } + if categoryID != nil { + conditions = append(conditions, fmt.Sprintf("t.category_id = $%d", n)) + args = append(args, *categoryID) + n++ + } + + where := "" + if len(conditions) > 0 { + where = "WHERE " + strings.Join(conditions, " AND ") + } + + limit := params.Limit + if limit <= 0 { + limit = 50 + } + offset := params.Offset + if offset < 0 { + offset = 0 + } + + query := fmt.Sprintf(` +SELECT + t.id, t.name, t.notes, t.color, + t.category_id, + c.name AS category_name, + c.color AS category_color, + t.metadata, t.creator_id, + u.name AS creator_name, + t.is_public, + COUNT(*) OVER() AS total +FROM data.tags t +LEFT JOIN data.categories c ON c.id = t.category_id +JOIN core.users u ON u.id = t.creator_id +%s +ORDER BY %s %s NULLS LAST, t.id ASC +LIMIT $%d OFFSET $%d`, where, sortCol, order, n, n+1) + + args = append(args, limit, offset) + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("TagRepo.List query: %w", err) + } + collected, err := pgx.CollectRows(rows, pgx.RowToStructByName[tagRowWithTotal]) + if err != nil { + return nil, fmt.Errorf("TagRepo.List scan: %w", err) + } + + items := make([]domain.Tag, len(collected)) + total := 0 + for i, row := range collected { + items[i] = toTag(row.tagRow) + total = row.Total + } + return &domain.TagOffsetPage{ + Items: items, + Total: total, + Offset: offset, + Limit: limit, + }, nil +} + +// --------------------------------------------------------------------------- +// GetByID +// --------------------------------------------------------------------------- + +func (r *TagRepo) GetByID(ctx context.Context, id uuid.UUID) (*domain.Tag, error) { + const query = tagSelectFrom + ` +WHERE t.id = $1` + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, id) + if err != nil { + return nil, fmt.Errorf("TagRepo.GetByID: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[tagRow]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrNotFound + } + return nil, fmt.Errorf("TagRepo.GetByID scan: %w", err) + } + t := toTag(row) + return &t, nil +} + +// --------------------------------------------------------------------------- +// Create +// --------------------------------------------------------------------------- + +func (r *TagRepo) Create(ctx context.Context, t *domain.Tag) (*domain.Tag, error) { + const query = ` +WITH ins AS ( + INSERT INTO data.tags (name, notes, color, category_id, metadata, creator_id, is_public) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING * +) +SELECT + ins.id, ins.name, ins.notes, ins.color, + ins.category_id, + c.name AS category_name, + c.color AS category_color, + ins.metadata, ins.creator_id, + u.name AS creator_name, + ins.is_public +FROM ins +LEFT JOIN data.categories c ON c.id = ins.category_id +JOIN core.users u ON u.id = ins.creator_id` + + var meta any + if len(t.Metadata) > 0 { + meta = t.Metadata + } + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, + t.Name, t.Notes, t.Color, t.CategoryID, meta, t.CreatorID, t.IsPublic) + if err != nil { + return nil, fmt.Errorf("TagRepo.Create: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[tagRow]) + if err != nil { + if isPgUniqueViolation(err) { + return nil, domain.ErrConflict + } + return nil, fmt.Errorf("TagRepo.Create scan: %w", err) + } + created := toTag(row) + return &created, nil +} + +// --------------------------------------------------------------------------- +// Update +// --------------------------------------------------------------------------- + +// Update replaces all mutable fields. The caller must merge current values with +// the patch (read-then-write) before calling this. +func (r *TagRepo) Update(ctx context.Context, id uuid.UUID, t *domain.Tag) (*domain.Tag, error) { + const query = ` +WITH upd AS ( + UPDATE data.tags SET + name = $2, + notes = $3, + color = $4, + category_id = $5, + metadata = COALESCE($6, metadata), + is_public = $7 + WHERE id = $1 + RETURNING * +) +SELECT + upd.id, upd.name, upd.notes, upd.color, + upd.category_id, + c.name AS category_name, + c.color AS category_color, + upd.metadata, upd.creator_id, + u.name AS creator_name, + upd.is_public +FROM upd +LEFT JOIN data.categories c ON c.id = upd.category_id +JOIN core.users u ON u.id = upd.creator_id` + + var meta any + if len(t.Metadata) > 0 { + meta = t.Metadata + } + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, + id, t.Name, t.Notes, t.Color, t.CategoryID, meta, t.IsPublic) + if err != nil { + return nil, fmt.Errorf("TagRepo.Update: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[tagRow]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrNotFound + } + if isPgUniqueViolation(err) { + return nil, domain.ErrConflict + } + return nil, fmt.Errorf("TagRepo.Update scan: %w", err) + } + updated := toTag(row) + return &updated, nil +} + +// --------------------------------------------------------------------------- +// Delete +// --------------------------------------------------------------------------- + +func (r *TagRepo) Delete(ctx context.Context, id uuid.UUID) error { + const query = `DELETE FROM data.tags WHERE id = $1` + + q := connOrTx(ctx, r.pool) + ct, err := q.Exec(ctx, query, id) + if err != nil { + return fmt.Errorf("TagRepo.Delete: %w", err) + } + if ct.RowsAffected() == 0 { + return domain.ErrNotFound + } + return nil +} + +// --------------------------------------------------------------------------- +// File–tag operations +// --------------------------------------------------------------------------- + +func (r *TagRepo) ListByFile(ctx context.Context, fileID uuid.UUID) ([]domain.Tag, error) { + const query = tagSelectFrom + ` +JOIN data.file_tag ft ON ft.tag_id = t.id +WHERE ft.file_id = $1 +ORDER BY t.name` + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, fileID) + if err != nil { + return nil, fmt.Errorf("TagRepo.ListByFile: %w", err) + } + collected, err := pgx.CollectRows(rows, pgx.RowToStructByName[tagRow]) + if err != nil { + return nil, fmt.Errorf("TagRepo.ListByFile scan: %w", err) + } + tags := make([]domain.Tag, len(collected)) + for i, row := range collected { + tags[i] = toTag(row) + } + return tags, nil +} + +func (r *TagRepo) AddFileTag(ctx context.Context, fileID, tagID uuid.UUID) error { + const query = ` +INSERT INTO data.file_tag (file_id, tag_id) VALUES ($1, $2) +ON CONFLICT DO NOTHING` + + q := connOrTx(ctx, r.pool) + if _, err := q.Exec(ctx, query, fileID, tagID); err != nil { + return fmt.Errorf("TagRepo.AddFileTag: %w", err) + } + return nil +} + +func (r *TagRepo) RemoveFileTag(ctx context.Context, fileID, tagID uuid.UUID) error { + const query = `DELETE FROM data.file_tag WHERE file_id = $1 AND tag_id = $2` + + q := connOrTx(ctx, r.pool) + if _, err := q.Exec(ctx, query, fileID, tagID); err != nil { + return fmt.Errorf("TagRepo.RemoveFileTag: %w", err) + } + return nil +} + +func (r *TagRepo) SetFileTags(ctx context.Context, fileID uuid.UUID, tagIDs []uuid.UUID) error { + q := connOrTx(ctx, r.pool) + + if _, err := q.Exec(ctx, + `DELETE FROM data.file_tag WHERE file_id = $1`, fileID); err != nil { + return fmt.Errorf("TagRepo.SetFileTags delete: %w", err) + } + if len(tagIDs) == 0 { + return nil + } + + placeholders := make([]string, len(tagIDs)) + args := []any{fileID} + for i, tagID := range tagIDs { + placeholders[i] = fmt.Sprintf("($1, $%d)", i+2) + args = append(args, tagID) + } + ins := `INSERT INTO data.file_tag (file_id, tag_id) VALUES ` + + strings.Join(placeholders, ", ") + ` ON CONFLICT DO NOTHING` + + if _, err := q.Exec(ctx, ins, args...); err != nil { + return fmt.Errorf("TagRepo.SetFileTags insert: %w", err) + } + return nil +} + +func (r *TagRepo) CommonTagsForFiles(ctx context.Context, fileIDs []uuid.UUID) ([]domain.Tag, error) { + if len(fileIDs) == 0 { + return []domain.Tag{}, nil + } + return r.queryTagsByPresence(ctx, fileIDs, "=") +} + +func (r *TagRepo) PartialTagsForFiles(ctx context.Context, fileIDs []uuid.UUID) ([]domain.Tag, error) { + if len(fileIDs) == 0 { + return []domain.Tag{}, nil + } + return r.queryTagsByPresence(ctx, fileIDs, "<") +} + +func (r *TagRepo) queryTagsByPresence(ctx context.Context, fileIDs []uuid.UUID, op string) ([]domain.Tag, error) { + placeholders := make([]string, len(fileIDs)) + args := make([]any, len(fileIDs)+1) + for i, id := range fileIDs { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + args[len(fileIDs)] = len(fileIDs) + n := len(fileIDs) + 1 + + query := fmt.Sprintf(` +SELECT + t.id, t.name, t.notes, t.color, + t.category_id, + c.name AS category_name, + c.color AS category_color, + t.metadata, t.creator_id, + u.name AS creator_name, + t.is_public +FROM data.tags t +JOIN data.file_tag ft ON ft.tag_id = t.id +LEFT JOIN data.categories c ON c.id = t.category_id +JOIN core.users u ON u.id = t.creator_id +WHERE ft.file_id IN (%s) +GROUP BY t.id, c.id, u.id +HAVING COUNT(DISTINCT ft.file_id) %s $%d +ORDER BY t.name`, + strings.Join(placeholders, ", "), op, n) + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("TagRepo.queryTagsByPresence: %w", err) + } + collected, err := pgx.CollectRows(rows, pgx.RowToStructByName[tagRow]) + if err != nil { + return nil, fmt.Errorf("TagRepo.queryTagsByPresence scan: %w", err) + } + tags := make([]domain.Tag, len(collected)) + for i, row := range collected { + tags[i] = toTag(row) + } + return tags, nil +} + +// --------------------------------------------------------------------------- +// TagRuleRepo — implements port.TagRuleRepo (separate type to avoid method collision) +// --------------------------------------------------------------------------- + +// TagRuleRepo handles tag-rule CRUD. +type TagRuleRepo struct { + pool *pgxpool.Pool +} + +var _ port.TagRuleRepo = (*TagRuleRepo)(nil) + +// NewTagRuleRepo creates a TagRuleRepo backed by pool. +func NewTagRuleRepo(pool *pgxpool.Pool) *TagRuleRepo { + return &TagRuleRepo{pool: pool} +} + +func (r *TagRuleRepo) ListByTag(ctx context.Context, tagID uuid.UUID) ([]domain.TagRule, error) { + const query = ` +SELECT + tr.when_tag_id, + tr.then_tag_id, + t.name AS then_tag_name, + tr.is_active +FROM data.tag_rules tr +JOIN data.tags t ON t.id = tr.then_tag_id +WHERE tr.when_tag_id = $1 +ORDER BY t.name` + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, tagID) + if err != nil { + return nil, fmt.Errorf("TagRuleRepo.ListByTag: %w", err) + } + collected, err := pgx.CollectRows(rows, pgx.RowToStructByName[tagRuleRow]) + if err != nil { + return nil, fmt.Errorf("TagRuleRepo.ListByTag scan: %w", err) + } + rules := make([]domain.TagRule, len(collected)) + for i, row := range collected { + rules[i] = toTagRule(row) + } + return rules, nil +} + +func (r *TagRuleRepo) Create(ctx context.Context, rule domain.TagRule) (*domain.TagRule, error) { + const query = ` +WITH ins AS ( + INSERT INTO data.tag_rules (when_tag_id, then_tag_id, is_active) + VALUES ($1, $2, $3) + RETURNING * +) +SELECT ins.when_tag_id, ins.then_tag_id, t.name AS then_tag_name, ins.is_active +FROM ins +JOIN data.tags t ON t.id = ins.then_tag_id` + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, query, rule.WhenTagID, rule.ThenTagID, rule.IsActive) + if err != nil { + return nil, fmt.Errorf("TagRuleRepo.Create: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[tagRuleRow]) + if err != nil { + if isPgUniqueViolation(err) { + return nil, domain.ErrConflict + } + return nil, fmt.Errorf("TagRuleRepo.Create scan: %w", err) + } + result := toTagRule(row) + return &result, nil +} + +func (r *TagRuleRepo) SetActive(ctx context.Context, whenTagID, thenTagID uuid.UUID, active bool) error { + const query = ` +UPDATE data.tag_rules SET is_active = $3 +WHERE when_tag_id = $1 AND then_tag_id = $2` + + q := connOrTx(ctx, r.pool) + ct, err := q.Exec(ctx, query, whenTagID, thenTagID, active) + if err != nil { + return fmt.Errorf("TagRuleRepo.SetActive: %w", err) + } + if ct.RowsAffected() == 0 { + return domain.ErrNotFound + } + return nil +} + +func (r *TagRuleRepo) Delete(ctx context.Context, whenTagID, thenTagID uuid.UUID) error { + const query = ` +DELETE FROM data.tag_rules +WHERE when_tag_id = $1 AND then_tag_id = $2` + + q := connOrTx(ctx, r.pool) + ct, err := q.Exec(ctx, query, whenTagID, thenTagID) + if err != nil { + return fmt.Errorf("TagRuleRepo.Delete: %w", err) + } + if ct.RowsAffected() == 0 { + return domain.ErrNotFound + } + return nil +} \ No newline at end of file diff --git a/backend/internal/handler/file_handler.go b/backend/internal/handler/file_handler.go index b810b67..df2cbe6 100644 --- a/backend/internal/handler/file_handler.go +++ b/backend/internal/handler/file_handler.go @@ -20,11 +20,12 @@ import ( // FileHandler handles all /files endpoints. type FileHandler struct { fileSvc *service.FileService + tagSvc *service.TagService } // NewFileHandler creates a FileHandler. -func NewFileHandler(fileSvc *service.FileService) *FileHandler { - return &FileHandler{fileSvc: fileSvc} +func NewFileHandler(fileSvc *service.FileService, tagSvc *service.TagService) *FileHandler { + return &FileHandler{fileSvc: fileSvc, tagSvc: tagSvc} } // --------------------------------------------------------------------------- @@ -498,117 +499,6 @@ func (h *FileHandler) PermanentDelete(c *gin.Context) { c.Status(http.StatusNoContent) } -// --------------------------------------------------------------------------- -// GET /files/:id/tags -// --------------------------------------------------------------------------- - -func (h *FileHandler) ListTags(c *gin.Context) { - id, ok := parseFileID(c) - if !ok { - return - } - - tags, err := h.fileSvc.ListFileTags(c.Request.Context(), id) - if err != nil { - respondError(c, err) - return - } - - items := make([]tagJSON, len(tags)) - for i, t := range tags { - items[i] = toTagJSON(t) - } - respondJSON(c, http.StatusOK, items) -} - -// --------------------------------------------------------------------------- -// PUT /files/:id/tags (replace all) -// --------------------------------------------------------------------------- - -func (h *FileHandler) SetTags(c *gin.Context) { - id, ok := parseFileID(c) - if !ok { - return - } - - var body struct { - TagIDs []string `json:"tag_ids" binding:"required"` - } - if err := c.ShouldBindJSON(&body); err != nil { - respondError(c, domain.ErrValidation) - return - } - - tagIDs, err := parseUUIDs(body.TagIDs) - if err != nil { - respondError(c, domain.ErrValidation) - return - } - - tags, err := h.fileSvc.SetFileTags(c.Request.Context(), id, tagIDs) - if err != nil { - respondError(c, err) - return - } - - items := make([]tagJSON, len(tags)) - for i, t := range tags { - items[i] = toTagJSON(t) - } - respondJSON(c, http.StatusOK, items) -} - -// --------------------------------------------------------------------------- -// PUT /files/:id/tags/:tag_id -// --------------------------------------------------------------------------- - -func (h *FileHandler) AddTag(c *gin.Context) { - fileID, ok := parseFileID(c) - if !ok { - return - } - tagID, err := uuid.Parse(c.Param("tag_id")) - if err != nil { - respondError(c, domain.ErrValidation) - return - } - - tags, err := h.fileSvc.AddTag(c.Request.Context(), fileID, tagID) - if err != nil { - respondError(c, err) - return - } - - items := make([]tagJSON, len(tags)) - for i, t := range tags { - items[i] = toTagJSON(t) - } - respondJSON(c, http.StatusOK, items) -} - -// --------------------------------------------------------------------------- -// DELETE /files/:id/tags/:tag_id -// --------------------------------------------------------------------------- - -func (h *FileHandler) RemoveTag(c *gin.Context) { - fileID, ok := parseFileID(c) - if !ok { - return - } - tagID, err := uuid.Parse(c.Param("tag_id")) - if err != nil { - respondError(c, domain.ErrValidation) - return - } - - if err := h.fileSvc.RemoveTag(c.Request.Context(), fileID, tagID); err != nil { - respondError(c, err) - return - } - - c.Status(http.StatusNoContent) -} - // --------------------------------------------------------------------------- // POST /files/bulk/tags // --------------------------------------------------------------------------- @@ -639,7 +529,7 @@ func (h *FileHandler) BulkSetTags(c *gin.Context) { return } - applied, err := h.fileSvc.BulkSetTags(c.Request.Context(), fileIDs, body.Action, tagIDs) + applied, err := h.tagSvc.BulkSetTags(c.Request.Context(), fileIDs, body.Action, tagIDs) if err != nil { respondError(c, err) return @@ -698,16 +588,16 @@ func (h *FileHandler) CommonTags(c *gin.Context) { return } - common, partial, err := h.fileSvc.CommonTags(c.Request.Context(), fileIDs) + common, partial, err := h.tagSvc.CommonTags(c.Request.Context(), fileIDs) if err != nil { respondError(c, err) return } - toStrs := func(ids []uuid.UUID) []string { - s := make([]string, len(ids)) - for i, id := range ids { - s[i] = id.String() + toStrs := func(tags []domain.Tag) []string { + s := make([]string, len(tags)) + for i, t := range tags { + s[i] = t.ID.String() } return s } diff --git a/backend/internal/handler/router.go b/backend/internal/handler/router.go index 0128455..07a70b1 100644 --- a/backend/internal/handler/router.go +++ b/backend/internal/handler/router.go @@ -7,7 +7,12 @@ import ( ) // NewRouter builds and returns a configured Gin engine. -func NewRouter(auth *AuthMiddleware, authHandler *AuthHandler, fileHandler *FileHandler) *gin.Engine { +func NewRouter( + auth *AuthMiddleware, + authHandler *AuthHandler, + fileHandler *FileHandler, + tagHandler *TagHandler, +) *gin.Engine { r := gin.New() r.Use(gin.Logger(), gin.Recovery()) @@ -18,7 +23,9 @@ func NewRouter(auth *AuthMiddleware, authHandler *AuthHandler, fileHandler *File v1 := r.Group("/api/v1") - // Auth endpoints — login and refresh are public; others require a valid token. + // ------------------------------------------------------------------------- + // Auth + // ------------------------------------------------------------------------- authGroup := v1.Group("/auth") { authGroup.POST("/login", authHandler.Login) @@ -32,13 +39,15 @@ func NewRouter(auth *AuthMiddleware, authHandler *AuthHandler, fileHandler *File } } - // File endpoints — all require authentication. + // ------------------------------------------------------------------------- + // Files (all require auth) + // ------------------------------------------------------------------------- files := v1.Group("/files", auth.Handle()) { files.GET("", fileHandler.List) files.POST("", fileHandler.Upload) - // Bulk routes must be registered before /:id to avoid ambiguity. + // Bulk + import routes registered before /:id to prevent param collision. files.POST("/bulk/tags", fileHandler.BulkSetTags) files.POST("/bulk/delete", fileHandler.BulkDelete) files.POST("/bulk/common-tags", fileHandler.CommonTags) @@ -56,10 +65,30 @@ func NewRouter(auth *AuthMiddleware, authHandler *AuthHandler, fileHandler *File files.POST("/:id/restore", fileHandler.Restore) files.DELETE("/:id/permanent", fileHandler.PermanentDelete) - files.GET("/:id/tags", fileHandler.ListTags) - files.PUT("/:id/tags", fileHandler.SetTags) - files.PUT("/:id/tags/:tag_id", fileHandler.AddTag) - files.DELETE("/:id/tags/:tag_id", fileHandler.RemoveTag) + // File–tag relations — served by TagHandler for auto-rule support. + files.GET("/:id/tags", tagHandler.FileListTags) + files.PUT("/:id/tags", tagHandler.FileSetTags) + files.PUT("/:id/tags/:tag_id", tagHandler.FileAddTag) + files.DELETE("/:id/tags/:tag_id", tagHandler.FileRemoveTag) + } + + // ------------------------------------------------------------------------- + // Tags (all require auth) + // ------------------------------------------------------------------------- + tags := v1.Group("/tags", auth.Handle()) + { + tags.GET("", tagHandler.List) + tags.POST("", tagHandler.Create) + + tags.GET("/:tag_id", tagHandler.Get) + tags.PATCH("/:tag_id", tagHandler.Update) + tags.DELETE("/:tag_id", tagHandler.Delete) + + tags.GET("/:tag_id/files", tagHandler.ListFiles) + + tags.GET("/:tag_id/rules", tagHandler.ListRules) + tags.POST("/:tag_id/rules", tagHandler.CreateRule) + tags.DELETE("/:tag_id/rules/:then_tag_id", tagHandler.DeleteRule) } return r diff --git a/backend/internal/handler/tag_handler.go b/backend/internal/handler/tag_handler.go new file mode 100644 index 0000000..3239689 --- /dev/null +++ b/backend/internal/handler/tag_handler.go @@ -0,0 +1,493 @@ +package handler + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + + "tanabata/backend/internal/domain" + "tanabata/backend/internal/port" + "tanabata/backend/internal/service" +) + +// TagHandler handles all /tags endpoints. +type TagHandler struct { + tagSvc *service.TagService + fileSvc *service.FileService +} + +// NewTagHandler creates a TagHandler. +func NewTagHandler(tagSvc *service.TagService, fileSvc *service.FileService) *TagHandler { + return &TagHandler{tagSvc: tagSvc, fileSvc: fileSvc} +} + +// --------------------------------------------------------------------------- +// Response types +// --------------------------------------------------------------------------- + +type tagRuleJSON struct { + WhenTagID string `json:"when_tag_id"` + ThenTagID string `json:"then_tag_id"` + ThenTagName string `json:"then_tag_name"` + IsActive bool `json:"is_active"` +} + +func toTagRuleJSON(r domain.TagRule) tagRuleJSON { + return tagRuleJSON{ + WhenTagID: r.WhenTagID.String(), + ThenTagID: r.ThenTagID.String(), + ThenTagName: r.ThenTagName, + IsActive: r.IsActive, + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func parseTagID(c *gin.Context) (uuid.UUID, bool) { + id, err := uuid.Parse(c.Param("tag_id")) + if err != nil { + respondError(c, domain.ErrValidation) + return uuid.UUID{}, false + } + return id, true +} + +func parseOffsetParams(c *gin.Context, defaultSort string) port.OffsetParams { + limit := 50 + if s := c.Query("limit"); s != "" { + if n, err := strconv.Atoi(s); err == nil && n > 0 && n <= 200 { + limit = n + } + } + offset := 0 + if s := c.Query("offset"); s != "" { + if n, err := strconv.Atoi(s); err == nil && n >= 0 { + offset = n + } + } + sort := c.DefaultQuery("sort", defaultSort) + order := c.DefaultQuery("order", "desc") + search := c.Query("search") + return port.OffsetParams{Sort: sort, Order: order, Search: search, Limit: limit, Offset: offset} +} + +// --------------------------------------------------------------------------- +// GET /tags +// --------------------------------------------------------------------------- + +func (h *TagHandler) List(c *gin.Context) { + params := parseOffsetParams(c, "created") + + page, err := h.tagSvc.List(c.Request.Context(), params) + if err != nil { + respondError(c, err) + return + } + + items := make([]tagJSON, len(page.Items)) + for i, t := range page.Items { + items[i] = toTagJSON(t) + } + respondJSON(c, http.StatusOK, gin.H{ + "items": items, + "total": page.Total, + "offset": page.Offset, + "limit": page.Limit, + }) +} + +// --------------------------------------------------------------------------- +// POST /tags +// --------------------------------------------------------------------------- + +func (h *TagHandler) Create(c *gin.Context) { + var body struct { + Name string `json:"name" binding:"required"` + Notes *string `json:"notes"` + Color *string `json:"color"` + CategoryID *string `json:"category_id"` + IsPublic *bool `json:"is_public"` + } + if err := c.ShouldBindJSON(&body); err != nil { + respondError(c, domain.ErrValidation) + return + } + + params := service.TagParams{ + Name: body.Name, + Notes: body.Notes, + Color: body.Color, + IsPublic: body.IsPublic, + } + if body.CategoryID != nil { + id, err := uuid.Parse(*body.CategoryID) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + params.CategoryID = &id + } + + t, err := h.tagSvc.Create(c.Request.Context(), params) + if err != nil { + respondError(c, err) + return + } + + respondJSON(c, http.StatusCreated, toTagJSON(*t)) +} + +// --------------------------------------------------------------------------- +// GET /tags/:tag_id +// --------------------------------------------------------------------------- + +func (h *TagHandler) Get(c *gin.Context) { + id, ok := parseTagID(c) + if !ok { + return + } + + t, err := h.tagSvc.Get(c.Request.Context(), id) + if err != nil { + respondError(c, err) + return + } + + respondJSON(c, http.StatusOK, toTagJSON(*t)) +} + +// --------------------------------------------------------------------------- +// PATCH /tags/:tag_id +// --------------------------------------------------------------------------- + +func (h *TagHandler) Update(c *gin.Context) { + id, ok := parseTagID(c) + if !ok { + return + } + + // Use a raw map to distinguish "field absent" from "field = null". + var raw map[string]any + if err := c.ShouldBindJSON(&raw); err != nil { + respondError(c, domain.ErrValidation) + return + } + + params := service.TagParams{} + + if v, ok := raw["name"]; ok { + if s, ok := v.(string); ok { + params.Name = s + } + } + if _, ok := raw["notes"]; ok { + if raw["notes"] == nil { + params.Notes = ptr("") + } else if s, ok := raw["notes"].(string); ok { + params.Notes = &s + } + } + if _, ok := raw["color"]; ok { + if raw["color"] == nil { + nilStr := "" + params.Color = &nilStr + } else if s, ok := raw["color"].(string); ok { + params.Color = &s + } + } + if _, ok := raw["category_id"]; ok { + if raw["category_id"] == nil { + nilID := uuid.Nil + params.CategoryID = &nilID // signals "unassign" + } else if s, ok := raw["category_id"].(string); ok { + cid, err := uuid.Parse(s) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + params.CategoryID = &cid + } + } + if v, ok := raw["is_public"]; ok { + if b, ok := v.(bool); ok { + params.IsPublic = &b + } + } + + t, err := h.tagSvc.Update(c.Request.Context(), id, params) + if err != nil { + respondError(c, err) + return + } + + respondJSON(c, http.StatusOK, toTagJSON(*t)) +} + +// --------------------------------------------------------------------------- +// DELETE /tags/:tag_id +// --------------------------------------------------------------------------- + +func (h *TagHandler) Delete(c *gin.Context) { + id, ok := parseTagID(c) + if !ok { + return + } + + if err := h.tagSvc.Delete(c.Request.Context(), id); err != nil { + respondError(c, err) + return + } + + c.Status(http.StatusNoContent) +} + +// --------------------------------------------------------------------------- +// GET /tags/:tag_id/files +// --------------------------------------------------------------------------- + +func (h *TagHandler) ListFiles(c *gin.Context) { + id, ok := parseTagID(c) + if !ok { + return + } + + limit := 50 + if s := c.Query("limit"); s != "" { + if n, err := strconv.Atoi(s); err == nil && n > 0 && n <= 200 { + limit = n + } + } + + // Delegate to file service with a tag filter. + page, err := h.fileSvc.List(c.Request.Context(), domain.FileListParams{ + Cursor: c.Query("cursor"), + Direction: "forward", + Limit: limit, + Sort: "created", + Order: "desc", + Filter: "{t=" + id.String() + "}", + }) + if err != nil { + respondError(c, err) + return + } + + items := make([]fileJSON, len(page.Items)) + for i, f := range page.Items { + items[i] = toFileJSON(f) + } + respondJSON(c, http.StatusOK, gin.H{ + "items": items, + "next_cursor": page.NextCursor, + "prev_cursor": page.PrevCursor, + }) +} + +// --------------------------------------------------------------------------- +// GET /tags/:tag_id/rules +// --------------------------------------------------------------------------- + +func (h *TagHandler) ListRules(c *gin.Context) { + id, ok := parseTagID(c) + if !ok { + return + } + + rules, err := h.tagSvc.ListRules(c.Request.Context(), id) + if err != nil { + respondError(c, err) + return + } + + items := make([]tagRuleJSON, len(rules)) + for i, r := range rules { + items[i] = toTagRuleJSON(r) + } + respondJSON(c, http.StatusOK, items) +} + +// --------------------------------------------------------------------------- +// POST /tags/:tag_id/rules +// --------------------------------------------------------------------------- + +func (h *TagHandler) CreateRule(c *gin.Context) { + whenTagID, ok := parseTagID(c) + if !ok { + return + } + + var body struct { + ThenTagID string `json:"then_tag_id" binding:"required"` + IsActive *bool `json:"is_active"` + ApplyToExisting *bool `json:"apply_to_existing"` + } + if err := c.ShouldBindJSON(&body); err != nil { + respondError(c, domain.ErrValidation) + return + } + + thenTagID, err := uuid.Parse(body.ThenTagID) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + isActive := true + if body.IsActive != nil { + isActive = *body.IsActive + } + applyToExisting := true + if body.ApplyToExisting != nil { + applyToExisting = *body.ApplyToExisting + } + + rule, err := h.tagSvc.CreateRule(c.Request.Context(), whenTagID, thenTagID, isActive, applyToExisting) + if err != nil { + respondError(c, err) + return + } + + respondJSON(c, http.StatusCreated, toTagRuleJSON(*rule)) +} + +// --------------------------------------------------------------------------- +// DELETE /tags/:tag_id/rules/:then_tag_id +// --------------------------------------------------------------------------- + +func (h *TagHandler) DeleteRule(c *gin.Context) { + whenTagID, ok := parseTagID(c) + if !ok { + return + } + + thenTagID, err := uuid.Parse(c.Param("then_tag_id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + if err := h.tagSvc.DeleteRule(c.Request.Context(), whenTagID, thenTagID); err != nil { + respondError(c, err) + return + } + + c.Status(http.StatusNoContent) +} + +// --------------------------------------------------------------------------- +// File-tag endpoints wired through TagService +// (called from file routes, shared handler logic lives here) +// --------------------------------------------------------------------------- + +// FileListTags handles GET /files/:id/tags. +func (h *TagHandler) FileListTags(c *gin.Context) { + fileID, err := uuid.Parse(c.Param("id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + tags, err := h.tagSvc.ListFileTags(c.Request.Context(), fileID) + if err != nil { + respondError(c, err) + return + } + + items := make([]tagJSON, len(tags)) + for i, t := range tags { + items[i] = toTagJSON(t) + } + respondJSON(c, http.StatusOK, items) +} + +// FileSetTags handles PUT /files/:id/tags. +func (h *TagHandler) FileSetTags(c *gin.Context) { + fileID, err := uuid.Parse(c.Param("id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + var body struct { + TagIDs []string `json:"tag_ids" binding:"required"` + } + if err := c.ShouldBindJSON(&body); err != nil { + respondError(c, domain.ErrValidation) + return + } + + tagIDs, err := parseUUIDs(body.TagIDs) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + tags, err := h.tagSvc.SetFileTags(c.Request.Context(), fileID, tagIDs) + if err != nil { + respondError(c, err) + return + } + + items := make([]tagJSON, len(tags)) + for i, t := range tags { + items[i] = toTagJSON(t) + } + respondJSON(c, http.StatusOK, items) +} + +// FileAddTag handles PUT /files/:id/tags/:tag_id. +func (h *TagHandler) FileAddTag(c *gin.Context) { + fileID, err := uuid.Parse(c.Param("id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + tagID, err := uuid.Parse(c.Param("tag_id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + tags, err := h.tagSvc.AddFileTag(c.Request.Context(), fileID, tagID) + if err != nil { + respondError(c, err) + return + } + + items := make([]tagJSON, len(tags)) + for i, t := range tags { + items[i] = toTagJSON(t) + } + respondJSON(c, http.StatusOK, items) +} + +// FileRemoveTag handles DELETE /files/:id/tags/:tag_id. +func (h *TagHandler) FileRemoveTag(c *gin.Context) { + fileID, err := uuid.Parse(c.Param("id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + tagID, err := uuid.Parse(c.Param("tag_id")) + if err != nil { + respondError(c, domain.ErrValidation) + return + } + + if err := h.tagSvc.RemoveFileTag(c.Request.Context(), fileID, tagID); err != nil { + respondError(c, err) + return + } + + c.Status(http.StatusNoContent) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func ptr(s string) *string { return &s } \ No newline at end of file diff --git a/backend/internal/port/repository.go b/backend/internal/port/repository.go index d638c05..4925bd0 100644 --- a/backend/internal/port/repository.go +++ b/backend/internal/port/repository.go @@ -63,6 +63,19 @@ type TagRepo interface { Create(ctx context.Context, t *domain.Tag) (*domain.Tag, error) Update(ctx context.Context, id uuid.UUID, t *domain.Tag) (*domain.Tag, error) Delete(ctx context.Context, id uuid.UUID) error + + // ListByFile returns all tags assigned to a specific file, ordered by name. + ListByFile(ctx context.Context, fileID uuid.UUID) ([]domain.Tag, error) + // AddFileTag inserts a single file→tag relation. No-op if already present. + AddFileTag(ctx context.Context, fileID, tagID uuid.UUID) error + // RemoveFileTag deletes a single file→tag relation. + RemoveFileTag(ctx context.Context, fileID, tagID uuid.UUID) error + // SetFileTags replaces all tags on a file (full replace semantics). + SetFileTags(ctx context.Context, fileID uuid.UUID, tagIDs []uuid.UUID) error + // CommonTagsForFiles returns tags present on every one of the given files. + CommonTagsForFiles(ctx context.Context, fileIDs []uuid.UUID) ([]domain.Tag, error) + // PartialTagsForFiles returns tags present on some but not all of the given files. + PartialTagsForFiles(ctx context.Context, fileIDs []uuid.UUID) ([]domain.Tag, error) } // TagRuleRepo is the persistence interface for auto-tag rules. diff --git a/backend/internal/service/file_service.go b/backend/internal/service/file_service.go index c414481..f8a5042 100644 --- a/backend/internal/service/file_service.go +++ b/backend/internal/service/file_service.go @@ -73,6 +73,7 @@ type FileService struct { storage port.FileStorage acl *ACLService audit *AuditService + tags *TagService tx port.Transactor importPath string // default server-side import directory } @@ -84,6 +85,7 @@ func NewFileService( storage port.FileStorage, acl *ACLService, audit *AuditService, + tags *TagService, tx port.Transactor, importPath string, ) *FileService { @@ -93,6 +95,7 @@ func NewFileService( storage: storage, acl: acl, audit: audit, + tags: tags, tx: tx, importPath: importPath, } @@ -166,10 +169,7 @@ func (s *FileService) Upload(ctx context.Context, p UploadParams) (*domain.File, } if len(p.TagIDs) > 0 { - if err := s.files.SetTags(ctx, created.ID, p.TagIDs); err != nil { - return err - } - tags, err := s.files.ListTags(ctx, created.ID) + tags, err := s.tags.SetFileTags(ctx, created.ID, p.TagIDs) if err != nil { return err } @@ -249,10 +249,7 @@ func (s *FileService) Update(ctx context.Context, id uuid.UUID, p UpdateParams) return updateErr } if p.TagIDs != nil { - if err := s.files.SetTags(ctx, id, *p.TagIDs); err != nil { - return err - } - tags, err := s.files.ListTags(ctx, id) + tags, err := s.tags.SetFileTags(ctx, id, *p.TagIDs) if err != nil { return err } @@ -447,120 +444,6 @@ func (s *FileService) GetPreview(ctx context.Context, id uuid.UUID) (io.ReadClos return s.storage.Preview(ctx, id) } -// --------------------------------------------------------------------------- -// Tag operations -// --------------------------------------------------------------------------- - -// ListFileTags returns the tags on a file, enforcing view ACL. -func (s *FileService) ListFileTags(ctx context.Context, fileID uuid.UUID) ([]domain.Tag, error) { - if _, err := s.Get(ctx, fileID); err != nil { - return nil, err - } - return s.files.ListTags(ctx, fileID) -} - -// SetFileTags replaces all tags on a file (full replace semantics), enforcing edit ACL. -func (s *FileService) SetFileTags(ctx context.Context, fileID uuid.UUID, tagIDs []uuid.UUID) ([]domain.Tag, error) { - userID, isAdmin, _ := domain.UserFromContext(ctx) - - f, err := s.files.GetByID(ctx, fileID) - if err != nil { - return nil, err - } - ok, err := s.acl.CanEdit(ctx, userID, isAdmin, f.CreatorID, fileObjectTypeID, fileID) - if err != nil { - return nil, err - } - if !ok { - return nil, domain.ErrForbidden - } - - if err := s.files.SetTags(ctx, fileID, tagIDs); err != nil { - return nil, err - } - - objType := fileObjectType - _ = s.audit.Log(ctx, "file_tag_add", &objType, &fileID, nil) - return s.files.ListTags(ctx, fileID) -} - -// AddTag adds a single tag to a file, enforcing edit ACL. -func (s *FileService) AddTag(ctx context.Context, fileID, tagID uuid.UUID) ([]domain.Tag, error) { - userID, isAdmin, _ := domain.UserFromContext(ctx) - - f, err := s.files.GetByID(ctx, fileID) - if err != nil { - return nil, err - } - ok, err := s.acl.CanEdit(ctx, userID, isAdmin, f.CreatorID, fileObjectTypeID, fileID) - if err != nil { - return nil, err - } - if !ok { - return nil, domain.ErrForbidden - } - - current, err := s.files.ListTags(ctx, fileID) - if err != nil { - return nil, err - } - // Only add if not already present. - for _, t := range current { - if t.ID == tagID { - return current, nil - } - } - ids := make([]uuid.UUID, 0, len(current)+1) - for _, t := range current { - ids = append(ids, t.ID) - } - ids = append(ids, tagID) - - if err := s.files.SetTags(ctx, fileID, ids); err != nil { - return nil, err - } - - objType := fileObjectType - _ = s.audit.Log(ctx, "file_tag_add", &objType, &fileID, map[string]any{"tag_id": tagID}) - return s.files.ListTags(ctx, fileID) -} - -// RemoveTag removes a single tag from a file, enforcing edit ACL. -func (s *FileService) RemoveTag(ctx context.Context, fileID, tagID uuid.UUID) error { - userID, isAdmin, _ := domain.UserFromContext(ctx) - - f, err := s.files.GetByID(ctx, fileID) - if err != nil { - return err - } - ok, err := s.acl.CanEdit(ctx, userID, isAdmin, f.CreatorID, fileObjectTypeID, fileID) - if err != nil { - return err - } - if !ok { - return domain.ErrForbidden - } - - current, err := s.files.ListTags(ctx, fileID) - if err != nil { - return err - } - ids := make([]uuid.UUID, 0, len(current)) - for _, t := range current { - if t.ID != tagID { - ids = append(ids, t.ID) - } - } - - if err := s.files.SetTags(ctx, fileID, ids); err != nil { - return err - } - - objType := fileObjectType - _ = s.audit.Log(ctx, "file_tag_remove", &objType, &fileID, map[string]any{"tag_id": tagID}) - return nil -} - // --------------------------------------------------------------------------- // Bulk operations // --------------------------------------------------------------------------- @@ -569,7 +452,6 @@ func (s *FileService) RemoveTag(ctx context.Context, fileID, tagID uuid.UUID) er func (s *FileService) BulkDelete(ctx context.Context, fileIDs []uuid.UUID) error { for _, id := range fileIDs { if err := s.Delete(ctx, id); err != nil { - // Skip files not found or forbidden; surface real errors. if err == domain.ErrNotFound || err == domain.ErrForbidden { continue } @@ -579,78 +461,6 @@ func (s *FileService) BulkDelete(ctx context.Context, fileIDs []uuid.UUID) error return nil } -// BulkSetTags adds or removes the given tags on multiple files. -// For "add": tags are appended to each file's existing set. -// For "remove": tags are removed from each file's existing set. -// Returns the tag IDs that were applied (the input tagIDs, for add). -func (s *FileService) BulkSetTags(ctx context.Context, fileIDs []uuid.UUID, action string, tagIDs []uuid.UUID) ([]uuid.UUID, error) { - for _, fileID := range fileIDs { - switch action { - case "add": - for _, tagID := range tagIDs { - if _, err := s.AddTag(ctx, fileID, tagID); err != nil { - if err == domain.ErrNotFound || err == domain.ErrForbidden { - continue - } - return nil, err - } - } - case "remove": - for _, tagID := range tagIDs { - if err := s.RemoveTag(ctx, fileID, tagID); err != nil { - if err == domain.ErrNotFound || err == domain.ErrForbidden { - continue - } - return nil, err - } - } - default: - return nil, domain.ErrValidation - } - } - if action == "add" { - return tagIDs, nil - } - return []uuid.UUID{}, nil -} - -// CommonTags loads the tag sets for all given files and splits them into: -// - common: tag IDs present on every file -// - partial: tag IDs present on some but not all files -func (s *FileService) CommonTags(ctx context.Context, fileIDs []uuid.UUID) (common, partial []uuid.UUID, err error) { - if len(fileIDs) == 0 { - return nil, nil, nil - } - - // Count how many files each tag appears on. - counts := map[uuid.UUID]int{} - for _, fid := range fileIDs { - tags, err := s.files.ListTags(ctx, fid) - if err != nil { - return nil, nil, err - } - for _, t := range tags { - counts[t.ID]++ - } - } - - n := len(fileIDs) - for id, cnt := range counts { - if cnt == n { - common = append(common, id) - } else { - partial = append(partial, id) - } - } - if common == nil { - common = []uuid.UUID{} - } - if partial == nil { - partial = []uuid.UUID{} - } - return common, partial, nil -} - // --------------------------------------------------------------------------- // Import // --------------------------------------------------------------------------- diff --git a/backend/internal/service/tag_service.go b/backend/internal/service/tag_service.go new file mode 100644 index 0000000..311f1db --- /dev/null +++ b/backend/internal/service/tag_service.go @@ -0,0 +1,393 @@ +package service + +import ( + "context" + "encoding/json" + + "github.com/google/uuid" + + "tanabata/backend/internal/domain" + "tanabata/backend/internal/port" +) + +const tagObjectType = "tag" +const tagObjectTypeID int16 = 2 // second row in 007_seed_data.sql object_types + +// TagParams holds the fields for creating or patching a tag. +type TagParams struct { + Name string + Notes *string + Color *string // nil = no change; pointer to empty string = clear + CategoryID *uuid.UUID // nil = no change; Nil UUID = unassign + Metadata json.RawMessage + IsPublic *bool +} + +// TagService handles tag CRUD, tag-rule management, and file–tag operations +// including automatic recursive rule application. +type TagService struct { + tags port.TagRepo + rules port.TagRuleRepo + acl *ACLService + audit *AuditService + tx port.Transactor +} + +// NewTagService creates a TagService. +func NewTagService( + tags port.TagRepo, + rules port.TagRuleRepo, + acl *ACLService, + audit *AuditService, + tx port.Transactor, +) *TagService { + return &TagService{ + tags: tags, + rules: rules, + acl: acl, + audit: audit, + tx: tx, + } +} + +// --------------------------------------------------------------------------- +// Tag CRUD +// --------------------------------------------------------------------------- + +// List returns a paginated, optionally filtered list of tags. +func (s *TagService) List(ctx context.Context, params port.OffsetParams) (*domain.TagOffsetPage, error) { + return s.tags.List(ctx, params) +} + +// Get returns a tag by ID. +func (s *TagService) Get(ctx context.Context, id uuid.UUID) (*domain.Tag, error) { + return s.tags.GetByID(ctx, id) +} + +// Create inserts a new tag record. +func (s *TagService) Create(ctx context.Context, p TagParams) (*domain.Tag, error) { + userID, _, _ := domain.UserFromContext(ctx) + + t := &domain.Tag{ + Name: p.Name, + Notes: p.Notes, + Color: p.Color, + CategoryID: p.CategoryID, + Metadata: p.Metadata, + CreatorID: userID, + } + if p.IsPublic != nil { + t.IsPublic = *p.IsPublic + } + + created, err := s.tags.Create(ctx, t) + if err != nil { + return nil, err + } + + objType := tagObjectType + _ = s.audit.Log(ctx, "tag_create", &objType, &created.ID, nil) + return created, nil +} + +// Update applies a partial patch to a tag. +// The service reads the current tag first so the caller only needs to supply +// the fields that should change. +func (s *TagService) Update(ctx context.Context, id uuid.UUID, p TagParams) (*domain.Tag, error) { + userID, isAdmin, _ := domain.UserFromContext(ctx) + + current, err := s.tags.GetByID(ctx, id) + if err != nil { + return nil, err + } + + ok, err := s.acl.CanEdit(ctx, userID, isAdmin, current.CreatorID, tagObjectTypeID, id) + if err != nil { + return nil, err + } + if !ok { + return nil, domain.ErrForbidden + } + + // Merge patch into current. + patch := *current // copy + if p.Name != "" { + patch.Name = p.Name + } + if p.Notes != nil { + patch.Notes = p.Notes + } + if p.Color != nil { + patch.Color = p.Color + } + if p.CategoryID != nil { + if *p.CategoryID == uuid.Nil { + patch.CategoryID = nil // explicit unassign + } else { + patch.CategoryID = p.CategoryID + } + } + if len(p.Metadata) > 0 { + patch.Metadata = p.Metadata + } + if p.IsPublic != nil { + patch.IsPublic = *p.IsPublic + } + + updated, err := s.tags.Update(ctx, id, &patch) + if err != nil { + return nil, err + } + + objType := tagObjectType + _ = s.audit.Log(ctx, "tag_edit", &objType, &id, nil) + return updated, nil +} + +// Delete removes a tag by ID, enforcing edit ACL. +func (s *TagService) Delete(ctx context.Context, id uuid.UUID) error { + userID, isAdmin, _ := domain.UserFromContext(ctx) + + t, err := s.tags.GetByID(ctx, id) + if err != nil { + return err + } + + ok, err := s.acl.CanEdit(ctx, userID, isAdmin, t.CreatorID, tagObjectTypeID, id) + if err != nil { + return err + } + if !ok { + return domain.ErrForbidden + } + + if err := s.tags.Delete(ctx, id); err != nil { + return err + } + + objType := tagObjectType + _ = s.audit.Log(ctx, "tag_delete", &objType, &id, nil) + return nil +} + +// --------------------------------------------------------------------------- +// Tag rules +// --------------------------------------------------------------------------- + +// ListRules returns all rules for a tag (when this tag is applied, these follow). +func (s *TagService) ListRules(ctx context.Context, tagID uuid.UUID) ([]domain.TagRule, error) { + return s.rules.ListByTag(ctx, tagID) +} + +// CreateRule adds a tag rule. If applyToExisting is true, the then_tag is +// retroactively applied to all files that already carry the when_tag. +// Retroactive application requires a FileRepo; it is deferred until wired +// in a future iteration (see port.FileRepo.ListByTag). +func (s *TagService) CreateRule(ctx context.Context, whenTagID, thenTagID uuid.UUID, isActive, _ bool) (*domain.TagRule, error) { + return s.rules.Create(ctx, domain.TagRule{ + WhenTagID: whenTagID, + ThenTagID: thenTagID, + IsActive: isActive, + }) +} + +// DeleteRule removes a tag rule. +func (s *TagService) DeleteRule(ctx context.Context, whenTagID, thenTagID uuid.UUID) error { + return s.rules.Delete(ctx, whenTagID, thenTagID) +} + +// --------------------------------------------------------------------------- +// File–tag operations (with auto-rule expansion) +// --------------------------------------------------------------------------- + +// ListFileTags returns the tags on a file. +func (s *TagService) ListFileTags(ctx context.Context, fileID uuid.UUID) ([]domain.Tag, error) { + return s.tags.ListByFile(ctx, fileID) +} + +// SetFileTags replaces all tags on a file, then applies active rules for all +// newly set tags (BFS expansion). Returns the full resulting tag set. +func (s *TagService) SetFileTags(ctx context.Context, fileID uuid.UUID, tagIDs []uuid.UUID) ([]domain.Tag, error) { + expanded, err := s.expandTagSet(ctx, tagIDs) + if err != nil { + return nil, err + } + + if err := s.tags.SetFileTags(ctx, fileID, expanded); err != nil { + return nil, err + } + + objType := fileObjectType + _ = s.audit.Log(ctx, "file_tag_add", &objType, &fileID, nil) + return s.tags.ListByFile(ctx, fileID) +} + +// AddFileTag adds a single tag to a file, then recursively applies active rules. +// Returns the full resulting tag set. +func (s *TagService) AddFileTag(ctx context.Context, fileID, tagID uuid.UUID) ([]domain.Tag, error) { + // Compute the full set including rule-expansion from tagID. + extra, err := s.expandTagSet(ctx, []uuid.UUID{tagID}) + if err != nil { + return nil, err + } + + // Fetch current tags so we don't lose them. + current, err := s.tags.ListByFile(ctx, fileID) + if err != nil { + return nil, err + } + + // Union: existing + expanded new tags. + seen := make(map[uuid.UUID]bool, len(current)+len(extra)) + for _, t := range current { + seen[t.ID] = true + } + merged := make([]uuid.UUID, len(current)) + for i, t := range current { + merged[i] = t.ID + } + for _, id := range extra { + if !seen[id] { + seen[id] = true + merged = append(merged, id) + } + } + + if err := s.tags.SetFileTags(ctx, fileID, merged); err != nil { + return nil, err + } + + objType := fileObjectType + _ = s.audit.Log(ctx, "file_tag_add", &objType, &fileID, map[string]any{"tag_id": tagID}) + return s.tags.ListByFile(ctx, fileID) +} + +// RemoveFileTag removes a single tag from a file. +func (s *TagService) RemoveFileTag(ctx context.Context, fileID, tagID uuid.UUID) error { + if err := s.tags.RemoveFileTag(ctx, fileID, tagID); err != nil { + return err + } + + objType := fileObjectType + _ = s.audit.Log(ctx, "file_tag_remove", &objType, &fileID, map[string]any{"tag_id": tagID}) + return nil +} + +// BulkSetTags adds or removes tags on multiple files (with rule expansion for add). +// Returns the tagIDs that were applied (the expanded input set for add; empty for remove). +func (s *TagService) BulkSetTags(ctx context.Context, fileIDs []uuid.UUID, action string, tagIDs []uuid.UUID) ([]uuid.UUID, error) { + if action != "add" && action != "remove" { + return nil, domain.ErrValidation + } + + // Pre-expand tag set once; all files get the same expansion. + var expanded []uuid.UUID + if action == "add" { + var err error + expanded, err = s.expandTagSet(ctx, tagIDs) + if err != nil { + return nil, err + } + } + + for _, fileID := range fileIDs { + switch action { + case "add": + current, err := s.tags.ListByFile(ctx, fileID) + if err != nil { + if err == domain.ErrNotFound { + continue + } + return nil, err + } + seen := make(map[uuid.UUID]bool, len(current)) + merged := make([]uuid.UUID, len(current)) + for i, t := range current { + seen[t.ID] = true + merged[i] = t.ID + } + for _, id := range expanded { + if !seen[id] { + seen[id] = true + merged = append(merged, id) + } + } + if err := s.tags.SetFileTags(ctx, fileID, merged); err != nil { + return nil, err + } + case "remove": + current, err := s.tags.ListByFile(ctx, fileID) + if err != nil { + if err == domain.ErrNotFound { + continue + } + return nil, err + } + remove := make(map[uuid.UUID]bool, len(tagIDs)) + for _, id := range tagIDs { + remove[id] = true + } + kept := make([]uuid.UUID, 0, len(current)) + for _, t := range current { + if !remove[t.ID] { + kept = append(kept, t.ID) + } + } + if err := s.tags.SetFileTags(ctx, fileID, kept); err != nil { + return nil, err + } + } + } + + if action == "add" { + return expanded, nil + } + return []uuid.UUID{}, nil +} + +// CommonTags returns tags present on ALL given files and tags present on SOME. +func (s *TagService) CommonTags(ctx context.Context, fileIDs []uuid.UUID) (common, partial []domain.Tag, err error) { + common, err = s.tags.CommonTagsForFiles(ctx, fileIDs) + if err != nil { + return nil, nil, err + } + partial, err = s.tags.PartialTagsForFiles(ctx, fileIDs) + if err != nil { + return nil, nil, err + } + return common, partial, nil +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// expandTagSet runs a BFS from the given seed tags, following active tag rules, +// and returns the full set of tag IDs that should be applied (seeds + auto-applied). +func (s *TagService) expandTagSet(ctx context.Context, seeds []uuid.UUID) ([]uuid.UUID, error) { + visited := make(map[uuid.UUID]bool, len(seeds)) + queue := make([]uuid.UUID, 0, len(seeds)) + + for _, id := range seeds { + if !visited[id] { + visited[id] = true + queue = append(queue, id) + } + } + + for i := 0; i < len(queue); i++ { + tagID := queue[i] + rules, err := s.rules.ListByTag(ctx, tagID) + if err != nil { + return nil, err + } + for _, r := range rules { + if r.IsActive && !visited[r.ThenTagID] { + visited[r.ThenTagID] = true + queue = append(queue, r.ThenTagID) + } + } + } + + return queue, nil +} \ No newline at end of file