mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 03:18:12 +08:00
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
This commit is contained in:
@@ -4372,6 +4372,35 @@ void init_ops(nb::module_& m) {
|
||||
a (array): Input array or scalar.
|
||||
dtype (Dtype): The data type to change to.
|
||||
|
||||
Returns:
|
||||
array: The array with the new type.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"hadamard_transform",
|
||||
&hadamard_transform,
|
||||
nb::arg(),
|
||||
"scale"_a = 1.0,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the Walsh-Hadamard transform along the final axis.
|
||||
|
||||
Equivalent to:
|
||||
```python
|
||||
from scipy.linalg import hadamard
|
||||
|
||||
y = hadamard(len(x)) @ x
|
||||
```
|
||||
|
||||
Supports sizes `n = m*2^k` where m in (1, 12, 20, 28)
|
||||
and 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
scale (float): Scale the output by this factor.
|
||||
|
||||
Returns:
|
||||
array: The array with the new type.
|
||||
)pbdoc");
|
||||
|
||||
Reference in New Issue
Block a user