mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
This commit is contained in:
12
mlx/ops.cpp
12
mlx/ops.cpp
@@ -451,6 +451,18 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return flatten(a, 0, a.ndim() - 1, s);
|
||||
}
|
||||
|
||||
array hadamard_transform(
|
||||
const array& a,
|
||||
float scale /* = 1.0 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
|
||||
return array(
|
||||
a.shape(),
|
||||
dtype,
|
||||
std::make_shared<Hadamard>(to_stream(s), scale),
|
||||
{astype(a, dtype, s)});
|
||||
}
|
||||
|
||||
array squeeze(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
|
||||
Reference in New Issue
Block a user