138 lines
3.5 KiB
Go
138 lines
3.5 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/H1K0/Kiraku/internal/models"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// connPool is the package-wide PostgreSQL connection pool that every query
// function in this package reads from. It is assigned by InitDB; before that
// it holds its zero value (nil).
var (
	connPool *pgxpool.Pool
)
|
|
|
|
func InitDB(connString string) error {
|
|
poolConfig, err := pgxpool.ParseConfig(connString)
|
|
if err != nil {
|
|
return fmt.Errorf("error while parsing connection string: %w", err)
|
|
}
|
|
|
|
poolConfig.MaxConns = 10
|
|
poolConfig.MinConns = 0
|
|
poolConfig.MaxConnLifetime = time.Hour
|
|
poolConfig.HealthCheckPeriod = 30 * time.Second
|
|
|
|
connPool, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("error while initializing DB connection pool: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//#region Users
|
|
|
|
func UserLogin(ctx context.Context, login, password string) (user_id string, err error) {
|
|
row := connPool.QueryRow(ctx, "SELECT id FROM users WHERE login=$1 AND password=crypt($2, password)", login, password)
|
|
err = row.Scan(&user_id)
|
|
return
|
|
}
|
|
|
|
func UserAuth(ctx context.Context, user_id string) (ok, editor bool) {
|
|
row := connPool.QueryRow(ctx, "SELECT editor FROM users WHERE id=$1", user_id)
|
|
err := row.Scan(&editor)
|
|
ok = (err == nil)
|
|
return
|
|
}
|
|
|
|
//#endregion Users
|
|
|
|
//#region Persons
|
|
|
|
func PersonsGet(ctx context.Context, user_id, filter, sort string, limit, offset int) (persons models.Persons, statusCode int, err error) {
|
|
if ok, _ := UserAuth(ctx, user_id); !ok {
|
|
err = fmt.Errorf("Unauthorized")
|
|
statusCode = http.StatusUnauthorized
|
|
return
|
|
}
|
|
queryGet := "SELECT id, name, coalesce(sort_name, '') FROM persons WHERE position($1 in lower(name))>0 OR position($1 in lower(sort_name))>0"
|
|
if sort != "" {
|
|
sort_options := strings.Split(sort, ",")
|
|
queryGet += " ORDER BY "
|
|
for i, sort_option := range sort_options {
|
|
sort_order := sort_option[:1]
|
|
sort_field := sort_option[1:]
|
|
switch sort_order {
|
|
case "+":
|
|
sort_order = "ASC"
|
|
case "-":
|
|
sort_order = "DESC"
|
|
default:
|
|
err = fmt.Errorf("invalid sorting order mark: %q", sort)
|
|
statusCode = http.StatusBadRequest
|
|
return
|
|
}
|
|
switch sort_field {
|
|
case "name":
|
|
case "sortName":
|
|
sort_field = "coalesce(sort_name, name)"
|
|
case "birthdate":
|
|
case "deathdate":
|
|
default:
|
|
err = fmt.Errorf("invalid sorting field: %q", sort_field)
|
|
statusCode = http.StatusBadRequest
|
|
return
|
|
}
|
|
if i > 0 {
|
|
queryGet += ", "
|
|
}
|
|
queryGet += fmt.Sprintf("%s %s NULLS LAST", sort_field, sort_order)
|
|
}
|
|
}
|
|
queryCount := queryGet
|
|
if limit >= 0 {
|
|
queryGet += fmt.Sprintf(" LIMIT %d", limit)
|
|
}
|
|
if offset > 0 {
|
|
queryGet += fmt.Sprintf(" OFFSET %d", offset)
|
|
}
|
|
filter = strings.ToLower(filter)
|
|
rows, err := connPool.Query(ctx, queryGet, filter)
|
|
if err != nil {
|
|
statusCode = http.StatusInternalServerError
|
|
return
|
|
}
|
|
count := 0
|
|
for rows.Next() {
|
|
var person models.PersonBrief
|
|
err = rows.Scan(&person.ID, &person.Name, &person.SortName)
|
|
if err != nil {
|
|
err = fmt.Errorf("error while fetching persons: %w", err)
|
|
statusCode = http.StatusInternalServerError
|
|
return
|
|
}
|
|
persons.Persons = append(persons.Persons, person)
|
|
count++
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
statusCode = http.StatusInternalServerError
|
|
return
|
|
}
|
|
persons.Pagination.Limit = limit
|
|
persons.Pagination.Offset = offset
|
|
persons.Pagination.Count = count
|
|
queryCount = fmt.Sprintf("SELECT count(*) FROM (%s) tmp", queryCount)
|
|
row := connPool.QueryRow(ctx, queryCount, filter)
|
|
err = row.Scan(&persons.Pagination.Total)
|
|
if err != nil {
|
|
statusCode = http.StatusInternalServerError
|
|
} else {
|
|
statusCode = http.StatusOK
|
|
}
|
|
return
|
|
}
|
|
|
|
//#endregion Persons
|