Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard
* add scale
* review comments

Co-authored-by: Alex Barron <abarron22@apple.com>
commit c34a5ae7f7 (parent e2aa6ec8ae)
@@ -80,7 +80,7 @@ template <typename T, int N, int max_radix, int read_width>
 
       STEEL_PRAGMA_UNROLL
       for (short r = 0; r < max_radix; r++) {
-        buf[j + h * r] = x[r];
+        buf[j + h * r] = T(x[r]);
       }
 
       h <<= logR;
@@ -106,7 +106,7 @@ template <typename T, int N, int max_radix, int read_width>
 
       STEEL_PRAGMA_UNROLL
       for (short r = 0; r < final_radix; r++) {
-        buf[j + h * r] = x[r];
+        buf[j + h * r] = T(x[r]);
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -118,7 +118,7 @@ template <typename T, int N, int max_radix, int read_width>
     short index = j * read_width * num_threads + i * read_width;
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < read_width; r++) {
-      out[batch_idx + index + r] = buf[index + r] * scale;
+      out[batch_idx + index + r] = T(buf[index + r] * scale);
     }
   }
 }
@@ -161,7 +161,7 @@ template <typename T, int N, int M, int read_width>
     for (short c = 0; c < M; c++) {
       STEEL_PRAGMA_UNROLL
       for (short r = 0; r < read_width; r++) {
-        out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
+        out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
      }
    }
  }
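
The four kernel hunks above share one fix: the radix butterflies and the final scaling are computed in float, and the results are now wrapped in an explicit T(...) construction before being written to the threadgroup buffer and to the output, which the emulated bfloat16 type needs. Below is a minimal Python sketch of the same numeric pattern (accumulate in float32, narrow to the storage dtype only on store); hadamard_reference is a hypothetical helper for illustration, not part of MLX or of this kernel.

    import mlx.core as mx

    def hadamard_reference(x, scale):
        # Build the (symmetric) Sylvester Hadamard matrix in float32
        n = x.shape[-1]
        h = mx.array([[1.0]])
        while h.shape[0] < n:
            h = mx.concatenate(
                [mx.concatenate([h, h], axis=1), mx.concatenate([h, -h], axis=1)],
                axis=0,
            )
        # Accumulate in float32, cast back to the input dtype only when storing
        y32 = (x.astype(mx.float32) @ h) * scale
        return y32.astype(x.dtype)

    x = mx.random.normal((4, 64)).astype(mx.bfloat16)
    y = hadamard_reference(x, scale=1.0 / 8.0)  # 1/sqrt(64)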
@@ -453,8 +453,10 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
 
 array hadamard_transform(
     const array& a,
-    float scale /* = 1.0 */,
+    std::optional<float> scale_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
+  // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
+  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
   auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
   return array(
       a.shape(),
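
With no scale given, the op now defaults to 1/sqrt(N), which makes the Hadamard matrix orthonormal: its rows have unit norm and the transform is its own inverse. A small NumPy/SciPy check of that property, illustrative only:

    import numpy as np
    from scipy.linalg import hadamard

    n = 64
    h = hadamard(n) / np.sqrt(n)            # orthonormal Hadamard matrix
    assert np.allclose(h @ h.T, np.eye(n))  # rows are orthonormal
    x = np.random.randn(n).astype(np.float32)
    assert np.allclose(h @ (h @ x), x, atol=1e-5)  # applying it twice recovers x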
@@ -134,7 +134,7 @@ array flatten(const array& a, StreamOrDevice s = {});
 /** Multiply the array by the Hadamard matrix of corresponding size. */
 array hadamard_transform(
     const array& a,
-    float scale = 1.0f,
+    std::optional<float> scale = std::nullopt,
     StreamOrDevice s = {});
 
 /** Remove singleton dimensions at the given axes. */
@@ -4379,11 +4379,11 @@ void init_ops(nb::module_& m) {
       "hadamard_transform",
       &hadamard_transform,
       nb::arg(),
-      "scale"_a = 1.0,
+      "scale"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def hadamard_transform(a: array, float scale = 1.0, stream: Union[None, Stream, Device] = None) -> array"),
+          "def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the Walsh-Hadamard transform along the final axis.
 
@@ -4393,7 +4393,7 @@ void init_ops(nb::module_& m) {
 
            from scipy.linalg import hadamard
 
-           y = hadamard(len(x)) @ x
+           y = (hadamard(len(x)) @ x) * scale
 
         Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k
         <= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.
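
A quick way to check the documented equivalence end to end, assuming mlx and scipy are installed; with scale omitted the op falls back to the new 1/sqrt(n) default. This is a sketch, not part of the test suite:

    import mlx.core as mx
    import numpy as np
    from scipy.linalg import hadamard

    n = 128
    x = np.random.randn(n).astype(np.float32)
    y_ref = (hadamard(n) @ x) / np.sqrt(n)   # docstring formula with the default scale
    y = mx.hadamard_transform(mx.array(x))   # scale=None -> 1/sqrt(n)
    np.testing.assert_allclose(np.array(y), y_ref, atol=2e-4)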
@@ -4401,6 +4401,7 @@ void init_ops(nb::module_& m) {
         Args:
             a (array): Input array or scalar.
             scale (float): Scale the output by this factor.
+                Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.
 
         Returns:
             array: The transformed array.
@@ -2496,6 +2496,13 @@ class TestOps(mlx_tests.MLXTestCase):
                 atol = 2e-4 if dtype == np.float32 else 5e-2 * k
                 np.testing.assert_allclose(y, y_np, atol=atol)
 
+                # bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory
+                if dtype == np.float16 and k < 14:
+                    y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)
+                    np.testing.assert_allclose(
+                        y_bf16.astype(mx.float16), y, atol=atol * 2
+                    )
+
     def test_hadamard_grad_vmap(self):
         np.random.seed(4)
 
@@ -2509,7 +2516,7 @@ class TestOps(mlx_tests.MLXTestCase):
         c = mx.array(c).astype(mx.float32)
 
         def hadamard_transform(x):
-            return h @ x
+            return h @ x / mx.sqrt(x.shape[-1])
 
         out = mx.vjp(hadamard_transform, [x], [c])
         out_t = mx.vjp(mx.hadamard_transform, [x], [c])
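
The reference function now matches the op's default scaling, so both sides of the comparison compute the same orthonormal transform. Because H/sqrt(n) is symmetric and orthogonal, its VJP is just the same transform applied to the cotangent. A standalone sketch of that check, assuming mlx, numpy, and scipy are available (not part of the test file):

    import mlx.core as mx
    import numpy as np
    from scipy.linalg import hadamard

    n = 16
    h = mx.array(hadamard(n).astype(np.float32))
    x = mx.random.normal((n,))
    c = mx.random.normal((n,))

    def reference(x):
        return h @ x / mx.sqrt(x.shape[-1])

    (out,), (grad,) = mx.vjp(reference, [x], [c])
    (out_t,), (grad_t,) = mx.vjp(mx.hadamard_transform, [x], [c])
    # The cotangent is transformed by the same orthonormal matrix in both cases
    assert np.allclose(np.array(grad), np.array(grad_t), atol=1e-4)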