package db

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/H1K0/Kiraku/internal/models"
	"github.com/jackc/pgx/v5/pgxpool"
)

// connPool is the package-wide PostgreSQL connection pool used by every query
// function in this package. It is nil until InitDB succeeds.
var connPool *pgxpool.Pool

// InitDB parses connString, configures the package-level connection pool
// (at most 10 connections, no minimum number of open connections, each
// connection recycled after one hour, health checks every 30 seconds) and
// stores it for use by the other functions in this package.
func InitDB(connString string) error {
	poolConfig, err := pgxpool.ParseConfig(connString)
	if err != nil {
		return fmt.Errorf("error while parsing connection string: %w", err)
	}
	poolConfig.MaxConns = 10
	poolConfig.MinConns = 0
	poolConfig.MaxConnLifetime = time.Hour
	poolConfig.HealthCheckPeriod = 30 * time.Second
	connPool, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
	if err != nil {
		return fmt.Errorf("error while initializing DB connection pool: %w", err)
	}
	return nil
}

//#region Users

// UserLogin returns the id of the user whose login equals login and whose
// stored password hash matches password. The comparison is done in SQL via
// crypt(); this presumably relies on the pgcrypto extension, so verify that
// against the schema. If no row matches, err is the error returned by Scan.
func UserLogin(ctx context.Context, login, password string) (user_id string, err error) {
	row := connPool.QueryRow(ctx,
		"SELECT id FROM users WHERE login=$1 AND password=crypt($2, password)",
		login, password)
	err = row.Scan(&user_id)
	return
}

// UserAuth reports whether a user with the given id exists (ok) and, if so,
// that user's editor flag. Any query or scan error, including "no rows",
// yields ok == false.
func UserAuth(ctx context.Context, user_id string) (ok, editor bool) {
	row := connPool.QueryRow(ctx, "SELECT editor FROM users WHERE id=$1", user_id)
	err := row.Scan(&editor)
	ok = (err == nil)
	return
}

//#endregion Users

//#region Persons

// personsOrderBy turns a comma-separated sort specification such as
// "+name,-birthdate" into an SQL ORDER BY clause with a leading space.
// Each option must start with '+' (ASC) or '-' (DESC), followed by one of
// name, sortName, birthdate or deathdate. Every term gets NULLS LAST.
// An empty specification yields "". Only whitelisted field names ever reach
// the SQL text, so the result is safe to concatenate into a query.
func personsOrderBy(spec string) (string, error) {
	if spec == "" {
		return "", nil
	}
	options := strings.Split(spec, ",")
	terms := make([]string, 0, len(options))
	for _, option := range options {
		// An empty option (e.g. from a trailing comma) has no order mark;
		// reject it instead of panicking on option[:1].
		if option == "" {
			return "", fmt.Errorf("invalid sorting order mark: %q", spec)
		}
		var order string
		switch option[:1] {
		case "+":
			order = "ASC"
		case "-":
			order = "DESC"
		default:
			return "", fmt.Errorf("invalid sorting order mark: %q", spec)
		}
		field := option[1:]
		switch field {
		case "name", "birthdate", "deathdate":
			// Used verbatim as column names.
		case "sortName":
			field = "coalesce(sort_name, name)"
		default:
			return "", fmt.Errorf("invalid sorting field: %q", field)
		}
		terms = append(terms, field+" "+order+" NULLS LAST")
	}
	return " ORDER BY " + strings.Join(terms, ", "), nil
}

// PersonsGet returns one page of persons whose name or sort name contains
// filter. Matching is case-insensitive: filter is lower-cased here and
// compared against lower(name) and lower(sort_name).
//
// Parameters and results:
//   - user_id must pass UserAuth; otherwise the result is 401 Unauthorized.
//   - sort is a comma-separated list of "+field" / "-field" options (see
//     personsOrderBy); a malformed value yields 400 Bad Request.
//   - A negative limit adds no LIMIT clause; offset is applied only when
//     it is positive.
//   - Pagination.Count is the number of rows on this page; Pagination.Total
//     is the number of all matching rows.
//   - statusCode is 200 on success and 500 on database errors.
func PersonsGet(ctx context.Context, user_id, filter, sort string, limit, offset int) (persons models.Persons, statusCode int, err error) {
	if ok, _ := UserAuth(ctx, user_id); !ok {
		return persons, http.StatusUnauthorized, fmt.Errorf("Unauthorized")
	}

	orderBy, err := personsOrderBy(sort)
	if err != nil {
		return persons, http.StatusBadRequest, err
	}

	// The page query and the total count share this base; the count needs
	// neither ORDER BY nor LIMIT/OFFSET.
	const queryBase = "SELECT id, name, coalesce(sort_name, '') FROM persons WHERE position($1 in lower(name))>0 OR position($1 in lower(sort_name))>0"

	queryGet := queryBase + orderBy
	if limit >= 0 {
		queryGet += fmt.Sprintf(" LIMIT %d", limit)
	}
	if offset > 0 {
		queryGet += fmt.Sprintf(" OFFSET %d", offset)
	}

	filter = strings.ToLower(filter)

	rows, err := connPool.Query(ctx, queryGet, filter)
	if err != nil {
		return persons, http.StatusInternalServerError, fmt.Errorf("error while querying persons: %w", err)
	}
	// Release the rows (and their pooled connection) on every path,
	// including the early return on a Scan error below.
	defer rows.Close()

	for rows.Next() {
		var person models.PersonBrief
		if err = rows.Scan(&person.ID, &person.Name, &person.SortName); err != nil {
			return persons, http.StatusInternalServerError, fmt.Errorf("error while fetching persons: %w", err)
		}
		persons.Persons = append(persons.Persons, person)
	}
	if err = rows.Err(); err != nil {
		return persons, http.StatusInternalServerError, fmt.Errorf("error while fetching persons: %w", err)
	}

	persons.Pagination.Limit = limit
	persons.Pagination.Offset = offset
	persons.Pagination.Count = len(persons.Persons)

	queryCount := "SELECT count(*) FROM (" + queryBase + ") tmp"
	if err = connPool.QueryRow(ctx, queryCount, filter).Scan(&persons.Pagination.Total); err != nil {
		return persons, http.StatusInternalServerError, fmt.Errorf("error while counting persons: %w", err)
	}
	return persons, http.StatusOK, nil
}

//#endregion Persons