mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -80,7 +80,7 @@ template <typename T, int N, int max_radix, int read_width>
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < max_radix; r++) {
-      buf[j + h * r] = x[r];
+      buf[j + h * r] = T(x[r]);
     }
 
     h <<= logR;
@@ -106,7 +106,7 @@ template <typename T, int N, int max_radix, int read_width>
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < final_radix; r++) {
-      buf[j + h * r] = x[r];
+      buf[j + h * r] = T(x[r]);
     }
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -118,7 +118,7 @@ template <typename T, int N, int max_radix, int read_width>
     short index = j * read_width * num_threads + i * read_width;
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < read_width; r++) {
-      out[batch_idx + index + r] = buf[index + r] * scale;
+      out[batch_idx + index + r] = T(buf[index + r] * scale);
     }
   }
 }
@@ -161,7 +161,7 @@ template <typename T, int N, int M, int read_width>
   for (short c = 0; c < M; c++) {
     STEEL_PRAGMA_UNROLL
     for (short r = 0; r < read_width; r++) {
-      out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
+      out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
     }
   }
 }
@@ -453,8 +453,10 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
 
 array hadamard_transform(
     const array& a,
-    float scale /* = 1.0 */,
+    std::optional<float> scale_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
+  // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
+  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
   auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
   return array(
       a.shape(),
@@ -134,7 +134,7 @@ array flatten(const array& a, StreamOrDevice s = {});
 
 /** Multiply the array by the Hadamard matrix of corresponding size. */
 array hadamard_transform(
     const array& a,
-    float scale = 1.0f,
+    std::optional<float> scale = std::nullopt,
     StreamOrDevice s = {});
 
 /** Remove singleton dimensions at the given axes. */
Reference in New Issue
Block a user