diff --git a/README.md b/README.md index 20b38b8..53ed583 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # OpenMythos +> **Disclaimer:** OpenMythos is an independent, community-driven theoretical reconstruction based solely on publicly available research and speculation. It is not affiliated with, endorsed by, or connected to Anthropic or any of their proprietary systems. + OpenMythos is an open-source, theoretical implementation of the Claude Mythos model. It implements a Recurrent-Depth Transformer (RDT) with three stages: **Prelude** (transformer blocks), a looped **Recurrent Block** (up to `max_loop_iters`), and a final **Coda**. Attention is switchable between MLA and GQA, and the feed-forward uses a sparse MoE with routed and shared experts ideal for exploring compute-adaptive, depth-variable reasoning. diff --git a/test_main.py b/test_main.py index 6c9d4dc..6f35fac 100644 --- a/test_main.py +++ b/test_main.py @@ -121,6 +121,139 @@ class TestRoPE: assert not torch.allclose(out[0, 0], out[0, 1]) +# --------------------------------------------------------------------------- +# RoPE extended — correctness invariants +# --------------------------------------------------------------------------- + + +class TestRoPEExtended: + """Comprehensive correctness tests for precompute_rope_freqs and apply_rope.""" + + # --- precompute_rope_freqs --- + + def test_position_zero_is_unit_phasor(self): + """freqs[0] must be all 1+0j (angle = 0 * freq = 0 for every pair).""" + freqs = precompute_rope_freqs(dim=16, max_len=8) + expected = torch.ones(8, dtype=torch.complex64) + assert torch.allclose(freqs[0], expected, atol=1e-6) + + def test_all_phasors_have_unit_magnitude(self): + """Every phasor magnitude must be 1 — RoPE is an isometric rotation.""" + freqs = precompute_rope_freqs(dim=16, max_len=32) + assert torch.allclose(freqs.abs(), torch.ones_like(freqs.abs()), atol=1e-6) + + def test_angles_equal_outer_product(self): + """freqs[t, k].angle() must equal t × base_freq[k] for all t, k.""" + dim, max_len, theta = 8, 6, 500000.0 + freqs = precompute_rope_freqs(dim=dim, max_len=max_len, theta=theta) + base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + t = torch.arange(max_len, dtype=torch.float32) + expected = torch.polar(torch.ones(max_len, dim // 2), torch.outer(t, base)) + assert torch.allclose(freqs.real, expected.real, atol=1e-6) + assert torch.allclose(freqs.imag, expected.imag, atol=1e-6) + + def test_higher_theta_produces_smaller_angles(self): + """Larger theta → slower frequency decay → smaller rotation angle per step.""" + dim, max_len = 16, 8 + freqs_fast = precompute_rope_freqs(dim=dim, max_len=max_len, theta=100.0) + freqs_slow = precompute_rope_freqs(dim=dim, max_len=max_len, theta=500000.0) + assert (freqs_fast[1].angle().abs() > freqs_slow[1].angle().abs()).all() + + def test_default_theta_matches_explicit(self): + """Omitting theta must equal passing theta=500000.0.""" + f1 = precompute_rope_freqs(16, 8) + f2 = precompute_rope_freqs(16, 8, theta=500000.0) + assert torch.allclose(f1.real, f2.real) and torch.allclose(f1.imag, f2.imag) + + # --- apply_rope --- + + def test_position_zero_is_identity(self): + """T=1 input uses only freqs[0] = 1+0j, so output must equal input.""" + freqs = precompute_rope_freqs(dim=16, max_len=8) + x = torch.randn(2, 1, 4, 16) + out = apply_rope(x, freqs) + assert torch.allclose(x, out, atol=1e-6) + + def test_dtype_float32_preserved(self): + freqs = precompute_rope_freqs(dim=16, max_len=16) + x = torch.randn(1, 4, 2, 16).float() + assert apply_rope(x, freqs).dtype == torch.float32 + + def test_dtype_float16_preserved(self): + freqs = precompute_rope_freqs(dim=16, max_len=16) + x = torch.randn(1, 4, 2, 16).half() + assert apply_rope(x, freqs).dtype == torch.float16 + + def test_inverse_rotation_recovers_input(self): + """Rotating by freqs then by conj(freqs) (inverse) must recover the original.""" + dim = 16 + freqs = precompute_rope_freqs(dim=dim, max_len=8) + x = torch.randn(2, 4, 3, dim) + rotated = apply_rope(x, freqs) + xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2)) + inv = freqs.conj()[: rotated.shape[1]].unsqueeze(0).unsqueeze(2) + recovered = torch.view_as_real(xc * inv).flatten(-2).to(x.dtype) + assert torch.allclose(x, recovered, atol=1e-5) + + def test_batch_independence(self): + """Output for one batch item must not depend on other items in the batch.""" + dim = 16 + freqs = precompute_rope_freqs(dim=dim, max_len=16) + torch.manual_seed(7) + x_a = torch.randn(1, 4, 2, dim) + x_b = torch.randn(1, 4, 2, dim) + solo = apply_rope(x_a, freqs) + batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs)[:1] + assert torch.allclose(solo, batched, atol=1e-6) + + def test_head_independence(self): + """All heads at the same position must receive identical rotations.""" + dim = 16 + freqs = precompute_rope_freqs(dim=dim, max_len=8) + x = torch.randn(1, 4, 1, dim).expand(1, 4, 3, dim).contiguous() + out = apply_rope(x, freqs) + assert torch.allclose(out[:, :, 0], out[:, :, 1], atol=1e-6) + assert torch.allclose(out[:, :, 1], out[:, :, 2], atol=1e-6) + + def test_relative_position_property(self): + """ + Core RoPE invariant: depends only on (n-m). + Two pairs with the same offset must produce the same dot product. + """ + dim, max_len = 16, 32 + freqs = precompute_rope_freqs(dim=dim, max_len=max_len) + torch.manual_seed(42) + q = torch.randn(1, 1, 1, dim) + k = torch.randn(1, 1, 1, dim) + + def rope_at(tensor, pos): + """Rotate tensor at a specific position by embedding it in a zero sequence.""" + seq = torch.zeros(1, pos + 1, 1, dim) + seq[0, pos] = tensor[0, 0] + return apply_rope(seq, freqs)[:, pos : pos + 1] + + # Both pairs have relative offset n - m = 6 + dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum() + dot_1_7 = (rope_at(q, 1) * rope_at(k, 7)).sum() + assert torch.allclose(dot_3_9, dot_1_7, atol=1e-5) + + def test_max_len_boundary(self): + """apply_rope must handle T == max_len without error or NaN.""" + max_len = 10 + freqs = precompute_rope_freqs(dim=8, max_len=max_len) + x = torch.randn(1, max_len, 2, 8) + out = apply_rope(x, freqs) + assert out.shape == x.shape + assert not torch.isnan(out).any() + + def test_exceeds_max_len_raises(self): + """apply_rope must raise RuntimeError when T > max_len.""" + freqs = precompute_rope_freqs(dim=8, max_len=4) + x = torch.randn(1, 8, 2, 8) # T=8 > max_len=4 + with pytest.raises(RuntimeError): + apply_rope(x, freqs) + + # --------------------------------------------------------------------------- # GQAttention # ---------------------------------------------------------------------------