// Copyright © 2024 Apple Inc. #pragma once #include #include "mlx/utils.h" namespace mlx::core { // From http://neilsloane.com/hadamard/ constexpr std::string_view h12 = R"( +-++++++++++ --+-+-+-+-+- +++-++----++ +---+--+-++- +++++-++---- +-+---+--+-+ ++--+++-++-- +--++---+--+ ++----+++-++ +--+-++---+- ++++----+++- +-+--+-++--- )"; constexpr std::string_view h20 = R"( +----+----++--++-++- -+----+---+++---+-++ --+----+---+++-+-+-+ ---+----+---+++++-+- ----+----++--++-++-+ -+++++-----+--+++--+ +-+++-+---+-+--+++-- ++-++--+---+-+--+++- +++-+---+---+-+--+++ ++++-----++--+-+--++ --++-+-++-+-----++++ ---++-+-++-+---+-+++ +---++-+-+--+--++-++ ++---++-+----+-+++-+ -++---++-+----+++++- -+--+--++-+----+---- +-+-----++-+----+--- -+-+-+---+--+----+-- --+-+++------+----+- +--+--++------+----+ )"; constexpr std::string_view h28 = R"( +------++----++-+--+-+--++-- -+-----+++-----+-+--+-+--++- --+-----+++---+-+-+----+--++ ---+-----+++---+-+-+-+--+--+ ----+-----+++---+-+-+++--+-- -----+-----++++--+-+--++--+- ------++----++-+--+-+--++--+ --++++-+-------++--+++-+--+- ---++++-+-----+-++--+-+-+--+ +---+++--+----++-++--+-+-+-- ++---++---+----++-++--+-+-+- +++---+----+----++-++--+-+-+ ++++--------+-+--++-++--+-+- -++++--------+++--++--+--+-+ -+-++-++--++--+--------++++- +-+-++--+--++--+--------++++ -+-+-++--+--++--+----+---+++ +-+-+-++--+--+---+---++---++ ++-+-+-++--+------+--+++---+ -++-+-+-++--+------+-++++--- +-++-+---++--+------+-++++-- -++--++-+-++-+++----++------ +-++--++-+-++-+++-----+----- ++-++---+-+-++-+++-----+---- -++-++-+-+-+-+--+++-----+--- --++-++++-+-+----+++-----+-- +--++-+-++-+-+----+++-----+- ++--++-+-++-+-+----++------+ )"; inline const std::map hadamard_matrices() { return {{12, h12}, {20, h20}, {28, h28}}; } inline std::pair decompose_hadamard(int n) { // n = m*2^k int m = 1; if (!is_power_of_2(n)) { auto h_matrices = hadamard_matrices(); for (auto [factor, _] : h_matrices) { if (n % factor == 0) { m = factor; n /= factor; break; } } if (m == 1) { throw std::invalid_argument( "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); } } if (n > (1 << 26)) { throw std::invalid_argument( "[hadamard] Only supports n = m*2^k where k <= 26"); } return {n, m}; } } // namespace mlx::core