Files
hamori/app.py
T
H1K0 c147c47acb feat: add minimal Gradio web UI (app.py)
Single-page form wrapping src.generate.generate_period: pick model, mode,
key, style, function, time, sampling params and optional prefix; returns
the chord grid plus downloadable .chord and .mid files. Russian usage
instructions are embedded on the same page.

Auto-length output is capped at 16 bars (the period maximum) so a model
that never emits EOS can't run away into dozens of NC/hold bars.

Added per the author's explicit request — web UI was previously out of
scope; updated CLAUDE.md and README accordingly. Choices for style/
function/time are derived from VOCAB so the form can't drift from the
tokenizer.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-04 15:38:01 +03:00

320 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Minimal Gradio web UI for hamori — generate a harmonic period from a browser.
A single-page form that wraps :func:`src.generate.generate_period`. Pick the
mode, key, style and sampling parameters; the app returns the chord grid plus
downloadable ``.chord`` and ``.mid`` files. Russian usage instructions are
embedded on the same page (see the "Инструкция" accordion).
This is a convenience front-end for demonstration only — all generation logic
lives in ``src/``. The CLI (``scripts/generate.py``) remains the canonical
entry point.
Usage:
python app.py # serve on http://127.0.0.1:7860
python app.py --share # also create a temporary public link
python app.py --port 8000
"""
from __future__ import annotations
import argparse
import tempfile
from dataclasses import replace
from functools import lru_cache
from pathlib import Path
from uuid import uuid4
import gradio as gr
import torch
from src.generate import generate_period
from src.midi_export import chord_file_to_midi
from src.model import ChordTransformer
from src.tokenizer import VOCAB, write_chord_file
# ---------------------------------------------------------------------------
# Constants — choices derived from the vocabulary so the form never drifts.
# ---------------------------------------------------------------------------
CHECKPOINT_DIR = Path(__file__).resolve().parent / "checkpoints"
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
# A period is 4–16 bars by definition; cap runaway "auto" output at this ceiling.
MAX_PERIOD_BARS = 16


def _vocab_suffixes(prefix: str) -> list[str]:
    """Return every VOCAB token starting with *prefix*, with the prefix stripped.

    Keeps the dropdown choices (styles, functions, time signatures) in lockstep
    with the tokenizer vocabulary, preserving VOCAB order.
    """
    return [tok[len(prefix):] for tok in VOCAB if tok.startswith(prefix)]


STYLES = _vocab_suffixes("STYLE_")
FUNCTIONS = _vocab_suffixes("FUNC_")
TIMES = _vocab_suffixes("TIME_")
# Files generated for download live here for the lifetime of the process.
# A ``TemporaryDirectory`` (unlike a bare ``mkdtemp``) registers a finalizer
# that deletes the directory at interpreter exit, so repeated server runs do
# not leave stale .chord/.mid files behind in the system temp dir. The object
# must stay referenced at module level: if it were garbage-collected, the
# directory would be removed immediately.
_OUTPUT_TMP = tempfile.TemporaryDirectory(prefix="hamori_webui_")
OUTPUT_DIR = Path(_OUTPUT_TMP.name)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------------------------------------------------------------------
# Model loading (cached per checkpoint)
# ---------------------------------------------------------------------------


@lru_cache(maxsize=2)
def _load_model(checkpoint: str) -> ChordTransformer:
    """Build a ChordTransformer from ``CHECKPOINT_DIR/<checkpoint>.pt``.

    ``checkpoint`` is the file stem (e.g. ``"pretrained"`` or ``"finetuned"``).
    The loaded object must provide ``"model_config"`` (constructor keyword
    arguments) and ``"model_state"`` (a state dict). The model is moved to
    ``DEVICE`` and put in eval mode. Up to two models are memoised by stem.
    Any error from loading propagates to the caller and is not cached.
    """
    ckpt_file = CHECKPOINT_DIR / f"{checkpoint}.pt"
    # weights_only=True restricts unpickling to tensors and plain containers.
    payload = torch.load(ckpt_file, map_location=DEVICE, weights_only=True)
    net = ChordTransformer(**payload["model_config"])
    net.load_state_dict(payload["model_state"])
    net.to(DEVICE)
    net.eval()
    return net
def _available_checkpoints() -> list[str]:
    """Return the sorted stems of all ``*.pt`` files in CHECKPOINT_DIR (possibly empty)."""
    stems = [ckpt.stem for ckpt in CHECKPOINT_DIR.glob("*.pt")]
    stems.sort()
    return stems
# ---------------------------------------------------------------------------
# Generation callback
# ---------------------------------------------------------------------------
def _format_bars(period) -> str:
"""Render the bar grid as aligned text, one bar per line."""
width = max((len(s) for bar in period.bars for s in bar), default=1)
lines = []
for i, bar in enumerate(period.bars, 1):
cells = " ".join(s.rjust(width) for s in bar)
lines.append(f"Bar {i:2d} | {cells}")
return "\n".join(lines)
def _resolve_prefix(
    prefix_text: str | None, tonic_anchor: bool, key: str, mode: str
) -> list[str] | None:
    """Choose the chord prefix that seeds generation, or ``None`` for none.

    Explicit user text wins and is split on whitespace. Otherwise, when
    *tonic_anchor* is set, the prefix is the single tonic chord: ``key`` for
    major, ``key + "m"`` for minor.
    """
    text = (prefix_text or "").strip()
    if text:
        return text.split()
    if tonic_anchor:
        return [key + ("m" if mode == "minor" else "")]
    return None


def generate(
    checkpoint: str,
    mode: str,
    key: str,
    style: str,
    function: str,
    time: str,
    subdivision: int,
    auto_bars: bool,
    n_bars: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    tonic_anchor: bool,
    prefix_text: str,
    seed,
    tempo: int,
):
    """Run one generation and return (status, bar grid, .chord path, .mid path).

    This is the Gradio callback for the form; arguments arrive in the order of
    the button's ``inputs`` list. ``seed`` and ``tempo`` come from
    ``gr.Number`` fields and are ``None`` when the user clears them. A ``None``
    seed means an unseeded (random) run. ``auto_bars`` lets the model choose
    the length, capped at MAX_PERIOD_BARS.

    Returns:
        ``(status_markdown, grid_text, chord_path, midi_path)``. On a bad
        tempo or on a load, generation or export failure, the status carries
        a ``❌`` message and both paths are ``None``. These errors are
        reported to the UI rather than raised.
    """
    # A cleared Tempo field gives None, and MIDI needs a positive BPM. Validate
    # up front so the export step below cannot crash on it.
    bpm = int(tempo) if tempo is not None else 0
    if bpm <= 0:
        return "❌ Укажите темп (Tempo, BPM) больше нуля.", "", None, None

    try:
        model = _load_model(checkpoint)
    except Exception as exc:  # noqa: BLE001 — surface any load error to the UI
        return f"❌ Не удалось загрузить чекпойнт «{checkpoint}»: {exc}", "", None, None

    target_key = f"{key}_{mode}"
    prefix_chords = _resolve_prefix(prefix_text, tonic_anchor, key, mode)
    seed_val = int(seed) if seed is not None else None
    bars_arg = None if auto_bars else int(n_bars)
    try:
        period = generate_period(
            model=model,
            mode=mode,
            time=time,
            subdivision=int(subdivision),
            style=style,
            function=function,
            key=target_key,
            prefix=prefix_chords,
            temperature=float(temperature),
            top_p=float(top_p),
            n_bars=bars_arg,
            seed=seed_val,
            repetition_penalty=float(repetition_penalty),
        )
    except Exception as exc:  # noqa: BLE001 — show generation errors verbatim
        return f"❌ Ошибка генерации: {exc}", "", None, None

    # "Auto" lets the model close via EOS, but if it never does it can run away
    # into dozens of NC/hold bars. A period is 4–16 bars — cap the tail.
    truncated = len(period.bars) > MAX_PERIOD_BARS
    if truncated:
        period = replace(period, bars=period.bars[:MAX_PERIOD_BARS])
    period = replace(period, title=f"hamori — {key} {mode}, {function}")
    grid = _format_bars(period)

    # Keep '#' out of the download file name; the random suffix keeps
    # concurrent/repeated runs from overwriting each other's files.
    stem = f"hamori_{key.replace('#', 'sharp')}_{mode}_{uuid4().hex[:6]}"
    chord_path = OUTPUT_DIR / f"{stem}.chord"
    midi_path = OUTPUT_DIR / f"{stem}.mid"
    try:
        write_chord_file(period, chord_path)
        chord_file_to_midi(chord_path, midi_path, tempo=bpm)
    except Exception as exc:  # noqa: BLE001 — surface export errors to the UI
        return f"❌ Ошибка экспорта файлов: {exc}", grid, None, None

    status = (
        f"✅ Готово — {len(period.bars)} тактов · {target_key} · "
        f"модель: {checkpoint} · seed: {seed_val if seed_val is not None else 'random'}"
    )
    if truncated:
        status += f" · обрезано до {MAX_PERIOD_BARS} тактов (период ≤ {MAX_PERIOD_BARS})"
    return status, grid, str(chord_path), str(midi_path)
# ---------------------------------------------------------------------------
# Russian instructions (rendered inline)
# ---------------------------------------------------------------------------
# Markdown shown inside the «Инструкция» accordion. Keep the parameter table in
# sync with the form fields and their defaults in build_ui().
INSTRUCTIONS_RU = """
## Как пользоваться
**hamori** генерирует одну гармоническую фразу (период, 4–16 тактов) в заданной
тональности и стиле. Это инструмент-подсказчик: он предлагает аккордовую
последовательность, которую вы дорабатываете в DAW.
### Параметры
| Поле | Что задаёт |
|------|------------|
| **Модель** | `finetuned` — обучена на вашем корпусе (рекомендуется). `pretrained` — только McGill Billboard, более «общий» звук. |
| **Лад / Тональность** | Мажор или минор и тоника результата (например, `F# major`). Модель генерирует в C/Am и транспонирует в выбранную тональность. |
| **Стиль / Функция** | Метки условия. `H1K0` — авторский стиль. Функция — роль фрагмента (куплет, припев…). |
| **Размер / Subdivision** | Тактовый размер и число позиций на такт. По умолчанию `4/4` и `4`. |
| **Число тактов** | Длина периода. «Авто» — модель сама решает, где закрыть фразу. |
| **Temperature** | Разброс. `1.0` — норма. Выше — смелее и хаотичнее, ниже — предсказуемее. |
| **Top-p** | Нуклеус-сэмплинг. `0.9` — норма. Ниже — консервативнее. |
| **Repetition penalty** | Борется с зацикливанием (I–V–I–V). `0.0` — выкл. Для `pretrained` попробуйте `0.5–1.0`; для `finetuned` обычно не нужно. |
| **Tonic anchor** | Если префикс пуст — начинать с тоники, чтобы фраза держалась в тональности. |
| **Префикс** | Свои стартовые аккорды через пробел в выбранной тональности, напр. `Cmaj7 . Am7 .`. `.` — держать, `NC` — без аккорда. Если задан, перекрывает tonic anchor. |
| **Seed** | Фиксирует случайность для воспроизводимости. Очистите поле для случайного результата. |
| **Tempo** | Темп MIDI-файла (BPM). На сами аккорды не влияет. |
### Результат
- **Сетка аккордов** — текстовый предпросмотр периода.
- **`.chord`** — исходный формат проекта (человекочитаемый).
- **`.mid`** — импортируйте в REAPER перетаскиванием на дорожку.
### Рекомендации для старта
Модель `finetuned`, `temperature = 1.0`, `top-p = 0.9`, tonic anchor включён.
Если получается монотонно или зациклено — поднимите temperature или добавьте
repetition penalty.
"""
# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------


def build_ui() -> gr.Blocks:
    """Assemble the single-page Gradio app and wire the button to ``generate``.

    Left column: the generation form (model, key, conditioning labels,
    length, sampling knobs, extras) plus the run button. Right column: status
    line, chord-grid preview and the two download slots. The Russian
    instructions sit in a collapsed accordion below. Defaults prefer
    ``finetuned`` / ``H1K0`` / ``chorus`` / ``4/4`` when those values are
    offered, otherwise the first available choice.
    """

    def pick(options, preferred):
        # Preferred value when it is offered, otherwise the first option.
        return preferred if preferred in options else options[0]

    found = _available_checkpoints()
    # With no checkpoints on disk, fall back to the expected names; generating
    # with one then reports the load error in the status line.
    ckpt_choices = found or ["finetuned", "pretrained"]

    with gr.Blocks(title="hamori — генератор гармонии") as app:
        gr.Markdown(
            "# hamori 🎶 — генератор гармонических периодов\n"
            "Заполните форму и нажмите **Сгенерировать**. "
            "Подробности — в разделе «Инструкция» внизу."
        )
        with gr.Row():
            # ---- Left: form ------------------------------------------------
            with gr.Column(scale=1):
                in_ckpt = gr.Radio(
                    choices=ckpt_choices,
                    value=pick(ckpt_choices, "finetuned"),
                    label="Модель",
                )
                with gr.Row():
                    in_mode = gr.Radio(["major", "minor"], value="major", label="Лад")
                    in_key = gr.Dropdown(NOTE_NAMES, value="C", label="Тональность")
                with gr.Row():
                    in_style = gr.Dropdown(
                        STYLES, value=pick(STYLES, "H1K0"), label="Стиль",
                    )
                    in_function = gr.Dropdown(
                        FUNCTIONS, value=pick(FUNCTIONS, "chorus"), label="Функция",
                    )
                with gr.Row():
                    in_time = gr.Dropdown(
                        TIMES, value=pick(TIMES, "4/4"), label="Размер",
                    )
                    in_subdivision = gr.Radio([4, 8], value=4, label="Subdivision")
                with gr.Row():
                    in_auto = gr.Checkbox(value=False, label="Авто (длина сама)")
                    in_bars = gr.Slider(4, 16, value=8, step=1, label="Число тактов")
                with gr.Accordion("Сэмплирование", open=True):
                    in_temperature = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.05, label="Temperature",
                    )
                    in_top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    in_rep_penalty = gr.Slider(
                        0.0, 2.0, value=0.0, step=0.1, label="Repetition penalty",
                    )
                with gr.Accordion("Дополнительно", open=False):
                    in_anchor = gr.Checkbox(value=True, label="Tonic anchor")
                    in_prefix = gr.Textbox(
                        label="Префикс (аккорды через пробел)",
                        placeholder="напр. Cmaj7 . Am7 .",
                    )
                    in_seed = gr.Number(
                        value=42, precision=0, label="Seed (пусто = случайно)",
                    )
                    in_tempo = gr.Number(value=90, precision=0, label="Tempo (BPM)")
                run_btn = gr.Button("Сгенерировать", variant="primary")
            # ---- Right: outputs -------------------------------------------
            with gr.Column(scale=1):
                out_status = gr.Markdown()
                out_grid = gr.Textbox(
                    label="Сетка аккордов", lines=10, interactive=False,
                )
                out_chord = gr.File(label="Скачать .chord")
                out_midi = gr.File(label="Скачать .mid")
        with gr.Accordion("Инструкция", open=False):
            gr.Markdown(INSTRUCTIONS_RU)

        # Order must match the positional parameters of generate().
        form_inputs = [
            in_ckpt, in_mode, in_key, in_style, in_function, in_time,
            in_subdivision, in_auto, in_bars, in_temperature, in_top_p,
            in_rep_penalty, in_anchor, in_prefix, in_seed, in_tempo,
        ]
        run_btn.click(
            fn=generate,
            inputs=form_inputs,
            outputs=[out_status, out_grid, out_chord, out_midi],
        )
    return app
def main() -> None:
    """Parse the command-line flags, build the UI and launch the Gradio server."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--port", type=int, default=7860, help="Server port (default: 7860).",
    )
    parser.add_argument(
        "--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1).",
    )
    parser.add_argument(
        "--share", action="store_true", help="Create a temporary public Gradio link.",
    )
    opts = parser.parse_args()
    app = build_ui()
    app.launch(server_name=opts.host, server_port=opts.port, share=opts.share)


if __name__ == "__main__":
    main()