Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast
* shapeless broadcast
* fix build + nits
* use broadcast arrays in quantize matmul
* some cleanup / consistency
* mend
* some comments
* add vjp, jvp for broadcast axes
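Note on the change (an illustrative sketch, not part of the commit): previously Broadcast captured a fixed output shape when the graph was traced, so a graph compiled or exported without concrete shapes could not rebroadcast when input sizes changed. After this change the target shape can be recomputed from the inputs at run time. A minimal illustration of the user-visible behavior, assuming the public mlx::core ops and the mlx/mlx.h umbrella header:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Shapes only need to be broadcast-compatible when the graph runs,
  // which is what shapeless compile/export relies on.
  auto a = ones({4, 1});
  auto b = ones({1, 5});
  auto outs = broadcast_arrays({a, b}); // both outputs have shape {4, 5}
  eval(outs);
  return 0;
}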
@@ -686,51 +686,51 @@ std::vector<array> BitwiseBinary::vjp(
   return jvp(primals, cotangents, argnums);
 }
 
-std::vector<array> Broadcast::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
-    const std::vector<int>& argnums,
-    const std::vector<array>&) {
-  assert(argnums.size() == 1);
+std::vector<array>
+broadcast_vjp(const array& primal, const array& cotan, const Stream& s) {
   // Reduce cotangents to the shape of the primal
-  auto& shape = primals[0].shape();
-  auto& cotan = cotangents[0];
+  auto& shape = primal.shape();
   int diff = cotan.ndim() - shape.size();
-  std::vector<int> reduce_axes;
-  for (int i = 0; i < cotan.ndim(); ++i) {
-    if (i < diff) {
-      reduce_axes.push_back(i);
-    } else if (shape[i - diff] != cotan.shape(i)) {
+  std::vector<int> squeeze_axes(diff);
+  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
+  auto reduce_axes = squeeze_axes;
+  for (int i = diff; i < cotan.ndim(); ++i) {
+    if (shape[i - diff] != cotan.shape(i)) {
       reduce_axes.push_back(i);
     }
   }
-  return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};
+  return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)};
+}
+
+std::vector<array> Broadcast::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
 }
 
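Why broadcast_vjp sums and then squeezes: every primal element the broadcast replicated receives the sum of the cotangents it produced, and the axes the broadcast prepended are squeezed away instead of reshaped, so no static shape is baked into the gradient graph. The same reduction rule as a standalone sketch; reduce_to_shape is a hypothetical helper built on the public ops rather than the internal Stream plumbing:

#include <numeric>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

// Sum the cotangent over every axis the broadcast created or stretched,
// then squeeze away the created leading axes (mirrors broadcast_vjp).
array reduce_to_shape(const array& cotan, const Shape& shape) {
  int diff = cotan.ndim() - static_cast<int>(shape.size());
  std::vector<int> squeeze_axes(diff);
  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
  auto reduce_axes = squeeze_axes;
  for (int i = diff; i < cotan.ndim(); ++i) {
    if (shape[i - diff] != cotan.shape(i)) {
      reduce_axes.push_back(i);
    }
  }
  return squeeze(sum(cotan, reduce_axes, true), squeeze_axes);
}

int main() {
  // A primal of shape {3} broadcast to {2, 3}: each element fed two
  // outputs, so its gradient is the sum over axis 0.
  auto g = reduce_to_shape(ones({2, 3}), {3});
  eval(g); // shape {3}, every entry 2
  return 0;
}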
 std::vector<array> Broadcast::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(argnums.size() == 1);
-  return {broadcast_to(tangents[0], shape_, stream())};
+  return {array(
+      shape_,
+      tangents[0].dtype(),
+      std::make_shared<Broadcast>(stream(), shape_),
+      tangents)};
 }
 
 std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto ax = axes[0];
-  auto in = inputs[0];
+  auto& in = inputs[0];
   if (ax >= 0) {
-    auto in_shape = in.shape();
     int diff = shape_.size() - in.ndim() + 1;
     assert(diff >= 0);
-    in_shape.insert(in_shape.begin(), diff, 1);
+    shape_.insert(shape_.begin() + ax + diff, in.shape(ax));
     ax += diff;
-    shape_.insert(shape_.begin() + ax, in_shape[ax]);
-    in = reshape(in, in_shape, stream());
   }
   return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
@@ -740,11 +740,76 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
   return shape_ == b_other.shape_;
 }
 
-std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
-  if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
-    throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape");
+Shape Broadcast::output_shape(const std::vector<array>& inputs) {
+  auto shape = inputs[0].shape();
+  for (int i = 1; i < inputs.size(); ++i) {
+    shape = broadcast_shapes(shape, inputs[i].shape());
   }
-  return {shape_};
+  return shape;
 }
 
+std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
+  if (inputs.size() < 2) {
+    if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
+      throw std::invalid_argument(
+          "[Broadcast] Unable to infer broadcast shape");
+    }
+    return {shape_};
+  }
+  return {output_shape(inputs)};
+};
 
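The new Broadcast::output_shape folds broadcast_shapes pairwise over the inputs, so the target shape is rederived from whatever shapes arrive at run time rather than read from shape_. For example, folding {4, 1}, {1, 5}, and {1} yields {4, 5}; a sketch assuming the public broadcast_shapes helper and the Shape alias:

#include <cassert>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Fold broadcast_shapes pairwise, as Broadcast::output_shape does.
  std::vector<Shape> in_shapes = {{4, 1}, {1, 5}, {1}};
  auto shape = in_shapes[0];
  for (size_t i = 1; i < in_shapes.size(); ++i) {
    shape = broadcast_shapes(shape, in_shapes[i]);
  }
  assert((shape == Shape{4, 5}));
  return 0;
}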
+std::vector<array> BroadcastAxes::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
+}
+
+std::vector<array> BroadcastAxes::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  return {array(
+      output_shape(primals, ignore_axes_),
+      tangents[0].dtype(),
+      std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
+      tangents)};
+}
+
+std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
+}
+
+bool BroadcastAxes::is_equivalent(const Primitive& other) const {
+  const auto& b_other = static_cast<const BroadcastAxes&>(other);
+  return ignore_axes_ == b_other.ignore_axes_;
+}
+
+Shape BroadcastAxes::output_shape(
+    const std::vector<array>& inputs,
+    const std::vector<int>& ignore_axes) {
+  auto shape = Shape{};
+  for (auto& in : inputs) {
+    auto in_shape = in.shape();
+    for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {
+      in_shape.erase(in_shape.begin() + in.ndim() + *it);
+    }
+    shape = broadcast_shapes(shape, in_shape);
+  }
+  int dims = ignore_axes.size() + shape.size();
+  for (auto ax : ignore_axes) {
+    shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax));
+  }
+  return shape;
+}
+
+std::vector<Shape> BroadcastAxes::output_shapes(
+    const std::vector<array>& inputs) {
+  return {output_shape(inputs, ignore_axes_)};
+}
+
 std::vector<array> Ceil::vjp(
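BroadcastAxes broadcasts every dimension except the listed axes (negative indices, counted from the end), which are re-inserted from the first input. Tracing output_shape with ignore_axes = {-1}: {2, 1, 7} and {3, 7} drop their last axes, {2, 1} and {3} broadcast to {2, 3}, and the ignored axis goes back to give {2, 3, 7}. The same rule as a standalone sketch; broadcast_ignoring is a hypothetical name:

#include <cassert>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

// Hypothetical standalone version of the BroadcastAxes shape rule:
// broadcast everything except the ignored (negative) axes, then put the
// ignored axes back from the first input.
Shape broadcast_ignoring(
    const std::vector<Shape>& in_shapes,
    const std::vector<int>& ignore_axes) {
  Shape shape{};
  for (auto& s : in_shapes) {
    auto trimmed = s;
    for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {
      trimmed.erase(trimmed.begin() + static_cast<int>(s.size()) + *it);
    }
    shape = broadcast_shapes(shape, trimmed);
  }
  int dims = static_cast<int>(ignore_axes.size() + shape.size());
  for (auto ax : ignore_axes) {
    auto& first = in_shapes[0];
    shape.insert(
        shape.begin() + dims + ax, first[static_cast<int>(first.size()) + ax]);
  }
  return shape;
}

int main() {
  // Ignoring the last axis: {2, 1, 7} and {3, 7} broadcast to {2, 3, 7}.
  assert((broadcast_ignoring({{2, 1, 7}, {3, 7}}, {-1}) == Shape{2, 3, 7}));
  return 0;
}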
@@ -3066,14 +3131,9 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
 
-  auto shape = in.shape();
-  for (auto ax : axes_) {
-    shape[ax] = 1;
-  }
   auto& cotan = cotangents[0];
   if (reduce_type_ == Reduce::Sum) {
-    return {
-        broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())};
+    return {broadcast_arrays({cotan, in}, stream())[0]};
   } else if (reduce_type_ == Reduce::Prod) {
     auto s = stream();
     auto prod_grad_single_axis =
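The Sum branch can drop its reshape because the cotangent evidently keeps the reduced axes as size-1 dimensions (otherwise broadcast_arrays would not line up), so broadcasting the pair stretches it straight back to the input shape. A small sketch of the equivalent computation with the user-facing ops:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Cotangent of a dims-kept sum over axis 1 already lines up with the
  // input; broadcasting the pair stretches it back to the input shape.
  auto in = ones({4, 5});
  auto cotan = ones({4, 1});
  auto grad = broadcast_arrays({cotan, in})[0]; // shape {4, 5}
  eval(grad);
  return 0;
}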
@@ -3129,7 +3189,7 @@ std::vector<array> Reduce::vjp(
 
       return {grad};
     } else {
-      return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])};
+      return {prod_grad_single_axis(in, cotan, axes_[0])};
     }
 
   } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
@@ -3139,9 +3199,7 @@ std::vector<array> Reduce::vjp(
     }
     auto mask = equal(in, out, stream());
     auto normalizer = sum(mask, axes_, true, stream());
-    auto cotan_reshape = reshape(cotan, shape, stream());
-    cotan_reshape = divide(cotan_reshape, normalizer, stream());
-    return {multiply(cotan_reshape, mask, stream())};
+    return {multiply(divide(cotan, normalizer, stream()), mask, stream())};
   }
 
   else {
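In the Min/Max branch the mask picks out the extremal entries and the normalizer splits the cotangent evenly among ties; the reshape disappears for the same keep-dims reason as in the Sum branch. A numeric check of the tie-splitting rule, sketched with the user-facing ops:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // max over axis 1 with dims kept; two entries tie for the maximum.
  auto in = reshape(array({1.0f, 3.0f, 3.0f}), {1, 3});
  auto out = max(in, {1}, /* keepdims */ true);
  auto mask = equal(in, out);             // {{0, 1, 1}}
  auto normalizer = sum(mask, {1}, true); // 2 ties share the cotangent
  auto cotan = ones({1, 1});
  auto grad = multiply(divide(cotan, normalizer), mask);
  eval(grad); // {{0.0, 0.5, 0.5}}
  return 0;
}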