Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-07-23 14:54:43 -07:00
committed by GitHub
parent e2aa6ec8ae
commit c34a5ae7f7
5 changed files with 20 additions and 10 deletions

View File

@@ -134,7 +134,7 @@ array flatten(const array& a, StreamOrDevice s = {});
/** Multiply the array by the Hadamard matrix of corresponding size. */
array hadamard_transform(
const array& a,
float scale = 1.0f,
std::optional<float> scale = std::nullopt,
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axes. */