Files
mlx/mlx/backend/common/hadamard.h
Alex Barron a3c287354f Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00

105 lines
2.2 KiB
C++

// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
} // namespace mlx::core