mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -91,8 +91,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
|
||||
const auto& small = ndim1 > ndim2 ? s2 : s1;
|
||||
Shape out_shape(ndim);
|
||||
for (int i = ndim - 1; i >= diff; --i) {
|
||||
int a = big[i];
|
||||
int b = small[i - diff];
|
||||
auto a = big[i];
|
||||
auto b = small[i - diff];
|
||||
if (b == a) {
|
||||
out_shape[i] = a;
|
||||
} else if (a == 1 || b == 1) {
|
||||
@@ -100,7 +100,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
|
||||
out_shape[i] = a * b;
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
|
||||
msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
|
||||
<< " cannot be broadcast.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user