Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
Awni Hannun authored 2025-01-09 11:04:24 -08:00, committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions
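All of the changes below lean on NumPy-style broadcasting: shapes are right-aligned and a size-1 axis stretches to match its counterpart. As a reference point for reading the diff, here is a minimal standalone sketch of that rule (the helper name broadcast_shapes_sketch is made up for illustration; mlx's real broadcast_shapes also reports which shapes conflict):

// Minimal standalone sketch of the NumPy-style broadcasting rule that the
// Broadcast/BroadcastAxes primitives below rely on. Illustration only, not
// mlx's implementation.
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int>;

Shape broadcast_shapes_sketch(const Shape& a, const Shape& b) {
  // Align trailing dimensions; a size-1 axis stretches to match the other.
  size_t n = std::max(a.size(), b.size());
  Shape out(n);
  for (size_t i = 0; i < n; ++i) {
    int da = i < n - a.size() ? 1 : a[i - (n - a.size())];
    int db = i < n - b.size() ? 1 : b[i - (n - b.size())];
    if (da != db && da != 1 && db != 1) {
      throw std::invalid_argument("shapes cannot be broadcast");
    }
    out[i] = std::max(da, db);
  }
  return out;
}

int main() {
  // {3, 1, 4} vs {2, 1, 5, 4} -> {2, 3, 5, 4}
  for (auto d : broadcast_shapes_sketch({3, 1, 4}, {2, 1, 5, 4})) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
}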

mlx/primitives.cpp

@@ -686,51 +686,51 @@ std::vector<array> BitwiseBinary::vjp(
   return jvp(primals, cotangents, argnums);
 }
 
-std::vector<array> Broadcast::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
-    const std::vector<int>& argnums,
-    const std::vector<array>&) {
-  assert(argnums.size() == 1);
+std::vector<array>
+broadcast_vjp(const array& primal, const array& cotan, const Stream& s) {
   // Reduce cotangents to the shape of the primal
-  auto& shape = primals[0].shape();
-  auto& cotan = cotangents[0];
+  auto& shape = primal.shape();
   int diff = cotan.ndim() - shape.size();
-  std::vector<int> reduce_axes;
-  for (int i = 0; i < cotan.ndim(); ++i) {
-    if (i < diff) {
-      reduce_axes.push_back(i);
-    } else if (shape[i - diff] != cotan.shape(i)) {
+  std::vector<int> squeeze_axes(diff);
+  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
+  auto reduce_axes = squeeze_axes;
+  for (int i = diff; i < cotan.ndim(); ++i) {
+    if (shape[i - diff] != cotan.shape(i)) {
       reduce_axes.push_back(i);
     }
   }
-  return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};
+  return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)};
 }
+
+std::vector<array> Broadcast::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
+}
 
 std::vector<array> Broadcast::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(argnums.size() == 1);
-  return {broadcast_to(tangents[0], shape_, stream())};
+  return {array(
+      shape_,
+      tangents[0].dtype(),
+      std::make_shared<Broadcast>(stream(), shape_),
+      tangents)};
 }
 
 std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto ax = axes[0];
-  auto in = inputs[0];
+  auto& in = inputs[0];
   if (ax >= 0) {
-    auto in_shape = in.shape();
     int diff = shape_.size() - in.ndim() + 1;
     assert(diff >= 0);
-    in_shape.insert(in_shape.begin(), diff, 1);
+    shape_.insert(shape_.begin() + ax + diff, in.shape(ax));
     ax += diff;
-    shape_.insert(shape_.begin() + ax, in_shape[ax]);
-    in = reshape(in, in_shape, stream());
   }
   return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
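To see what the new broadcast_vjp is doing with its two axis lists: leading axes that broadcasting added are both summed and squeezed away, while pre-existing axes that were stretched from size 1 are summed with keepdims. A hypothetical rehearsal of that bookkeeping on concrete shapes (the shapes are made up for illustration):

// Standalone rehearsal of the axis bookkeeping in broadcast_vjp above.
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> primal = {3, 1, 4};     // shape before broadcast
  std::vector<int> cotan = {2, 3, 5, 4};   // shape of incoming cotangent
  int diff = cotan.size() - primal.size(); // leading axes added: 1

  std::vector<int> squeeze_axes(diff);
  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
  auto reduce_axes = squeeze_axes;
  for (int i = diff; i < (int)cotan.size(); ++i) {
    if (primal[i - diff] != cotan[i]) {
      reduce_axes.push_back(i);  // axis stretched from 1 -> sum it
    }
  }
  // Prints: squeeze {0}, reduce {0 2}. Summing with keepdims gives
  // {1, 3, 1, 4}; squeezing axis 0 recovers the primal shape {3, 1, 4}.
  std::cout << "squeeze:";
  for (auto a : squeeze_axes) std::cout << ' ' << a;
  std::cout << "\nreduce:";
  for (auto a : reduce_axes) std::cout << ' ' << a;
  std::cout << '\n';
}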
@@ -740,11 +740,76 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
   return shape_ == b_other.shape_;
 }
 
-std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
-  if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
-    throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape");
+Shape Broadcast::output_shape(const std::vector<array>& inputs) {
+  auto shape = inputs[0].shape();
+  for (int i = 1; i < inputs.size(); ++i) {
+    shape = broadcast_shapes(shape, inputs[i].shape());
   }
-  return {shape_};
+  return shape;
 }
+
+std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
+  if (inputs.size() < 2) {
+    if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
+      throw std::invalid_argument(
+          "[Broadcast] Unable to infer broadcast shape");
+    }
+    return {shape_};
+  }
+  return {output_shape(inputs)};
+};
+
+std::vector<array> BroadcastAxes::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
+}
+
+std::vector<array> BroadcastAxes::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  return {array(
+      output_shape(primals, ignore_axes_),
+      tangents[0].dtype(),
+      std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
+      tangents)};
+}
+
+std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
+}
+
+bool BroadcastAxes::is_equivalent(const Primitive& other) const {
+  const auto& b_other = static_cast<const BroadcastAxes&>(other);
+  return ignore_axes_ == b_other.ignore_axes_;
+}
+
+Shape BroadcastAxes::output_shape(
+    const std::vector<array>& inputs,
+    const std::vector<int>& ignore_axes) {
+  auto shape = Shape{};
+  for (auto& in : inputs) {
+    auto in_shape = in.shape();
+    for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {
+      in_shape.erase(in_shape.begin() + in.ndim() + *it);
+    }
+    shape = broadcast_shapes(shape, in_shape);
+  }
+  int dims = ignore_axes.size() + shape.size();
+  for (auto ax : ignore_axes) {
+    shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax));
+  }
+  return shape;
+}
+
+std::vector<Shape> BroadcastAxes::output_shapes(
+    const std::vector<array>& inputs) {
+  return {output_shape(inputs, ignore_axes_)};
+}
 
 std::vector<array> Ceil::vjp(
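BroadcastAxes exempts some axes from broadcasting; judging from the in.ndim() + *it indexing, ignore_axes holds negative, from-the-end positions. A worked example with made-up shapes, tracing output_shape by hand:

// Worked example (hypothetical shapes) of the BroadcastAxes::output_shape
// logic above: drop the ignored axes from every input, broadcast what
// remains, then splice the ignored axes back in from inputs[0].
#include <iostream>
#include <vector>

int main() {
  using Shape = std::vector<int>;
  Shape a = {2, 1, 4};             // stands in for inputs[0]
  Shape b = {3, 4};                // stands in for inputs[1]
  std::vector<int> ignore = {-1};  // last axis is exempt from broadcasting

  // Erasing the ignored axis: a -> {2, 1}, b -> {3}.
  // Broadcasting the remainders: {2, 1} with {3} -> {2, 3}.
  Shape shape = {2, 3};

  // Re-insert the ignored axis using inputs[0]'s size for it:
  // dims = 1 + 2 = 3, insert at begin + 3 + (-1) = position 2.
  int dims = (int)(ignore.size() + shape.size());
  for (auto ax : ignore) {
    shape.insert(shape.begin() + dims + ax, a[(int)a.size() + ax]);
  }
  for (auto d : shape) std::cout << d << ' ';  // prints: 2 3 4
  std::cout << '\n';
}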
@@ -3066,14 +3131,9 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
-  auto shape = in.shape();
-  for (auto ax : axes_) {
-    shape[ax] = 1;
-  }
   auto& cotan = cotangents[0];
   if (reduce_type_ == Reduce::Sum) {
-    return {
-        broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())};
+    return {broadcast_arrays({cotan, in}, stream())[0]};
   } else if (reduce_type_ == Reduce::Prod) {
     auto s = stream();
     auto prod_grad_single_axis =
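The Sum branch can drop its reshape because the cotangent now evidently arrives with the reduced axes kept as size 1, so broadcasting it against the input directly is the whole gradient. A scalar-level illustration with hypothetical values:

// Hypothetical illustration of the Sum vjp above: with the cotangent at
// shape {2, 1} (reduced axis retained as size 1), broadcasting against the
// {2, 3} input replicates each cotangent entry across the summed axis,
// which is exactly d(sum)/d(input).
#include <iostream>

int main() {
  double cotan[2][1] = {{10.0}, {20.0}};  // cotangent of sum over axis 1
  double grad[2][3];
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      grad[i][j] = cotan[i][0];  // the size-1 axis stretches across j
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) std::cout << grad[i][j] << ' ';
    std::cout << '\n';  // 10 10 10 / 20 20 20
  }
}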
@@ -3129,7 +3189,7 @@ std::vector<array> Reduce::vjp(
       return {grad};
     } else {
-      return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])};
+      return {prod_grad_single_axis(in, cotan, axes_[0])};
     }
   } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
@@ -3139,9 +3199,7 @@ std::vector<array> Reduce::vjp(
     }
     auto mask = equal(in, out, stream());
     auto normalizer = sum(mask, axes_, true, stream());
-    auto cotan_reshape = reshape(cotan, shape, stream());
-    cotan_reshape = divide(cotan_reshape, normalizer, stream());
-    return {multiply(cotan_reshape, mask, stream())};
+    return {multiply(divide(cotan, normalizer, stream()), mask, stream())};
   } else {
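For Min/Max, the mask/normalizer pair routes the cotangent only to entries that achieved the extremum and splits it evenly among ties. A small numeric check of that rule, with made-up values:

// Numeric check of the Min/Max vjp rule above (hypothetical values): mask
// marks entries equal to the max, normalizer counts the ties, and the
// cotangent is divided among them so tied maxima share the gradient.
#include <iostream>

int main() {
  double in[4] = {1.0, 3.0, 3.0, 2.0};  // max = 3 with a two-way tie
  double out = 3.0;
  double cotan = 1.0;

  double normalizer = 0.0;
  for (double v : in) normalizer += (v == out) ? 1.0 : 0.0;  // 2 ties

  for (double v : in) {
    double mask = (v == out) ? 1.0 : 0.0;
    std::cout << (cotan / normalizer) * mask << ' ';  // 0 0.5 0.5 0
  }
  std::cout << '\n';
}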