Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-07-23 14:54:43 -07:00
committed by GitHub
parent e2aa6ec8ae
commit c34a5ae7f7
5 changed files with 20 additions and 10 deletions

View File

@@ -80,7 +80,7 @@ template <typename T, int N, int max_radix, int read_width>
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = x[r];
buf[j + h * r] = T(x[r]);
}
h <<= logR;
@@ -106,7 +106,7 @@ template <typename T, int N, int max_radix, int read_width>
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = x[r];
buf[j + h * r] = T(x[r]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -118,7 +118,7 @@ template <typename T, int N, int max_radix, int read_width>
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = buf[index + r] * scale;
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
}
@@ -161,7 +161,7 @@ template <typename T, int N, int M, int read_width>
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
}
}
}

View File

@@ -453,8 +453,10 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
array hadamard_transform(
const array& a,
float scale /* = 1.0 */,
std::optional<float> scale_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
return array(
a.shape(),

View File

@@ -134,7 +134,7 @@ array flatten(const array& a, StreamOrDevice s = {});
/** Multiply the array by the Hadamard matrix of corresponding size. */
array hadamard_transform(
const array& a,
float scale = 1.0f,
std::optional<float> scale = std::nullopt,
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axes. */