From 6fa05013873401c6c2cf15a178126a66fb0d6500 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 2 Jan 2025 16:36:33 -0800
Subject: [PATCH] Fix concatenate/slice_update vjp + reduce binary size (#1735)

* fix concatenate vjp + reduce binary size

* also cast in slice update
---
 mlx/backend/metal/kernels/copy.metal | 23 +++++++++++++----------
 mlx/ops.cpp                          | 22 +++++++++++-----------
 mlx/ops.h                            |  7 ++-----
 python/tests/test_autograd.py        | 12 ++++++++++++
 4 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal
index 298b48fe9..68e6dcec6 100644
--- a/mlx/backend/metal/kernels/copy.metal
+++ b/mlx/backend/metal/kernels/copy.metal
@@ -12,21 +12,24 @@
   instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
   instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
   instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
-  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
-  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
-  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
   instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
-  instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
   instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
   instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
-  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
-  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
-  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
-  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
-  instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
-  instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
+  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)
+
+#define instantiate_copy_same(tname, type) \
+  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
+  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
+  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
+  instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
+  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
+  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
+  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
+  instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4) \
+  instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4)
 
 #define instantiate_copy_itype(itname, itype) \
+  instantiate_copy_same(itname ##itname, itype) \
   instantiate_copy_all(itname ##bool_, itype, bool) \
   instantiate_copy_all(itname ##uint8, itype, uint8_t) \
   instantiate_copy_all(itname ##uint16, itype, uint16_t) \
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index ac7238904..ad0a64697 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -708,19 +708,19 @@ array slice_update(
   auto [has_neg_strides, upd_shape] =
       normalize_slice(src.shape(), start, stop, strides);
 
-  // Broadcast update shape to slice shape
-  auto update_broadcasted = broadcast_to(update, upd_shape, s);
+  // Cast update to src type and broadcast update shape to slice shape
+  auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);
 
   // If the entire src is the slice, just return the update
   if (!has_neg_strides && upd_shape == src.shape()) {
-    return astype(update_broadcasted, src.dtype(), s);
+    return upd;
   }
   return array(
       src.shape(),
       src.dtype(),
       std::make_shared<SliceUpdate>(
           to_stream(s), std::move(start), std::move(stop), std::move(strides)),
-      {src, update_broadcasted});
+      {src, upd});
 }
 
 /** Update a slice from the source array with stride 1 in each dimension */
@@ -869,7 +869,7 @@ array clip(
 }
 
 array concatenate(
-    const std::vector<array>& arrays,
+    std::vector<array> arrays,
     int axis,
     StreamOrDevice s /* = {} */) {
   if (arrays.size() == 0) {
@@ -915,6 +915,9 @@ array concatenate(
 
   // Promote all the arrays to the same type
   auto dtype = result_type(arrays);
+  for (auto& a : arrays) {
+    a = astype(a, dtype, s);
+  }
 
   return array(
       std::move(shape),
@@ -923,14 +926,11 @@ array concatenate(
       std::move(arrays));
 }
 
-array concatenate(
-    const std::vector<array>& arrays,
-    StreamOrDevice s /* = {} */) {
-  std::vector<array> flat_inputs;
+array concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} */) {
   for (auto& a : arrays) {
-    flat_inputs.push_back(flatten(a, s));
+    a = flatten(a, s);
   }
-  return concatenate(flat_inputs, 0, s);
+  return concatenate(std::move(arrays), 0, s);
 }
 
 /** Stack arrays along a new axis */
diff --git a/mlx/ops.h b/mlx/ops.h
index 0f92ff372..ec8cf20e2 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -211,11 +211,8 @@ array clip(
     StreamOrDevice s = {});
 
 /** Concatenate arrays along a given axis. */
-array concatenate(
-    const std::vector<array>& arrays,
-    int axis,
-    StreamOrDevice s = {});
-array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
+array concatenate(std::vector<array> arrays, int axis, StreamOrDevice s = {});
+array concatenate(std::vector<array> arrays, StreamOrDevice s = {});
 
 /** Stack arrays along a new axis. */
 array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index a1106b2a4..3553824aa 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -620,6 +620,18 @@ class TestAutograd(mlx_tests.MLXTestCase):
         x = mx.zeros((2, 4, 8))
         self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))
 
+    def test_concatenate_vjps(self):
+        def fun(x, y):
+            return mx.concatenate([x, y])
+
+        x = mx.array([1, 2, 3], mx.float32)
+        y = mx.array([1, 2, 3], mx.float16)
+        grads = mx.vjp(fun, (x, y), (mx.ones((6,)),))[1]
+        self.assertTrue(mx.allclose(grads[0], mx.ones(3)))
+        self.assertTrue(mx.allclose(grads[1], mx.ones(3)))
+        self.assertEqual(grads[0].dtype, mx.float32)
+        self.assertEqual(grads[1].dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()
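
For illustration only (not part of the patch or its test suite): a minimal Python sketch of the slice_update half of the fix, assuming indexed assignment on an mx.array lowers to slice_update as in the C++ change above. The update is now cast to the destination dtype before the scatter, so a mixed-precision assignment keeps the destination's dtype.

import mlx.core as mx

# Hypothetical example: mixed-precision slice assignment.
src = mx.zeros((4,), dtype=mx.float32)  # float32 destination
upd = mx.ones((2,), dtype=mx.float16)   # float16 update
src[1:3] = upd                          # update is cast to float32 internally
print(src.dtype)                        # mlx.core.float32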