mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
faster rms norm (#2433)
This commit is contained in:
@@ -3049,6 +3049,25 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.power(mx.array(0j), float("nan"))
|
||||
self.assertTrue(mx.isnan(out))
|
||||
|
||||
def test_irregular_alignments(self):
|
||||
# Unaligned unary op
|
||||
a = mx.ones((64, 1))
|
||||
b = -a[1:]
|
||||
self.assertTrue(mx.all(b == -1.0))
|
||||
|
||||
# Unaligned binary op
|
||||
a = mx.ones((64, 1))
|
||||
b = a[1:]
|
||||
c = b + b
|
||||
self.assertTrue(mx.all(c == 2.0))
|
||||
|
||||
# Unaligned ternary op
|
||||
a = mx.ones((64, 1))
|
||||
b = mx.zeros((63, 1))
|
||||
c = mx.ones((63, 1)).astype(mx.bool_)
|
||||
d = mx.where(c, a[1:], b)
|
||||
self.assertTrue(mx.all(d == 1.0))
|
||||
|
||||
|
||||
class TestBroadcast(mlx_tests.MLXTestCase):
|
||||
def test_broadcast_shapes(self):
|
||||
|
Reference in New Issue
Block a user