mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
@@ -375,6 +375,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected)
|
||||
|
||||
def test_move_swap_axes(self):
|
||||
x = mx.zeros((2, 3, 4))
|
||||
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
|
||||
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
|
||||
|
||||
def test_sum(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
Reference in New Issue
Block a user