Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-24 10:51:21 +08:00

Fix concatenate/slice_update vjp + reduce binary size (#1735)

* fix concatenate vjp + reduce binary size
* also cast in slice update

parent ae69cb15e9
commit 6fa0501387
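A minimal sketch (not part of the commit, assuming MLX's Python slice assignment lowers to slice_update) of the behavior the second bullet describes: an update with a different dtype is now cast to the destination array's dtype before the write.

import mlx.core as mx

dst = mx.zeros((4,), dtype=mx.float32)
upd = mx.ones((2,), dtype=mx.float16)
dst[1:3] = upd        # the float16 update is cast to the destination dtype
print(dst.dtype)      # float32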
@@ -12,21 +12,24 @@
  instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
  instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
  instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
  instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
  instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
  instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
  instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)

#define instantiate_copy_same(tname, type) \
  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
  instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
  instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4) \
  instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4)

#define instantiate_copy_itype(itname, itype) \
  instantiate_copy_same(itname ##itname, itype) \
  instantiate_copy_all(itname ##bool_, itype, bool) \
  instantiate_copy_all(itname ##uint8, itype, uint8_t) \
  instantiate_copy_all(itname ##uint16, itype, uint16_t) \
mlx/ops.cpp (22 lines changed)
@@ -708,19 +708,19 @@ array slice_update(
   auto [has_neg_strides, upd_shape] =
       normalize_slice(src.shape(), start, stop, strides);
 
-  // Broadcast update shape to slice shape
-  auto update_broadcasted = broadcast_to(update, upd_shape, s);
+  // Cast update to src type and broadcast update shape to slice shape
+  auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);
 
   // If the entire src is the slice, just return the update
   if (!has_neg_strides && upd_shape == src.shape()) {
-    return astype(update_broadcasted, src.dtype(), s);
+    return upd;
   }
   return array(
       src.shape(),
       src.dtype(),
       std::make_shared<SliceUpdate>(
           to_stream(s), std::move(start), std::move(stop), std::move(strides)),
-      {src, update_broadcasted});
+      {src, upd});
 }
 
 /** Update a slice from the source array with stride 1 in each dimension */
@@ -869,7 +869,7 @@ array clip(
 }
 
 array concatenate(
-    const std::vector<array>& arrays,
+    std::vector<array> arrays,
     int axis,
     StreamOrDevice s /* = {} */) {
   if (arrays.size() == 0) {
@@ -915,6 +915,9 @@ array concatenate(
 
   // Promote all the arrays to the same type
   auto dtype = result_type(arrays);
+  for (auto& a : arrays) {
+    a = astype(a, dtype, s);
+  }
 
   return array(
       std::move(shape),
@@ -923,14 +926,11 @@ array concatenate(
       std::move(arrays));
 }
 
-array concatenate(
-    const std::vector<array>& arrays,
-    StreamOrDevice s /* = {} */) {
-  std::vector<array> flat_inputs;
+array concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} */) {
   for (auto& a : arrays) {
-    flat_inputs.push_back(flatten(a, s));
+    a = flatten(a, s);
   }
-  return concatenate(flat_inputs, 0, s);
+  return concatenate(std::move(arrays), 0, s);
 }
 
 /** Stack arrays along a new axis */
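For context (not part of the diff, assuming the standard mlx.core Python API), a short sketch of the user-visible effect of the promotion added above: mixed-dtype inputs are cast to their result_type before the Concatenate primitive runs, so the output dtype follows ordinary type promotion.

import mlx.core as mx

a = mx.array([1, 2], dtype=mx.int32)
b = mx.array([1.5, 2.5], dtype=mx.float32)
c = mx.concatenate([a, b], axis=0)
print(c.dtype)   # float32: both inputs are cast to the promoted dtype first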
@@ -211,11 +211,8 @@ array clip(
     StreamOrDevice s = {});
 
 /** Concatenate arrays along a given axis. */
-array concatenate(
-    const std::vector<array>& arrays,
-    int axis,
-    StreamOrDevice s = {});
-array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
+array concatenate(std::vector<array> arrays, int axis, StreamOrDevice s = {});
+array concatenate(std::vector<array> arrays, StreamOrDevice s = {});
 
 /** Stack arrays along a new axis. */
 array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
@@ -620,6 +620,18 @@ class TestAutograd(mlx_tests.MLXTestCase):
         x = mx.zeros((2, 4, 8))
         self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))
 
+    def test_concatenate_vjps(self):
+        def fun(x, y):
+            return mx.concatenate([x, y])
+
+        x = mx.array([1, 2, 3], mx.float32)
+        y = mx.array([1, 2, 3], mx.float16)
+        grads = mx.vjp(fun, (x, y), (mx.ones((6,)),))[1]
+        self.assertTrue(mx.allclose(grads[0], mx.ones(3)))
+        self.assertTrue(mx.allclose(grads[1], mx.ones(3)))
+        self.assertEqual(grads[0].dtype, mx.float32)
+        self.assertEqual(grads[1].dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()