diff --git a/backend/internal/domain/entity.go b/backend/internal/domain/entity.go index 76fa57c..9b23718 100644 --- a/backend/internal/domain/entity.go +++ b/backend/internal/domain/entity.go @@ -52,7 +52,6 @@ type ( Creator User `json:"creator"` Notes pgtype.Text `json:"notes"` Metadata json.RawMessage `json:"metadata"` - Tags []TagCore `json:"tags"` Viewed int `json:"viewed"` } ) diff --git a/backend/internal/domain/errors.go b/backend/internal/domain/errors.go new file mode 100644 index 0000000..cf9bba5 --- /dev/null +++ b/backend/internal/domain/errors.go @@ -0,0 +1,46 @@ +package domain + +type ErrorCode string + +const ( + // File errors + ErrFileNotFound ErrorCode = "FILE_NOT_FOUND" + ErrMIMENotSupported ErrorCode = "MIME_NOT_SUPPORTEDF" + + // Tag errors + ErrTagNotFound ErrorCode = "TAG_NOT_FOUND" + + // General errors + ErrValidation ErrorCode = "VALIDATION_ERROR" + ErrInternal ErrorCode = "INTERNAL_SERVER_ERROR" +) + +type DomainError struct { + Err error `json:"-"` + Code ErrorCode `json:"code"` + Message string `json:"message"` + Details []any `json:"-"` +} + +func (e *DomainError) Error() string { + if e.Err != nil { + return e.Message + ": " + e.Err.Error() + } + return e.Message +} + +func NewDomainError(err error, code ErrorCode, details ...any) *DomainError { + return &DomainError{ + Err: err, + Code: code, + Details: details, + } +} + +func NewUnexpectedError(err error) *DomainError { + return &DomainError{ + Err: err, + Code: ErrInternal, + Message: "An unexpected error occured", + } +} diff --git a/backend/internal/domain/repositories.go b/backend/internal/domain/repositories.go new file mode 100644 index 0000000..1cff795 --- /dev/null +++ b/backend/internal/domain/repositories.go @@ -0,0 +1,16 @@ +package domain + +import ( + "context" + "encoding/json" + "time" +) + +type FileRepository interface { + GetAccess(ctx context.Context, user_id int, file_id string) (canView, canEdit bool, domainErr *DomainError) + GetSlice(ctx context.Context, user_id int, filter, sort string, limit, offset int) (files Slice[FileItem], domainErr *DomainError) + Get(ctx context.Context, user_id int, file_id string) (file FileFull, domainErr *DomainError) + Add(ctx context.Context, user_id int, name, mime string, datetime time.Time, notes string, metadata json.RawMessage) (file FileCore, domainErr *DomainError) + Update(ctx context.Context, user_id int, file_id string, updates map[string]interface{}) (domainErr *DomainError) + Delete(ctx context.Context, user_id int, file_id string) (domainErr *DomainError) +} diff --git a/backend/internal/infrastructure/persistence/postgres/db.go b/backend/internal/infrastructure/persistence/postgres/db.go new file mode 100644 index 0000000..4103121 --- /dev/null +++ b/backend/internal/infrastructure/persistence/postgres/db.go @@ -0,0 +1,50 @@ +package postgres + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "tanabata/internal/domain" +) + +// Initialize PostgreSQL database driver +func New(dbURL string) (*pgxpool.Pool, error) { + poolConfig, err := pgxpool.ParseConfig(dbURL) + if err != nil { + return nil, fmt.Errorf("error while parsing connection string: %w", err) + } + + poolConfig.MaxConns = 100 + poolConfig.MinConns = 0 + poolConfig.MaxConnLifetime = time.Hour + poolConfig.HealthCheckPeriod = 30 * time.Second + + db, err := pgxpool.NewWithConfig(context.Background(), poolConfig) + if err != nil { + return nil, fmt.Errorf("error while initializing DB connections pool: %w", err) + } + return db, nil +} + +// Transaction wrapper +func transaction(ctx context.Context, db *pgxpool.Pool, handler func(context.Context, pgx.Tx) *domain.DomainError) (domainErr *domain.DomainError) { + tx, err := db.Begin(ctx) + if err != nil { + domainErr = domain.NewUnexpectedError(err) + return + } + domainErr = handler(ctx, tx) + if domainErr != nil { + tx.Rollback(ctx) + return + } + err = tx.Commit(ctx) + if err != nil { + domainErr = domain.NewUnexpectedError(err) + } + return +} diff --git a/backend/internal/infrastructure/persistence/postgres/file_repository.go b/backend/internal/infrastructure/persistence/postgres/file_repository.go new file mode 100644 index 0000000..a0e7f61 --- /dev/null +++ b/backend/internal/infrastructure/persistence/postgres/file_repository.go @@ -0,0 +1,272 @@ +package postgres + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + + "tanabata/internal/domain" +) + +type FileRepository struct { + db *pgxpool.Pool +} + +func NewFileRepository(db *pgxpool.Pool) *FileRepository { + return &FileRepository{db: db} +} + +// Check if user can view file +func (s *FileRepository) GetAccess(ctx context.Context, user_id int, file_id string) (canView, canEdit bool, domainErr *domain.DomainError) { + row := s.db.QueryRow(ctx, ` + SELECT + COALESCE(a.view, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE), + COALESCE(a.edit, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE) + FROM data.files f + LEFT JOIN acl.files a ON a.file_id=f.id AND a.user_id=$1 + LEFT JOIN system.users u ON u.id=$1 + WHERE f.id=$2 + `, user_id, file_id) + err := row.Scan(&canView, &canEdit) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + return + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + return + } + domainErr = domain.NewUnexpectedError(err) + } + return +} + +// Get a set of files +func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort string, limit, offset int) (files domain.Slice[domain.FileItem], domainErr *domain.DomainError) { + filterCond, err := filterToSQL(filter) + if err != nil { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "filter", err.Error()) + return + } + sortExpr, err := sortToSQL(sort) + if err != nil { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "sort param", err.Error()) + return + } + // prepare query + query := ` + SELECT + f.id, + f.name, + m.name, + m.extension, + uuid_extract_timestamp(f.id), + u.name, + u.is_admin + FROM data.files f + JOIN system.mime m ON m.id=f.mime_id + JOIN system.users u ON u.id=f.creator_id + WHERE NOT f.is_deleted AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=f.id AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) AND + ` + query += filterCond + queryCount := query + query += sortExpr + if limit >= 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + if offset > 0 { + query += fmt.Sprintf(" OFFSET %d", offset) + } + // execute query + domainErr = transaction(ctx, s.db, func(ctx context.Context, tx pgx.Tx) (domainErr *domain.DomainError) { + rows, err := tx.Query(ctx, query, user_id) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "22P02", "22007": + domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + return + case "42P10": + domainErr = domain.NewDomainError(err, domain.ErrValidation, "sort field", sort[1:]) + return + } + } + domainErr = domain.NewUnexpectedError(err) + return + } + defer rows.Close() + count := 0 + for rows.Next() { + var file domain.FileItem + err = rows.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin) + if err != nil { + domainErr = domain.NewUnexpectedError(err) + return + } + files.Data = append(files.Data, file) + count++ + } + err = rows.Err() + if err != nil { + domainErr = domain.NewUnexpectedError(err) + return + } + files.Pagination.Limit = limit + files.Pagination.Offset = offset + files.Pagination.Count = count + row := tx.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount), user_id) + err = row.Scan(&files.Pagination.Total) + if err != nil { + domainErr = domain.NewUnexpectedError(err) + } + return + }) + return +} + +// Get file +func (s *FileRepository) Get(ctx context.Context, user_id int, file_id string) (file domain.FileFull, domainErr *domain.DomainError) { + row := s.db.QueryRow(ctx, ` + SELECT + f.id, + f.name, + m.name, + m.extension, + uuid_extract_timestamp(f.id), + u.name, + u.is_admin, + f.notes, + f.metadata, + (SELECT COUNT(*) FROM activity.file_views fv WHERE fv.file_id=$2 AND fv.user_id=$1) + FROM data.files f + JOIN system.mime m ON m.id=f.mime_id + JOIN system.users u ON u.id=f.creator_id + WHERE NOT f.is_deleted AND f.id=$2 AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) + `, user_id, file_id) + err := row.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin, &file.Notes, &file.Metadata, &file.Viewed) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + return + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + return + } + domainErr = domain.NewUnexpectedError(err) + return + } + return +} + +// Add file +func (s *FileRepository) Add(ctx context.Context, user_id int, name, mime string, datetime time.Time, notes string, metadata json.RawMessage) (file domain.FileCore, domainErr *domain.DomainError) { + var mime_id int + var extension string + row := s.db.QueryRow(ctx, "SELECT id, extension FROM system.mime WHERE name=$1", mime) + err := row.Scan(&mime_id, &extension) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + domainErr = domain.NewDomainError(err, domain.ErrMIMENotSupported, mime) + return + } + domainErr = domain.NewUnexpectedError(err) + return + } + row = s.db.QueryRow(ctx, ` + INSERT INTO data.files (name, mime_id, datetime, creator_id, notes, metadata) + VALUES (NULLIF($1, ''), $2, $3, $4, NULLIF($5 ,''), $6) + RETURNING id + `, name, mime_id, datetime, user_id, notes, metadata) + err = row.Scan(&file.ID) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + return + } + domainErr = domain.NewUnexpectedError(err) + return + } + file.Name.String = name + file.Name.Valid = (name != "") + file.MIME.Name = mime + file.MIME.Extension = extension + return +} + +// Update file +func (s *FileRepository) Update(ctx context.Context, user_id int, file_id string, updates map[string]interface{}) (domainErr *domain.DomainError) { + if len(updates) == 0 { + domainErr = domain.NewDomainError(errors.ErrUnsupported, domain.ErrValidation, "request body", "no fields provided for update") + return + } + writableFields := map[string]bool{ + "name": true, + "datetime": true, + "notes": true, + "metadata": true, + } + query := "UPDATE data.files SET" + newValues := []interface{}{user_id} + count := 2 + for field, value := range updates { + if !writableFields[field] { + domainErr = domain.NewDomainError(errors.ErrUnsupported, domain.ErrValidation, "field", field) + return + } + query += fmt.Sprintf(" %s=NULLIF($%d, '')", field, count) + newValues = append(newValues, value) + count++ + } + query += fmt.Sprintf( + " WHERE id=$%d AND (creator_id=$1 OR (SELECT edit FROM acl.files WHERE file_id=$%d AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))", + count, count) + newValues = append(newValues, file_id) + commandTag, err := s.db.Exec(ctx, query, newValues...) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + return + } + domainErr = domain.NewUnexpectedError(err) + return + } + if commandTag.RowsAffected() == 0 { + domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + return + } + return +} + +// Delete file +func (s *FileRepository) Delete(ctx context.Context, user_id int, file_id string) (domainErr *domain.DomainError) { + commandTag, err := s.db.Exec(ctx, + "DELETE FROM data.files WHERE id=$2 AND (creator_id=$1 OR (SELECT edit FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))", + user_id, file_id) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { + domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + return + } + domainErr = domain.NewUnexpectedError(err) + return + } + if commandTag.RowsAffected() == 0 { + domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + return + } + return +} diff --git a/backend/internal/storage/postgres/utils.go b/backend/internal/infrastructure/persistence/postgres/utils.go similarity index 55% rename from backend/internal/storage/postgres/utils.go rename to backend/internal/infrastructure/persistence/postgres/utils.go index f190f5a..a755e97 100644 --- a/backend/internal/storage/postgres/utils.go +++ b/backend/internal/infrastructure/persistence/postgres/utils.go @@ -1,21 +1,52 @@ package postgres import ( + "errors" "fmt" "net/http" "strconv" "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) +// Handle database error +func handleDBError(errIn error) (statusCode int, err error) { + if errIn == nil { + statusCode = http.StatusOK + return + } + if errors.Is(errIn, pgx.ErrNoRows) { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + return + } + var pgErr *pgconn.PgError + if errors.As(errIn, &pgErr) { + switch pgErr.Code { + case "22P02", "22007": // Invalid data format + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + return + case "23505": // Unique constraint violation + err = fmt.Errorf("already exists") + statusCode = http.StatusConflict + return + } + } + return http.StatusInternalServerError, errIn +} + // Convert "filter" URL param to SQL "WHERE" condition -func filterToSQL(filter string) (sql string, statusCode int, err error) { +func filterToSQL(filter string) (sql string, err error) { // filterTokens := strings.Split(string(filter), ";") sql = "(true)" return } // Convert "sort" URL param to SQL "ORDER BY" -func sortToSQL(sort string) (sql string, statusCode int, err error) { +func sortToSQL(sort string) (sql string, err error) { if sort == "" { return } @@ -32,7 +63,6 @@ func sortToSQL(sort string) (sql string, statusCode int, err error) { sortOrder = "DESC" default: err = fmt.Errorf("invalid sorting order mark: %q", sortOrder) - statusCode = http.StatusBadRequest return } // validate sorting column @@ -40,7 +70,6 @@ func sortToSQL(sort string) (sql string, statusCode int, err error) { n, err = strconv.Atoi(sortColumn) if err != nil || n < 0 { err = fmt.Errorf("invalid sorting column: %q", sortColumn) - statusCode = http.StatusBadRequest return } // add sorting option to query diff --git a/backend/internal/interfaces/rest/handlers/error_handler.go b/backend/internal/interfaces/rest/handlers/error_handler.go new file mode 100644 index 0000000..b03fb4e --- /dev/null +++ b/backend/internal/interfaces/rest/handlers/error_handler.go @@ -0,0 +1,44 @@ +package rest + +import ( + "fmt" + "net/http" + + "tanabata/internal/domain" +) + +type ErrorResponse struct { + Error string `json:"error"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type ErrorMapper struct{} + +func (m *ErrorMapper) MapError(err domain.DomainError) (int, ErrorResponse) { + switch err.Code { + case domain.ErrFileNotFound: + return http.StatusNotFound, ErrorResponse{ + Error: "Not Found", + Code: string(err.Code), + Message: fmt.Sprintf("File %q not found", err.Details...), + } + case domain.ErrMIMENotSupported: + return http.StatusNotFound, ErrorResponse{ + Error: "MIME not supported", + Code: string(err.Code), + Message: fmt.Sprintf("MIME not supported: %q", err.Details...), + } + case domain.ErrValidation: + return http.StatusNotFound, ErrorResponse{ + Error: "Bad Request", + Code: string(err.Code), + Message: fmt.Sprintf("Invalid %s: %s", err.Details...), + } + } + return http.StatusInternalServerError, ErrorResponse{ + Error: "Internal Server Error", + Code: string(err.Code), + Message: "An unexpected error occured", + } +} diff --git a/backend/internal/storage/postgres/files.go b/backend/internal/storage/postgres/files.go deleted file mode 100644 index e5f96c9..0000000 --- a/backend/internal/storage/postgres/files.go +++ /dev/null @@ -1,259 +0,0 @@ -package postgres - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/jackc/pgx/v5" - - "tanabata/internal/domain" -) - -// Check if user can view file -func FileGetAccess(user_id int, file_id string) (canView, canEdit bool, err error) { - ctx := context.Background() - row := connPool.QueryRow(ctx, ` - SELECT - COALESCE(a.view, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE), - COALESCE(a.edit, FALSE) OR f.creator_id=$1 OR COALESCE(u.is_admin, FALSE) - FROM data.files f - LEFT JOIN acl.files a ON a.file_id=f.id AND a.user_id=$1 - LEFT JOIN system.users u ON u.id=$1 - WHERE f.id=$2 - `, user_id, file_id) - err = row.Scan(&canView, &canEdit) - return -} - -// Get a set of files -func FileGetSlice(user_id int, filter, sort string, limit, offset int) (files domain.Slice[domain.FileItem], statusCode int, err error) { - filterCond, statusCode, err := filterToSQL(filter) - if err != nil { - return - } - sortExpr, statusCode, err := sortToSQL(sort) - if err != nil { - return - } - // prepare query - query := ` - SELECT - f.id, - f.name, - m.name, - m.extension, - uuid_extract_timestamp(f.id), - u.name, - u.is_admin - FROM data.files f - JOIN system.mime m ON m.id=f.mime_id - JOIN system.users u ON u.id=f.creator_id - WHERE NOT f.is_deleted AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=f.id AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) AND - ` - query += filterCond - queryCount := query - query += sortExpr - if limit >= 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - } - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) - } - // execute query - statusCode, err = transaction(func(ctx context.Context, tx pgx.Tx) (statusCode int, err error) { - rows, err := tx.Query(ctx, query, user_id) - if err != nil { - statusCode, err = handleDBError(err) - return - } - defer rows.Close() - count := 0 - for rows.Next() { - var file domain.FileItem - err = rows.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin) - if err != nil { - statusCode = http.StatusInternalServerError - return - } - files.Data = append(files.Data, file) - count++ - } - err = rows.Err() - if err != nil { - statusCode = http.StatusInternalServerError - return - } - files.Pagination.Limit = limit - files.Pagination.Offset = offset - files.Pagination.Count = count - row := tx.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount), user_id) - err = row.Scan(&files.Pagination.Total) - if err != nil { - statusCode = http.StatusInternalServerError - } - return - }) - if err == nil { - statusCode = http.StatusOK - } - return -} - -// Get file -func FileGet(user_id int, file_id string) (file domain.FileFull, statusCode int, err error) { - ctx := context.Background() - row := connPool.QueryRow(ctx, ` - SELECT - f.id, - f.name, - m.name, - m.extension, - uuid_extract_timestamp(f.id), - u.name, - u.is_admin, - f.notes, - f.metadata, - (SELECT COUNT(*) FROM activity.file_views fv WHERE fv.file_id=$2 AND fv.user_id=$1) - FROM data.files f - JOIN system.mime m ON m.id=f.mime_id - JOIN system.users u ON u.id=f.creator_id - WHERE NOT f.is_deleted AND f.id=$2 AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) - `, user_id, file_id) - err = row.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin, &file.Notes, &file.Metadata, &file.Viewed) - if err != nil { - statusCode, err = handleDBError(err) - return - } - rows, err := connPool.Query(ctx, ` - SELECT - t.id, - t.name, - COALESCE(t.color, c.color) - FROM data.tags t - LEFT JOIN data.categories c ON c.id=t.category_id - JOIN data.file_tag ft ON ft.tag_id=t.id - WHERE ft.file_id=$1 - `, file_id) - if err != nil { - statusCode, err = handleDBError(err) - return - } - defer rows.Close() - for rows.Next() { - var tag domain.TagCore - err = rows.Scan(&tag.ID, &tag.Name, &tag.Color) - if err != nil { - statusCode = http.StatusInternalServerError - return - } - file.Tags = append(file.Tags, tag) - } - err = rows.Err() - if err != nil { - statusCode = http.StatusInternalServerError - return - } - statusCode = http.StatusOK - return -} - -// Add file -func FileAdd(user_id int, name, mime string, datetime time.Time, notes string, metadata json.RawMessage) (file domain.FileCore, statusCode int, err error) { - ctx := context.Background() - var mime_id int - var extension string - row := connPool.QueryRow(ctx, "SELECT id, extension FROM system.mime WHERE name=$1", mime) - err = row.Scan(&mime_id, &extension) - if err != nil { - if err == pgx.ErrNoRows { - err = fmt.Errorf("unsupported file type: %q", mime) - statusCode = http.StatusBadRequest - } else { - statusCode, err = handleDBError(err) - } - return - } - row = connPool.QueryRow(ctx, ` - INSERT INTO data.files (name, mime_id, datetime, creator_id, notes, metadata) - VALUES (NULLIF($1, ''), $2, $3, $4, NULLIF($5 ,''), $6) - RETURNING id - `, name, mime_id, datetime, user_id, notes, metadata) - err = row.Scan(&file.ID) - if err != nil { - statusCode, err = handleDBError(err) - return - } - file.Name.String = name - file.Name.Valid = (name != "") - file.MIME.Name = mime - file.MIME.Extension = extension - statusCode = http.StatusOK - return -} - -// Update file -func FileUpdate(user_id int, file_id string, updates map[string]interface{}) (statusCode int, err error) { - if len(updates) == 0 { - err = fmt.Errorf("no fields provided for update") - statusCode = http.StatusBadRequest - return - } - writableFields := map[string]bool{ - "name": true, - "datetime": true, - "notes": true, - "metadata": true, - } - query := "UPDATE data.files SET" - newValues := []interface{}{user_id} - count := 2 - for field, value := range updates { - if !writableFields[field] { - err = fmt.Errorf("invalid field: %q", field) - statusCode = http.StatusBadRequest - return - } - query += fmt.Sprintf(" %s=NULLIF($%d, '')", field, count) - newValues = append(newValues, value) - count++ - } - query += fmt.Sprintf( - " WHERE id=$%d AND (creator_id=$1 OR (SELECT edit FROM acl.files WHERE file_id=$%d AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))", - count, count) - newValues = append(newValues, file_id) - ctx := context.Background() - commandTag, err := connPool.Exec(ctx, query, newValues...) - if err != nil { - statusCode, err = handleDBError(err) - return - } - if commandTag.RowsAffected() == 0 { - err = fmt.Errorf("not found") - statusCode = http.StatusNotFound - return - } - statusCode = http.StatusNoContent - return -} - -// Delete file -func FileDelete(user_id int, file_id string) (statusCode int, err error) { - ctx := context.Background() - commandTag, err := connPool.Exec(ctx, - "DELETE FROM data.files WHERE id=$2 AND (creator_id=$1 OR (SELECT edit FROM acl.files WHERE file_id=$2 AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1))", - user_id, file_id) - if err != nil { - statusCode, err = handleDBError(err) - return - } - if commandTag.RowsAffected() == 0 { - err = fmt.Errorf("not found") - statusCode = http.StatusNotFound - return - } - statusCode = http.StatusNoContent - return -} diff --git a/backend/internal/storage/postgres/store.go b/backend/internal/storage/postgres/store.go deleted file mode 100644 index 56723f3..0000000 --- a/backend/internal/storage/postgres/store.go +++ /dev/null @@ -1,79 +0,0 @@ -package postgres - -import ( - "context" - "errors" - "fmt" - "net/http" - "time" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgxpool" -) - -var connPool *pgxpool.Pool - -func InitDB(connString string) error { - poolConfig, err := pgxpool.ParseConfig(connString) - if err != nil { - return fmt.Errorf("error while parsing connection string: %w", err) - } - - poolConfig.MaxConns = 100 - poolConfig.MinConns = 0 - poolConfig.MaxConnLifetime = time.Hour - poolConfig.HealthCheckPeriod = 30 * time.Second - - connPool, err = pgxpool.NewWithConfig(context.Background(), poolConfig) - if err != nil { - return fmt.Errorf("error while initializing DB connections pool: %w", err) - } - return nil -} - -func transaction(handler func(context.Context, pgx.Tx) (statusCode int, err error)) (statusCode int, err error) { - ctx := context.Background() - tx, err := connPool.Begin(ctx) - if err != nil { - statusCode = http.StatusInternalServerError - return - } - statusCode, err = handler(ctx, tx) - if err != nil { - tx.Rollback(ctx) - return - } - err = tx.Commit(ctx) - if err != nil { - statusCode = http.StatusInternalServerError - } - return -} - -// Handle database error -func handleDBError(errIn error) (statusCode int, err error) { - if errIn == nil { - statusCode = http.StatusOK - return - } - if errors.Is(errIn, pgx.ErrNoRows) { - err = fmt.Errorf("not found") - statusCode = http.StatusNotFound - return - } - var pgErr *pgconn.PgError - if errors.As(errIn, &pgErr) { - switch pgErr.Code { - case "22P02", "22007": // Invalid data format - err = fmt.Errorf("%s", pgErr.Message) - statusCode = http.StatusBadRequest - return - case "23505": // Unique constraint violation - err = fmt.Errorf("already exists") - statusCode = http.StatusConflict - return - } - } - return http.StatusInternalServerError, errIn -}