Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun
2023-12-14 12:59:12 -08:00
committed by GitHub
parent f55908bc48
commit e5851e52b1
10 changed files with 399 additions and 7 deletions

View File

@@ -862,6 +862,22 @@ void init_array(py::module_& m) {
py::kw_only(),
"stream"_a = none,
"See :func:`any`.")
.def(
"moveaxis",
&moveaxis,
"source"_a,
"destination"_a,
py::kw_only(),
"stream"_a = none,
"See :func:`moveaxis`.")
.def(
"swapaxes",
&swapaxes,
"axis1"_a,
"axis2"_a,
py::kw_only(),
"stream"_a = none,
"See :func:`moveaxis`.")
.def(
"transpose",
[](const array& a, py::args axes, StreamOrDevice s) {

View File

@@ -1591,6 +1591,50 @@ void init_ops(py::module_& m) {
Returns:
array: The ceil of ``a``.
)pbdoc");
m.def(
"moveaxis",
&moveaxis,
"a"_a,
py::pos_only(),
"source"_a,
"destiantion"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array
Move an axis to a new position.
Args:
a (array): Input array.
source (int): Specifies the source axis.
destination (int): Specifies the destination axis.
Returns:
array: The array with the axis moved.
)pbdoc");
m.def(
"swapaxes",
&swapaxes,
"a"_a,
py::pos_only(),
"axis1"_a,
"axis2"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array
Swap two axes of an array.
Args:
a (array): Input array.
axis1 (int): Specifies the first axis.
axis2 (int): Specifies the second axis.
Returns:
array: The array with swapped axes.
)pbdoc");
m.def(
"transpose",
[](const array& a,

View File

@@ -375,6 +375,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected)
def test_move_swap_axes(self):
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
def test_sum(self):
x = mx.array(
[

View File

@@ -163,6 +163,61 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
def test_vmap_indexing(self):
x = mx.arange(16).reshape(2, 2, 2, 2)
inds = mx.array([[0, 1, 0], [1, 1, 0]])
out = mx.vmap(lambda x, y: x[y], in_axes=(0, 0))(x, inds)
expected = mx.array(
[
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
[[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
]
)
self.assertTrue(mx.array_equal(out, expected))
out = mx.vmap(lambda x, y: x[y], in_axes=(0, None))(x, inds)
expected = mx.array(
[
[
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
[[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
],
[
[[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
[[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
],
]
)
self.assertTrue(mx.array_equal(out, expected))
out = mx.vmap(lambda x, y: x[y], in_axes=(None, 0))(x, inds)
expected = mx.array(
[
[
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
],
[
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
],
]
)
self.assertTrue(mx.array_equal(out, expected))
inds2 = mx.array([[0, 1, 0], [0, 1, 0]])
out = mx.vmap(lambda x, y, z: x[y, z], in_axes=(None, 0, 0))(x, inds, inds2)
expected = mx.array(
[
[[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
[[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
]
)
self.assertTrue(mx.array_equal(out, expected))
if __name__ == "__main__":
unittest.main()