package postgres import ( "context" "encoding/json" "errors" "fmt" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "tanabata/internal/domain" ) type FileRepository struct { db *pgxpool.Pool } func NewFileRepository(db *pgxpool.Pool) *FileRepository { return &FileRepository{db: db} } // Check if user can view file func (s *FileRepository) GetAccess(ctx context.Context, user_id int, file_id string) (canView, canEdit bool, domainErr *domain.DomainError) { row := s.db.QueryRow(ctx, ` SELECT COALESCE(a.view, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE), COALESCE(a.edit, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE) FROM data.files f LEFT JOIN acl.files a ON a.file_id=f.id AND a.user_id=$1 LEFT JOIN system.users u ON u.id=$1 WHERE f.id=$2 `, user_id, file_id) err := row.Scan(&canView, &canEdit) if err != nil { if errors.Is(err, pgx.ErrNoRows) { domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) return } var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return } domainErr = domain.NewUnexpectedError(err) } return } // Get a set of files func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort string, limit, offset int) (files domain.Slice[domain.FileItem], domainErr *domain.DomainError) { filterCond, err := filterToSQL(filter) if err != nil { domainErr = domain.NewDomainError(err, domain.ErrValidation, "filter", err.Error()) return } sortExpr, err := sortToSQL(sort) if err != nil { domainErr = domain.NewDomainError(err, domain.ErrValidation, "sort param", err.Error()) return } // prepare query query := ` SELECT f.id, f.name, m.name, m.extension, uuid_extract_timestamp(f.id), u.name, u.is_admin FROM data.files f JOIN system.mime m ON m.id=f.mime_id JOIN system.users u ON u.id=f.creator_id WHERE NOT f.is_deleted AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=f.id AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) AND ` query += filterCond queryCount := query query += sortExpr if limit >= 0 { query += fmt.Sprintf(" LIMIT %d", limit) } if offset > 0 { query += fmt.Sprintf(" OFFSET %d", offset) } // execute query domainErr = transaction(ctx, s.db, func(ctx context.Context, tx pgx.Tx) (domainErr *domain.DomainError) { rows, err := tx.Query(ctx, query, user_id) if err != nil && !errors.Is(err, pgx.ErrNoRows) { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { switch pgErr.Code { case "22P02", "22007": domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return case "42P10": domainErr = domain.NewDomainError(err, domain.ErrValidation, "sort field", sort[1:]) return } } domainErr = domain.NewUnexpectedError(err) return } defer rows.Close() count := 0 for rows.Next() { var file domain.FileItem err = rows.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin) if err != nil { domainErr = domain.NewUnexpectedError(err) return } files.Data = append(files.Data, file) count++ } err = rows.Err() if err != nil { domainErr = domain.NewUnexpectedError(err) return } files.Pagination.Limit = limit files.Pagination.Offset = offset files.Pagination.Count = count row := tx.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount), user_id) err = row.Scan(&files.Pagination.Total) if err != nil { domainErr = domain.NewUnexpectedError(err) } return }) return } // Get file func (s *FileRepository) Get(ctx context.Context, user_id int, file_id string) (file domain.FileFull, domainErr *domain.DomainError) { row := s.db.QueryRow(ctx, ` SELECT f.id, f.name, m.name, m.extension, uuid_extract_timestamp(f.id), u.name, u.is_admin, f.notes, f.metadata, (SELECT COUNT(*) FROM activity.file_views fv WHERE fv.file_id=$2 AND fv.user_id=$1) FROM data.files f JOIN system.mime m ON m.id=f.mime_id JOIN system.users u ON u.id=f.creator_id WHERE NOT f.is_deleted AND f.id=$2 AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) `, user_id, file_id) err := row.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin, &file.Notes, &file.Metadata, &file.Viewed) if err != nil { if errors.Is(err, pgx.ErrNoRows) { domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) return } var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return } domainErr = domain.NewUnexpectedError(err) return } return } // Add file func (s *FileRepository) Add(ctx context.Context, user_id int, name, mime string, datetime time.Time, notes string, metadata json.RawMessage) (file domain.FileCore, domainErr *domain.DomainError) { var mime_id int var extension string row := s.db.QueryRow(ctx, "SELECT id, extension FROM system.mime WHERE name=$1", mime) err := row.Scan(&mime_id, &extension) if err != nil { if errors.Is(err, pgx.ErrNoRows) { domainErr = domain.NewDomainError(err, domain.ErrMIMENotSupported, mime) return } domainErr = domain.NewUnexpectedError(err) return } row = s.db.QueryRow(ctx, ` INSERT INTO data.files (name, mime_id, datetime, creator_id, notes, metadata) VALUES (NULLIF($1, ''), $2, $3, $4, NULLIF($5 ,''), $6) RETURNING id `, name, mime_id, datetime, user_id, notes, metadata) err = row.Scan(&file.ID) if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return } domainErr = domain.NewUnexpectedError(err) return } file.Name = &name file.MIME.Name = mime file.MIME.Extension = extension return } // Update file func (s *FileRepository) Update(ctx context.Context, user_id int, file_id string, updates map[string]interface{}) (domainErr *domain.DomainError) { if len(updates) == 0 { domainErr = domain.NewDomainError(errors.ErrUnsupported, domain.ErrValidation, "request body", "no fields provided for update") return } writableFields := map[string]bool{ "name": true, "datetime": true, "notes": true, "metadata": true, } query := "UPDATE data.files SET" newValues := []interface{}{user_id} count := 2 for field, value := range updates { if !writableFields[field] { domainErr = domain.NewDomainError(errors.ErrUnsupported, domain.ErrValidation, "field", field) return } query += fmt.Sprintf(" %s=NULLIF($%d, '')", field, count) newValues = append(newValues, value) count++ } query += fmt.Sprintf( " WHERE id=$%d AND (creator_id=$1 OR (SELECT edit FROM acl.files WHERE file_id=$%d AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))", count, count) newValues = append(newValues, file_id) commandTag, err := s.db.Exec(ctx, query, newValues...) if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return } domainErr = domain.NewUnexpectedError(err) return } if commandTag.RowsAffected() == 0 { domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) return } return } // Delete file func (s *FileRepository) Delete(ctx context.Context, user_id int, file_id string) (domainErr *domain.DomainError) { commandTag, err := s.db.Exec(ctx, "DELETE FROM data.files WHERE id=$2 AND (creator_id=$1 OR (SELECT edit FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))", user_id, file_id) if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return } domainErr = domain.NewUnexpectedError(err) return } if commandTag.RowsAffected() == 0 { domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) return } return } func (s *FileRepository) GetTags(ctx context.Context, user_id int, file_id string) (tags []domain.TagItem, domainErr *domain.DomainError) { rows, err := s.db.Query(ctx, ` SELECT t.id, t.name, t.color, c.id, c.name, c.color FROM data.tags t LEFT JOIN data.categories c ON c.id=t.category_id JOIN data.file_tag ft ON ft.tag_id=t.id AND ft.file_id=$2 JOIN data.files f ON f.id=$2 WHERE NOT f.is_deleted AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) `, user_id, file_id) if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) return } domainErr = domain.NewUnexpectedError(err) return } defer rows.Close() for rows.Next() { var tag domain.TagItem err = rows.Scan(&tag.ID, &tag.Name, &tag.Color, &tag.Category.ID, &tag.Category.Name, &tag.Category.Color) if err != nil { domainErr = domain.NewUnexpectedError(err) return } tags = append(tags, tag) } err = rows.Err() if err != nil { domainErr = domain.NewUnexpectedError(err) } return }