Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-07-23 14:54:43 -07:00
committed by GitHub
parent e2aa6ec8ae
commit c34a5ae7f7
5 changed files with 20 additions and 10 deletions

View File

@@ -453,8 +453,10 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
array hadamard_transform(
const array& a,
float scale /* = 1.0 */,
std::optional<float> scale_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
return array(
a.shape(),