mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 01:51:18 +08:00
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2 * working for m*2^k * add scale and check contiguity * add size check * clean up * fix test * add grads + vmap * gpu only * skip on linux * test typo * add cpu impl * remove gpu only tests * fix linux build + add is_equivalent
This commit is contained in:
parent
03cf033f82
commit
a3c287354f
70
benchmarks/python/hadamard_bench.py
Normal file
70
benchmarks/python/hadamard_bench.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from time_utils import measure_runtime
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def had(x):
|
||||||
|
y = mx.hadamard_transform(x)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def copy(x):
|
||||||
|
y = x + 1.0
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def run(dtype):
|
||||||
|
system_size = 2**26
|
||||||
|
outputs = {}
|
||||||
|
for test_fn in (had, copy):
|
||||||
|
for m in [1, 12, 20, 28]:
|
||||||
|
if test_fn == copy:
|
||||||
|
key = "copy"
|
||||||
|
elif m == 1:
|
||||||
|
key = "had_2^k"
|
||||||
|
else:
|
||||||
|
key = "had_m*2^k"
|
||||||
|
outputs.setdefault(key, {})
|
||||||
|
for k in range(7, 14):
|
||||||
|
n = m * 2**k
|
||||||
|
if n > 2**15:
|
||||||
|
continue
|
||||||
|
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
|
||||||
|
x = mx.array(x_np)
|
||||||
|
runtime_ms = measure_runtime(test_fn, x=x)
|
||||||
|
bytes_per_gb = 1e9
|
||||||
|
ms_per_s = 1e3
|
||||||
|
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
|
||||||
|
bandwidth_gb = (
|
||||||
|
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
|
||||||
|
)
|
||||||
|
print(n, bandwidth_gb)
|
||||||
|
outputs[key][n] = bandwidth_gb
|
||||||
|
|
||||||
|
colors = {
|
||||||
|
"copy": "black",
|
||||||
|
"had_2^k": "steelblue",
|
||||||
|
"had_m*2^k": "skyblue",
|
||||||
|
}
|
||||||
|
for key, output in outputs.items():
|
||||||
|
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
|
||||||
|
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
|
||||||
|
plt.xlabel("N")
|
||||||
|
plt.ylabel("Bandwidth (GB/s)")
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(f"bench_{dtype.__name__}.png")
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--fp16", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
dtype = np.float16 if args.fp16 else np.float32
|
||||||
|
run(dtype)
|
@ -72,6 +72,7 @@ Operations
|
|||||||
gather_qmm
|
gather_qmm
|
||||||
greater
|
greater
|
||||||
greater_equal
|
greater_equal
|
||||||
|
hadamard_transform
|
||||||
identity
|
identity
|
||||||
inner
|
inner
|
||||||
isclose
|
isclose
|
||||||
|
@ -50,6 +50,7 @@ DEFAULT(GatherMM)
|
|||||||
DEFAULT(GatherQMM)
|
DEFAULT(GatherQMM)
|
||||||
DEFAULT(Greater)
|
DEFAULT(Greater)
|
||||||
DEFAULT(GreaterEqual)
|
DEFAULT(GreaterEqual)
|
||||||
|
DEFAULT(Hadamard)
|
||||||
DEFAULT(Less)
|
DEFAULT(Less)
|
||||||
DEFAULT(LessEqual)
|
DEFAULT(LessEqual)
|
||||||
DEFAULT(Load)
|
DEFAULT(Load)
|
||||||
|
@ -42,6 +42,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||||
|
@ -68,6 +68,7 @@ DEFAULT(Full)
|
|||||||
DEFAULT(Gather)
|
DEFAULT(Gather)
|
||||||
DEFAULT(Greater)
|
DEFAULT(Greater)
|
||||||
DEFAULT(GreaterEqual)
|
DEFAULT(GreaterEqual)
|
||||||
|
DEFAULT(Hadamard)
|
||||||
DEFAULT(Less)
|
DEFAULT(Less)
|
||||||
DEFAULT(LessEqual)
|
DEFAULT(LessEqual)
|
||||||
DEFAULT(Load)
|
DEFAULT(Load)
|
||||||
|
107
mlx/backend/common/hadamard.cpp
Normal file
107
mlx/backend/common/hadamard.cpp
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/hadamard.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// n = 2^k component
|
||||||
|
template <typename T>
|
||||||
|
void hadamard_n(array& out, int n, int m, float scale) {
|
||||||
|
for (int b = 0; b < out.size() / n; b++) {
|
||||||
|
size_t loc = b * n;
|
||||||
|
T* data_ptr = out.data<T>() + loc;
|
||||||
|
int h = 1;
|
||||||
|
int n_over_2 = n / 2;
|
||||||
|
while (h < n) {
|
||||||
|
for (int i = 0; i < n / 2; i++) {
|
||||||
|
int k = i & (h - 1);
|
||||||
|
int j = ((i - k) << 1) + k;
|
||||||
|
float x = *(data_ptr + j);
|
||||||
|
float y = *(data_ptr + j + h);
|
||||||
|
*(data_ptr + j) = x + y;
|
||||||
|
*(data_ptr + j + h) = x - y;
|
||||||
|
if (h == n_over_2) {
|
||||||
|
*(data_ptr + j) *= scale;
|
||||||
|
*(data_ptr + j + h) *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h <<= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// m component
|
||||||
|
template <typename T>
|
||||||
|
void hadamard_m(array& out, int n, int m, float scale) {
|
||||||
|
auto h_matrices = hadamard_matrices();
|
||||||
|
auto& matrix = h_matrices[m];
|
||||||
|
auto start = 1;
|
||||||
|
auto end = matrix.find('\n', start);
|
||||||
|
std::vector<bool> hmat_vec;
|
||||||
|
while (end != std::string_view::npos) {
|
||||||
|
auto row = matrix.substr(start, end - start);
|
||||||
|
for (int i = 0; i < row.length(); i++) {
|
||||||
|
hmat_vec.push_back(row[i] == '+');
|
||||||
|
}
|
||||||
|
start = end + 1;
|
||||||
|
end = matrix.find('\n', start);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int b = 0; b < out.size() / m / n; b++) {
|
||||||
|
size_t loc = b * n * m;
|
||||||
|
T* data_ptr = out.data<T>() + loc;
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
std::vector<float> out(m);
|
||||||
|
for (int j = 0; j < m; j++) {
|
||||||
|
for (int k = 0; k < m; k++) {
|
||||||
|
float x = *(data_ptr + i + k * n);
|
||||||
|
if (hmat_vec[k + j * m]) {
|
||||||
|
out[j] += x;
|
||||||
|
} else {
|
||||||
|
out[j] -= x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j = 0; j < m; j++) {
|
||||||
|
*(data_ptr + i + j * n) = out[j] * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void hadamard(array& out, int n, int m, float scale) {
|
||||||
|
float n_scale = m > 1 ? 1.0 : scale;
|
||||||
|
hadamard_n<T>(out, n, m, n_scale);
|
||||||
|
if (m > 1) {
|
||||||
|
hadamard_m<T>(out, n, m, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
// Copy input to output
|
||||||
|
copy(in, out, CopyType::General);
|
||||||
|
|
||||||
|
int axis = out.ndim() - 1;
|
||||||
|
auto [n, m] = decompose_hadamard(out.shape(axis));
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case float32:
|
||||||
|
return hadamard<float>(out, n, m, scale_);
|
||||||
|
case float16:
|
||||||
|
return hadamard<float16_t>(out, n, m, scale_);
|
||||||
|
case bfloat16:
|
||||||
|
return hadamard<bfloat16_t>(out, n, m, scale_);
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("[hadamard] Unsupported type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
105
mlx/backend/common/hadamard.h
Normal file
105
mlx/backend/common/hadamard.h
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// From http://neilsloane.com/hadamard/
|
||||||
|
constexpr std::string_view h12 = R"(
|
||||||
|
+-++++++++++
|
||||||
|
--+-+-+-+-+-
|
||||||
|
+++-++----++
|
||||||
|
+---+--+-++-
|
||||||
|
+++++-++----
|
||||||
|
+-+---+--+-+
|
||||||
|
++--+++-++--
|
||||||
|
+--++---+--+
|
||||||
|
++----+++-++
|
||||||
|
+--+-++---+-
|
||||||
|
++++----+++-
|
||||||
|
+-+--+-++---
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view h20 = R"(
|
||||||
|
+----+----++--++-++-
|
||||||
|
-+----+---+++---+-++
|
||||||
|
--+----+---+++-+-+-+
|
||||||
|
---+----+---+++++-+-
|
||||||
|
----+----++--++-++-+
|
||||||
|
-+++++-----+--+++--+
|
||||||
|
+-+++-+---+-+--+++--
|
||||||
|
++-++--+---+-+--+++-
|
||||||
|
+++-+---+---+-+--+++
|
||||||
|
++++-----++--+-+--++
|
||||||
|
--++-+-++-+-----++++
|
||||||
|
---++-+-++-+---+-+++
|
||||||
|
+---++-+-+--+--++-++
|
||||||
|
++---++-+----+-+++-+
|
||||||
|
-++---++-+----+++++-
|
||||||
|
-+--+--++-+----+----
|
||||||
|
+-+-----++-+----+---
|
||||||
|
-+-+-+---+--+----+--
|
||||||
|
--+-+++------+----+-
|
||||||
|
+--+--++------+----+
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view h28 = R"(
|
||||||
|
+------++----++-+--+-+--++--
|
||||||
|
-+-----+++-----+-+--+-+--++-
|
||||||
|
--+-----+++---+-+-+----+--++
|
||||||
|
---+-----+++---+-+-+-+--+--+
|
||||||
|
----+-----+++---+-+-+++--+--
|
||||||
|
-----+-----++++--+-+--++--+-
|
||||||
|
------++----++-+--+-+--++--+
|
||||||
|
--++++-+-------++--+++-+--+-
|
||||||
|
---++++-+-----+-++--+-+-+--+
|
||||||
|
+---+++--+----++-++--+-+-+--
|
||||||
|
++---++---+----++-++--+-+-+-
|
||||||
|
+++---+----+----++-++--+-+-+
|
||||||
|
++++--------+-+--++-++--+-+-
|
||||||
|
-++++--------+++--++--+--+-+
|
||||||
|
-+-++-++--++--+--------++++-
|
||||||
|
+-+-++--+--++--+--------++++
|
||||||
|
-+-+-++--+--++--+----+---+++
|
||||||
|
+-+-+-++--+--+---+---++---++
|
||||||
|
++-+-+-++--+------+--+++---+
|
||||||
|
-++-+-+-++--+------+-++++---
|
||||||
|
+-++-+---++--+------+-++++--
|
||||||
|
-++--++-+-++-+++----++------
|
||||||
|
+-++--++-+-++-+++-----+-----
|
||||||
|
++-++---+-+-++-+++-----+----
|
||||||
|
-++-++-+-+-+-+--+++-----+---
|
||||||
|
--++-++++-+-+----+++-----+--
|
||||||
|
+--++-+-++-+-+----+++-----+-
|
||||||
|
++--++-+-++-+-+----++------+
|
||||||
|
)";
|
||||||
|
|
||||||
|
inline const std::map<int, std::string_view> hadamard_matrices() {
|
||||||
|
return {{12, h12}, {20, h20}, {28, h28}};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::pair<int, int> decompose_hadamard(int n) {
|
||||||
|
// n = m*2^k
|
||||||
|
int m = 1;
|
||||||
|
if (!is_power_of_2(n)) {
|
||||||
|
auto h_matrices = hadamard_matrices();
|
||||||
|
for (auto [factor, _] : h_matrices) {
|
||||||
|
if (n % factor == 0) {
|
||||||
|
m = factor;
|
||||||
|
n /= factor;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (m == 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {n, m};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -52,6 +52,7 @@ make_jit_source(
|
|||||||
)
|
)
|
||||||
make_jit_source(scatter)
|
make_jit_source(scatter)
|
||||||
make_jit_source(gather)
|
make_jit_source(gather)
|
||||||
|
make_jit_source(hadamard)
|
||||||
|
|
||||||
if (MLX_METAL_JIT)
|
if (MLX_METAL_JIT)
|
||||||
target_sources(
|
target_sources(
|
||||||
@ -132,6 +133,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
203
mlx/backend/metal/hadamard.cpp
Normal file
203
mlx/backend/metal/hadamard.cpp
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/compiled.h"
|
||||||
|
#include "mlx/backend/common/hadamard.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/jit/includes.h"
|
||||||
|
#include "mlx/backend/metal/kernels.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||||
|
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||||
|
|
||||||
|
std::string gen_hadamard_codelet(int m) {
|
||||||
|
// Generate a O(m^2) hadamard codelet for a given M
|
||||||
|
// using the hadamard matrices above
|
||||||
|
//
|
||||||
|
// e.g. m = 2
|
||||||
|
// METAL_FUNC void hadamard_m(thread float *x) {
|
||||||
|
// float tmp[2];
|
||||||
|
// tmp[0] = + x[0] + x[1];
|
||||||
|
// tmp[1] = + x[0] - x[1];
|
||||||
|
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
auto h_matrices = hadamard_matrices();
|
||||||
|
auto& matrix = h_matrices[m];
|
||||||
|
|
||||||
|
std::ostringstream source;
|
||||||
|
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
|
||||||
|
if (m == 1) {
|
||||||
|
source << "}" << std::endl;
|
||||||
|
return source.str();
|
||||||
|
}
|
||||||
|
source << " float tmp[" << m << "];" << std::endl;
|
||||||
|
auto start = 1;
|
||||||
|
auto end = matrix.find('\n', start);
|
||||||
|
|
||||||
|
int index = 0;
|
||||||
|
while (end != std::string_view::npos) {
|
||||||
|
source << " tmp[" << index << "] = ";
|
||||||
|
auto row = matrix.substr(start, end - start);
|
||||||
|
for (int i = 0; i < row.length(); i++) {
|
||||||
|
source << " " << row[i] << " x[" << i << "]";
|
||||||
|
}
|
||||||
|
source << ";" << std::endl;
|
||||||
|
start = end + 1;
|
||||||
|
end = matrix.find('\n', start);
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
|
||||||
|
<< std::endl;
|
||||||
|
source << "}" << std::endl;
|
||||||
|
return source.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_hadamard(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
int batch_size,
|
||||||
|
int threads_per,
|
||||||
|
const std::string kernel_name,
|
||||||
|
float scale,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
const auto& lib_name = kernel_name.substr(1);
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
auto kernel = d.get_kernel(kernel_name, lib);
|
||||||
|
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
compute_encoder.set_input_array(in, 0);
|
||||||
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
std::vector<array> copies;
|
||||||
|
// Only support the last axis for now
|
||||||
|
int axis = in.ndim() - 1;
|
||||||
|
auto check_input = [&copies, &s](const array& x) {
|
||||||
|
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||||
|
bool no_copy = x.flags().row_contiguous;
|
||||||
|
if (no_copy) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||||
|
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||||
|
return copies.back();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const array& in_contiguous = check_input(in);
|
||||||
|
|
||||||
|
if (in_contiguous.is_donatable()) {
|
||||||
|
out.move_shared_buffer(in_contiguous);
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [n, m] = decompose_hadamard(in.shape(axis));
|
||||||
|
|
||||||
|
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
int max_radix = std::min(n, 16);
|
||||||
|
// Use read_width 2 for m = 28 to avoid register spilling
|
||||||
|
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||||
|
auto kernel_name = kname.str();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
auto codelet = gen_hadamard_codelet(m);
|
||||||
|
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||||
|
kernel_source << get_template_definition(
|
||||||
|
"n" + kernel_name,
|
||||||
|
"hadamard_n",
|
||||||
|
get_type_string(in.dtype()),
|
||||||
|
n,
|
||||||
|
max_radix,
|
||||||
|
read_width);
|
||||||
|
kernel_source << get_template_definition(
|
||||||
|
"m" + kernel_name,
|
||||||
|
"hadamard_m",
|
||||||
|
get_type_string(in.dtype()),
|
||||||
|
n,
|
||||||
|
m,
|
||||||
|
read_width);
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
int batch_size = in.size() / n;
|
||||||
|
int threads_per = n / max_radix;
|
||||||
|
|
||||||
|
if (m > 1) {
|
||||||
|
// When m is greater than 1, we decompose the
|
||||||
|
// computation into two uploads to the GPU:
|
||||||
|
//
|
||||||
|
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||||
|
//
|
||||||
|
// y = h48 @ x
|
||||||
|
//
|
||||||
|
// Upload 1:
|
||||||
|
// tmp = a.reshape(12, 4) @ h4
|
||||||
|
//
|
||||||
|
// Upload 2:
|
||||||
|
// y = h12 @ tmp
|
||||||
|
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||||
|
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||||
|
copies.push_back(temp);
|
||||||
|
|
||||||
|
launch_hadamard(
|
||||||
|
in_contiguous,
|
||||||
|
temp,
|
||||||
|
batch_size,
|
||||||
|
threads_per,
|
||||||
|
"n" + kernel_name,
|
||||||
|
1.0,
|
||||||
|
s);
|
||||||
|
|
||||||
|
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||||
|
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||||
|
batch_size = in.size() / m / read_width / threads_per;
|
||||||
|
launch_hadamard(
|
||||||
|
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
|
||||||
|
} else {
|
||||||
|
launch_hadamard(
|
||||||
|
in_contiguous,
|
||||||
|
out,
|
||||||
|
batch_size,
|
||||||
|
threads_per,
|
||||||
|
"n" + kernel_name,
|
||||||
|
scale_,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -18,6 +18,7 @@ const char* binary();
|
|||||||
const char* binary_two();
|
const char* binary_two();
|
||||||
const char* copy();
|
const char* copy();
|
||||||
const char* fft();
|
const char* fft();
|
||||||
|
const char* hadamard();
|
||||||
const char* quantized();
|
const char* quantized();
|
||||||
const char* ternary();
|
const char* ternary();
|
||||||
const char* scan();
|
const char* scan();
|
||||||
|
167
mlx/backend/metal/kernels/hadamard.h
Normal file
167
mlx/backend/metal/kernels/hadamard.h
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
#include <metal_common>
|
||||||
|
#include <metal_compute>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// Thread local Hadamard transform for 2^R
|
||||||
|
template <short R>
|
||||||
|
METAL_FUNC void radix_func(thread float* x) {
|
||||||
|
constexpr short logR = __builtin_ctz(R);
|
||||||
|
short h = 1;
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short s = 0; s < logR; s++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < R / 2; i++) {
|
||||||
|
short k = i & (h - 1);
|
||||||
|
short j = ((i - k) << 1) + k;
|
||||||
|
float a = x[j];
|
||||||
|
float b = x[j + h];
|
||||||
|
x[j] = a + b;
|
||||||
|
x[j + h] = a - b;
|
||||||
|
}
|
||||||
|
h <<= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int N, int max_radix, int read_width>
|
||||||
|
[[kernel]] void hadamard_n(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device T* out [[buffer(1)]],
|
||||||
|
constant const float& scale,
|
||||||
|
uint3 elem [[thread_position_in_grid]],
|
||||||
|
uint3 grid [[threads_per_grid]]) {
|
||||||
|
// Compute a Hadamard transform of size N = 2^k
|
||||||
|
//
|
||||||
|
// Equivalent to:
|
||||||
|
// from scipy.linalg import hadamard
|
||||||
|
// y = hadamard(len(x)) @ x
|
||||||
|
|
||||||
|
constexpr short num_threads = N / max_radix;
|
||||||
|
constexpr short logN = __builtin_ctz(N);
|
||||||
|
constexpr short logR = __builtin_ctz(max_radix);
|
||||||
|
constexpr short num_steps = logN / logR;
|
||||||
|
constexpr short logFinal = logN % logR;
|
||||||
|
constexpr short final_radix = 1 << (logFinal);
|
||||||
|
|
||||||
|
int batch_idx = elem.x * N;
|
||||||
|
short i = elem.y;
|
||||||
|
|
||||||
|
threadgroup T buf[N];
|
||||||
|
|
||||||
|
// Read values from device
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < max_radix / read_width; j++) {
|
||||||
|
short index = j * read_width * num_threads + i * read_width;
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < read_width; r++) {
|
||||||
|
buf[index + r] = in[batch_idx + index + r];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
float x[max_radix];
|
||||||
|
short h = 1;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short s = 0; s < num_steps; s++) {
|
||||||
|
short k = i & (h - 1);
|
||||||
|
short j = ((i - k) << logR) + k;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < max_radix; r++) {
|
||||||
|
x[r] = buf[j + h * r];
|
||||||
|
}
|
||||||
|
|
||||||
|
radix_func<max_radix>(x);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < max_radix; r++) {
|
||||||
|
buf[j + h * r] = x[r];
|
||||||
|
}
|
||||||
|
|
||||||
|
h <<= logR;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do the final radix
|
||||||
|
// e.g. max_radix = 16
|
||||||
|
// N = 1024 = 16 * 16 * 4
|
||||||
|
if (final_radix > 1) {
|
||||||
|
// Each thread does multiple butterflies
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int t = 0; t < max_radix / final_radix; t++) {
|
||||||
|
short index = i + t * num_threads;
|
||||||
|
short k = index & (h - 1);
|
||||||
|
short j = ((index - k) << logFinal) + k;
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < final_radix; r++) {
|
||||||
|
x[r] = buf[j + h * r];
|
||||||
|
}
|
||||||
|
|
||||||
|
radix_func<final_radix>(x);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < final_radix; r++) {
|
||||||
|
buf[j + h * r] = x[r];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write values to device
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < max_radix / read_width; j++) {
|
||||||
|
short index = j * read_width * num_threads + i * read_width;
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < read_width; r++) {
|
||||||
|
out[batch_idx + index + r] = buf[index + r] * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int N, int M, int read_width>
|
||||||
|
[[kernel]] void hadamard_m(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device T* out [[buffer(1)]],
|
||||||
|
constant const float& scale,
|
||||||
|
uint3 elem [[thread_position_in_grid]],
|
||||||
|
uint3 grid [[threads_per_grid]]) {
|
||||||
|
// Compute a Hadamard transform of size M
|
||||||
|
// using a naive O(M^2) codelet.
|
||||||
|
//
|
||||||
|
// This kernel is the second stage in the computation
|
||||||
|
// of a Hadamard transform of size M*N where N = 2^k.
|
||||||
|
|
||||||
|
int index = elem.x * grid.y + elem.y;
|
||||||
|
short i = index % (N / read_width);
|
||||||
|
int batch_idx = index / (N / read_width) * M * N;
|
||||||
|
|
||||||
|
float x[read_width][M];
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short c = 0; c < M; c++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < read_width; r++) {
|
||||||
|
x[r][c] = in[batch_idx + c * N + i * read_width + r];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < read_width; r++) {
|
||||||
|
// This function is JIT compiled for M
|
||||||
|
// using the Hadamard matrix strings in `metal/hadamard.cpp`
|
||||||
|
hadamard_radix_m(x[r]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write back to device
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short c = 0; c < M; c++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short r = 0; r < read_width; r++) {
|
||||||
|
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -130,17 +130,6 @@ inline void debug_set_primitive_buffer_label(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_power_of_2(int n) {
|
|
||||||
return ((n & (n - 1)) == 0) && n != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int next_power_of_2(int n) {
|
|
||||||
if (is_power_of_2(n)) {
|
|
||||||
return n;
|
|
||||||
}
|
|
||||||
return pow(2, std::ceil(std::log2(n)));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive) {
|
std::string get_primitive_string(Primitive* primitive) {
|
||||||
std::ostringstream op_t;
|
std::ostringstream op_t;
|
||||||
primitive->print(op_t);
|
primitive->print(op_t);
|
||||||
|
@ -61,6 +61,7 @@ NO_CPU(GatherMM)
|
|||||||
NO_CPU(GatherQMM)
|
NO_CPU(GatherQMM)
|
||||||
NO_CPU(Greater)
|
NO_CPU(Greater)
|
||||||
NO_CPU(GreaterEqual)
|
NO_CPU(GreaterEqual)
|
||||||
|
NO_CPU(Hadamard)
|
||||||
NO_CPU(Less)
|
NO_CPU(Less)
|
||||||
NO_CPU(LessEqual)
|
NO_CPU(LessEqual)
|
||||||
NO_CPU(Load)
|
NO_CPU(Load)
|
||||||
|
@ -62,6 +62,7 @@ NO_GPU(GatherMM)
|
|||||||
NO_GPU(GatherQMM)
|
NO_GPU(GatherQMM)
|
||||||
NO_GPU(Greater)
|
NO_GPU(Greater)
|
||||||
NO_GPU(GreaterEqual)
|
NO_GPU(GreaterEqual)
|
||||||
|
NO_GPU(Hadamard)
|
||||||
NO_GPU(Less)
|
NO_GPU(Less)
|
||||||
NO_GPU(LessEqual)
|
NO_GPU(LessEqual)
|
||||||
NO_GPU(Load)
|
NO_GPU(Load)
|
||||||
|
12
mlx/ops.cpp
12
mlx/ops.cpp
@ -451,6 +451,18 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return flatten(a, 0, a.ndim() - 1, s);
|
return flatten(a, 0, a.ndim() - 1, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array hadamard_transform(
|
||||||
|
const array& a,
|
||||||
|
float scale /* = 1.0 */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
|
||||||
|
return array(
|
||||||
|
a.shape(),
|
||||||
|
dtype,
|
||||||
|
std::make_shared<Hadamard>(to_stream(s), scale),
|
||||||
|
{astype(a, dtype, s)});
|
||||||
|
}
|
||||||
|
|
||||||
array squeeze(
|
array squeeze(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
|
@ -131,6 +131,12 @@ array flatten(
|
|||||||
/** Flatten the array to 1D. */
|
/** Flatten the array to 1D. */
|
||||||
array flatten(const array& a, StreamOrDevice s = {});
|
array flatten(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Multiply the array by the Hadamard matrix of corresponding size. */
|
||||||
|
array hadamard_transform(
|
||||||
|
const array& a,
|
||||||
|
float scale = 1.0f,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Remove singleton dimensions at the given axes. */
|
/** Remove singleton dimensions at the given axes. */
|
||||||
array squeeze(
|
array squeeze(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -3976,4 +3976,42 @@ bool View::is_equivalent(const Primitive& other) const {
|
|||||||
return (dtype_ == a_other.dtype_);
|
return (dtype_ == a_other.dtype_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(axes.size() == 1);
|
||||||
|
auto& s = stream();
|
||||||
|
if (axes[0] == inputs[0].ndim() - 1) {
|
||||||
|
auto a = moveaxis(inputs[0], axes[0], 0, s);
|
||||||
|
auto b = hadamard_transform(a, scale_, s);
|
||||||
|
return {{b}, {0}};
|
||||||
|
}
|
||||||
|
return {{hadamard_transform(inputs[0], scale_, s)}, axes};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Hadamard::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>&) {
|
||||||
|
assert(primals.size() == 1);
|
||||||
|
assert(argnums.size() == 1);
|
||||||
|
return jvp(primals, cotangents, argnums);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Hadamard::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 1);
|
||||||
|
assert(argnums.size() == 1);
|
||||||
|
return {hadamard_transform(tangents[0], scale_, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Hadamard::is_equivalent(const Primitive& other) const {
|
||||||
|
const Hadamard& h_other = static_cast<const Hadamard&>(other);
|
||||||
|
return scale_ == h_other.scale_;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1064,6 +1064,27 @@ class GreaterEqual : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Hadamard : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit Hadamard(Stream stream, float scale)
|
||||||
|
: UnaryPrimitive(stream), scale_(scale) {}
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Hadamard)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
float scale_;
|
||||||
|
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class Less : public UnaryPrimitive {
|
class Less : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
|
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
|
||||||
|
12
mlx/utils.h
12
mlx/utils.h
@ -118,4 +118,16 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
|||||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||||
return os << static_cast<float>(v);
|
return os << static_cast<float>(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_power_of_2(int n) {
|
||||||
|
return ((n & (n - 1)) == 0) && n != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int next_power_of_2(int n) {
|
||||||
|
if (is_power_of_2(n)) {
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
return pow(2, std::ceil(std::log2(n)));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -4372,6 +4372,35 @@ void init_ops(nb::module_& m) {
|
|||||||
a (array): Input array or scalar.
|
a (array): Input array or scalar.
|
||||||
dtype (Dtype): The data type to change to.
|
dtype (Dtype): The data type to change to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The array with the new type.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"hadamard_transform",
|
||||||
|
&hadamard_transform,
|
||||||
|
nb::arg(),
|
||||||
|
"scale"_a = 1.0,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Perform the Walsh-Hadamard transform along the final axis.
|
||||||
|
|
||||||
|
Equivalent to:
|
||||||
|
```python
|
||||||
|
from scipy.linalg import hadamard
|
||||||
|
|
||||||
|
y = hadamard(len(x)) @ x
|
||||||
|
```
|
||||||
|
|
||||||
|
Supports sizes `n = m*2^k` where m in (1, 12, 20, 28)
|
||||||
|
and 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array or scalar.
|
||||||
|
scale (float): Scale the output by this factor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The array with the new type.
|
array: The array with the new type.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
@ -2425,6 +2425,104 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
a_out = out.view(mx.int32)
|
a_out = out.view(mx.int32)
|
||||||
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
|
self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))
|
||||||
|
|
||||||
|
def _hadamard(self, N):
|
||||||
|
# Matches scipy.linalg.hadamard
|
||||||
|
H = np.array([[1]], dtype=np.int64)
|
||||||
|
for i in range(0, np.log2(N).astype(np.int64)):
|
||||||
|
H = np.vstack((np.hstack((H, H)), np.hstack((H, -H))))
|
||||||
|
return H
|
||||||
|
|
||||||
|
def test_hadamard(self):
|
||||||
|
h28_str = """
|
||||||
|
+------++----++-+--+-+--++--
|
||||||
|
-+-----+++-----+-+--+-+--++-
|
||||||
|
--+-----+++---+-+-+----+--++
|
||||||
|
---+-----+++---+-+-+-+--+--+
|
||||||
|
----+-----+++---+-+-+++--+--
|
||||||
|
-----+-----++++--+-+--++--+-
|
||||||
|
------++----++-+--+-+--++--+
|
||||||
|
--++++-+-------++--+++-+--+-
|
||||||
|
---++++-+-----+-++--+-+-+--+
|
||||||
|
+---+++--+----++-++--+-+-+--
|
||||||
|
++---++---+----++-++--+-+-+-
|
||||||
|
+++---+----+----++-++--+-+-+
|
||||||
|
++++--------+-+--++-++--+-+-
|
||||||
|
-++++--------+++--++--+--+-+
|
||||||
|
-+-++-++--++--+--------++++-
|
||||||
|
+-+-++--+--++--+--------++++
|
||||||
|
-+-+-++--+--++--+----+---+++
|
||||||
|
+-+-+-++--+--+---+---++---++
|
||||||
|
++-+-+-++--+------+--+++---+
|
||||||
|
-++-+-+-++--+------+-++++---
|
||||||
|
+-++-+---++--+------+-++++--
|
||||||
|
-++--++-+-++-+++----++------
|
||||||
|
+-++--++-+-++-+++-----+-----
|
||||||
|
++-++---+-+-++-+++-----+----
|
||||||
|
-++-++-+-+-+-+--+++-----+---
|
||||||
|
--++-++++-+-+----+++-----+--
|
||||||
|
+--++-+-++-+-+----+++-----+-
|
||||||
|
++--++-+-++-+-+----++------+
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_h_string(h_str):
|
||||||
|
return np.array(
|
||||||
|
[[1 if s == "+" else -1 for s in row] for row in h_str.split()]
|
||||||
|
)
|
||||||
|
|
||||||
|
h28 = parse_h_string(h28_str)
|
||||||
|
|
||||||
|
np.random.seed(7)
|
||||||
|
tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15))
|
||||||
|
for dtype, m, k in tests:
|
||||||
|
# skip large m=28 cases because they're very slow in NumPy
|
||||||
|
if (m > 1 and k > 8) or (dtype != np.float16 and k == 14):
|
||||||
|
continue
|
||||||
|
with self.subTest(dtype=dtype, m=m, k=k):
|
||||||
|
n = m * 2**k
|
||||||
|
b = 4
|
||||||
|
scale = 0.34
|
||||||
|
x = np.random.normal(size=(b, n)).astype(dtype)
|
||||||
|
# contiguity check
|
||||||
|
x = mx.array(x)[::2]
|
||||||
|
y = mx.hadamard_transform(x, scale=scale)
|
||||||
|
mx.eval(y)
|
||||||
|
h = (
|
||||||
|
self._hadamard(2**k)
|
||||||
|
if m == 1
|
||||||
|
else np.kron(h28, self._hadamard(2**k))
|
||||||
|
)
|
||||||
|
y_np = np.einsum("ij,bj->bi", h, x) * scale
|
||||||
|
atol = 2e-4 if dtype == np.float32 else 5e-2 * k
|
||||||
|
np.testing.assert_allclose(y, y_np, atol=atol)
|
||||||
|
|
||||||
|
def test_hadamard_grad_vmap(self):
|
||||||
|
np.random.seed(4)
|
||||||
|
|
||||||
|
for k in range(2, 8):
|
||||||
|
n = 2**k
|
||||||
|
x = np.random.normal(size=(n,))
|
||||||
|
h = self._hadamard(n)
|
||||||
|
c = np.random.normal(size=(n,))
|
||||||
|
x = mx.array(x).astype(mx.float32)
|
||||||
|
h = mx.array(h).astype(mx.float32)
|
||||||
|
c = mx.array(c).astype(mx.float32)
|
||||||
|
|
||||||
|
def hadamard_transform(x):
|
||||||
|
return h @ x
|
||||||
|
|
||||||
|
out = mx.vjp(hadamard_transform, [x], [c])
|
||||||
|
out_t = mx.vjp(mx.hadamard_transform, [x], [c])
|
||||||
|
np.testing.assert_allclose(out, out_t, atol=1e-4)
|
||||||
|
|
||||||
|
for axis in (0, 1, 2):
|
||||||
|
vht = mx.vmap(mx.vmap(hadamard_transform, 0, 0), axis, axis)
|
||||||
|
vht_t = mx.vmap(mx.vmap(mx.hadamard_transform, 0, 0), axis, axis)
|
||||||
|
|
||||||
|
xb = mx.array(np.random.normal(size=(n, n, n)))
|
||||||
|
out = vht(xb)
|
||||||
|
out_t = vht_t(xb)
|
||||||
|
np.testing.assert_allclose(out, out_t, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user