feat: add bigram repetition penalty to generate_period
Tracks ROOT-level bigrams (prev_root → curr_root) across chord-change events. At each FREE position, subtracts penalty * count(prev→root) from ROOT logits, capped at 3.0 to prevent NC/HOLD flooding at extreme values. Practical range: 0.5 (mild, breaks loops after 2 occurrences) to 1.0 (aggressive). Default 0.0 keeps backward compatibility. Added --repetition-penalty flag to scripts/generate.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -91,6 +91,11 @@ def main() -> None:
|
||||
help="Hard cap on generated tokens (default: 300).")
|
||||
ap.add_argument("--bars", type=int, default=None,
|
||||
help="Stop after this many complete bars (default: let the model decide).")
|
||||
ap.add_argument("--repetition-penalty", type=float, default=0.0,
|
||||
dest="repetition_penalty",
|
||||
help="Bigram repetition penalty: subtracts penalty * count(prev→root) "
|
||||
"from ROOT logits at each chord-change position. "
|
||||
"0.0 = disabled (default). Suggested: 0.5–1.5.")
|
||||
ap.add_argument("--seed", type=int, default=None,
|
||||
help="Random seed for reproducibility.")
|
||||
ap.add_argument("--tempo", type=int, default=90,
|
||||
@@ -140,6 +145,7 @@ def main() -> None:
|
||||
max_tokens=args.max_tokens,
|
||||
n_bars=args.bars,
|
||||
seed=args.seed,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
)
|
||||
|
||||
# Give generated periods a readable title
|
||||
|
||||
Reference in New Issue
Block a user