332 lines
9.6 KiB
Go
332 lines
9.6 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"tanabata/internal/domain"
|
|
)
|
|
|
|
type FileRepository struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewFileRepository(db *pgxpool.Pool) *FileRepository {
|
|
return &FileRepository{db: db}
|
|
}
|
|
|
|
// Get user permissions on file
|
|
func (s *FileRepository) GetAccess(ctx context.Context, user_id int, file_id string) (canView, canEdit bool, domainErr *domain.DomainError) {
|
|
row := s.db.QueryRow(ctx, `
|
|
SELECT
|
|
COALESCE(a.view, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE),
|
|
COALESCE(a.edit, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE)
|
|
FROM data.files f
|
|
LEFT JOIN acl.files a ON a.file_id=f.id AND a.user_id=$1
|
|
LEFT JOIN system.users u ON u.id=$1
|
|
WHERE f.id=$2
|
|
`, user_id, file_id)
|
|
err := row.Scan(&canView, &canEdit)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err)
|
|
return
|
|
}
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
switch pgErr.Code {
|
|
case "22P02":
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid file id: %q", file_id)).Wrap(err)
|
|
return
|
|
}
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Get a set of files
|
|
func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort string, limit, offset int) (files domain.Slice[domain.FileItem], domainErr *domain.DomainError) {
|
|
filterCond, err := filterToSQL(filter)
|
|
if err != nil {
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid filter string: %q", filter)).Wrap(err)
|
|
return
|
|
}
|
|
sortExpr, err := sortToSQL(sort)
|
|
if err != nil {
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid sorting parameter: %q", sort)).Wrap(err)
|
|
return
|
|
}
|
|
// prepare query
|
|
query := `
|
|
SELECT
|
|
f.id,
|
|
f.name,
|
|
m.name,
|
|
m.extension,
|
|
uuid_extract_timestamp(f.id),
|
|
u.name,
|
|
u.is_admin
|
|
FROM data.files f
|
|
JOIN system.mime m ON m.id=f.mime_id
|
|
JOIN system.users u ON u.id=f.creator_id
|
|
WHERE f.is_deleted IS FALSE AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=f.id AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) AND
|
|
`
|
|
query += filterCond
|
|
queryCount := query
|
|
query += sortExpr
|
|
if limit >= 0 {
|
|
query += fmt.Sprintf(" LIMIT %d", limit)
|
|
}
|
|
if offset > 0 {
|
|
query += fmt.Sprintf(" OFFSET %d", offset)
|
|
}
|
|
// execute query
|
|
domainErr = transaction(ctx, s.db, func(ctx context.Context, tx pgx.Tx) (domainErr *domain.DomainError) {
|
|
rows, err := tx.Query(ctx, query, user_id)
|
|
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
switch pgErr.Code {
|
|
case "42P10":
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid sorting field: %q", sort[1:])).Wrap(err)
|
|
return
|
|
}
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
count := 0
|
|
for rows.Next() {
|
|
var file domain.FileItem
|
|
err = rows.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin)
|
|
if err != nil {
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
files.Data = append(files.Data, file)
|
|
count++
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
files.Pagination.Limit = limit
|
|
files.Pagination.Offset = offset
|
|
files.Pagination.Count = count
|
|
row := tx.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount), user_id)
|
|
err = row.Scan(&files.Pagination.Total)
|
|
if err != nil {
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
}
|
|
return
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get file
|
|
func (s *FileRepository) Get(ctx context.Context, user_id int, file_id string) (file domain.FileFull, domainErr *domain.DomainError) {
|
|
row := s.db.QueryRow(ctx, `
|
|
SELECT
|
|
f.id,
|
|
f.name,
|
|
m.name,
|
|
m.extension,
|
|
uuid_extract_timestamp(f.id),
|
|
u.name,
|
|
u.is_admin,
|
|
f.notes,
|
|
f.metadata,
|
|
(SELECT COUNT(*) FROM activity.file_views fv WHERE fv.file_id=$2 AND fv.user_id=$1)
|
|
FROM data.files f
|
|
JOIN system.mime m ON m.id=f.mime_id
|
|
JOIN system.users u ON u.id=f.creator_id
|
|
WHERE f.is_deleted IS FALSE
|
|
`, user_id, file_id)
|
|
err := row.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin, &file.Notes, &file.Metadata, &file.Viewed)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err)
|
|
return
|
|
}
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
switch pgErr.Code {
|
|
case "22P02":
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid file id: %q", file_id)).Wrap(err)
|
|
return
|
|
}
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// Add file
|
|
func (s *FileRepository) Add(ctx context.Context, user_id int, name, mime string, datetime time.Time, notes string, metadata json.RawMessage) (file domain.FileCore, domainErr *domain.DomainError) {
|
|
var mime_id int
|
|
var extension string
|
|
row := s.db.QueryRow(ctx, "SELECT id, extension FROM system.mime WHERE name=$1", mime)
|
|
err := row.Scan(&mime_id, &extension)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
domainErr = domain.NewErrorMIMENotSupported(mime).Wrap(err)
|
|
return
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
row = s.db.QueryRow(ctx, `
|
|
INSERT INTO data.files (name, mime_id, datetime, creator_id, notes, metadata)
|
|
VALUES (NULLIF($1, ''), $2, $3, $4, NULLIF($5 ,''), $6)
|
|
RETURNING id
|
|
`, name, mime_id, datetime, user_id, notes, metadata)
|
|
err = row.Scan(&file.ID)
|
|
if err != nil {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
switch pgErr.Code {
|
|
case "22007":
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid datetime: %q", datetime)).Wrap(err)
|
|
return
|
|
case "23502":
|
|
domainErr = domain.NewErrorBadRequest("Unable to set NULL to some fields").Wrap(err)
|
|
return
|
|
}
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
file.Name = &name
|
|
file.MIME.Name = mime
|
|
file.MIME.Extension = extension
|
|
return
|
|
}
|
|
|
|
// Update file
|
|
func (s *FileRepository) Update(ctx context.Context, file_id string, updates map[string]interface{}) (domainErr *domain.DomainError) {
|
|
if len(updates) == 0 {
|
|
// domainErr = domain.NewErrorBadRequest(nil, "No fields provided for update")
|
|
return
|
|
}
|
|
query := "UPDATE data.files SET"
|
|
newValues := []interface{}{file_id}
|
|
count := 2
|
|
for field, value := range updates {
|
|
switch field {
|
|
case "name", "notes":
|
|
query += fmt.Sprintf(" %s=NULLIF($%d, '')", field, count)
|
|
case "datetime":
|
|
query += fmt.Sprintf(" %s=NULLIF($%d, '')::timestamptz", field, count)
|
|
case "metadata":
|
|
query += fmt.Sprintf(" %s=NULLIF($%d, '')::jsonb", field, count)
|
|
default:
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Unknown field: %q", field))
|
|
return
|
|
}
|
|
newValues = append(newValues, value)
|
|
count++
|
|
}
|
|
query += fmt.Sprintf(" WHERE id=$1 AND is_deleted IS FALSE")
|
|
commandTag, err := s.db.Exec(ctx, query, newValues...)
|
|
if err != nil {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
switch pgErr.Code {
|
|
case "22P02":
|
|
domainErr = domain.NewErrorBadRequest("Invalid format of some values").Wrap(err)
|
|
return
|
|
case "22007":
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid datetime: %q", updates["datetime"])).Wrap(err)
|
|
return
|
|
case "23502":
|
|
domainErr = domain.NewErrorBadRequest("Some fields cannot be empty").Wrap(err)
|
|
return
|
|
}
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
if commandTag.RowsAffected() == 0 {
|
|
domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// Delete file
|
|
func (s *FileRepository) Delete(ctx context.Context, file_id string) (domainErr *domain.DomainError) {
|
|
commandTag, err := s.db.Exec(ctx,
|
|
"UPDATE data.files SET is_deleted=true WHERE id=$1 AND is_deleted IS FALSE",
|
|
file_id)
|
|
if err != nil {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
switch pgErr.Code {
|
|
case "22P02":
|
|
domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid file id: %q", file_id)).Wrap(err)
|
|
return
|
|
}
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
if commandTag.RowsAffected() == 0 {
|
|
domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// Get list of tags of file
|
|
func (s *FileRepository) GetTags(ctx context.Context, user_id int, file_id string) (tags []domain.TagItem, domainErr *domain.DomainError) {
|
|
rows, err := s.db.Query(ctx, `
|
|
SELECT
|
|
t.id,
|
|
t.name,
|
|
t.color,
|
|
c.id,
|
|
c.name,
|
|
c.color
|
|
FROM data.tags t
|
|
LEFT JOIN data.categories c ON c.id=t.category_id
|
|
JOIN data.file_tag ft ON ft.tag_id=t.id AND ft.file_id=$2
|
|
JOIN data.files f ON f.id=$2
|
|
WHERE NOT f.is_deleted AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))
|
|
`, user_id, file_id)
|
|
if err != nil {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") {
|
|
domainErr = domain.NewErrorBadRequest(pgErr.Message).Wrap(err)
|
|
return
|
|
}
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var tag domain.TagItem
|
|
err = rows.Scan(&tag.ID, &tag.Name, &tag.Color, &tag.Category.ID, &tag.Category.Name, &tag.Category.Color)
|
|
if err != nil {
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
return
|
|
}
|
|
tags = append(tags, tag)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
domainErr = domain.NewErrorUnexpected().Wrap(err)
|
|
}
|
|
return
|
|
}
|