Add more checks and clearer error messages to conv operations (#563)

* Add more checks and clearer error messages to conv operations
This commit is contained in:
Jagrit Digani
2024-01-26 15:13:26 -08:00
committed by GitHub
parent 8fa6b322b9
commit bf17ab5002
3 changed files with 38 additions and 1 deletions

View File

@@ -106,7 +106,9 @@ class RoPE(Module):
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
half_D = D // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))