[fix][rope] Every decode token was stuck at position 0, so <q_decoded, k_cached> lost the (n - m) term entirely
This commit is contained in:
parent
537b116b3e
commit
18cca894dd
4 changed files with 235 additions and 20 deletions
24
test_main.py
24
test_main.py
|
|
@ -96,20 +96,20 @@ class TestRoPE:
|
|||
def test_apply_rope_shape(self):
|
||||
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||
x = torch.randn(B, T, 4, 16)
|
||||
out = apply_rope(x, freqs)
|
||||
out = apply_rope(x, freqs[:T])
|
||||
assert out.shape == x.shape
|
||||
|
||||
def test_apply_rope_preserves_norm(self):
|
||||
# rotation is an isometry — norms must be unchanged
|
||||
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||
x = torch.randn(B, T, 4, 16)
|
||||
out = apply_rope(x, freqs)
|
||||
out = apply_rope(x, freqs[:T])
|
||||
assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5)
|
||||
|
||||
def test_different_positions_differ(self):
|
||||
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||
x = torch.ones(1, 2, 1, 16)
|
||||
out = apply_rope(x, freqs)
|
||||
out = apply_rope(x, freqs[:2])
|
||||
# position 0 and position 1 should produce different rotations
|
||||
assert not torch.allclose(out[0, 0], out[0, 1])
|
||||
|
||||
|
|
@ -168,27 +168,27 @@ class TestRoPEExtended:
|
|||
"""T=1 input uses only freqs[0] = 1+0j, so output must equal input."""
|
||||
freqs = precompute_rope_freqs(dim=16, max_len=8)
|
||||
x = torch.randn(2, 1, 4, 16)
|
||||
out = apply_rope(x, freqs)
|
||||
out = apply_rope(x, freqs[:1])
|
||||
assert torch.allclose(x, out, atol=1e-6)
|
||||
|
||||
def test_dtype_float32_preserved(self):
|
||||
freqs = precompute_rope_freqs(dim=16, max_len=16)
|
||||
x = torch.randn(1, 4, 2, 16).float()
|
||||
assert apply_rope(x, freqs).dtype == torch.float32
|
||||
assert apply_rope(x, freqs[:4]).dtype == torch.float32
|
||||
|
||||
def test_dtype_float16_preserved(self):
|
||||
freqs = precompute_rope_freqs(dim=16, max_len=16)
|
||||
x = torch.randn(1, 4, 2, 16).half()
|
||||
assert apply_rope(x, freqs).dtype == torch.float16
|
||||
assert apply_rope(x, freqs[:4]).dtype == torch.float16
|
||||
|
||||
def test_inverse_rotation_recovers_input(self):
|
||||
"""Rotating by freqs then by conj(freqs) (inverse) must recover the original."""
|
||||
dim = 16
|
||||
freqs = precompute_rope_freqs(dim=dim, max_len=8)
|
||||
x = torch.randn(2, 4, 3, dim)
|
||||
rotated = apply_rope(x, freqs)
|
||||
rotated = apply_rope(x, freqs[:4])
|
||||
xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2))
|
||||
inv = freqs.conj()[: rotated.shape[1]].unsqueeze(0).unsqueeze(2)
|
||||
inv = freqs.conj()[:4].unsqueeze(0).unsqueeze(2)
|
||||
recovered = torch.view_as_real(xc * inv).flatten(-2).to(x.dtype)
|
||||
assert torch.allclose(x, recovered, atol=1e-5)
|
||||
|
||||
|
|
@ -199,8 +199,8 @@ class TestRoPEExtended:
|
|||
torch.manual_seed(7)
|
||||
x_a = torch.randn(1, 4, 2, dim)
|
||||
x_b = torch.randn(1, 4, 2, dim)
|
||||
solo = apply_rope(x_a, freqs)
|
||||
batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs)[:1]
|
||||
solo = apply_rope(x_a, freqs[:4])
|
||||
batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs[:4])[:1]
|
||||
assert torch.allclose(solo, batched, atol=1e-6)
|
||||
|
||||
def test_head_independence(self):
|
||||
|
|
@ -208,7 +208,7 @@ class TestRoPEExtended:
|
|||
dim = 16
|
||||
freqs = precompute_rope_freqs(dim=dim, max_len=8)
|
||||
x = torch.randn(1, 4, 1, dim).expand(1, 4, 3, dim).contiguous()
|
||||
out = apply_rope(x, freqs)
|
||||
out = apply_rope(x, freqs[:4])
|
||||
assert torch.allclose(out[:, :, 0], out[:, :, 1], atol=1e-6)
|
||||
assert torch.allclose(out[:, :, 1], out[:, :, 2], atol=1e-6)
|
||||
|
||||
|
|
@ -227,7 +227,7 @@ class TestRoPEExtended:
|
|||
"""Rotate tensor at a specific position by embedding it in a zero sequence."""
|
||||
seq = torch.zeros(1, pos + 1, 1, dim)
|
||||
seq[0, pos] = tensor[0, 0]
|
||||
return apply_rope(seq, freqs)[:, pos : pos + 1]
|
||||
return apply_rope(seq, freqs[: pos + 1])[:, pos : pos + 1]
|
||||
|
||||
# Both pairs have relative offset n - m = 6
|
||||
dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue