diff --git a/backend/internal/domain/errors.go b/backend/internal/domain/errors.go index 29da1a5..a4bd62c 100644 --- a/backend/internal/domain/errors.go +++ b/backend/internal/domain/errors.go @@ -1,18 +1,20 @@ package domain +import "fmt" + type ErrorCode string const ( // File errors - ErrFileNotFound ErrorCode = "FILE_NOT_FOUND" - ErrMIMENotSupported ErrorCode = "MIME_NOT_SUPPORTED" + ErrCodeFileNotFound ErrorCode = "FILE_NOT_FOUND" + ErrCodeMIMENotSupported ErrorCode = "MIME_NOT_SUPPORTED" // Tag errors - ErrTagNotFound ErrorCode = "TAG_NOT_FOUND" + ErrCodeTagNotFound ErrorCode = "TAG_NOT_FOUND" // General errors - ErrValidation ErrorCode = "VALIDATION_ERROR" - ErrInternal ErrorCode = "INTERNAL_SERVER_ERROR" + ErrCodeBadRequest ErrorCode = "BAD_REQUEST" + ErrCodeInternal ErrorCode = "INTERNAL_SERVER_ERROR" ) type DomainError struct { @@ -22,25 +24,42 @@ type DomainError struct { Details []any `json:"-"` } -func (e *DomainError) Error() string { - if e.Err != nil { - return e.Message + ": " + e.Err.Error() - } - return e.Message +func (e *DomainError) Wrap(err error) *DomainError { + e.Err = err + return e } -func NewDomainError(err error, code ErrorCode, details ...any) *DomainError { +func NewErrorFileNotFound(file_id string) *DomainError { return &DomainError{ - Err: err, - Code: code, - Details: details, + Code: ErrCodeFileNotFound, + Message: fmt.Sprintf("File not found: %q", file_id), } } -func NewUnexpectedError(err error) *DomainError { +func NewErrorMIMENotSupported(mime string) *DomainError { return &DomainError{ - Err: err, - Code: ErrInternal, + Code: ErrCodeMIMENotSupported, + Message: fmt.Sprintf("MIME not supported: %q", mime), + } +} + +func NewErrorTagNotFound(tag_id string) *DomainError { + return &DomainError{ + Code: ErrCodeTagNotFound, + Message: fmt.Sprintf("Tag not found: %q", tag_id), + } +} + +func NewErrorBadRequest(message string) *DomainError { + return &DomainError{ + Code: ErrCodeBadRequest, + Message: message, + } +} + +func NewErrorUnexpected() *DomainError { + return &DomainError{ + Code: ErrCodeInternal, Message: "An unexpected error occured", } } diff --git a/backend/internal/infrastructure/persistence/postgres/db.go b/backend/internal/infrastructure/persistence/postgres/db.go index 2cc4be4..de1b4f6 100644 --- a/backend/internal/infrastructure/persistence/postgres/db.go +++ b/backend/internal/infrastructure/persistence/postgres/db.go @@ -38,7 +38,7 @@ func New(dbURL string) (*pgxpool.Pool, error) { func transaction(ctx context.Context, db *pgxpool.Pool, handler func(context.Context, pgx.Tx) *domain.DomainError) (domainErr *domain.DomainError) { tx, err := db.Begin(ctx) if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } domainErr = handler(ctx, tx) @@ -48,7 +48,7 @@ func transaction(ctx context.Context, db *pgxpool.Pool, handler func(context.Con } err = tx.Commit(ctx) if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) } return } diff --git a/backend/internal/infrastructure/persistence/postgres/file_repository.go b/backend/internal/infrastructure/persistence/postgres/file_repository.go index 83a32d3..8cbb8af 100644 --- a/backend/internal/infrastructure/persistence/postgres/file_repository.go +++ b/backend/internal/infrastructure/persistence/postgres/file_repository.go @@ -36,15 +36,18 @@ func (s *FileRepository) GetAccess(ctx context.Context, user_id int, file_id str err := row.Scan(&canView, &canEdit) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err) return } var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) - return + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "22P02": + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid file id: %q", file_id)).Wrap(err) + return + } } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) } return } @@ -53,28 +56,28 @@ func (s *FileRepository) GetAccess(ctx context.Context, user_id int, file_id str func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort string, limit, offset int) (files domain.Slice[domain.FileItem], domainErr *domain.DomainError) { filterCond, err := filterToSQL(filter) if err != nil { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "filter", err.Error()) + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid filter string: %q", filter)).Wrap(err) return } sortExpr, err := sortToSQL(sort) if err != nil { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "sort param", err.Error()) + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid sorting parameter: %q", sort)).Wrap(err) return } // prepare query query := ` - SELECT - f.id, - f.name, - m.name, - m.extension, - uuid_extract_timestamp(f.id), - u.name, - u.is_admin - FROM data.files f - JOIN system.mime m ON m.id=f.mime_id - JOIN system.users u ON u.id=f.creator_id - WHERE f.is_deleted IS FALSE AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=f.id AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) AND + SELECT + f.id, + f.name, + m.name, + m.extension, + uuid_extract_timestamp(f.id), + u.name, + u.is_admin + FROM data.files f + JOIN system.mime m ON m.id=f.mime_id + JOIN system.users u ON u.id=f.creator_id + WHERE f.is_deleted IS FALSE AND (f.creator_id=$1 OR (SELECT view FROM acl.files WHERE file_id=f.id AND user_id=$1) OR (SELECT is_admin FROM system.users WHERE id=$1)) AND ` query += filterCond queryCount := query @@ -92,15 +95,12 @@ func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort var pgErr *pgconn.PgError if errors.As(err, &pgErr) { switch pgErr.Code { - case "22P02", "22007": - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) - return case "42P10": - domainErr = domain.NewDomainError(err, domain.ErrValidation, "sort field", sort[1:]) + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid sorting field: %q", sort[1:])).Wrap(err) return } } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } defer rows.Close() @@ -109,7 +109,7 @@ func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort var file domain.FileItem err = rows.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin) if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } files.Data = append(files.Data, file) @@ -117,7 +117,7 @@ func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort } err = rows.Err() if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } files.Pagination.Limit = limit @@ -126,7 +126,7 @@ func (s *FileRepository) GetSlice(ctx context.Context, user_id int, filter, sort row := tx.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount), user_id) err = row.Scan(&files.Pagination.Total) if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) } return }) @@ -155,15 +155,18 @@ func (s *FileRepository) Get(ctx context.Context, user_id int, file_id string) ( err := row.Scan(&file.ID, &file.Name, &file.MIME.Name, &file.MIME.Extension, &file.CreatedAt, &file.Creator.Name, &file.Creator.IsAdmin, &file.Notes, &file.Metadata, &file.Viewed) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err) return } var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) - return + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "22P02": + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid file id: %q", file_id)).Wrap(err) + return + } } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } return @@ -177,10 +180,10 @@ func (s *FileRepository) Add(ctx context.Context, user_id int, name, mime string err := row.Scan(&mime_id, &extension) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - domainErr = domain.NewDomainError(err, domain.ErrMIMENotSupported, mime) + domainErr = domain.NewErrorMIMENotSupported(mime).Wrap(err) return } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } row = s.db.QueryRow(ctx, ` @@ -191,11 +194,17 @@ func (s *FileRepository) Add(ctx context.Context, user_id int, name, mime string err = row.Scan(&file.ID) if err != nil { var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) - return + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "22007": + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid datetime: %q", datetime)).Wrap(err) + return + case "23502": + domainErr = domain.NewErrorBadRequest("Unable to set NULL to some fields").Wrap(err) + return + } } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } file.Name = &name @@ -207,7 +216,7 @@ func (s *FileRepository) Add(ctx context.Context, user_id int, name, mime string // Update file func (s *FileRepository) Update(ctx context.Context, file_id string, updates map[string]interface{}) (domainErr *domain.DomainError) { if len(updates) == 0 { - domainErr = domain.NewDomainError(nil, domain.ErrValidation, "request body", "no fields provided for update") + // domainErr = domain.NewErrorBadRequest(nil, "No fields provided for update") return } query := "UPDATE data.files SET" @@ -222,34 +231,34 @@ func (s *FileRepository) Update(ctx context.Context, file_id string, updates map case "metadata": query += fmt.Sprintf(" %s=NULLIF($%d, '')::jsonb", field, count) default: - domainErr = domain.NewDomainError(nil, domain.ErrValidation, "field", field) + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Unknown field: %q", field)) return } newValues = append(newValues, value) count++ } - query += fmt.Sprintf(" WHERE id=$1") + query += fmt.Sprintf(" WHERE id=$1 AND is_deleted IS FALSE") commandTag, err := s.db.Exec(ctx, query, newValues...) if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { switch pgErr.Code { - case "22P02", "22007": - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + case "22P02": + domainErr = domain.NewErrorBadRequest("Invalid format of some values").Wrap(err) return - case "42804": - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + case "22007": + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid datetime: %q", updates["datetime"])).Wrap(err) return case "23502": - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + domainErr = domain.NewErrorBadRequest("Some fields cannot be empty").Wrap(err) return } } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } if commandTag.RowsAffected() == 0 { - domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err) return } return @@ -258,19 +267,22 @@ func (s *FileRepository) Update(ctx context.Context, file_id string, updates map // Delete file func (s *FileRepository) Delete(ctx context.Context, file_id string) (domainErr *domain.DomainError) { commandTag, err := s.db.Exec(ctx, - "UPDATE data.files SET is_deleted=true WHERE id=$1", + "UPDATE data.files SET is_deleted=true WHERE id=$1 AND is_deleted IS FALSE", file_id) if err != nil { var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) - return + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "22P02": + domainErr = domain.NewErrorBadRequest(fmt.Sprintf("Invalid file id: %q", file_id)).Wrap(err) + return + } } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } if commandTag.RowsAffected() == 0 { - domainErr = domain.NewDomainError(err, domain.ErrFileNotFound, file_id) + domainErr = domain.NewErrorFileNotFound(file_id).Wrap(err) return } return @@ -295,10 +307,10 @@ func (s *FileRepository) GetTags(ctx context.Context, user_id int, file_id strin if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && (pgErr.Code == "22P02" || pgErr.Code == "22007") { - domainErr = domain.NewDomainError(err, domain.ErrValidation, "format", pgErr.Message) + domainErr = domain.NewErrorBadRequest(pgErr.Message).Wrap(err) return } - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } defer rows.Close() @@ -306,14 +318,14 @@ func (s *FileRepository) GetTags(ctx context.Context, user_id int, file_id strin var tag domain.TagItem err = rows.Scan(&tag.ID, &tag.Name, &tag.Color, &tag.Category.ID, &tag.Category.Name, &tag.Category.Color) if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) return } tags = append(tags, tag) } err = rows.Err() if err != nil { - domainErr = domain.NewUnexpectedError(err) + domainErr = domain.NewErrorUnexpected().Wrap(err) } return } diff --git a/backend/internal/interfaces/rest/handlers/error_handler.go b/backend/internal/interfaces/rest/handlers/error_handler.go index b03fb4e..1f3f225 100644 --- a/backend/internal/interfaces/rest/handlers/error_handler.go +++ b/backend/internal/interfaces/rest/handlers/error_handler.go @@ -1,7 +1,6 @@ package rest import ( - "fmt" "net/http" "tanabata/internal/domain" @@ -17,28 +16,28 @@ type ErrorMapper struct{} func (m *ErrorMapper) MapError(err domain.DomainError) (int, ErrorResponse) { switch err.Code { - case domain.ErrFileNotFound: + case domain.ErrCodeFileNotFound: return http.StatusNotFound, ErrorResponse{ Error: "Not Found", Code: string(err.Code), - Message: fmt.Sprintf("File %q not found", err.Details...), + Message: err.Message, } - case domain.ErrMIMENotSupported: + case domain.ErrCodeMIMENotSupported: return http.StatusNotFound, ErrorResponse{ Error: "MIME not supported", Code: string(err.Code), - Message: fmt.Sprintf("MIME not supported: %q", err.Details...), + Message: err.Message, } - case domain.ErrValidation: + case domain.ErrCodeBadRequest: return http.StatusNotFound, ErrorResponse{ Error: "Bad Request", Code: string(err.Code), - Message: fmt.Sprintf("Invalid %s: %s", err.Details...), + Message: err.Message, } } return http.StatusInternalServerError, ErrorResponse{ Error: "Internal Server Error", Code: string(err.Code), - Message: "An unexpected error occured", + Message: err.Message, } }