Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-07-23 14:54:43 -07:00
committed by GitHub
parent e2aa6ec8ae
commit c34a5ae7f7
5 changed files with 20 additions and 10 deletions

View File

@@ -4379,11 +4379,11 @@ void init_ops(nb::module_& m) {
"hadamard_transform",
&hadamard_transform,
nb::arg(),
"scale"_a = 1.0,
"scale"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
"def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the Walsh-Hadamard transform along the final axis.
@@ -4393,7 +4393,7 @@ void init_ops(nb::module_& m) {
from scipy.linalg import hadamard
y = hadamard(len(x)) @ x
y = (hadamard(len(x)) @ x) * scale
Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k
<= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.
@@ -4401,6 +4401,7 @@ void init_ops(nb::module_& m) {
Args:
a (array): Input array or scalar.
scale (float): Scale the output by this factor.
Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.
Returns:
array: The transformed array.