mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix concatenate/slice_update vjp + reduce binary size (#1735)
* fix concatenate vjp + reduce binary size
* also cast in slice update
This commit is contained in:
22
mlx/ops.cpp
22
mlx/ops.cpp
@@ -708,19 +708,19 @@ array slice_update(
|
||||
auto [has_neg_strides, upd_shape] =
|
||||
normalize_slice(src.shape(), start, stop, strides);
|
||||
|
||||
// Broadcast update shape to slice shape
|
||||
auto update_broadcasted = broadcast_to(update, upd_shape, s);
|
||||
// Cast update to src type and broadcast update shape to slice shape
|
||||
auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);
|
||||
|
||||
// If the entire src is the slice, just return the update
|
||||
if (!has_neg_strides && upd_shape == src.shape()) {
|
||||
return astype(update_broadcasted, src.dtype(), s);
|
||||
return upd;
|
||||
}
|
||||
return array(
|
||||
src.shape(),
|
||||
src.dtype(),
|
||||
std::make_shared<SliceUpdate>(
|
||||
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
||||
{src, update_broadcasted});
|
||||
{src, upd});
|
||||
}
|
||||
|
||||
/** Update a slice from the source array with stride 1 in each dimension */
|
||||
@@ -869,7 +869,7 @@ array clip(
|
||||
}
|
||||
|
||||
array concatenate(
|
||||
const std::vector<array>& arrays,
|
||||
std::vector<array> arrays,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (arrays.size() == 0) {
|
||||
@@ -915,6 +915,9 @@ array concatenate(
|
||||
|
||||
// Promote all the arrays to the same type
|
||||
auto dtype = result_type(arrays);
|
||||
for (auto& a : arrays) {
|
||||
a = astype(a, dtype, s);
|
||||
}
|
||||
|
||||
return array(
|
||||
std::move(shape),
|
||||
@@ -923,14 +926,11 @@ array concatenate(
|
||||
std::move(arrays));
|
||||
}
|
||||
|
||||
array concatenate(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> flat_inputs;
|
||||
array concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} */) {
|
||||
for (auto& a : arrays) {
|
||||
flat_inputs.push_back(flatten(a, s));
|
||||
a = flatten(a, s);
|
||||
}
|
||||
return concatenate(flat_inputs, 0, s);
|
||||
return concatenate(std::move(arrays), 0, s);
|
||||
}
|
||||
|
||||
/** Stack arrays along a new axis */
|
||||
|
||||
Reference in New Issue
Block a user