c147c47acb
Single-page form wrapping src.generate.generate_period: pick model, mode, key, style, function, time, sampling params and optional prefix; returns the chord grid plus downloadable .chord and .mid files. Russian usage instructions are embedded on the same page. Auto-length output is capped at 16 bars (the period maximum) so a model that never emits EOS can't run away into dozens of NC/hold bars. Added per the author's explicit request — web UI was previously out of scope; updated CLAUDE.md and README accordingly. Choices for style/function/time are derived from VOCAB so the form can't drift from the tokenizer. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
320 lines
14 KiB
Python
320 lines
14 KiB
Python
"""Minimal Gradio web UI for hamori — generate a harmonic period from a browser.
|
||
|
||
A single-page form that wraps :func:`src.generate.generate_period`. Pick the
|
||
mode, key, style and sampling parameters; the app returns the chord grid plus
|
||
downloadable ``.chord`` and ``.mid`` files. Russian usage instructions are
|
||
embedded on the same page (see the "Инструкция" accordion).
|
||
|
||
This is a convenience front-end for demonstration only — all generation logic
|
||
lives in ``src/``. The CLI (``scripts/generate.py``) remains the canonical
|
||
entry point.
|
||
|
||
Usage:
|
||
python app.py # serve on http://127.0.0.1:7860
|
||
python app.py --share # also create a temporary public link
|
||
python app.py --port 8000
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import tempfile
|
||
from dataclasses import replace
|
||
from functools import lru_cache
|
||
from pathlib import Path
|
||
from uuid import uuid4
|
||
|
||
import gradio as gr
|
||
import torch
|
||
|
||
from src.generate import generate_period
|
||
from src.midi_export import chord_file_to_midi
|
||
from src.model import ChordTransformer
|
||
from src.tokenizer import VOCAB, write_chord_file
|
||
|
||
# ---------------------------------------------------------------------------
# Constants — choices derived from the vocabulary so the form never drifts.
# ---------------------------------------------------------------------------

CHECKPOINT_DIR = Path(__file__).resolve().parent / "checkpoints"

NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

# A period is 4–16 bars by definition; cap runaway "auto" output at this ceiling.
MAX_PERIOD_BARS = 16


def _vocab_values(prefix: str) -> list[str]:
    """Return every VOCAB token starting with *prefix*, prefix stripped, in VOCAB order."""
    return [token[len(prefix):] for token in VOCAB if token.startswith(prefix)]


STYLES = _vocab_values("STYLE_")
FUNCTIONS = _vocab_values("FUNC_")
TIMES = _vocab_values("TIME_")

# Files generated for download live here for the lifetime of the process.
# A TemporaryDirectory (kept alive by this module-level reference) deletes the
# directory and its contents when finalized — at the latest at interpreter
# exit — so repeated runs of the app don't leave stale directories in /tmp.
_OUTPUT_TMPDIR = tempfile.TemporaryDirectory(prefix="hamori_webui_")
OUTPUT_DIR = Path(_OUTPUT_TMPDIR.name)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Model loading (cached per checkpoint)
# ---------------------------------------------------------------------------

@lru_cache(maxsize=2)
def _load_model(checkpoint: str) -> ChordTransformer:
    """Load and cache a ChordTransformer by checkpoint stem ('pretrained' / 'finetuned').

    Args:
        checkpoint: Bare file stem of a ``.pt`` file inside ``CHECKPOINT_DIR``.
            The value comes from the web form — and therefore from any client
            of a ``--share`` link — so names with path components (``../x``,
            ``/abs/path``, ``a/b``) are rejected to keep loads confined to
            ``CHECKPOINT_DIR``.

    Returns:
        The model built from ``ckpt["model_config"]``, with ``ckpt["model_state"]``
        loaded, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        ValueError: if *checkpoint* is empty or is not a bare file name.
        Errors from ``torch.load`` (e.g. a missing file) propagate unchanged;
        ``lru_cache`` does not cache failed calls.
    """
    if not checkpoint or Path(checkpoint).name != checkpoint:
        raise ValueError(f"invalid checkpoint name: {checkpoint!r}")

    path = CHECKPOINT_DIR / f"{checkpoint}.pt"
    # weights_only=True refuses arbitrary pickled objects inside the file.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=True)
    model = ChordTransformer(**ckpt["model_config"])
    model.load_state_dict(ckpt["model_state"])
    model.to(DEVICE)
    model.eval()
    return model
|
||
|
||
|
||
def _available_checkpoints() -> list[str]:
    """List the ``.pt`` checkpoint stems found in CHECKPOINT_DIR, sorted alphabetically."""
    stems = [ckpt_file.stem for ckpt_file in CHECKPOINT_DIR.glob("*.pt")]
    stems.sort()
    return stems
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Generation callback
# ---------------------------------------------------------------------------

def _format_bars(period) -> str:
    """Lay out ``period.bars`` as text: one ``Bar NN | ...`` line per bar.

    Every cell is right-justified to the widest chord symbol in the whole
    period, so columns line up across bars. An empty period yields ``""``.
    """
    rows = period.bars
    cell_width = max((len(symbol) for row in rows for symbol in row), default=1)

    def render(number, row):
        padded = " ".join(symbol.rjust(cell_width) for symbol in row)
        return f"Bar {number:2d} | " + padded

    return "\n".join(render(number, row) for number, row in enumerate(rows, 1))
|
||
|
||
|
||
def _resolve_prefix(prefix_text, tonic_anchor: bool, key: str, mode: str) -> list[str] | None:
    """Turn the prefix textbox and tonic-anchor checkbox into generate_period's ``prefix``.

    Explicit text wins and is split on whitespace. Otherwise, if the anchor is
    on, the prefix is the tonic chord symbol of the chosen key (``"F#"`` for
    F# major, ``"Am"`` for A minor). With neither, returns ``None``.
    """
    text = (prefix_text or "").strip()
    if text:
        return text.split()
    if tonic_anchor:
        return [key + ("m" if mode == "minor" else "")]
    return None


def generate(
    checkpoint: str,
    mode: str,
    key: str,
    style: str,
    function: str,
    time: str,
    subdivision: int,
    auto_bars: bool,
    n_bars: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    tonic_anchor: bool,
    prefix_text: str,
    seed,
    tempo: int,
):
    """Run one generation and return (status, bar grid, .chord path, .mid path).

    Parameters are the raw values of the form components, in form order.
    ``seed`` is ``None`` when the field is cleared, which means random. ``tempo``
    may also arrive as ``None`` when cleared.

    Every failure is reported through the Markdown status string instead of
    being raised, so the page always shows a readable message. The grid and
    file slots are filled only as far as the run got:
    ``(status, "", None, None)`` before a period exists, and the grid plus any
    successfully written file afterwards.
    """
    # Validate cheap inputs before the (possibly slow) model load/generation.
    tempo_val = int(tempo) if tempo is not None else 0
    if tempo_val <= 0:
        return "❌ Темп (BPM) должен быть положительным числом.", "", None, None

    try:
        model = _load_model(checkpoint)
    except Exception as exc:  # noqa: BLE001 — surface any load error to the UI
        return f"❌ Не удалось загрузить чекпойнт «{checkpoint}»: {exc}", "", None, None

    target_key = f"{key}_{mode}"
    prefix_chords = _resolve_prefix(prefix_text, tonic_anchor, key, mode)
    seed_val = int(seed) if seed is not None else None
    bars_arg = None if auto_bars else int(n_bars)

    try:
        period = generate_period(
            model=model,
            mode=mode,
            time=time,
            subdivision=int(subdivision),
            style=style,
            function=function,
            key=target_key,
            prefix=prefix_chords,
            temperature=float(temperature),
            top_p=float(top_p),
            n_bars=bars_arg,
            seed=seed_val,
            repetition_penalty=float(repetition_penalty),
        )
    except Exception as exc:  # noqa: BLE001 — show generation errors verbatim
        return f"❌ Ошибка генерации: {exc}", "", None, None

    # "Auto" lets the model close via EOS, but if it never does it can run away
    # into dozens of NC/hold bars. A period is 4–16 bars — cap the tail.
    truncated = len(period.bars) > MAX_PERIOD_BARS
    if truncated:
        period = replace(period, bars=period.bars[:MAX_PERIOD_BARS])

    period = replace(period, title=f"hamori — {key} {mode}, {function}")
    grid = _format_bars(period)

    # '#' is awkward in download file names; the uuid suffix avoids collisions.
    stem = f"hamori_{key.replace('#', 'sharp')}_{mode}_{uuid4().hex[:6]}"
    chord_path = OUTPUT_DIR / f"{stem}.chord"
    midi_path = OUTPUT_DIR / f"{stem}.mid"

    # Export failures still show the generated grid (and the .chord file when
    # only the MIDI step fails) instead of escaping as an unhandled exception.
    try:
        write_chord_file(period, chord_path)
    except Exception as exc:  # noqa: BLE001 — surface export errors to the UI
        return f"❌ Не удалось записать .chord: {exc}", grid, None, None
    try:
        chord_file_to_midi(chord_path, midi_path, tempo=tempo_val)
    except Exception as exc:  # noqa: BLE001 — surface export errors to the UI
        return f"❌ Не удалось создать MIDI: {exc}", grid, str(chord_path), None

    status = (
        f"✅ Готово — {len(period.bars)} тактов · {target_key} · "
        f"модель: {checkpoint} · seed: {seed_val if seed_val is not None else 'random'}"
    )
    if truncated:
        status += f" · обрезано до {MAX_PERIOD_BARS} тактов (период ≤ {MAX_PERIOD_BARS})"
    return status, grid, str(chord_path), str(midi_path)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Russian instructions (rendered inline)
# ---------------------------------------------------------------------------

# User-facing usage guide in Russian (Markdown). build_ui() renders it inside
# the collapsed "Инструкция" accordion below the form. It describes every form
# field and the recommended starting values (temperature 1.0, top-p 0.9, which
# match the slider defaults in build_ui()), so update it whenever a field or a
# default changes.
INSTRUCTIONS_RU = """
## Как пользоваться

**hamori** генерирует одну гармоническую фразу (период, 4–16 тактов) в заданной
тональности и стиле. Это инструмент-подсказчик: он предлагает аккордовую
последовательность, которую вы дорабатываете в DAW.

### Параметры

| Поле | Что задаёт |
|------|------------|
| **Модель** | `finetuned` — обучена на вашем корпусе (рекомендуется). `pretrained` — только McGill Billboard, более «общий» звук. |
| **Лад / Тональность** | Мажор или минор и тоника результата (например, `F# major`). Модель генерирует в C/Am и транспонирует в выбранную тональность. |
| **Стиль / Функция** | Метки условия. `H1K0` — авторский стиль. Функция — роль фрагмента (куплет, припев…). |
| **Размер / Subdivision** | Тактовый размер и число позиций на такт. По умолчанию `4/4` и `4`. |
| **Число тактов** | Длина периода. «Авто» — модель сама решает, где закрыть фразу. |
| **Temperature** | Разброс. `1.0` — норма. Выше — смелее и хаотичнее, ниже — предсказуемее. |
| **Top-p** | Нуклеус-сэмплинг. `0.9` — норма. Ниже — консервативнее. |
| **Repetition penalty** | Борется с зацикливанием (I–V–I–V). `0.0` — выкл. Для `pretrained` попробуйте `0.5–1.0`; для `finetuned` обычно не нужно. |
| **Tonic anchor** | Если префикс пуст — начинать с тоники, чтобы фраза держалась в тональности. |
| **Префикс** | Свои стартовые аккорды через пробел в выбранной тональности, напр. `Cmaj7 . Am7 .`. `.` — держать, `NC` — без аккорда. Если задан, перекрывает tonic anchor. |
| **Seed** | Фиксирует случайность для воспроизводимости. Очистите поле для случайного результата. |
| **Tempo** | Темп MIDI-файла (BPM). На сами аккорды не влияет. |

### Результат

- **Сетка аккордов** — текстовый предпросмотр периода.
- **`.chord`** — исходный формат проекта (человекочитаемый).
- **`.mid`** — импортируйте в REAPER перетаскиванием на дорожку.

### Рекомендации для старта
Модель `finetuned`, `temperature = 1.0`, `top-p = 0.9`, tonic anchor включён.
Если получается монотонно или зациклено — поднимите temperature или добавьте
repetition penalty.
"""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------

def build_ui() -> gr.Blocks:
    """Assemble the single-page Gradio app and wire the Generate button.

    Layout: a two-column row (parameter form on the left; status, chord grid
    and download slots on the right), followed by a collapsed accordion with
    the Russian usage instructions.

    Returns:
        The un-launched ``gr.Blocks`` app; the caller decides how to serve it.
    """

    def prefer(options, wanted):
        # Default to *wanted* when the vocabulary offers it, else the first option.
        return wanted if wanted in options else options[0]

    found = _available_checkpoints()
    if "finetuned" in found:
        initial_ckpt = "finetuned"
    elif found:
        initial_ckpt = found[0]
    else:
        initial_ckpt = "finetuned"

    with gr.Blocks(title="hamori — генератор гармонии") as demo:
        gr.Markdown(
            "# hamori 🎶 — генератор гармонических периодов\n"
            "Заполните форму и нажмите **Сгенерировать**. "
            "Подробности — в разделе «Инструкция» внизу."
        )

        with gr.Row():
            # Left column: every generation parameter.
            with gr.Column(scale=1):
                ckpt_in = gr.Radio(
                    choices=found or ["finetuned", "pretrained"],
                    value=initial_ckpt,
                    label="Модель",
                )
                with gr.Row():
                    mode_in = gr.Radio(["major", "minor"], value="major", label="Лад")
                    key_in = gr.Dropdown(NOTE_NAMES, value="C", label="Тональность")
                with gr.Row():
                    style_in = gr.Dropdown(
                        STYLES, value=prefer(STYLES, "H1K0"), label="Стиль",
                    )
                    func_in = gr.Dropdown(
                        FUNCTIONS, value=prefer(FUNCTIONS, "chorus"), label="Функция",
                    )
                with gr.Row():
                    time_in = gr.Dropdown(
                        TIMES, value=prefer(TIMES, "4/4"), label="Размер",
                    )
                    subdiv_in = gr.Radio([4, 8], value=4, label="Subdivision")

                with gr.Row():
                    auto_in = gr.Checkbox(value=False, label="Авто (длина сама)")
                    bars_in = gr.Slider(4, 16, value=8, step=1, label="Число тактов")

                with gr.Accordion("Сэмплирование", open=True):
                    temp_in = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.05, label="Temperature",
                    )
                    top_p_in = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    rep_in = gr.Slider(
                        0.0, 2.0, value=0.0, step=0.1, label="Repetition penalty",
                    )

                with gr.Accordion("Дополнительно", open=False):
                    anchor_in = gr.Checkbox(value=True, label="Tonic anchor")
                    prefix_in = gr.Textbox(
                        label="Префикс (аккорды через пробел)",
                        placeholder="напр. Cmaj7 . Am7 .",
                    )
                    seed_in = gr.Number(
                        value=42, precision=0, label="Seed (пусто = случайно)",
                    )
                    tempo_in = gr.Number(value=90, precision=0, label="Tempo (BPM)")

                run_btn = gr.Button("Сгенерировать", variant="primary")

            # Right column: status line, chord-grid preview, download slots.
            with gr.Column(scale=1):
                status_out = gr.Markdown()
                grid_out = gr.Textbox(
                    label="Сетка аккордов", lines=10, interactive=False,
                )
                chord_out = gr.File(label="Скачать .chord")
                midi_out = gr.File(label="Скачать .mid")

        with gr.Accordion("Инструкция", open=False):
            gr.Markdown(INSTRUCTIONS_RU)

        # Order must match generate()'s positional parameters exactly.
        form_fields = [
            ckpt_in, mode_in, key_in, style_in, func_in, time_in, subdiv_in,
            auto_in, bars_in, temp_in, top_p_in, rep_in,
            anchor_in, prefix_in, seed_in, tempo_in,
        ]
        result_slots = [status_out, grid_out, chord_out, midi_out]
        run_btn.click(fn=generate, inputs=form_fields, outputs=result_slots)

    return demo
|
||
|
||
|
||
def main() -> None:
    """Parse the command-line flags, then build and serve the web UI."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--port", type=int, default=7860, help="Server port (default: 7860).",
    )
    parser.add_argument(
        "--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1).",
    )
    parser.add_argument(
        "--share", action="store_true", help="Create a temporary public Gradio link.",
    )
    opts = parser.parse_args()

    app = build_ui()
    app.launch(server_name=opts.host, server_port=opts.port, share=opts.share)


if __name__ == "__main__":
    main()
|