diff --git a/internal/db/db.go b/internal/db/db.go index 831654d..1b3fc66 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -8,6 +8,8 @@ import ( "time" "github.com/H1K0/Kiraku/internal/models" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" ) @@ -31,6 +33,24 @@ func InitDB(connString string) error { return nil } +func transaction(ctx context.Context, handler func(pgx.Tx) (statusCode int, err error)) (statusCode int, err error) { + tx, err := connPool.Begin(ctx) + if err != nil { + statusCode = http.StatusInternalServerError + return + } + statusCode, err = handler(tx) + if err != nil { + tx.Rollback(ctx) + return + } + err = tx.Commit(ctx) + if err != nil { + statusCode = http.StatusInternalServerError + } + return +} + //#region Users func UserLogin(ctx context.Context, login, password string) (user_id string, err error) { @@ -51,12 +71,13 @@ func UserAuth(ctx context.Context, user_id string) (ok, editor bool) { //#region Persons func PersonsGet(ctx context.Context, user_id, filter, sort string, limit, offset int) (persons models.Persons, statusCode int, err error) { - if ok, _ := UserAuth(ctx, user_id); !ok { - err = fmt.Errorf("Unauthorized") + ok, _ := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") statusCode = http.StatusUnauthorized return } - queryGet := "SELECT id, name, coalesce(sort_name, '') FROM persons WHERE position($1 in lower(name))>0 OR position($1 in lower(sort_name))>0" + queryGet := "SELECT id, name, COALESCE(sort_name, '') FROM persons WHERE POSITION($1 IN LOWER(name))>0 OR POSITION($1 IN LOWER(sort_name))>0" if sort != "" { sort_options := strings.Split(sort, ",") queryGet += " ORDER BY " @@ -76,7 +97,7 @@ func PersonsGet(ctx context.Context, user_id, filter, sort string, limit, offset switch sort_field { case "name": case "sortName": - sort_field = "coalesce(sort_name, name)" + sort_field = "COALESCE(sort_name, name)" case "birthdate": case "deathdate": default: @@ -98,39 +119,200 @@ func PersonsGet(ctx context.Context, user_id, filter, sort string, limit, offset queryGet += fmt.Sprintf(" OFFSET %d", offset) } filter = strings.ToLower(filter) - rows, err := connPool.Query(ctx, queryGet, filter) - if err != nil { - statusCode = http.StatusInternalServerError - return - } - count := 0 - for rows.Next() { - var person models.PersonBrief - err = rows.Scan(&person.ID, &person.Name, &person.SortName) + statusCode, err = transaction(ctx, func(tx pgx.Tx) (statusCode int, err error) { + rows, err := tx.Query(ctx, queryGet, filter) if err != nil { - err = fmt.Errorf("error while fetching persons: %w", err) statusCode = http.StatusInternalServerError return } - persons.Persons = append(persons.Persons, person) - count++ - } - err = rows.Err() + count := 0 + for rows.Next() { + var person models.PersonBrief + err = rows.Scan(&person.ID, &person.Name, &person.SortName) + if err != nil { + statusCode = http.StatusInternalServerError + return + } + persons.Persons = append(persons.Persons, person) + count++ + } + err = rows.Err() + if err != nil { + statusCode = http.StatusInternalServerError + return + } + persons.Pagination.Limit = limit + persons.Pagination.Offset = offset + persons.Pagination.Count = count + queryCount = fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount) + row := tx.QueryRow(ctx, queryCount, filter) + err = row.Scan(&persons.Pagination.Total) + if err != nil { + statusCode = http.StatusInternalServerError + } + return + }) if err != nil { - statusCode = http.StatusInternalServerError return } - persons.Pagination.Limit = limit - persons.Pagination.Offset = offset - persons.Pagination.Count = count - queryCount = fmt.Sprintf("SELECT count(*) FROM (%s) tmp", queryCount) - row := connPool.QueryRow(ctx, queryCount, filter) - err = row.Scan(&persons.Pagination.Total) - if err != nil { - statusCode = http.StatusInternalServerError - } else { - statusCode = http.StatusOK + statusCode = http.StatusOK + return +} + +func PersonGet(ctx context.Context, user_id, person_id string) (person models.Person, statusCode int, err error) { + ok, _ := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return } + row := connPool.QueryRow(ctx, "SELECT id, name, COALESCE(sort_name, ''), COALESCE(TO_CHAR(birthdate, 'YYYY-MM-DD'), ''), COALESCE(TO_CHAR(deathdate, 'YYYY-MM-DD'), ''), COALESCE(info, '') FROM persons WHERE id=$1", person_id) + err = row.Scan(&person.ID, &person.Name, &person.SortName, &person.Birthdate, &person.Deathdate, &person.Deathdate) + if err != nil { + pgErr := err.(*pgconn.PgError) + if err == pgx.ErrNoRows { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + } else if pgErr.Code == "22P02" || pgErr.Code == "22007" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else { + statusCode = http.StatusInternalServerError + } + return + } + statusCode = http.StatusOK + return +} + +func PersonAdd(ctx context.Context, user_id, name, sortName, birthdate, deathdate, info string) (person models.Person, statusCode int, err error) { + ok, editor := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + if !editor { + err = fmt.Errorf("not allowed") + statusCode = http.StatusForbidden + return + } + row := connPool.QueryRow( + ctx, + "INSERT INTO persons (name, sort_name, birthdate, deathdate, info) "+ + "VALUES ($1, NULLIF($2, ''), NULLIF($3, '')::date, NULLIF($4, '')::date, NULLIF($5, '')) "+ + "RETURNING id, name, COALESCE(sort_name, ''), COALESCE(TO_CHAR(birthdate, 'YYYY-MM-DD'), ''), COALESCE(TO_CHAR(deathdate, 'YYYY-MM-DD'), ''), COALESCE(info, '')", + name, sortName, birthdate, deathdate, info, + ) + err = row.Scan(&person.ID, &person.Name, &person.SortName, &person.Birthdate, &person.Deathdate, &person.Info) + if err != nil { + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" || pgErr.Code == "22007" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else if pgErr.Code == "23505" { + err = fmt.Errorf("a person with this name already exists") + statusCode = http.StatusConflict + } else { + statusCode = http.StatusInternalServerError + } + return + } + statusCode = http.StatusOK + return +} + +func PersonUpdate(ctx context.Context, user_id, person_id string, values map[string]string) (person models.Person, statusCode int, err error) { + ok, editor := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + if !editor { + err = fmt.Errorf("not allowed") + statusCode = http.StatusForbidden + return + } + statusCode, err = transaction(ctx, func(tx pgx.Tx) (statusCode int, err error) { + for _, field := range []string{"name", "sortName", "birthdate", "deathdate", "info"} { + value, ok := values[field] + if !ok { + continue + } + if field == "sortName" { + field = "sort_name" + } + var query string + if field == "birthdate" || field == "deathdate" { + query = fmt.Sprintf("UPDATE persons SET %s=NULLIF($2, '')::date WHERE id=$1", field) + } else { + query = fmt.Sprintf("UPDATE persons SET %s=NULLIF($2, '') WHERE id=$1", field) + } + var commandTag pgconn.CommandTag + commandTag, err = tx.Exec(ctx, query, person_id, value) + if err != nil { + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" || pgErr.Code == "22007" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else if pgErr.Code == "23505" { + err = fmt.Errorf("a person with this name already exists") + statusCode = http.StatusConflict + } else { + statusCode = http.StatusInternalServerError + } + return + } + if commandTag.RowsAffected() == 0 { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + return + } + } + row := tx.QueryRow(ctx, "SELECT id, name, COALESCE(sort_name, ''), COALESCE(TO_CHAR(birthdate, 'YYYY-MM-DD'), ''), COALESCE(TO_CHAR(deathdate, 'YYYY-MM-DD'), ''), COALESCE(info, '') FROM persons WHERE id=$1", person_id) + err = row.Scan(&person.ID, &person.Name, &person.SortName, &person.Birthdate, &person.Deathdate, &person.Info) + if err != nil { + statusCode = http.StatusInternalServerError + } + return + }) + if err != nil { + return + } + statusCode = http.StatusOK + return +} + +func PersonDelete(ctx context.Context, user_id, person_id string) (statusCode int, err error) { + ok, editor := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + if !editor { + err = fmt.Errorf("not allowed") + statusCode = http.StatusForbidden + return + } + commandTag, err := connPool.Exec(ctx, "DELETE FROM persons WHERE id=$1", person_id) + if err != nil { + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else { + statusCode = http.StatusInternalServerError + } + return + } + if commandTag.RowsAffected() == 0 { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + return + } + statusCode = http.StatusNoContent return }