diff --git a/internal/db/roles.go b/internal/db/roles.go new file mode 100644 index 0000000..855cc0a --- /dev/null +++ b/internal/db/roles.go @@ -0,0 +1,241 @@ +package db + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/H1K0/Kiraku/internal/models" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +func RolesGet(ctx context.Context, user_id, filter, sort string, limit, offset int) (roles models.Roles, statusCode int, err error) { + ok, _ := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + queryGet := "SELECT id, name FROM roles WHERE POSITION($1 IN LOWER(name))>0" + if sort != "" { + sort_options := strings.Split(sort, ",") + queryGet += " ORDER BY " + for i, sort_option := range sort_options { + sort_order := sort_option[:1] + sort_field := sort_option[1:] + switch sort_order { + case "+": + sort_order = "ASC" + case "-": + sort_order = "DESC" + default: + err = fmt.Errorf("invalid sorting order mark: %q", sort) + statusCode = http.StatusBadRequest + return + } + switch sort_field { + case "name": + default: + err = fmt.Errorf("invalid sorting field: %q", sort_field) + statusCode = http.StatusBadRequest + return + } + if i > 0 { + queryGet += ", " + } + queryGet += fmt.Sprintf("%s %s NULLS LAST", sort_field, sort_order) + } + } + queryCount := queryGet + if limit >= 0 { + queryGet += fmt.Sprintf(" LIMIT %d", limit) + } + if offset > 0 { + queryGet += fmt.Sprintf(" OFFSET %d", offset) + } + filter = strings.ToLower(filter) + statusCode, err = transaction(ctx, func(tx pgx.Tx) (statusCode int, err error) { + rows, err := tx.Query(ctx, queryGet, filter) + if err != nil { + statusCode = http.StatusInternalServerError + return + } + count := 0 + for rows.Next() { + var role models.Role + err = rows.Scan(&role.ID, &role.Name) + if err != nil { + statusCode = http.StatusInternalServerError + return + } + roles.Roles = append(roles.Roles, role) + count++ + } + err = rows.Err() + if err != nil { + statusCode = http.StatusInternalServerError + return + } + roles.Pagination.Limit = limit + roles.Pagination.Offset = offset + roles.Pagination.Count = count + queryCount = fmt.Sprintf("SELECT COUNT(*) FROM (%s) tmp", queryCount) + row := tx.QueryRow(ctx, queryCount, filter) + err = row.Scan(&roles.Pagination.Total) + if err != nil { + statusCode = http.StatusInternalServerError + } + return + }) + if err != nil { + return + } + statusCode = http.StatusOK + return +} + +func RoleGet(ctx context.Context, user_id, person_id string) (role models.Role, statusCode int, err error) { + ok, _ := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + row := connPool.QueryRow(ctx, "SELECT id, name FROM roles WHERE id=$1", person_id) + err = row.Scan(&role.ID, &role.Name) + if err != nil { + if err == pgx.ErrNoRows { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + return + } + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" || pgErr.Code == "22007" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else { + statusCode = http.StatusInternalServerError + } + return + } + statusCode = http.StatusOK + return +} + +func RoleAdd(ctx context.Context, user_id, name string) (role models.Role, statusCode int, err error) { + ok, editor := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + if !editor { + err = fmt.Errorf("not allowed") + statusCode = http.StatusForbidden + return + } + row := connPool.QueryRow(ctx, "INSERT INTO roles (name) VALUES ($1) RETURNING id, name", name) + err = row.Scan(&role.ID, &role.Name) + if err != nil { + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" || pgErr.Code == "22007" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else if pgErr.Code == "23505" { + err = fmt.Errorf("a role with this name already exists") + statusCode = http.StatusConflict + } else { + statusCode = http.StatusInternalServerError + } + return + } + statusCode = http.StatusOK + return +} + +func RoleUpdate(ctx context.Context, user_id, role_id string, values map[string]string) (role models.Role, statusCode int, err error) { + ok, editor := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + if !editor { + err = fmt.Errorf("not allowed") + statusCode = http.StatusForbidden + return + } + statusCode, err = transaction(ctx, func(tx pgx.Tx) (statusCode int, err error) { + for _, field := range []string{"name"} { + value, ok := values[field] + if !ok { + continue + } + var commandTag pgconn.CommandTag + commandTag, err = tx.Exec(ctx, fmt.Sprintf("UPDATE roles SET %s=NULLIF($2, '') WHERE id=$1", field), role_id, value) + if err != nil { + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" || pgErr.Code == "22007" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else if pgErr.Code == "23505" { + err = fmt.Errorf("a person with this name already exists") + statusCode = http.StatusConflict + } else { + statusCode = http.StatusInternalServerError + } + return + } + if commandTag.RowsAffected() == 0 { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + return + } + } + row := tx.QueryRow(ctx, "SELECT id, name FROM roles WHERE id=$1", role_id) + err = row.Scan(&role.ID, &role.Name) + if err != nil { + statusCode = http.StatusInternalServerError + } + return + }) + if err != nil { + return + } + statusCode = http.StatusOK + return +} + +func RoleDelete(ctx context.Context, user_id, role_id string) (statusCode int, err error) { + ok, editor := UserAuth(ctx, user_id) + if !ok { + err = fmt.Errorf("unauthorized") + statusCode = http.StatusUnauthorized + return + } + if !editor { + err = fmt.Errorf("not allowed") + statusCode = http.StatusForbidden + return + } + commandTag, err := connPool.Exec(ctx, "DELETE FROM roles WHERE id=$1", role_id) + if err != nil { + pgErr := err.(*pgconn.PgError) + if pgErr.Code == "22P02" { + err = fmt.Errorf("%s", pgErr.Message) + statusCode = http.StatusBadRequest + } else { + statusCode = http.StatusInternalServerError + } + return + } + if commandTag.RowsAffected() == 0 { + err = fmt.Errorf("not found") + statusCode = http.StatusNotFound + return + } + statusCode = http.StatusNoContent + return +}