Fast Hadamard Transform (#1249)

* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
This commit is contained in:
Alex Barron
2024-07-09 20:39:01 -07:00
committed by GitHub
parent 03cf033f82
commit a3c287354f
22 changed files with 878 additions and 11 deletions

View File

@@ -4372,6 +4372,35 @@ void init_ops(nb::module_& m) {
a (array): Input array or scalar.
dtype (Dtype): The data type to change to.
Returns:
array: The array with the new type.
)pbdoc");
m.def(
"hadamard_transform",
&hadamard_transform,
nb::arg(),
"scale"_a = 1.0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the Walsh-Hadamard transform along the final axis.
Equivalent to:
```python
from scipy.linalg import hadamard
y = hadamard(len(x)) @ x
```
Supports sizes `n = m*2^k` where m in (1, 12, 20, 28)
and 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
Args:
a (array): Input array or scalar.
scale (float): Scale the output by this factor.
Returns:
array: The array with the new type.
)pbdoc");