diff --git a/example.py b/example.py
new file mode 100644
--- /dev/null
+++ b/example.py
@@ -0,0 +1,97 @@
+"""Small end-to-end OpenMythos example.
+
+Builds a tiny model, runs one forward pass and a short generation, and prints
+the parameter count, the output shapes, and the maximum of the tensor returned
+by ``model.recurrent.injection.get_A()``.
+
+Usage:
+    python example.py                   # MLA attention (default)
+    python example.py --attn-type gqa   # GQA attention
+"""
+
+import argparse
+
+import torch
+
+from open_mythos.main import MythosConfig, OpenMythos
+
+ATTN_TYPES = ("gqa", "mla")
+
+# Model dimensions shared by both attention variants.
+BASE_CONFIG = {
+    "vocab_size": 1000,
+    "dim": 256,
+    "n_heads": 8,
+    "max_seq_len": 128,
+    "max_loop_iters": 4,
+    "prelude_layers": 1,
+    "coda_layers": 1,
+    "n_experts": 8,
+    "n_shared_experts": 1,
+    "n_experts_per_tok": 2,
+    "expert_dim": 64,
+    "lora_rank": 8,
+}
+
+
+def build_config(attn_type: str) -> MythosConfig:
+    """Return the example MythosConfig for ``attn_type``.
+
+    Args:
+        attn_type: "gqa" or "mla".
+
+    Raises:
+        ValueError: If ``attn_type`` is not one of ATTN_TYPES.
+    """
+    if attn_type not in ATTN_TYPES:
+        raise ValueError(f"attn_type must be one of {ATTN_TYPES}, got {attn_type!r}")
+    if attn_type == "gqa":
+        return MythosConfig(**BASE_CONFIG, attn_type=attn_type, n_kv_heads=2)
+    return MythosConfig(
+        **BASE_CONFIG,
+        attn_type=attn_type,
+        n_kv_heads=8,  # unused by MLA but field must be valid
+        kv_lora_rank=32,
+        q_lora_rank=64,
+        qk_rope_head_dim=16,
+        qk_nope_head_dim=16,
+        v_head_dim=16,
+    )
+
+
+def run_example(attn_type: str = "mla") -> None:
+    """Build the model for ``attn_type``, run it, and print the results."""
+    cfg = build_config(attn_type)
+    tag = f"[{attn_type.upper()}]"
+
+    model = OpenMythos(cfg)
+    total = sum(p.numel() for p in model.parameters())
+    print(f"\n{tag} Parameters: {total:,}")
+
+    ids = torch.randint(0, cfg.vocab_size, (2, 16))
+    # Inference only: no need to record an autograd graph for the forward pass.
+    with torch.no_grad():
+        logits = model(ids, n_loops=4)
+    print(f"{tag} Logits shape: {logits.shape}")
+
+    out = model.generate(ids, max_new_tokens=8, n_loops=8)
+    print(f"{tag} Generated shape: {out.shape}")
+
+    A = model.recurrent.injection.get_A()
+    print(f"{tag} Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)")
+
+
+def main() -> None:
+    """Parse ``--attn-type`` from the command line and run the example."""
+    parser = argparse.ArgumentParser(description="Run a small OpenMythos example.")
+    parser.add_argument(
+        "--attn-type",
+        choices=ATTN_TYPES,
+        default="mla",
+        help="attention variant to build (default: %(default)s)",
+    )
+    run_example(parser.parse_args().attn_type)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/open_mythos/main.py b/open_mythos/main.py
index 77be9d8..aa2f878 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-
 # ---------------------------------------------------------------------------
 # Config
 # ---------------------------------------------------------------------------
@@ -1014,65 +1013,3 @@ class OpenMythos(nn.Module):
             input_ids = torch.cat([input_ids, next_tok], dim=1)
 
         return input_ids
-
-# ---------------------------------------------------------------------------
-# Quick smoke test
-# ---------------------------------------------------------------------------
-
-
-def _smoke(attn_type: str) -> None:
-    """
-    Instantiate a small OpenMythos model with the given attention type, run a
-    forward pass and a short generation, and verify the LTI spectral radius.
-
-    Args:
-        attn_type -- "gqa" or "mla"
-    """
-    base = dict(
-        vocab_size=1000,
-        dim=256,
-        n_heads=8,
-        max_seq_len=128,
-        max_loop_iters=4,
-        prelude_layers=1,
-        coda_layers=1,
-        n_experts=8,
-        n_shared_experts=1,
-        n_experts_per_tok=2,
-        expert_dim=64,
-        lora_rank=8,
-        attn_type=attn_type,
-    )
-    if attn_type == "gqa":
-        cfg = MythosConfig(**base, n_kv_heads=2)
-    else:
-        cfg = MythosConfig(
-            **base,
-            n_kv_heads=8,  # unused by MLA but field must be valid
-            kv_lora_rank=32,
-            q_lora_rank=64,
-            qk_rope_head_dim=16,
-            qk_nope_head_dim=16,
-            v_head_dim=16,
-        )
-
-    model = OpenMythos(cfg)
-    total = sum(p.numel() for p in model.parameters())
-    print(f"\n[{attn_type.upper()}] Parameters: {total:,}")
-
-    ids = torch.randint(0, cfg.vocab_size, (2, 16))
-    logits = model(ids, n_loops=4)
-    print(f"[{attn_type.upper()}] Logits shape: {logits.shape}")
-
-    out = model.generate(ids, max_new_tokens=8, n_loops=8)
-    print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
-
-    A = model.recurrent.injection.get_A()
-    print(
-        f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
-    )
-
-
-if __name__ == "__main__":
-    _smoke("gqa")
-    _smoke("mla")