faster rms norm (#2433)

This commit is contained in:
Awni Hannun
2025-07-29 13:12:00 -07:00
committed by GitHub
parent 970dbe8e25
commit ef631d63af
11 changed files with 210 additions and 112 deletions

View File

@@ -3049,6 +3049,25 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.power(mx.array(0j), float("nan"))
self.assertTrue(mx.isnan(out))
def test_irregular_alignments(self):
# Unaligned unary op
a = mx.ones((64, 1))
b = -a[1:]
self.assertTrue(mx.all(b == -1.0))
# Unaligned binary op
a = mx.ones((64, 1))
b = a[1:]
c = b + b
self.assertTrue(mx.all(c == 2.0))
# Unaligned ternary op
a = mx.ones((64, 1))
b = mx.zeros((63, 1))
c = mx.ones((63, 1)).astype(mx.bool_)
d = mx.where(c, a[1:], b)
self.assertTrue(mx.all(d == 1.0))
class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self):