"""Minimal Gradio web UI for hamori — generate a harmonic period from a browser. A single-page form that wraps :func:`src.generate.generate_period`. Pick the mode, key, style and sampling parameters; the app returns the chord grid plus downloadable ``.chord`` and ``.mid`` files. Russian usage instructions are embedded on the same page (see the "Инструкция" accordion). This is a convenience front-end for demonstration only — all generation logic lives in ``src/``. The CLI (``scripts/generate.py``) remains the canonical entry point. Usage: python app.py # serve on http://127.0.0.1:7860 python app.py --share # also create a temporary public link python app.py --port 8000 """ from __future__ import annotations import argparse import tempfile from dataclasses import replace from functools import lru_cache from pathlib import Path from time import perf_counter from uuid import uuid4 import gradio as gr import torch from src.generate import generate_period from src.midi_export import chord_file_to_midi from src.model import ChordTransformer from src.tokenizer import VOCAB, write_chord_file # --------------------------------------------------------------------------- # Constants — choices derived from the vocabulary so the form never drifts. # --------------------------------------------------------------------------- CHECKPOINT_DIR = Path(__file__).resolve().parent / "checkpoints" NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] # A period is 4–16 bars by definition; cap runaway "auto" output at this ceiling. MAX_PERIOD_BARS = 16 STYLES = [t[len("STYLE_"):] for t in VOCAB if t.startswith("STYLE_")] FUNCTIONS = [t[len("FUNC_"):] for t in VOCAB if t.startswith("FUNC_")] TIMES = [t[len("TIME_"):] for t in VOCAB if t.startswith("TIME_")] # Files generated for download live here for the lifetime of the process. OUTPUT_DIR = Path(tempfile.mkdtemp(prefix="hamori_webui_")) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # --------------------------------------------------------------------------- # Model loading (cached per checkpoint) # --------------------------------------------------------------------------- @lru_cache(maxsize=2) def _load_model(checkpoint: str) -> ChordTransformer: """Load and cache a ChordTransformer by checkpoint stem ('pretrained' / 'finetuned').""" path = CHECKPOINT_DIR / f"{checkpoint}.pt" ckpt = torch.load(path, map_location=DEVICE, weights_only=True) model = ChordTransformer(**ckpt["model_config"]) model.load_state_dict(ckpt["model_state"]) model.to(DEVICE) model.eval() return model def _available_checkpoints() -> list[str]: return sorted(p.stem for p in CHECKPOINT_DIR.glob("*.pt")) # --------------------------------------------------------------------------- # Generation callback # --------------------------------------------------------------------------- def _format_bars(period) -> str: """Render the bar grid as aligned text, one bar per line.""" width = max((len(s) for bar in period.bars for s in bar), default=1) lines = [] for i, bar in enumerate(period.bars, 1): cells = " ".join(s.rjust(width) for s in bar) lines.append(f"Bar {i:2d} | {cells}") return "\n".join(lines) def generate( checkpoint: str, mode: str, key: str, style: str, function: str, time: str, subdivision: int, auto_bars: bool, n_bars: int, temperature: float, top_p: float, repetition_penalty: float, tonic_anchor: bool, prefix_text: str, seed, tempo: int, ): """Run one generation and return (status, bar grid, .chord path, .mid path).""" t0 = perf_counter() try: model = _load_model(checkpoint) except Exception as exc: # noqa: BLE001 — surface any load error to the UI return f"❌ Не удалось загрузить чекпойнт «{checkpoint}»: {exc}", "", None, None target_key = f"{key}_{mode}" # Prefix: explicit text wins; otherwise optionally anchor to the tonic. prefix_chords: list[str] | None prefix_text = (prefix_text or "").strip() if prefix_text: prefix_chords = prefix_text.split() elif tonic_anchor: prefix_chords = [key + ("m" if mode == "minor" else "")] else: prefix_chords = None seed_val = int(seed) if seed is not None else None bars_arg = None if auto_bars else int(n_bars) try: period = generate_period( model=model, mode=mode, time=time, subdivision=int(subdivision), style=style, function=function, key=target_key, prefix=prefix_chords, temperature=float(temperature), top_p=float(top_p), n_bars=bars_arg, seed=seed_val, repetition_penalty=float(repetition_penalty), ) except Exception as exc: # noqa: BLE001 — show generation errors verbatim return f"❌ Ошибка генерации: {exc}", "", None, None # "Auto" lets the model close via EOS, but if it never does it can run away # into dozens of NC/hold bars. A period is 4–16 bars — cap the tail. truncated = False if len(period.bars) > MAX_PERIOD_BARS: period = replace(period, bars=period.bars[:MAX_PERIOD_BARS]) truncated = True period = replace(period, title=f"hamori — {key} {mode}, {function}") stem = f"hamori_{key.replace('#', 'sharp')}_{mode}_{uuid4().hex[:6]}" chord_path = OUTPUT_DIR / f"{stem}.chord" midi_path = OUTPUT_DIR / f"{stem}.mid" write_chord_file(period, chord_path) chord_file_to_midi(chord_path, midi_path, tempo=int(tempo)) elapsed = perf_counter() - t0 status = ( f"✅ Готово — {len(period.bars)} тактов · {target_key} · " f"модель: {checkpoint} · seed: {seed_val if seed_val is not None else 'random'} · " f"{elapsed:.2f} с" ) if truncated: status += f" · обрезано до {MAX_PERIOD_BARS} тактов (период ≤ 16)" return status, _format_bars(period), str(chord_path), str(midi_path) # --------------------------------------------------------------------------- # Button-state helpers — give immediate, persistent feedback while generating. # Gradio's built-in spinner barely flashes once the model is cached, so we also # disable the button and show a "generating" status until the work completes. # --------------------------------------------------------------------------- def _begin_generation(): """Show a busy state the instant the button is clicked (runs before generate).""" return ( gr.update(value="⏳ Генерация…", interactive=False), # run button "⏳ Генерация…", # status message "", # clear the previous grid ) def _end_generation(): """Restore the run button after generation finishes (success or handled error).""" return gr.update(value="Сгенерировать", interactive=True) # --------------------------------------------------------------------------- # Russian instructions (rendered inline) # --------------------------------------------------------------------------- INSTRUCTIONS_RU = """ ## Как пользоваться **hamori** генерирует одну гармоническую фразу (период, 4–16 тактов) в заданной тональности и стиле. Это инструмент-подсказчик: он предлагает аккордовую последовательность, которую вы дорабатываете в DAW. ### Параметры | Поле | Что задаёт | |------|------------| | **Модель** | `finetuned` — обучена на вашем корпусе (рекомендуется). `pretrained` — только McGill Billboard, более «общий» звук. | | **Лад / Тональность** | Мажор или минор и тоника результата (например, `F# major`). Модель генерирует в C/Am и транспонирует в выбранную тональность. | | **Стиль / Функция** | Метки условия. `H1K0` — авторский стиль. Функция — роль фрагмента (куплет, припев…). | | **Размер / Subdivision** | Тактовый размер и число позиций на такт. По умолчанию `4/4` и `4`. | | **Число тактов** | Длина периода. «Авто» — модель сама решает, где закрыть фразу. | | **Temperature** | Разброс. `1.0` — норма. Выше — смелее и хаотичнее, ниже — предсказуемее. | | **Top-p** | Нуклеус-сэмплинг. `0.9` — норма. Ниже — консервативнее. | | **Repetition penalty** | Борется с зацикливанием (I–V–I–V). `0.0` — выкл. Для `pretrained` попробуйте `0.5–1.0`; для `finetuned` обычно не нужно. | | **Tonic anchor** | Если префикс пуст — начинать с тоники, чтобы фраза держалась в тональности. | | **Префикс** | Свои стартовые аккорды через пробел в выбранной тональности, напр. `Cmaj7 . Am7 .`. `.` — держать, `NC` — без аккорда. Если задан, перекрывает tonic anchor. | | **Seed** | Фиксирует случайность для воспроизводимости. Очистите поле для случайного результата. | | **Tempo** | Темп MIDI-файла (BPM). На сами аккорды не влияет. | ### Результат - **Сетка аккордов** — текстовый предпросмотр периода. - **`.chord`** — исходный формат проекта (человекочитаемый). - **`.mid`** — импортируйте в REAPER перетаскиванием на дорожку. ### Рекомендации для старта Модель `finetuned`, `temperature = 1.0`, `top-p = 0.9`, tonic anchor включён. Если получается монотонно или зациклено — поднимите temperature или добавьте repetition penalty. """ # --------------------------------------------------------------------------- # UI definition # --------------------------------------------------------------------------- def build_ui() -> gr.Blocks: checkpoints = _available_checkpoints() default_ckpt = "finetuned" if "finetuned" in checkpoints else ( checkpoints[0] if checkpoints else "finetuned" ) with gr.Blocks(title="hamori — генератор аккордовых последовательностей") as demo: gr.Markdown( "# hamori 🎶 — генератор аккордовых последовательностей\n" "Заполните форму и нажмите **Сгенерировать**. " "Подробности — в разделе «Инструкция» внизу." ) with gr.Row(): # ---- Left: form ------------------------------------------------ with gr.Column(scale=1): checkpoint = gr.Radio( choices=checkpoints or ["finetuned", "pretrained"], value=default_ckpt, label="Модель", ) with gr.Row(): mode = gr.Radio(["major", "minor"], value="major", label="Лад") key = gr.Dropdown(NOTE_NAMES, value="C", label="Тональность") with gr.Row(): style = gr.Dropdown( STYLES, value="H1K0" if "H1K0" in STYLES else STYLES[0], label="Стиль", ) function = gr.Dropdown( FUNCTIONS, value="chorus" if "chorus" in FUNCTIONS else FUNCTIONS[0], label="Функция", ) with gr.Row(): time = gr.Dropdown( TIMES, value="4/4" if "4/4" in TIMES else TIMES[0], label="Размер", ) subdivision = gr.Radio([4, 8], value=4, label="Subdivision") with gr.Row(): auto_bars = gr.Checkbox(value=False, label="Авто") n_bars = gr.Slider(4, 16, value=8, step=1, label="Число тактов") with gr.Accordion("Сэмплирование", open=True): temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p") repetition_penalty = gr.Slider( 0.0, 2.0, value=0.0, step=0.1, label="Repetition penalty", ) with gr.Accordion("Дополнительно", open=False): tonic_anchor = gr.Checkbox(value=True, label="Tonic anchor") prefix_text = gr.Textbox( label="Префикс (аккорды через пробел)", placeholder="напр. Cmaj7 . Am7 .", ) seed = gr.Number(value=42, precision=0, label="Seed (пусто = случайно)") tempo = gr.Number(value=90, precision=0, label="Tempo (BPM)") # ---- Right: outputs ------------------------------------------- with gr.Column(scale=1): # Button lives with the outputs so repeated generation needs no # scrolling back up to the form. run = gr.Button("Сгенерировать", variant="primary") status = gr.Markdown() bars_out = gr.Textbox(label="Сетка аккордов", lines=10, interactive=False) chord_file = gr.File(label="Скачать .chord") midi_file = gr.File(label="Скачать .mid") with gr.Accordion("Инструкция", open=False): gr.Markdown(INSTRUCTIONS_RU) run.click( fn=_begin_generation, inputs=None, outputs=[run, status, bars_out], ).then( fn=generate, inputs=[ checkpoint, mode, key, style, function, time, subdivision, auto_bars, n_bars, temperature, top_p, repetition_penalty, tonic_anchor, prefix_text, seed, tempo, ], outputs=[status, bars_out, chord_file, midi_file], ).then( fn=_end_generation, inputs=None, outputs=[run], ) return demo def main() -> None: ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) ap.add_argument("--port", type=int, default=7860, help="Server port (default: 7860).") ap.add_argument("--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1).") ap.add_argument("--share", action="store_true", help="Create a temporary public Gradio link.") args = ap.parse_args() demo = build_ui() demo.launch(server_name=args.host, server_port=args.port, share=args.share) if __name__ == "__main__": main()