Fix concatenate/slice_update vjp + reduce binary size (#1735)

* fix concatenate vjp + reduce binary size

* also cast the update to the source dtype in slice_update
Awni Hannun 2025-01-02 16:36:33 -08:00 committed by GitHub
parent ae69cb15e9
commit 6fa0501387
4 changed files with 38 additions and 26 deletions
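For orientation, here is a hedged Python sketch of the headline fix (it mirrors the new test at the bottom of the diff and is not itself part of the commit): with mixed-dtype inputs, the vjp of concatenate should return each cotangent in the corresponding input's original dtype.

import mlx.core as mx

x = mx.array([1, 2, 3], mx.float32)
y = mx.array([1, 2, 3], mx.float16)

# The output (and its cotangent) is float32; the vjp of the astype that
# concatenate now inserts casts y's piece of the cotangent back to float16.
grads = mx.vjp(lambda a, b: mx.concatenate([a, b]), (x, y), (mx.ones((6,)),))[1]
print(grads[0].dtype, grads[1].dtype)  # float32 float16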

The first change trims the Metal copy kernel instantiations:

@@ -12,21 +12,24 @@
  instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int)      \
  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int)      \
  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int)      \
-  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int)    \
-  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int)    \
-  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int)    \
  instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int)      \
-  instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int)    \
  instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype)      \
  instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype)      \
-  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)      \
-  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype)    \
-  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype)    \
-  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype)    \
-  instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)      \
-  instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
+  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)
+
+#define instantiate_copy_same(tname, type)                                \
+  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int)     \
+  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int)     \
+  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int)     \
+  instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int)     \
+  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type)     \
+  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type)     \
+  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type)     \
+  instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4)       \
+  instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4)

#define instantiate_copy_itype(itname, itype)             \
+  instantiate_copy_same(itname ##itname, itype)           \
  instantiate_copy_all(itname ##bool_, itype, bool)       \
  instantiate_copy_all(itname ##uint8, itype, uint8_t)    \
  instantiate_copy_all(itname ##uint16, itype, uint16_t)  \
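Why this reduces binary size (a hedged sketch, not from the commit message): the strided-to-strided "gg" copy kernels only ever run with matching input and output dtypes, so instantiating them for every dtype pair via instantiate_copy_all was wasted code. Moving them into instantiate_copy_same turns the per-pair quadratic blow-up into one set per dtype; with an assumed dtype count the savings look roughly like this:

# A rough, assumption-laden count of the kernel variants affected by this change.
num_dtypes = 13        # assumed number of dtypes covered by the copy instantiations
moved_kernels = 9      # the nine instantiations moved into instantiate_copy_same

before = num_dtypes * num_dtypes * moved_kernels   # one set per (itype, otype) pair
after = num_dtypes * moved_kernels                 # one set per dtype
print(before, "->", after)                         # 1521 -> 117 under these assumptions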

Next, the op implementations for slice_update and concatenate:

@@ -708,19 +708,19 @@ array slice_update(
  auto [has_neg_strides, upd_shape] =
      normalize_slice(src.shape(), start, stop, strides);

-  // Broadcast update shape to slice shape
-  auto update_broadcasted = broadcast_to(update, upd_shape, s);
+  // Cast update to src type and broadcast update shape to slice shape
+  auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);

  // If the entire src is the slice, just return the update
  if (!has_neg_strides && upd_shape == src.shape()) {
-    return astype(update_broadcasted, src.dtype(), s);
+    return upd;
  }

  return array(
      src.shape(),
      src.dtype(),
      std::make_shared<SliceUpdate>(
          to_stream(s), std::move(start), std::move(stop), std::move(strides)),
-      {src, update_broadcasted});
+      {src, upd});
}

/** Update a slice from the source array with stride 1 in each dimension */
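A small usage sketch of the slice_update change (illustrative, not part of the diff; it assumes the Python bindings route basic slice assignment through slice_update): the update is cast to the destination dtype and broadcast up front, so both the primitive path and the whole-array early-return path see an update that already matches src.

import mlx.core as mx

dst = mx.zeros((4, 4), mx.float32)

# Partial slice: the float16 update is cast to float32 before SliceUpdate runs.
dst[:2] = mx.ones((2, 4), mx.float16)
print(dst.dtype)  # float32

# Whole-array slice: the early return hands back the already-cast,
# already-broadcast update, so the dtype still matches the destination.
dst[:] = mx.ones((), mx.float16)
print(dst.dtype)  # float32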
@@ -869,7 +869,7 @@ array clip(
}

array concatenate(
-    const std::vector<array>& arrays,
+    std::vector<array> arrays,
    int axis,
    StreamOrDevice s /* = {} */) {
  if (arrays.size() == 0) {
@@ -915,6 +915,9 @@ array concatenate(
  // Promote all the arrays to the same type
  auto dtype = result_type(arrays);
+  for (auto& a : arrays) {
+    a = astype(a, dtype, s);
+  }

  return array(
      std::move(shape),
@@ -923,14 +926,11 @@
      std::move(arrays));
}

-array concatenate(
-    const std::vector<array>& arrays,
-    StreamOrDevice s /* = {} */) {
-  std::vector<array> flat_inputs;
+array concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} */) {
  for (auto& a : arrays) {
-    flat_inputs.push_back(flatten(a, s));
+    a = flatten(a, s);
  }
-  return concatenate(flat_inputs, 0, s);
+  return concatenate(std::move(arrays), 0, s);
}
/** Stack arrays along a new axis */
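For the forward pass, the loop added above means concatenate itself now promotes mixed-dtype inputs to their common result type before the Concatenate primitive is constructed; a brief sketch (not part of the diff):

import mlx.core as mx

x = mx.array([1, 2, 3], mx.float32)
y = mx.array([4, 5, 6], mx.float16)

# result_type(float32, float16) is float32, so y is cast up and the primitive
# only ever sees same-dtype inputs (which is also what lets the vjp restore
# each input's original dtype).
out = mx.concatenate([x, y])
print(out.dtype)  # float32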

The corresponding declarations in the ops header are updated to take the input vector by value:

@@ -211,11 +211,8 @@ array clip(
    StreamOrDevice s = {});

/** Concatenate arrays along a given axis. */
-array concatenate(
-    const std::vector<array>& arrays,
-    int axis,
-    StreamOrDevice s = {});
-array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
+array concatenate(std::vector<array> arrays, int axis, StreamOrDevice s = {});
+array concatenate(std::vector<array> arrays, StreamOrDevice s = {});

/** Stack arrays along a new axis. */
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});

Finally, a regression test for the concatenate vjp dtypes:

@@ -620,6 +620,18 @@ class TestAutograd(mlx_tests.MLXTestCase):
        x = mx.zeros((2, 4, 8))
        self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))

+    def test_concatenate_vjps(self):
+        def fun(x, y):
+            return mx.concatenate([x, y])
+
+        x = mx.array([1, 2, 3], mx.float32)
+        y = mx.array([1, 2, 3], mx.float16)
+        grads = mx.vjp(fun, (x, y), (mx.ones((6,)),))[1]
+        self.assertTrue(mx.allclose(grads[0], mx.ones(3)))
+        self.assertTrue(mx.allclose(grads[1], mx.ones(3)))
+        self.assertEqual(grads[0].dtype, mx.float32)
+        self.assertEqual(grads[1].dtype, mx.float16)
+

if __name__ == "__main__":
    unittest.main()