Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun
2023-12-14 12:59:12 -08:00
committed by GitHub
parent f55908bc48
commit e5851e52b1
10 changed files with 399 additions and 7 deletions

View File

@@ -375,6 +375,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected)
def test_move_swap_axes(self):
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
def test_sum(self):
x = mx.array(
[