mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 09:07:12 +08:00
Shapeless support for zeros/ones_like (#2726)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* shapeless support for zeros/ones_like * Improvements * fix access after moved
This commit is contained in:
@@ -482,6 +482,28 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
|
||||
|
||||
def test_shapeless_compile_full_like(self):
|
||||
x_shape = (1, 1, 32)
|
||||
x = mx.zeros((x_shape))
|
||||
|
||||
def zeros_fun(x):
|
||||
return mx.zeros_like(x)
|
||||
|
||||
def ones_fun(x):
|
||||
return mx.ones_like(x)
|
||||
|
||||
compiled_zero_like = mx.compile(zeros_fun, shapeless=True)
|
||||
compiled_ones_like = mx.compile(ones_fun, shapeless=True)
|
||||
|
||||
self.assertEqual(compiled_zero_like(x).shape, x_shape)
|
||||
self.assertEqual(compiled_ones_like(x).shape, x_shape)
|
||||
|
||||
y_shape = (2, 2, 16)
|
||||
y = mx.zeros(y_shape)
|
||||
|
||||
self.assertEqual(compiled_zero_like(y).shape, y_shape)
|
||||
self.assertEqual(compiled_ones_like(y).shape, y_shape)
|
||||
|
||||
def test_compile_with_constant(self):
|
||||
# Test float
|
||||
@partial(mx.compile)
|
||||
@@ -842,7 +864,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_many_outputs(self):
|
||||
|
||||
@mx.compile
|
||||
def fun(arr):
|
||||
arrs = [arr] * 64
|
||||
|
||||
Reference in New Issue
Block a user