Fast Hadamard Transform (#1249)

* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
This commit is contained in:
Alex Barron
2024-07-09 20:39:01 -07:00
committed by GitHub
parent 03cf033f82
commit a3c287354f
22 changed files with 878 additions and 11 deletions

View File

@@ -451,6 +451,18 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
return flatten(a, 0, a.ndim() - 1, s);
}
// Build a lazy array node that applies the Hadamard primitive to `a`.
//
// a:     input array; the result has the same shape as `a`.
// scale: forwarded unchanged to the Hadamard primitive (default 1.0);
//        presumably a multiplicative factor on the output -- confirm against
//        the primitive's implementation.
// s:     stream or device used both for the dtype cast and for the primitive.
//
// Floating-point inputs keep their dtype; every other dtype is cast to
// float32 before being handed to the primitive.
array hadamard_transform(
    const array& a,
    float scale /* = 1.0 */,
    StreamOrDevice s /* = {} */) {
  // Choose the dtype the transform will run in (and that the output reports).
  const bool input_is_floating = issubdtype(a.dtype(), floating);
  auto out_dtype = input_is_floating ? a.dtype() : float32;

  // Cast the input to the chosen dtype on the requested stream/device.
  auto casted_input = astype(a, out_dtype, s);

  // The primitive carries the resolved stream and the caller's scale.
  auto primitive = std::make_shared<Hadamard>(to_stream(s), scale);

  return array(a.shape(), out_dtype, primitive, {casted_input});
}
array squeeze(
const array& a,
const std::vector<int>& axes,