Fix concatenate/slice_update vjp + reduce binary size (#1735)

* fix concatenate vjp + reduce binary size

* also cast in slice update
This commit is contained in:
Awni Hannun
2025-01-02 16:36:33 -08:00
committed by GitHub
parent ae69cb15e9
commit 6fa0501387
4 changed files with 38 additions and 26 deletions

View File

@@ -708,19 +708,19 @@ array slice_update(
auto [has_neg_strides, upd_shape] =
normalize_slice(src.shape(), start, stop, strides);
// Broadcast update shape to slice shape
auto update_broadcasted = broadcast_to(update, upd_shape, s);
// Cast update to src type and broadcast update shape to slice shape
auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);
// If the entire src is the slice, just return the update
if (!has_neg_strides && upd_shape == src.shape()) {
return astype(update_broadcasted, src.dtype(), s);
return upd;
}
return array(
src.shape(),
src.dtype(),
std::make_shared<SliceUpdate>(
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
{src, update_broadcasted});
{src, upd});
}
/** Update a slice from the source array with stride 1 in each dimension */
@@ -869,7 +869,7 @@ array clip(
}
array concatenate(
const std::vector<array>& arrays,
std::vector<array> arrays,
int axis,
StreamOrDevice s /* = {} */) {
if (arrays.size() == 0) {
@@ -915,6 +915,9 @@ array concatenate(
// Promote all the arrays to the same type
auto dtype = result_type(arrays);
for (auto& a : arrays) {
a = astype(a, dtype, s);
}
return array(
std::move(shape),
@@ -923,14 +926,11 @@ array concatenate(
std::move(arrays));
}
array concatenate(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> flat_inputs;
array concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} */) {
for (auto& a : arrays) {
flat_inputs.push_back(flatten(a, s));
a = flatten(a, s);
}
return concatenate(flat_inputs, 0, s);
return concatenate(std::move(arrays), 0, s);
}
/** Stack arrays along a new axis */