Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177)

* handle hadamard and addmm on empty inputs

* fix
This commit is contained in:
Awni Hannun
2025-05-12 10:48:57 -07:00
committed by GitHub
parent caaa3f1f8c
commit 8f3d208dce
5 changed files with 52 additions and 1 deletions

View File

@@ -472,6 +472,10 @@ array hadamard_transform(
const array& a,
std::optional<float> scale_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
if (a.size() == 0) {
throw std::invalid_argument(
"[hadamard_transform] Does not support empty arrays.");
}
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
int n = a.ndim() > 0 ? a.shape(-1) : 1;
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
@@ -4326,6 +4330,10 @@ array addmm(
c = reshape(c, c_reshape, s);
}
if (c.shape() != out_shape) {
throw std::invalid_argument(
"[addmm] input c must broadcast to the output shape");
}
auto out = array(
std::move(out_shape),