Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase):
expected[1:, 2:, 3:] = update
self.assertTrue(mx.array_equal(expected, out))
def test_broadcast_arrays(self):
a = mx.array(1)
b = mx.array(1.0)
a, b = mx.broadcast_arrays(a, b)
self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.int32)
self.assertEqual(b.shape, ())
self.assertEqual(b.dtype, mx.float32)
a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
self.assertEqual(a.shape, (3, 4, 2))
self.assertEqual(b.shape, (3, 4, 2))
if __name__ == "__main__":
unittest.main()