diff --git a/.gitignore b/.gitignore index b7faf40..dd631d4 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__/ # Distribution / packaging .Python build/ +experimental/ develop-eggs/ dist/ downloads/ diff --git a/open_mythos/__init__.py b/open_mythos/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/open_mythos/main.py b/open_mythos/main.py new file mode 100644 index 0000000..77be9d8 --- /dev/null +++ b/open_mythos/main.py @@ -0,0 +1,1078 @@ +""" +OpenMythos v1 — Recurrent-Depth Transformer +Architecture: Prelude → [Looped Recurrent Block]×T → Coda +MoE FFN (DeepSeek-style), GQA or MLA, RoPE, RMSNorm, KV cache, LTI-stable injection, ACT halting +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +@dataclass +class MythosConfig: + """ + Hyperparameter configuration for OpenMythos. + + Core: + vocab_size -- token vocabulary size + dim -- model hidden dimension + n_heads -- number of query attention heads + n_kv_heads -- number of key/value heads (GQA; ignored by MLA) + max_seq_len -- maximum sequence length for RoPE precomputation + max_loop_iters -- default recurrent loop depth T at inference + prelude_layers -- number of standard transformer layers before the loop + coda_layers -- number of standard transformer layers after the loop + + Attention (attn_type selects between the two): + attn_type -- "gqa" for Grouped Query Attention, "mla" for Multi-Latent Attention + kv_lora_rank -- [MLA] compressed KV latent dimension stored in the cache + q_lora_rank -- [MLA] compressed Q latent dimension + qk_rope_head_dim-- [MLA] per-head dims that receive RoPE + qk_nope_head_dim-- [MLA] per-head dims without positional encoding + v_head_dim -- [MLA] per-head value dimension + + MoE FFN (used inside the recurrent block): + n_experts -- total number of routed expert FFNs + n_shared_experts-- number of always-active shared experts + n_experts_per_tok-- top-K experts selected per token by the router + expert_dim -- hidden dimension inside each fine-grained expert + + Other: + act_threshold -- ACT halting threshold (cumulative probability to stop looping) + rope_theta -- RoPE base frequency + lora_rank -- rank of the per-loop depth-wise LoRA adapter + """ + + vocab_size: int = 32000 + dim: int = 2048 + n_heads: int = 16 + n_kv_heads: int = 4 # GQA: fewer KV heads than Q heads + max_seq_len: int = 4096 + max_loop_iters: int = 16 # T — recurrent depth at inference + prelude_layers: int = 2 + coda_layers: int = 2 + # Attention type: "gqa" | "mla" + attn_type: str = "mla" + # MLA params (only used when attn_type="mla") + kv_lora_rank: int = 512 # compressed KV latent cached instead of full K/V + q_lora_rank: int = 1536 # compressed Q latent dim + qk_rope_head_dim: int = 64 # per-head dims that receive RoPE + qk_nope_head_dim: int = 128 # per-head dims without RoPE + v_head_dim: int = 128 # per-head value dim + # MoE + n_experts: int = 64 + n_shared_experts: int = 2 + n_experts_per_tok: int = 4 # top-K routed + expert_dim: int = 512 # fine-grained: dim // (n_experts // n_experts_per_tok) + # ACT halting + act_threshold: float = 0.99 + # RoPE + rope_theta: float = 500000.0 + # LoRA depth adaptation + lora_rank: int = 16 + + +# --------------------------------------------------------------------------- +# RMSNorm +# --------------------------------------------------------------------------- + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (Zhang & Sennrich, 2019). + + Normalizes by the RMS of the input rather than mean+variance, with a + learned per-channel rescaling weight. No bias term. Used in place of + LayerNorm throughout the model for stability and efficiency. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + """ + Args: + dim -- feature dimension to normalize over + eps -- small constant added before sqrt for numerical stability + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x -- input tensor of shape (..., dim) + Returns: + RMS-normalized tensor of the same shape, rescaled by self.weight + """ + rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() + return x * rms * self.weight + + +# --------------------------------------------------------------------------- +# RoPE +# --------------------------------------------------------------------------- + + +def precompute_rope_freqs( + dim: int, max_len: int, theta: float = 500000.0 +) -> torch.Tensor: + """ + Precompute complex-valued RoPE rotation matrices for positions 0..max_len-1. + + Each position gets a complex phasor e^{i·m·θ_k} for each frequency pair k. + Stored as a complex tensor so that rotation is a single pointwise multiply. + + Args: + dim -- head dimension (must be even); frequencies are computed for dim//2 pairs + max_len -- maximum sequence length to precompute + theta -- RoPE base (higher = slower frequency decay; 500k is the LLaMA-3 default) + + Returns: + complex64 tensor of shape (max_len, dim//2) + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + t = torch.arange(max_len, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs) + + +def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Apply rotary positional embeddings to query or key tensors. + + Interprets each pair of adjacent features as a 2D complex number and + multiplies by the precomputed phasor for that position, rotating the + representation in the complex plane without changing its norm. + + Args: + x -- tensor of shape (B, T, H, head_dim); head_dim must be even + freqs_cis -- precomputed complex frequencies of shape (max_len, head_dim//2) + + Returns: + Rotated tensor of the same shape and dtype as x + """ + xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[: x.shape[1]].unsqueeze(0).unsqueeze(2) + return torch.view_as_real(xc * freqs_cis).flatten(-2).to(x.dtype) + + +# --------------------------------------------------------------------------- +# Grouped Query Attention with KV cache +# --------------------------------------------------------------------------- + + +class GQAttention(nn.Module): + """ + Grouped Query Attention (Ainslie et al., 2023). + + Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is + shared across n_heads // n_kv_heads query heads, reducing the KV cache size + by that factor while keeping full query expressiveness. + + RoPE is applied to both Q and K. K and V are stored in kv_cache after + RoPE application so that cached values are already positionally encoded and + do not need to be re-rotated on retrieval. + """ + + def __init__(self, cfg: MythosConfig): + """ + Args: + cfg -- MythosConfig; uses dim, n_heads, n_kv_heads + """ + super().__init__() + self.n_heads = cfg.n_heads + self.n_kv_heads = cfg.n_kv_heads + self.head_dim = cfg.dim // cfg.n_heads + self.groups = cfg.n_heads // cfg.n_kv_heads + + self.wq = nn.Linear(cfg.dim, cfg.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + kv_cache: Optional[dict] = None, + cache_key: str = "default", + ) -> torch.Tensor: + """ + Args: + x -- input of shape (B, T, dim) + freqs_cis -- RoPE frequencies for head_dim, shape (T, head_dim//2) + mask -- additive causal mask of shape (1, 1, T, S) or None + kv_cache -- dict mutated in-place; stores {"k": ..., "v": ...} per cache_key + cache_key -- unique key identifying this layer in the cache dict + + Returns: + Output tensor of shape (B, T, dim) + """ + B, T, _ = x.shape + q = self.wq(x).view(B, T, self.n_heads, self.head_dim) + k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) + v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) + + q = apply_rope(q, freqs_cis) + k = apply_rope(k, freqs_cis) + + if kv_cache is not None: + if cache_key in kv_cache: + k = torch.cat([kv_cache[cache_key]["k"], k], dim=1) + v = torch.cat([kv_cache[cache_key]["v"], v], dim=1) + kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()} + + # expand KV to match Q heads + k = k.repeat_interleave(self.groups, dim=2) + v = v.repeat_interleave(self.groups, dim=2) + + q = q.transpose(1, 2) # (B, H, T, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + scale = self.head_dim**-0.5 + attn = torch.matmul(q, k.transpose(-2, -1)) * scale + if mask is not None: + attn = attn + mask + attn = F.softmax(attn, dim=-1) + out = torch.matmul(attn, v) + out = out.transpose(1, 2).contiguous().view(B, T, -1) + return self.wo(out) + + +# --------------------------------------------------------------------------- +# Multi-Latent Attention (DeepSeek-V2 style) +# --------------------------------------------------------------------------- + + +class MLAttention(nn.Module): + """ + Multi-Latent Attention (DeepSeek-V2, 2024). + + The key insight: instead of caching full K and V tensors (each of size + n_heads × head_dim per token), MLA compresses the KV path through a + low-rank latent c_kv and only caches that plus the RoPE keys. K_nope and + V are reconstructed from c_kv at each decoding step, trading a cheap + linear projection for dramatically smaller cache memory. + + Q path: + x → q_down (dim→q_lora_rank) → q_norm + → q_up_nope (q_lora_rank → n_heads×qk_nope_head_dim) [no RoPE] + → q_up_rope (q_lora_rank → n_heads×qk_rope_head_dim) [RoPE applied] + q = cat(q_nope, q_rope) per head + + KV path: + x → kv_down (dim → kv_lora_rank + qk_rope_head_dim) + splits into c_kv (latent, cached) and k_rope_raw (shared across heads) + k_rope = RoPE(expand(k_rope_raw)) — applied before caching + c_kv → kv_norm → kv_up → [k_nope | v] — reconstructed each step + k = cat(k_nope, k_rope) per head + + Cache stores: c_kv (kv_lora_rank) + k_rope (n_heads × qk_rope_head_dim), + versus full GQA cache: n_kv_heads × head_dim × 2. At production scale this + is roughly a 10–20× memory reduction. + """ + + def __init__(self, cfg: MythosConfig): + """ + Args: + cfg -- MythosConfig; uses dim, n_heads, kv_lora_rank, q_lora_rank, + qk_rope_head_dim, qk_nope_head_dim, v_head_dim + """ + super().__init__() + self.n_heads = cfg.n_heads + self.kv_lora_rank = cfg.kv_lora_rank + self.qk_rope_dim = cfg.qk_rope_head_dim + self.qk_nope_dim = cfg.qk_nope_head_dim + self.v_dim = cfg.v_head_dim + self.q_head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim + + # Q compression + self.q_down = nn.Linear(cfg.dim, cfg.q_lora_rank, bias=False) + self.q_norm = RMSNorm(cfg.q_lora_rank) + self.q_up_nope = nn.Linear( + cfg.q_lora_rank, cfg.n_heads * cfg.qk_nope_head_dim, bias=False + ) + self.q_up_rope = nn.Linear( + cfg.q_lora_rank, cfg.n_heads * cfg.qk_rope_head_dim, bias=False + ) + + # KV compression: output is [c_kv | k_rope_raw] concatenated + self.kv_down = nn.Linear( + cfg.dim, cfg.kv_lora_rank + cfg.qk_rope_head_dim, bias=False + ) + self.kv_norm = RMSNorm(cfg.kv_lora_rank) + self.kv_up = nn.Linear( + cfg.kv_lora_rank, + cfg.n_heads * (cfg.qk_nope_head_dim + cfg.v_head_dim), + bias=False, + ) + + self.wo = nn.Linear(cfg.n_heads * cfg.v_head_dim, cfg.dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + kv_cache: Optional[dict] = None, + cache_key: str = "default", + ) -> torch.Tensor: + """ + Args: + x -- input of shape (B, T, dim) + freqs_cis -- RoPE frequencies sized for qk_rope_head_dim, shape (T, rope_dim//2) + mask -- additive causal mask of shape (1, 1, T, S) or None + kv_cache -- dict mutated in-place; stores {"c_kv": ..., "k_rope": ...} + cache_key -- unique key identifying this layer in the cache dict + + Returns: + Output tensor of shape (B, T, dim) + """ + B, T, _ = x.shape + + # Q + c_q = self.q_norm(self.q_down(x)) + q_nope = self.q_up_nope(c_q).view(B, T, self.n_heads, self.qk_nope_dim) + q_rope = self.q_up_rope(c_q).view(B, T, self.n_heads, self.qk_rope_dim) + q_rope = apply_rope(q_rope, freqs_cis) + q = torch.cat([q_nope, q_rope], dim=-1) # (B, T, H, nope+rope) + + # KV compress + kv_raw = self.kv_down(x) + c_kv = kv_raw[..., : self.kv_lora_rank] # (B, T, lora_rank) ← cached + k_rope = kv_raw[..., self.kv_lora_rank :] # (B, T, rope_dim) + # expand rope keys across heads and apply RoPE before caching so + # retrieved keys are already positionally encoded + k_rope = ( + k_rope.unsqueeze(2) + .expand(B, T, self.n_heads, self.qk_rope_dim) + .contiguous() + ) + k_rope = apply_rope(k_rope, freqs_cis) # (B, T, H, rope_dim) ← cached + + if kv_cache is not None: + if cache_key in kv_cache: + c_kv = torch.cat([kv_cache[cache_key]["c_kv"], c_kv], dim=1) + k_rope = torch.cat([kv_cache[cache_key]["k_rope"], k_rope], dim=1) + kv_cache[cache_key] = {"c_kv": c_kv.detach(), "k_rope": k_rope.detach()} + + S = c_kv.shape[1] # full sequence length including cache + + # reconstruct K_nope and V from latent (not cached, recomputed each step) + kv = self.kv_up(self.kv_norm(c_kv)) # (B, S, H*(nope+v)) + kv = kv.view(B, S, self.n_heads, self.qk_nope_dim + self.v_dim) + k_nope = kv[..., : self.qk_nope_dim] # (B, S, H, nope) + v = kv[..., self.qk_nope_dim :] # (B, S, H, v_dim) + k = torch.cat([k_nope, k_rope], dim=-1) # (B, S, H, nope+rope) + + # attention + q = q.transpose(1, 2) # (B, H, T, q_head_dim) + k = k.transpose(1, 2) # (B, H, S, q_head_dim) + v = v.transpose(1, 2) # (B, H, S, v_dim) + + scale = self.q_head_dim**-0.5 + attn = torch.matmul(q, k.transpose(-2, -1)) * scale + if mask is not None: + attn = attn + mask + attn = F.softmax(attn, dim=-1) + out = torch.matmul(attn, v) # (B, H, T, v_dim) + out = out.transpose(1, 2).contiguous().view(B, T, -1) + return self.wo(out) + + +# --------------------------------------------------------------------------- +# DeepSeek-style MoE FFN +# --------------------------------------------------------------------------- + + +class Expert(nn.Module): + """ + Single SwiGLU feed-forward expert. + + Implements the gated linear unit variant: output = down(silu(gate(x)) * up(x)). + Used both as individual routed experts inside MoEFFN and as the standard dense + FFN in prelude/coda blocks (where expert_dim = dim * 4 // 3). + """ + + def __init__(self, dim: int, expert_dim: int): + """ + Args: + dim -- input and output feature dimension + expert_dim -- inner (hidden) dimension of the expert + """ + super().__init__() + self.gate = nn.Linear(dim, expert_dim, bias=False) + self.up = nn.Linear(dim, expert_dim, bias=False) + self.down = nn.Linear(expert_dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x -- input of shape (..., dim) + Returns: + Tensor of shape (..., dim) + """ + return self.down(F.silu(self.gate(x)) * self.up(x)) + + +class MoEFFN(nn.Module): + """ + Fine-grained Mixture-of-Experts FFN (DeepSeekMoE, Dai et al., 2024). + + Two classes of experts: + - Routed experts: n_experts small FFNs; each token activates top-K of them + via a learned router. A per-expert bias on router logits is updated during + training to keep load balanced across experts without distorting the loss. + - Shared experts: n_shared_experts larger FFNs always activated for every token, + absorbing common cross-domain patterns (syntax, basic reasoning) that would + otherwise be redundantly learned by many routed experts. + + Total activated parameters per token ≈ topk/n_experts of routed + all shared, + keeping compute sparse while the total parameter count stays large. + """ + + def __init__(self, cfg: MythosConfig): + """ + Args: + cfg -- MythosConfig; uses n_experts, n_shared_experts, n_experts_per_tok, + dim, expert_dim + """ + super().__init__() + self.n_experts = cfg.n_experts + self.n_shared = cfg.n_shared_experts + self.topk = cfg.n_experts_per_tok + + self.router = nn.Linear(cfg.dim, cfg.n_experts, bias=False) + # load-balancing bias adjusted externally during training; not a gradient param + self.register_buffer("router_bias", torch.zeros(cfg.n_experts)) + + self.routed_experts = nn.ModuleList( + [Expert(cfg.dim, cfg.expert_dim) for _ in range(cfg.n_experts)] + ) + self.shared_experts = nn.ModuleList( + [ + Expert(cfg.dim, cfg.expert_dim * cfg.n_experts_per_tok) + for _ in range(self.n_shared) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x -- input of shape (B, T, dim) + Returns: + Tensor of shape (B, T, dim); shared expert outputs are summed on top + of the weighted routed expert outputs + """ + B, T, D = x.shape + flat = x.view(B * T, D) + + # router — bias shifts logits for load balancing without touching loss + logits = self.router(flat) + self.router_bias # (B*T, n_experts) + scores = F.softmax(logits, dim=-1) + topk_scores, topk_idx = scores.topk(self.topk, dim=-1) + topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm + + # routed expert dispatch (token-level scatter) + out = torch.zeros_like(flat) + for i in range(self.topk): + expert_ids = topk_idx[:, i] + token_scores = topk_scores[:, i].unsqueeze(-1) + for eid in range(self.n_experts): + mask = expert_ids == eid + if not mask.any(): + continue + out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask]) + + # shared experts always fire for every token + for shared in self.shared_experts: + out = out + shared(flat) + + return out.view(B, T, D) + + +# --------------------------------------------------------------------------- +# Loop-index RoPE (differentiates recurrent block across iterations) +# --------------------------------------------------------------------------- + + +def loop_index_embedding( + h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0 +) -> torch.Tensor: + """ + Inject a sinusoidal loop-index signal into the first loop_dim channels of h. + + Analogous to RoPE for sequence position, but applied over recurrence depth + instead of token position. Without this, the shared recurrent block weights + must handle both early-stage pattern-matching and late-stage refinement with + no signal distinguishing which loop they are on. Adding the loop index lets + the same parameters implement functionally distinct operations per iteration. + + Args: + h -- hidden state tensor of shape (B, T, dim) + loop_t -- current loop iteration index (0-based) + loop_dim -- number of leading channels to receive the embedding (must be even) + theta -- sinusoidal base frequency + + Returns: + h with a sinusoidal bias added to its first loop_dim channels; same shape + """ + freqs = 1.0 / ( + theta + ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim) + ) + angles = loop_t * freqs # (loop_dim//2,) + emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim] + emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype) + emb_full[:loop_dim] = emb + return h + emb_full.unsqueeze(0).unsqueeze(0) + + +# --------------------------------------------------------------------------- +# Depth-wise LoRA adapter (per loop iteration) +# --------------------------------------------------------------------------- + + +class LoRAAdapter(nn.Module): + """ + Depth-wise LoRA adaptation for the recurrent block (Bae et al., 2024). + + Pure weight-tying (identical weights every loop) limits expressiveness; + fully distinct weights per loop eliminate parameter savings. This adapter + sits in between: a shared low-rank down-projection and up-projection matrix B + are shared across all loops, while a small per-loop scale vector shifts the + effective transformation at each depth without adding significant parameters. + + delta(x, t) = (down(x) * scale[t]) @ B + """ + + def __init__(self, dim: int, rank: int, max_loops: int): + """ + Args: + dim -- model hidden dimension (input and output size) + rank -- low-rank bottleneck dimension + max_loops -- maximum number of loop iterations (determines embedding table size) + """ + super().__init__() + self.down = nn.Linear(dim, rank, bias=False) # shared A: dim → rank + self.B = nn.Parameter(torch.zeros(rank, dim)) # shared B: rank → dim + self.scale = nn.Embedding(max_loops, rank) # per-loop element-wise scale + + def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: + """ + Args: + x -- input tensor of shape (B, T, dim) + loop_t -- current loop index used to look up the per-loop scale + + Returns: + Delta tensor of shape (B, T, dim) to be added to the block output + """ + s = self.scale(torch.tensor(loop_t, device=x.device)) # (rank,) + down = self.down(x) * s # (B, T, rank) + return down @ self.B # (B, T, dim) + + +# --------------------------------------------------------------------------- +# Single Transformer Block (shared across recurrent loops) +# --------------------------------------------------------------------------- + + +class TransformerBlock(nn.Module): + """ + Standard pre-norm transformer block with swappable attention and optional MoE FFN. + + Attention is selected by cfg.attn_type: + "gqa" → GQAttention (Grouped Query Attention, fewer KV heads) + "mla" → MLAttention (Multi-Latent Attention, compressed KV cache) + + FFN is selected by use_moe: + True → MoEFFN (fine-grained routed + shared experts; used in RecurrentBlock) + False → Expert (dense SwiGLU FFN; used in Prelude and Coda) + """ + + def __init__(self, cfg: MythosConfig, use_moe: bool = False): + """ + Args: + cfg -- MythosConfig; attn_type selects the attention class + use_moe -- if True, use MoEFFN; otherwise use a dense Expert FFN + """ + super().__init__() + self.attn_norm = RMSNorm(cfg.dim) + self.ffn_norm = RMSNorm(cfg.dim) + self.attn = MLAttention(cfg) if cfg.attn_type == "mla" else GQAttention(cfg) + self.ffn = MoEFFN(cfg) if use_moe else Expert(cfg.dim, cfg.dim * 4 // 3) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + kv_cache: Optional[dict] = None, + cache_key: str = "default", + ) -> torch.Tensor: + """ + Args: + x -- input of shape (B, T, dim) + freqs_cis -- precomputed RoPE frequencies + mask -- additive causal mask or None + kv_cache -- cache dict mutated in-place by the attention layer + cache_key -- key identifying this layer in the cache + + Returns: + Output tensor of shape (B, T, dim) + """ + x = x + self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache, cache_key) + x = x + self.ffn(self.ffn_norm(x)) + return x + + +# --------------------------------------------------------------------------- +# LTI-stable injection parameters (spectral radius < 1 by construction) +# --------------------------------------------------------------------------- + + +class LTIInjection(nn.Module): + """ + Stable input injection for the recurrent update rule (Parcae, Prairie et al., 2026). + + The recurrent hidden state evolves as: + h_{t+1} = A · h_t + B · e + Transformer(h_t, e) + + where e is the encoded input injected at every loop step to prevent drift. + Without constraints, A can develop spectral radius ≥ 1, causing the hidden + state to explode across loop iterations and destabilize training. + + This class guarantees ρ(A) < 1 by construction via a ZOH discretization: + A_continuous = Diag(-exp(log_A)) always negative diagonal + A_discrete = exp(Δt · A_continuous) element-wise, values in (0, 1) + + where log_A and log_dt are learned parameters and exp ensures positivity. + This makes looped model training robust to hyperparameter choices and stable + even at high learning rates. + """ + + def __init__(self, dim: int): + """ + Args: + dim -- hidden state dimension; one scalar per channel for A and B + """ + super().__init__() + self.log_A = nn.Parameter(torch.zeros(dim)) # log of A_continuous magnitude + self.log_dt = nn.Parameter(torch.zeros(1)) # log of discretization step Δt + self.B = nn.Parameter(torch.ones(dim) * 0.1) + + def get_A(self) -> torch.Tensor: + """ + Compute the discretized diagonal state matrix A_discrete. + + Returns: + 1-D tensor of shape (dim,) with all values strictly in (0, 1), + guaranteeing ρ(A) < 1 regardless of learned parameter values. + """ + A_c = -torch.exp(self.log_A) # continuous diagonal, always < 0 + dt = torch.exp(self.log_dt) # always > 0 + return torch.exp(dt * A_c) # ZOH: values in (0, 1) + + def forward( + self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor + ) -> torch.Tensor: + """ + Compute h_{t+1} = A·h_t + B·e + transformer_out. + + Args: + h -- current hidden state (B, T, dim) + e -- encoded input from Prelude, frozen across loops (B, T, dim) + transformer_out -- output of the recurrent TransformerBlock at this step (B, T, dim) + + Returns: + Updated hidden state of shape (B, T, dim) + """ + A = self.get_A() + return A * h + self.B * e + transformer_out + + +# --------------------------------------------------------------------------- +# ACT halting (Adaptive Computation Time) +# --------------------------------------------------------------------------- + + +class ACTHalting(nn.Module): + """ + Adaptive Computation Time halting mechanism (Graves, 2016). + + Learns a per-position halting probability at each loop iteration. Positions + where the hidden state has converged (high cumulative halting probability) + stop accumulating updates, while positions still being refined continue. + This lets easy tokens halt early and hard tokens receive more computation, + all within the same batch. Also makes the model Turing-complete under + certain assumptions about the expressiveness of the transformer block. + """ + + def __init__(self, dim: int): + """ + Args: + dim -- hidden state dimension; input to the halting scalar predictor + """ + super().__init__() + self.halt = nn.Linear(dim, 1) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + """ + Predict per-position halting probability from the current hidden state. + + Args: + h -- hidden state of shape (B, T, dim) + + Returns: + Halting probability tensor of shape (B, T), values in (0, 1) + """ + return torch.sigmoid(self.halt(h)).squeeze(-1) + + +# --------------------------------------------------------------------------- +# Recurrent Block (one set of weights, looped T times) +# --------------------------------------------------------------------------- + + +class RecurrentBlock(nn.Module): + """ + The core recurrent block of OpenMythos — a single TransformerBlock looped T times. + + At each loop iteration t, the hidden state h is updated via: + 1. loop_index_embedding: inject sinusoidal loop-index signal into h + 2. TransformerBlock: compute attention + MoE FFN on normalized (h + e) + 3. LoRAAdapter: apply depth-wise LoRA delta to transformer output + 4. LTIInjection: stable update h = A·h + B·e + transformer_out + 5. ACTHalting: accumulate per-position halting probabilities; + positions that have converged stop contributing + + The encoded input e (output of the Prelude) is injected at every step to keep + the original input signal alive across arbitrary loop depth, preventing drift. + The ACT mechanism produces a weighted sum of hidden states across iterations, + where the weights reflect when each position converged. + + More loop iterations at inference = deeper reasoning chains, following the + depth-extrapolation property of looped transformers (Saunshi et al., 2025). + """ + + def __init__(self, cfg: MythosConfig): + """ + Args: + cfg -- MythosConfig; uses dim, lora_rank, max_loop_iters, act_threshold + """ + super().__init__() + self.cfg = cfg + self.block = TransformerBlock(cfg, use_moe=True) + self.injection = LTIInjection(cfg.dim) + self.act = ACTHalting(cfg.dim) + self.lora = LoRAAdapter(cfg.dim, cfg.lora_rank, cfg.max_loop_iters) + self.norm = RMSNorm(cfg.dim) + self.loop_dim = ( + cfg.dim // 8 + ) # fraction of channels receiving loop-index embedding + + def forward( + self, + h: torch.Tensor, + e: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor] = None, + n_loops: Optional[int] = None, + kv_cache: Optional[dict] = None, + ) -> torch.Tensor: + """ + Run the recurrent loop for up to n_loops iterations with ACT early exit. + + Args: + h -- initial hidden state from the Prelude, shape (B, T, dim) + e -- encoded input frozen for injection each step, shape (B, T, dim) + freqs_cis-- precomputed RoPE frequencies + mask -- additive causal mask or None + n_loops -- number of loop iterations; defaults to cfg.max_loop_iters. + Can be increased at inference for deeper reasoning (depth extrapolation). + kv_cache -- cache dict passed through to the inner TransformerBlock; + each loop iteration uses a separate cache key + + Returns: + ACT-weighted sum of hidden states across iterations, shape (B, T, dim) + """ + n_loops = n_loops or self.cfg.max_loop_iters + B, T, D = h.shape + + halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) + cumulative_p = torch.zeros(B, T, device=h.device) + h_out = torch.zeros_like(h) + + for t in range(n_loops): + h_loop = loop_index_embedding(h, t, self.loop_dim) + combined = self.norm(h_loop + e) + cache_key = f"recurrent_loop_{t}" + trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key) + trans_out = trans_out + self.lora(trans_out, t) + h = self.injection(h, e, trans_out) + + p = self.act(h) # (B, T) + still_running = ~halted + + # ACT remainder trick: once cumulative_p + p crosses threshold, + # assign the remaining probability mass as the final weight + remainder = (1.0 - cumulative_p).clamp(min=0) + weight = torch.where( + cumulative_p + p >= self.cfg.act_threshold, + remainder, + p, + ) + h_out = h_out + weight.unsqueeze(-1) * h + + cumulative_p = cumulative_p + p * still_running.float() + halted = halted | (cumulative_p >= self.cfg.act_threshold) + + if halted.all(): + break + + return h_out + + +# --------------------------------------------------------------------------- +# Full Model +# --------------------------------------------------------------------------- + + +class OpenMythos(nn.Module): + """ + OpenMythos — Recurrent-Depth Transformer language model. + + Implements the hypothesized Claude Mythos architecture as a Recurrent-Depth + Transformer (RDT). The model divides computation into three functional blocks: + + Input tokens + ↓ + [Prelude] — prelude_layers standard transformer blocks, run once + ↓ + [Recurrent Block] — one transformer block looped T times with input injection + ↑_______↓ h_{t+1} = A·h_t + B·e + Transformer(h_t, e) + ↓ + [Coda] — coda_layers standard transformer blocks, run once + ↓ + Output logits + + Key properties: + - Same weights, more loops → deeper reasoning, no parameter growth + - Depth extrapolation: train on N loops, test on N+k loops (emergent) + - ACT halting: variable compute per position within a batch + - MoE FFN in the recurrent block: breadth across domains + - LTI-stable injection: spectral radius < 1 guaranteed by construction + - Supports both GQA and MLA attention (set via cfg.attn_type) + """ + + def __init__(self, cfg: MythosConfig): + """ + Args: + cfg -- MythosConfig specifying all architecture hyperparameters + """ + super().__init__() + self.cfg = cfg + + self.embed = nn.Embedding(cfg.vocab_size, cfg.dim) + + # GQA uses full head_dim for RoPE; MLA uses only qk_rope_head_dim (decoupled) + freqs = precompute_rope_freqs( + cfg.dim // cfg.n_heads, cfg.max_seq_len, cfg.rope_theta + ) + self.register_buffer("freqs_cis", freqs) + freqs_mla = precompute_rope_freqs( + cfg.qk_rope_head_dim, cfg.max_seq_len, cfg.rope_theta + ) + self.register_buffer("freqs_cis_mla", freqs_mla) + + self.prelude = nn.ModuleList( + [TransformerBlock(cfg, use_moe=False) for _ in range(cfg.prelude_layers)] + ) + self.recurrent = RecurrentBlock(cfg) + self.coda = nn.ModuleList( + [TransformerBlock(cfg, use_moe=False) for _ in range(cfg.coda_layers)] + ) + + self.norm = RMSNorm(cfg.dim) + self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + self.head.weight = self.embed.weight # weight tying + + self._init_weights() + + def _init_weights(self) -> None: + """Initialize all linear and embedding weights with N(0, 0.02).""" + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + @staticmethod + def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: + """ + Build an additive causal mask: 0 on and below the diagonal, -inf above. + + Args: + seq_len -- sequence length + device -- target device + + Returns: + Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S) + """ + mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device) + return torch.triu(mask, diagonal=1) + + def forward( + self, + input_ids: torch.Tensor, + n_loops: Optional[int] = None, + kv_cache: Optional[dict] = None, + ) -> torch.Tensor: + """ + Forward pass through Prelude → Recurrent Block → Coda. + + Args: + input_ids -- token indices of shape (B, T) + n_loops -- recurrent loop depth; defaults to cfg.max_loop_iters. + Increase at inference to extrapolate to harder problems. + kv_cache -- dict mutated in-place for autoregressive KV caching; + pass an empty dict {} and reuse across decode steps + + Returns: + Logits of shape (B, T, vocab_size) + """ + B, T = input_ids.shape + device = input_ids.device + + x = self.embed(input_ids) + freqs_cis = ( + self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis + )[:T] + mask = self._causal_mask(T, device) if T > 1 else None + + for i, layer in enumerate(self.prelude): + x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}") + + e = x # encoded input frozen for injection every loop + x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache) + + for i, layer in enumerate(self.coda): + x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"coda_{i}") + + return self.head(self.norm(x)) + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + max_new_tokens: int = 64, + n_loops: int = 8, + temperature: float = 1.0, + top_k: int = 50, + ) -> torch.Tensor: + """ + Autoregressive token generation with KV caching. + + On step 0 the full prompt is processed. On subsequent steps only the + last generated token is passed, with all previous keys and values + retrieved from kv_cache. This keeps decode cost proportional to one + token per step rather than the full growing sequence. + + n_loops can be set higher than the training value to extrapolate to + harder problems at inference time (depth extrapolation property). + + Args: + input_ids -- prompt token indices of shape (B, T) + max_new_tokens -- number of tokens to generate + n_loops -- recurrent loop depth for each decode step + temperature -- softmax temperature; lower = more greedy + top_k -- restrict sampling to top-K logits (0 = disabled) + + Returns: + Token indices of shape (B, T + max_new_tokens) + """ + kv_cache: dict = {} + for step in range(max_new_tokens): + cur_ids = input_ids if step == 0 else input_ids[:, -1:] + logits = self.forward(cur_ids, n_loops=n_loops, kv_cache=kv_cache) + logits = logits[:, -1, :] / temperature + if top_k > 0: + v, _ = logits.topk(top_k) + logits[logits < v[:, -1:]] = float("-inf") + probs = F.softmax(logits, dim=-1) + next_tok = torch.multinomial(probs, num_samples=1) + input_ids = torch.cat([input_ids, next_tok], dim=1) + return input_ids + + +# --------------------------------------------------------------------------- +# Quick smoke test +# --------------------------------------------------------------------------- + + +def _smoke(attn_type: str) -> None: + """ + Instantiate a small OpenMythos model with the given attention type, run a + forward pass and a short generation, and verify the LTI spectral radius. + + Args: + attn_type -- "gqa" or "mla" + """ + base = dict( + vocab_size=1000, + dim=256, + n_heads=8, + max_seq_len=128, + max_loop_iters=4, + prelude_layers=1, + coda_layers=1, + n_experts=8, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=64, + lora_rank=8, + attn_type=attn_type, + ) + if attn_type == "gqa": + cfg = MythosConfig(**base, n_kv_heads=2) + else: + cfg = MythosConfig( + **base, + n_kv_heads=8, # unused by MLA but field must be valid + kv_lora_rank=32, + q_lora_rank=64, + qk_rope_head_dim=16, + qk_nope_head_dim=16, + v_head_dim=16, + ) + + model = OpenMythos(cfg) + total = sum(p.numel() for p in model.parameters()) + print(f"\n[{attn_type.upper()}] Parameters: {total:,}") + + ids = torch.randint(0, cfg.vocab_size, (2, 16)) + logits = model(ids, n_loops=4) + print(f"[{attn_type.upper()}] Logits shape: {logits.shape}") + + out = model.generate(ids, max_new_tokens=8, n_loops=8) + print(f"[{attn_type.upper()}] Generated shape: {out.shape}") + + A = model.recurrent.injection.get_A() + print( + f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)" + ) + + +if __name__ == "__main__": + _smoke("gqa") + _smoke("mla")