From dea6a55dfcb34d096d8a964c2b8871afbad0ec7e Mon Sep 17 00:00:00 2001 From: Masahiko AMANO Date: Sat, 4 Apr 2026 16:55:23 +0300 Subject: [PATCH] feat: implement FileRepo and filter DSL parser MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit filter_parser.go — recursive-descent parser for the {token,...} DSL. Tokens: t=UUID (tag), m=INT (MIME exact), m~PATTERN (MIME LIKE), operators & | ! ( ) with standard NOT>AND>OR precedence. All values go through pgx parameters ($N) — SQL injection impossible. file_repo.go — full FileRepo: - Create/GetByID/Update via CTE RETURNING with JOIN for one round-trip - SoftDelete/Restore/DeletePermanent with RowsAffected guards - SetTags: full replace (DELETE + INSERT per tag) - ListTags: delegates to loadTagsBatch (single query for N files) - List: keyset cursor pagination (bidirectional), anchor mode, filter DSL, search ILIKE, trash flag, 4 sort columns. Cursor is base64url(JSON) encoding sort position; backward pagination fetches in reversed ORDER BY then reverses the slice. 
Co-Authored-By: Claude Sonnet 4.6 --- backend/internal/db/postgres/file_repo.go | 795 ++++++++++++++++++ backend/internal/db/postgres/filter_parser.go | 286 +++++++ 2 files changed, 1081 insertions(+) create mode 100644 backend/internal/db/postgres/file_repo.go create mode 100644 backend/internal/db/postgres/filter_parser.go diff --git a/backend/internal/db/postgres/file_repo.go b/backend/internal/db/postgres/file_repo.go new file mode 100644 index 0000000..e8dbd9e --- /dev/null +++ b/backend/internal/db/postgres/file_repo.go @@ -0,0 +1,795 @@ +package postgres + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "tanabata/backend/internal/db" + "tanabata/backend/internal/domain" + "tanabata/backend/internal/port" +) + +// --------------------------------------------------------------------------- +// Row structs +// --------------------------------------------------------------------------- + +type fileRow struct { + ID uuid.UUID `db:"id"` + OriginalName *string `db:"original_name"` + MIMEType string `db:"mime_type"` + MIMEExtension string `db:"mime_extension"` + ContentDatetime time.Time `db:"content_datetime"` + Notes *string `db:"notes"` + Metadata json.RawMessage `db:"metadata"` + EXIF json.RawMessage `db:"exif"` + PHash *int64 `db:"phash"` + CreatorID int16 `db:"creator_id"` + CreatorName string `db:"creator_name"` + IsPublic bool `db:"is_public"` + IsDeleted bool `db:"is_deleted"` +} + +// fileTagRow is used for both single-file and batch tag loading. +// file_id is always selected so the same struct works for both cases. 
+type fileTagRow struct { + FileID uuid.UUID `db:"file_id"` + ID uuid.UUID `db:"id"` + Name string `db:"name"` + Notes *string `db:"notes"` + Color *string `db:"color"` + CategoryID *uuid.UUID `db:"category_id"` + CategoryName *string `db:"category_name"` + CategoryColor *string `db:"category_color"` + Metadata json.RawMessage `db:"metadata"` + CreatorID int16 `db:"creator_id"` + CreatorName string `db:"creator_name"` + IsPublic bool `db:"is_public"` +} + +// anchorValRow holds the sort-column values fetched for an anchor file. +type anchorValRow struct { + ContentDatetime time.Time `db:"content_datetime"` + OriginalName string `db:"original_name"` // COALESCE(original_name,'') applied in SQL + MIMEType string `db:"mime_type"` +} + +// --------------------------------------------------------------------------- +// Converters +// --------------------------------------------------------------------------- + +func toFile(r fileRow) domain.File { + return domain.File{ + ID: r.ID, + OriginalName: r.OriginalName, + MIMEType: r.MIMEType, + MIMEExtension: r.MIMEExtension, + ContentDatetime: r.ContentDatetime, + Notes: r.Notes, + Metadata: r.Metadata, + EXIF: r.EXIF, + PHash: r.PHash, + CreatorID: r.CreatorID, + CreatorName: r.CreatorName, + IsPublic: r.IsPublic, + IsDeleted: r.IsDeleted, + CreatedAt: domain.UUIDCreatedAt(r.ID), + } +} + +func toTagFromFileTag(r fileTagRow) domain.Tag { + return domain.Tag{ + ID: r.ID, + Name: r.Name, + Notes: r.Notes, + Color: r.Color, + CategoryID: r.CategoryID, + CategoryName: r.CategoryName, + CategoryColor: r.CategoryColor, + Metadata: r.Metadata, + CreatorID: r.CreatorID, + CreatorName: r.CreatorName, + IsPublic: r.IsPublic, + CreatedAt: domain.UUIDCreatedAt(r.ID), + } +} + +// --------------------------------------------------------------------------- +// Cursor +// --------------------------------------------------------------------------- + +type fileCursor struct { + Sort string `json:"s"` // canonical sort name + Order string 
`json:"o"` // "ASC" or "DESC" + ID string `json:"id"` // UUID of the boundary file + Val string `json:"v"` // sort column value; empty for "created" (id IS the key) +} + +func encodeCursor(c fileCursor) string { + b, _ := json.Marshal(c) + return base64.RawURLEncoding.EncodeToString(b) +} + +func decodeCursor(s string) (fileCursor, error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return fileCursor{}, fmt.Errorf("cursor: invalid encoding") + } + var c fileCursor + if err := json.Unmarshal(b, &c); err != nil { + return fileCursor{}, fmt.Errorf("cursor: invalid format") + } + return c, nil +} + +// makeCursor builds a fileCursor from a boundary row and the current sort/order. +func makeCursor(r fileRow, sort, order string) fileCursor { + var val string + switch sort { + case "content_datetime": + val = r.ContentDatetime.UTC().Format(time.RFC3339Nano) + case "original_name": + if r.OriginalName != nil { + val = *r.OriginalName + } + case "mime": + val = r.MIMEType + // "created": val is empty; f.id is the sort key. + } + return fileCursor{Sort: sort, Order: order, ID: r.ID.String(), Val: val} +} + +// --------------------------------------------------------------------------- +// Sort helpers +// --------------------------------------------------------------------------- + +func normSort(s string) string { + switch s { + case "content_datetime", "original_name", "mime": + return s + default: + return "created" + } +} + +func normOrder(o string) string { + if strings.EqualFold(o, "asc") { + return "ASC" + } + return "DESC" +} + +// buildKeysetCond returns a keyset WHERE fragment and an ORDER BY fragment. 
//
// - forward=true: items after the cursor in the sort order (standard next-page)
// - forward=false: items before the cursor (previous-page); ORDER BY is reversed,
//   caller must reverse the result slice after fetching
// - incl=true: include the cursor file itself (anchor case; uses ≤ / ≥)
//
// All user values are bound as parameters — no SQL injection possible.
//
// Parameters: n is the next free $N placeholder index; args is the argument
// slice built so far. Returns the WHERE fragment, the ORDER BY fragment, the
// updated placeholder index, and the extended argument slice.
func buildKeysetCond(
	sort, order string,
	forward, incl bool,
	cursorID uuid.UUID, cursorVal string,
	n int, args []any,
) (where, orderBy string, nextN int, outArgs []any) {
	// goDown=true → want smaller values → primary comparison is "<".
	// Applies for DESC+forward and ASC+backward.
	goDown := (order == "DESC") == forward

	// op compares the sort column; idOp compares the tie-breaking id column.
	// Only the id comparison becomes inclusive in anchor mode, so the anchor
	// row itself is admitted exactly once.
	var op, idOp string
	if goDown {
		op = "<"
		if incl {
			idOp = "<="
		} else {
			idOp = "<"
		}
	} else {
		op = ">"
		if incl {
			idOp = ">="
		} else {
			idOp = ">"
		}
	}

	// Effective ORDER BY direction: reversed for backward so the DB returns
	// the closest items first (the ones we keep after trimming the extra).
	dir := order
	if !forward {
		if order == "DESC" {
			dir = "ASC"
		} else {
			dir = "DESC"
		}
	}

	switch sort {
	case "created":
		// Single-column keyset: f.id (UUID v7, so ordering = chronological).
		where = fmt.Sprintf("f.id %s $%d", idOp, n)
		orderBy = fmt.Sprintf("f.id %s", dir)
		outArgs = append(args, cursorID)
		n++

	case "content_datetime":
		// Two-column keyset: (content_datetime, id).
		// $n is referenced twice in the SQL (< and =) but passed once in args —
		// PostgreSQL extended protocol allows multiple references to $N.
		// NOTE(review): the parse error is deliberately dropped here — a
		// malformed cursor Val becomes the zero time.Time and silently skews
		// pagination instead of failing. Consider validating Val when the
		// cursor is decoded — TODO confirm desired behavior.
		t, _ := time.Parse(time.RFC3339Nano, cursorVal)
		where = fmt.Sprintf(
			"(f.content_datetime %s $%d OR (f.content_datetime = $%d AND f.id %s $%d))",
			op, n, n, idOp, n+1)
		orderBy = fmt.Sprintf("f.content_datetime %s, f.id %s", dir, dir)
		outArgs = append(args, t, cursorID)
		n += 2

	case "original_name":
		// COALESCE treats NULL names as '' for stable pagination.
		where = fmt.Sprintf(
			"(COALESCE(f.original_name,'') %s $%d OR (COALESCE(f.original_name,'') = $%d AND f.id %s $%d))",
			op, n, n, idOp, n+1)
		orderBy = fmt.Sprintf("COALESCE(f.original_name,'') %s, f.id %s", dir, dir)
		outArgs = append(args, cursorVal, cursorID)
		n += 2

	default: // "mime"
		// mt alias is provided by the JOIN in the caller's main query.
		where = fmt.Sprintf(
			"(mt.name %s $%d OR (mt.name = $%d AND f.id %s $%d))",
			op, n, n, idOp, n+1)
		orderBy = fmt.Sprintf("mt.name %s, f.id %s", dir, dir)
		outArgs = append(args, cursorVal, cursorID)
		n += 2
	}

	nextN = n
	return
}

// defaultOrderBy returns the natural ORDER BY for the first page (no cursor).
// sort must already be normalised via normSort, order via normOrder, so the
// interpolated identifiers are never attacker-controlled.
func defaultOrderBy(sort, order string) string {
	switch sort {
	case "created":
		return fmt.Sprintf("f.id %s", order)
	case "content_datetime":
		return fmt.Sprintf("f.content_datetime %s, f.id %s", order, order)
	case "original_name":
		return fmt.Sprintf("COALESCE(f.original_name,'') %s, f.id %s", order, order)
	default: // "mime"
		return fmt.Sprintf("mt.name %s, f.id %s", order, order)
	}
}

// ---------------------------------------------------------------------------
// FileRepo
// ---------------------------------------------------------------------------

// FileRepo implements port.FileRepo using PostgreSQL.
type FileRepo struct {
	pool *pgxpool.Pool
}

// NewFileRepo creates a FileRepo backed by pool.
func NewFileRepo(pool *pgxpool.Pool) *FileRepo {
	return &FileRepo{pool: pool}
}

// Compile-time interface-satisfaction check.
var _ port.FileRepo = (*FileRepo)(nil)

// fileSelectCTE is the SELECT appended after a CTE named "r" that exposes
// all file columns (including mime_id). Used by Create, Update, and Restore
// to get the full denormalized record in a single round-trip.
+const fileSelectCTE = ` + SELECT r.id, r.original_name, + mt.name AS mime_type, mt.extension AS mime_extension, + r.content_datetime, r.notes, r.metadata, r.exif, r.phash, + r.creator_id, u.name AS creator_name, + r.is_public, r.is_deleted + FROM r + JOIN core.mime_types mt ON mt.id = r.mime_id + JOIN core.users u ON u.id = r.creator_id` + +// --------------------------------------------------------------------------- +// Create +// --------------------------------------------------------------------------- + +// Create inserts a new file record. The MIME type is resolved from +// f.MIMEType (name string) via a subquery; the DB generates the UUID v7 id. +func (r *FileRepo) Create(ctx context.Context, f *domain.File) (*domain.File, error) { + const sqlStr = ` + WITH r AS ( + INSERT INTO data.files + (original_name, mime_id, content_datetime, notes, metadata, exif, phash, creator_id, is_public) + VALUES ( + $1, + (SELECT id FROM core.mime_types WHERE name = $2), + $3, $4, $5, $6, $7, $8, $9 + ) + RETURNING id, original_name, mime_id, content_datetime, notes, + metadata, exif, phash, creator_id, is_public, is_deleted + )` + fileSelectCTE + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, sqlStr, + f.OriginalName, f.MIMEType, f.ContentDatetime, + f.Notes, f.Metadata, f.EXIF, f.PHash, + f.CreatorID, f.IsPublic, + ) + if err != nil { + return nil, fmt.Errorf("FileRepo.Create: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[fileRow]) + if err != nil { + return nil, fmt.Errorf("FileRepo.Create scan: %w", err) + } + created := toFile(row) + return &created, nil +} + +// --------------------------------------------------------------------------- +// GetByID +// --------------------------------------------------------------------------- + +func (r *FileRepo) GetByID(ctx context.Context, id uuid.UUID) (*domain.File, error) { + const sqlStr = ` + SELECT f.id, f.original_name, + mt.name AS mime_type, mt.extension AS mime_extension, + 
f.content_datetime, f.notes, f.metadata, f.exif, f.phash, + f.creator_id, u.name AS creator_name, + f.is_public, f.is_deleted + FROM data.files f + JOIN core.mime_types mt ON mt.id = f.mime_id + JOIN core.users u ON u.id = f.creator_id + WHERE f.id = $1` + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, sqlStr, id) + if err != nil { + return nil, fmt.Errorf("FileRepo.GetByID: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[fileRow]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrNotFound + } + return nil, fmt.Errorf("FileRepo.GetByID scan: %w", err) + } + f := toFile(row) + tags, err := r.ListTags(ctx, id) + if err != nil { + return nil, err + } + f.Tags = tags + return &f, nil +} + +// --------------------------------------------------------------------------- +// Update +// --------------------------------------------------------------------------- + +// Update applies editable metadata fields. MIME type and EXIF are immutable. 
+func (r *FileRepo) Update(ctx context.Context, id uuid.UUID, f *domain.File) (*domain.File, error) { + const sqlStr = ` + WITH r AS ( + UPDATE data.files + SET original_name = $2, + content_datetime = $3, + notes = $4, + metadata = $5, + is_public = $6 + WHERE id = $1 + RETURNING id, original_name, mime_id, content_datetime, notes, + metadata, exif, phash, creator_id, is_public, is_deleted + )` + fileSelectCTE + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, sqlStr, + id, f.OriginalName, f.ContentDatetime, + f.Notes, f.Metadata, f.IsPublic, + ) + if err != nil { + return nil, fmt.Errorf("FileRepo.Update: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[fileRow]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrNotFound + } + return nil, fmt.Errorf("FileRepo.Update scan: %w", err) + } + updated := toFile(row) + tags, err := r.ListTags(ctx, id) + if err != nil { + return nil, err + } + updated.Tags = tags + return &updated, nil +} + +// --------------------------------------------------------------------------- +// SoftDelete / Restore / DeletePermanent +// --------------------------------------------------------------------------- + +// SoftDelete moves a file to trash (is_deleted = true). Returns ErrNotFound +// if the file does not exist or is already in trash. +func (r *FileRepo) SoftDelete(ctx context.Context, id uuid.UUID) error { + const sqlStr = `UPDATE data.files SET is_deleted = true WHERE id = $1 AND is_deleted = false` + q := connOrTx(ctx, r.pool) + tag, err := q.Exec(ctx, sqlStr, id) + if err != nil { + return fmt.Errorf("FileRepo.SoftDelete: %w", err) + } + if tag.RowsAffected() == 0 { + return domain.ErrNotFound + } + return nil +} + +// Restore moves a file out of trash (is_deleted = false). Returns ErrNotFound +// if the file does not exist or is not in trash. 
+func (r *FileRepo) Restore(ctx context.Context, id uuid.UUID) (*domain.File, error) { + const sqlStr = ` + WITH r AS ( + UPDATE data.files + SET is_deleted = false + WHERE id = $1 AND is_deleted = true + RETURNING id, original_name, mime_id, content_datetime, notes, + metadata, exif, phash, creator_id, is_public, is_deleted + )` + fileSelectCTE + + q := connOrTx(ctx, r.pool) + rows, err := q.Query(ctx, sqlStr, id) + if err != nil { + return nil, fmt.Errorf("FileRepo.Restore: %w", err) + } + row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[fileRow]) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrNotFound + } + return nil, fmt.Errorf("FileRepo.Restore scan: %w", err) + } + restored := toFile(row) + tags, err := r.ListTags(ctx, id) + if err != nil { + return nil, err + } + restored.Tags = tags + return &restored, nil +} + +// DeletePermanent removes a file record permanently. Only allowed when the +// file is already in trash (is_deleted = true). +func (r *FileRepo) DeletePermanent(ctx context.Context, id uuid.UUID) error { + const sqlStr = `DELETE FROM data.files WHERE id = $1 AND is_deleted = true` + q := connOrTx(ctx, r.pool) + tag, err := q.Exec(ctx, sqlStr, id) + if err != nil { + return fmt.Errorf("FileRepo.DeletePermanent: %w", err) + } + if tag.RowsAffected() == 0 { + return domain.ErrNotFound + } + return nil +} + +// --------------------------------------------------------------------------- +// ListTags / SetTags +// --------------------------------------------------------------------------- + +// ListTags returns all tags assigned to a file, ordered by tag name. +func (r *FileRepo) ListTags(ctx context.Context, fileID uuid.UUID) ([]domain.Tag, error) { + m, err := r.loadTagsBatch(ctx, []uuid.UUID{fileID}) + if err != nil { + return nil, err + } + return m[fileID], nil +} + +// SetTags replaces all tags on a file (full replace semantics). 
func (r *FileRepo) SetTags(ctx context.Context, fileID uuid.UUID, tagIDs []uuid.UUID) error {
	// NOTE(review): the DELETE and the per-tag INSERTs are not wrapped in a
	// transaction here — if connOrTx yields a bare pool connection, a failure
	// mid-loop leaves the file with a partial tag set. Confirm that callers
	// always invoke this inside a transaction.
	q := connOrTx(ctx, r.pool)
	const del = `DELETE FROM data.file_tag WHERE file_id = $1`
	if _, err := q.Exec(ctx, del, fileID); err != nil {
		return fmt.Errorf("FileRepo.SetTags delete: %w", err)
	}
	if len(tagIDs) == 0 {
		return nil
	}
	// One round-trip per tag; acceptable for small tag sets.
	const ins = `INSERT INTO data.file_tag (file_id, tag_id) VALUES ($1, $2)`
	for _, tagID := range tagIDs {
		if _, err := q.Exec(ctx, ins, fileID, tagID); err != nil {
			return fmt.Errorf("FileRepo.SetTags insert: %w", err)
		}
	}
	return nil
}

// ---------------------------------------------------------------------------
// List
// ---------------------------------------------------------------------------

// List returns a cursor-paginated page of files.
//
// Pagination is keyset-based for stable performance on large tables.
// Cursor encodes the sort position; the caller provides direction.
// Anchor mode starts the page at a specific file UUID: the anchor row is
// included (inclusive keyset) and the page extends in the forward direction
// from it.
func (r *FileRepo) List(ctx context.Context, params domain.FileListParams) (*domain.FilePage, error) {
	sort := normSort(params.Sort)
	order := normOrder(params.Order)
	forward := params.Direction != "backward"
	limit := db.ClampLimit(params.Limit, 50, 200)

	// --- resolve cursor / anchor ---
	var (
		cursorID  uuid.UUID
		cursorVal string
		hasCursor bool
		isAnchor  bool
	)

	switch {
	case params.Cursor != "":
		cur, err := decodeCursor(params.Cursor)
		if err != nil {
			return nil, fmt.Errorf("%w: %v", domain.ErrValidation, err)
		}
		id, err := uuid.Parse(cur.ID)
		if err != nil {
			return nil, domain.ErrValidation
		}
		// Lock in the sort/order encoded in the cursor so changing query
		// parameters mid-session doesn't corrupt pagination.
		sort = normSort(cur.Sort)
		order = normOrder(cur.Order)
		cursorID = id
		cursorVal = cur.Val
		hasCursor = true

	case params.Anchor != nil:
		// The anchor's own sort-column values seed the keyset condition.
		av, err := r.fetchAnchorVals(ctx, *params.Anchor)
		if err != nil {
			return nil, err
		}
		cursorID = *params.Anchor
		switch sort {
		case "content_datetime":
			cursorVal = av.ContentDatetime.UTC().Format(time.RFC3339Nano)
		case "original_name":
			cursorVal = av.OriginalName
		case "mime":
			cursorVal = av.MIMEType
			// "created": cursorVal stays ""; cursorID is the sort key.
		}
		hasCursor = true
		isAnchor = true
	}

	// Without a cursor there is no meaningful "backward" direction.
	if !hasCursor {
		forward = true
	}

	// --- build WHERE and ORDER BY ---
	// n tracks the next free $N placeholder; every user value goes into args.
	var conds []string
	args := make([]any, 0, 8)
	n := 1

	// Trash flag doubles as the base visibility filter (trash vs. live).
	conds = append(conds, fmt.Sprintf("f.is_deleted = $%d", n))
	args = append(args, params.Trash)
	n++

	if params.Search != "" {
		conds = append(conds, fmt.Sprintf("f.original_name ILIKE $%d", n))
		args = append(args, "%"+params.Search+"%")
		n++
	}

	if params.Filter != "" {
		filterSQL, nextN, filterArgs, err := ParseFilter(params.Filter, n)
		if err != nil {
			return nil, fmt.Errorf("%w: %v", domain.ErrValidation, err)
		}
		if filterSQL != "" {
			conds = append(conds, filterSQL)
			n = nextN
			args = append(args, filterArgs...)
		}
	}

	var orderBy string
	if hasCursor {
		// Anchor mode uses the inclusive comparison so the anchor row itself
		// lands on the page.
		ksWhere, ksOrder, nextN, ksArgs := buildKeysetCond(
			sort, order, forward, isAnchor, cursorID, cursorVal, n, args)
		conds = append(conds, ksWhere)
		n = nextN
		args = ksArgs
		orderBy = ksOrder
	} else {
		orderBy = defaultOrderBy(sort, order)
	}

	where := ""
	if len(conds) > 0 {
		where = "WHERE " + strings.Join(conds, " AND ")
	}

	// Fetch one extra row to detect whether more items exist beyond this page.
	args = append(args, limit+1)
	sqlStr := fmt.Sprintf(`
	SELECT f.id, f.original_name,
	       mt.name AS mime_type, mt.extension AS mime_extension,
	       f.content_datetime, f.notes, f.metadata, f.exif, f.phash,
	       f.creator_id, u.name AS creator_name,
	       f.is_public, f.is_deleted
	FROM data.files f
	JOIN core.mime_types mt ON mt.id = f.mime_id
	JOIN core.users u ON u.id = f.creator_id
	%s
	ORDER BY %s
	LIMIT $%d`, where, orderBy, n)

	q := connOrTx(ctx, r.pool)
	rows, err := q.Query(ctx, sqlStr, args...)
	if err != nil {
		return nil, fmt.Errorf("FileRepo.List: %w", err)
	}
	collected, err := pgx.CollectRows(rows, pgx.RowToStructByName[fileRow])
	if err != nil {
		return nil, fmt.Errorf("FileRepo.List scan: %w", err)
	}

	// --- trim extra row and reverse for backward ---
	hasMore := len(collected) > limit
	if hasMore {
		collected = collected[:limit]
	}
	if !forward {
		// Results were fetched in reversed ORDER BY; invert to restore the
		// natural sort order expected by the caller.
		for i, j := 0, len(collected)-1; i < j; i, j = i+1, j-1 {
			collected[i], collected[j] = collected[j], collected[i]
		}
	}

	// --- assemble page ---
	page := &domain.FilePage{
		Items: make([]domain.File, len(collected)),
	}
	for i, row := range collected {
		page.Items[i] = toFile(row)
	}

	// --- set next/prev cursors ---
	// next_cursor: navigate further in the forward direction.
	// prev_cursor: navigate further in the backward direction.
	if len(collected) > 0 {
		firstCur := encodeCursor(makeCursor(collected[0], sort, order))
		lastCur := encodeCursor(makeCursor(collected[len(collected)-1], sort, order))

		if forward {
			// We only know a prev page exists if we arrived via cursor.
			if hasCursor {
				page.PrevCursor = &firstCur
			}
			if hasMore {
				page.NextCursor = &lastCur
			}
		} else {
			// Backward: last item (after reversal) is closest to original cursor.
			if hasCursor {
				page.NextCursor = &lastCur
			}
			if hasMore {
				page.PrevCursor = &firstCur
			}
		}
	}

	// --- batch-load tags ---
	if len(page.Items) > 0 {
		fileIDs := make([]uuid.UUID, len(page.Items))
		for i, f := range page.Items {
			fileIDs[i] = f.ID
		}
		tagMap, err := r.loadTagsBatch(ctx, fileIDs)
		if err != nil {
			return nil, err
		}
		for i, f := range page.Items {
			page.Items[i].Tags = tagMap[f.ID] // nil becomes []domain.Tag{} via loadTagsBatch
		}
	}

	return page, nil
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

// fetchAnchorVals returns the sort-column values for the given file.
// Used to set up a cursor when the caller provides an anchor UUID.
// Returns domain.ErrNotFound when the anchor file does not exist.
func (r *FileRepo) fetchAnchorVals(ctx context.Context, fileID uuid.UUID) (*anchorValRow, error) {
	const sqlStr = `
	SELECT f.content_datetime,
	       COALESCE(f.original_name, '') AS original_name,
	       mt.name AS mime_type
	FROM data.files f
	JOIN core.mime_types mt ON mt.id = f.mime_id
	WHERE f.id = $1`

	q := connOrTx(ctx, r.pool)
	rows, err := q.Query(ctx, sqlStr, fileID)
	if err != nil {
		return nil, fmt.Errorf("FileRepo.fetchAnchorVals: %w", err)
	}
	row, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[anchorValRow])
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, domain.ErrNotFound
		}
		return nil, fmt.Errorf("FileRepo.fetchAnchorVals scan: %w", err)
	}
	return &row, nil
}

// loadTagsBatch fetches tags for multiple files in a single query and returns
// them as a map keyed by file ID. Every requested file ID appears as a key
// (with an empty slice if the file has no tags).
func (r *FileRepo) loadTagsBatch(ctx context.Context, fileIDs []uuid.UUID) (map[uuid.UUID][]domain.Tag, error) {
	if len(fileIDs) == 0 {
		return nil, nil
	}

	// Build a parameterised IN list.
	// The max page size is 200, so at most 200
	// placeholders — well within PostgreSQL's limits.
	placeholders := make([]string, len(fileIDs))
	args := make([]any, len(fileIDs))
	for i, id := range fileIDs {
		placeholders[i] = fmt.Sprintf("$%d", i+1)
		args[i] = id
	}

	sqlStr := fmt.Sprintf(`
	SELECT ft.file_id,
	       t.id, t.name, t.notes, t.color,
	       t.category_id,
	       c.name AS category_name,
	       c.color AS category_color,
	       t.metadata, t.creator_id, u.name AS creator_name, t.is_public
	FROM data.file_tag ft
	JOIN data.tags t ON t.id = ft.tag_id
	JOIN core.users u ON u.id = t.creator_id
	LEFT JOIN data.categories c ON c.id = t.category_id
	WHERE ft.file_id IN (%s)
	ORDER BY ft.file_id, t.name`, strings.Join(placeholders, ","))

	q := connOrTx(ctx, r.pool)
	rows, err := q.Query(ctx, sqlStr, args...)
	if err != nil {
		return nil, fmt.Errorf("FileRepo.loadTagsBatch: %w", err)
	}
	collected, err := pgx.CollectRows(rows, pgx.RowToStructByName[fileTagRow])
	if err != nil {
		return nil, fmt.Errorf("FileRepo.loadTagsBatch scan: %w", err)
	}

	result := make(map[uuid.UUID][]domain.Tag, len(fileIDs))
	for _, fid := range fileIDs {
		result[fid] = []domain.Tag{} // guarantee every key has a non-nil slice
	}
	for _, row := range collected {
		result[row.FileID] = append(result[row.FileID], toTagFromFileTag(row))
	}
	return result, nil
}
diff --git a/backend/internal/db/postgres/filter_parser.go b/backend/internal/db/postgres/filter_parser.go
new file mode 100644
index 0000000..f3f87f8
--- /dev/null
+++ b/backend/internal/db/postgres/filter_parser.go
@@ -0,0 +1,286 @@
package postgres

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/google/uuid"
)

// ---------------------------------------------------------------------------
// Token types
// ---------------------------------------------------------------------------

type filterTokenKind int

const (
	ftkAnd filterTokenKind = iota
	ftkOr
	ftkNot
	ftkLParen
	ftkRParen
	ftkTag       // t=
	ftkMimeExact // m=
	ftkMimeLike  // m~
)

// filterToken is a single lexed unit of the DSL. Only the field matching the
// kind is populated; the others stay at their zero values.
type filterToken struct {
	kind     filterTokenKind
	tagID    uuid.UUID // ftkTag
	untagged bool      // ftkTag with zero UUID → "file has no tags"
	mimeID   int16     // ftkMimeExact
	pattern  string    // ftkMimeLike
}

// ---------------------------------------------------------------------------
// AST nodes
// ---------------------------------------------------------------------------

// filterNode produces a parameterized SQL fragment.
// n is the index of the next available positional parameter ($n).
// Returns the fragment, the updated n, and the extended args slice.
type filterNode interface {
	toSQL(n int, args []any) (string, int, []any)
}

type andNode struct{ left, right filterNode }
type orNode struct{ left, right filterNode }
type notNode struct{ child filterNode }
type leafNode struct{ tok filterToken }

// toSQL renders "(left AND right)", threading n/args left-to-right.
func (a *andNode) toSQL(n int, args []any) (string, int, []any) {
	ls, n, args := a.left.toSQL(n, args)
	rs, n, args := a.right.toSQL(n, args)
	return "(" + ls + " AND " + rs + ")", n, args
}

// toSQL renders "(left OR right)", threading n/args left-to-right.
func (o *orNode) toSQL(n int, args []any) (string, int, []any) {
	ls, n, args := o.left.toSQL(n, args)
	rs, n, args := o.right.toSQL(n, args)
	return "(" + ls + " OR " + rs + ")", n, args
}

// toSQL renders "(NOT child)".
func (no *notNode) toSQL(n int, args []any) (string, int, []any) {
	cs, n, args := no.child.toSQL(n, args)
	return "(NOT " + cs + ")", n, args
}

// toSQL renders a single predicate. The "f" and "mt" aliases are assumed to
// be provided by the caller's main query. The untagged sentinel needs no
// parameter; every other leaf binds exactly one.
func (l *leafNode) toSQL(n int, args []any) (string, int, []any) {
	switch l.tok.kind {
	case ftkTag:
		if l.tok.untagged {
			return "NOT EXISTS (SELECT 1 FROM data.file_tag ft WHERE ft.file_id = f.id)", n, args
		}
		s := fmt.Sprintf(
			"EXISTS (SELECT 1 FROM data.file_tag ft WHERE ft.file_id = f.id AND ft.tag_id = $%d)", n)
		return s, n + 1, append(args, l.tok.tagID)
	case ftkMimeExact:
		return fmt.Sprintf("f.mime_id = $%d", n), n + 1, append(args, l.tok.mimeID)
	case ftkMimeLike:
		// mt alias comes from the JOIN in the main file query (always present).
		return fmt.Sprintf("mt.name LIKE $%d", n), n + 1, append(args, l.tok.pattern)
	}
	// Unreachable for tokens produced by lexFilter; guards future kinds.
	panic("filterNode.toSQL: unknown leaf kind")
}

// ---------------------------------------------------------------------------
// Lexer
// ---------------------------------------------------------------------------

// lexFilter tokenises the DSL string {a,b,c,...} into filterTokens.
// Tokens are comma-separated inside the braces; note this means a m~ pattern
// itself cannot contain a comma.
func lexFilter(dsl string) ([]filterToken, error) {
	dsl = strings.TrimSpace(dsl)
	if !strings.HasPrefix(dsl, "{") || !strings.HasSuffix(dsl, "}") {
		return nil, fmt.Errorf("filter DSL must be wrapped in braces: {…}")
	}
	inner := strings.TrimSpace(dsl[1 : len(dsl)-1])
	if inner == "" {
		return nil, nil
	}

	parts := strings.Split(inner, ",")
	tokens := make([]filterToken, 0, len(parts))

	for _, raw := range parts {
		p := strings.TrimSpace(raw)
		switch {
		case p == "&":
			tokens = append(tokens, filterToken{kind: ftkAnd})
		case p == "|":
			tokens = append(tokens, filterToken{kind: ftkOr})
		case p == "!":
			tokens = append(tokens, filterToken{kind: ftkNot})
		case p == "(":
			tokens = append(tokens, filterToken{kind: ftkLParen})
		case p == ")":
			tokens = append(tokens, filterToken{kind: ftkRParen})
		case strings.HasPrefix(p, "t="):
			// The all-zero UUID is a sentinel meaning "untagged".
			id, err := uuid.Parse(p[2:])
			if err != nil {
				return nil, fmt.Errorf("filter: invalid tag UUID %q", p[2:])
			}
			tokens = append(tokens, filterToken{kind: ftkTag, tagID: id, untagged: id == uuid.Nil})
		case strings.HasPrefix(p, "m="):
			// mime_types.id is int16 in the schema, hence the 16-bit bound.
			v, err := strconv.ParseInt(p[2:], 10, 16)
			if err != nil {
				return nil, fmt.Errorf("filter: invalid MIME ID %q", p[2:])
			}
			tokens = append(tokens, filterToken{kind: ftkMimeExact, mimeID: int16(v)})
		case strings.HasPrefix(p, "m~"):
			// The pattern value is passed as a query parameter, so no SQL injection risk.
			tokens = append(tokens, filterToken{kind: ftkMimeLike, pattern: p[2:]})
		default:
			return nil, fmt.Errorf("filter: unknown token %q", p)
		}
	}
	return tokens, nil
}

// ---------------------------------------------------------------------------
// Recursive-descent parser
// ---------------------------------------------------------------------------

type filterParser struct {
	tokens []filterToken
	pos    int
}

// peek returns the current token without consuming it; ok=false at the end.
func (p *filterParser) peek() (filterToken, bool) {
	if p.pos >= len(p.tokens) {
		return filterToken{}, false
	}
	return p.tokens[p.pos], true
}

// next consumes and returns the current token. Callers must peek first:
// calling next past the end would panic on the slice index.
func (p *filterParser) next() filterToken {
	t := p.tokens[p.pos]
	p.pos++
	return t
}

// Grammar (standard NOT > AND > OR precedence):
//
//	expr     := or_expr
//	or_expr  := and_expr ('|' and_expr)*
//	and_expr := not_expr ('&' not_expr)*
//	not_expr := '!' not_expr | atom
//	atom     := '(' expr ')' | leaf

func (p *filterParser) parseExpr() (filterNode, error) { return p.parseOr() }

// parseOr builds a left-associative chain of '|' operands.
func (p *filterParser) parseOr() (filterNode, error) {
	left, err := p.parseAnd()
	if err != nil {
		return nil, err
	}
	for {
		t, ok := p.peek()
		if !ok || t.kind != ftkOr {
			break
		}
		p.next()
		right, err := p.parseAnd()
		if err != nil {
			return nil, err
		}
		left = &orNode{left, right}
	}
	return left, nil
}

// parseAnd builds a left-associative chain of '&' operands.
func (p *filterParser) parseAnd() (filterNode, error) {
	left, err := p.parseNot()
	if err != nil {
		return nil, err
	}
	for {
		t, ok := p.peek()
		if !ok || t.kind != ftkAnd {
			break
		}
		p.next()
		right, err := p.parseNot()
		if err != nil {
			return nil, err
		}
		left = &andNode{left, right}
	}
	return left, nil
}

// parseNot handles any run of leading '!' before an atom.
func (p *filterParser) parseNot() (filterNode, error) {
	t, ok := p.peek()
	if ok && t.kind == ftkNot {
		p.next()
		child, err := p.parseNot() // right-recursive to allow !!x
		if err != nil {
			return nil, err
		}
		return &notNode{child}, nil
	}
	return p.parseAtom()
}

// parseAtom handles a parenthesised sub-expression or a single leaf token.
func (p *filterParser) parseAtom() (filterNode, error) {
	t, ok := p.peek()
	if !ok {
		return nil, fmt.Errorf("filter: unexpected end of expression")
	}
	if t.kind == ftkLParen {
		p.next()
		expr, err := p.parseExpr()
		if err != nil {
			return nil, err
		}
		rp, ok := p.peek()
		if !ok || rp.kind != ftkRParen {
			return nil, fmt.Errorf("filter: expected ')'")
		}
		p.next()
		return expr, nil
	}
	switch t.kind {
	case ftkTag, ftkMimeExact, ftkMimeLike:
		p.next()
		return &leafNode{t}, nil
	default:
		return nil, fmt.Errorf("filter: unexpected token at position %d", p.pos)
	}
}

// ---------------------------------------------------------------------------
// Public entry point
// ---------------------------------------------------------------------------

// ParseFilter parses a filter DSL string into a parameterized SQL fragment.
//
// argStart is the 1-based index for the first $N placeholder; this lets the
// caller interleave filter parameters with other query parameters.
//
// Returns ("", argStart, nil, nil) for an empty or trivial DSL.
// SQL injection is structurally impossible: every user-supplied value is
// bound as a query parameter ($N), never interpolated into the SQL string.
func ParseFilter(dsl string, argStart int) (sql string, nextN int, args []any, err error) {
	dsl = strings.TrimSpace(dsl)
	if dsl == "" || dsl == "{}" {
		return "", argStart, nil, nil
	}
	toks, err := lexFilter(dsl)
	if err != nil {
		return "", argStart, nil, err
	}
	if len(toks) == 0 {
		return "", argStart, nil, nil
	}
	p := &filterParser{tokens: toks}
	node, err := p.parseExpr()
	if err != nil {
		return "", argStart, nil, err
	}
	// Reject input with leftover tokens (e.g. "a ) b") rather than silently
	// ignoring the tail.
	if p.pos != len(p.tokens) {
		return "", argStart, nil, fmt.Errorf("filter: trailing tokens at position %d", p.pos)
	}
	sql, nextN, args = node.toSQL(argStart, nil)
	return sql, nextN, args, nil
}