Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -91,8 +91,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
const auto& small = ndim1 > ndim2 ? s2 : s1;
Shape out_shape(ndim);
for (int i = ndim - 1; i >= diff; --i) {
int a = big[i];
int b = small[i - diff];
auto a = big[i];
auto b = small[i - diff];
if (b == a) {
out_shape[i] = a;
} else if (a == 1 || b == 1) {
@@ -100,7 +100,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
out_shape[i] = a * b;
} else {
std::ostringstream msg;
msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
<< " cannot be broadcast.";
throw std::invalid_argument(msg.str());
}
}