Fix bfloat16 Hadamard (#1283)

* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-07-23 14:54:43 -07:00
committed by GitHub
parent e2aa6ec8ae
commit c34a5ae7f7
5 changed files with 20 additions and 10 deletions

View File

@@ -80,7 +80,7 @@ template <typename T, int N, int max_radix, int read_width>
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = x[r];
buf[j + h * r] = T(x[r]);
}
h <<= logR;
@@ -106,7 +106,7 @@ template <typename T, int N, int max_radix, int read_width>
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = x[r];
buf[j + h * r] = T(x[r]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -118,7 +118,7 @@ template <typename T, int N, int max_radix, int read_width>
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = buf[index + r] * scale;
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
}
@@ -161,7 +161,7 @@ template <typename T, int N, int M, int read_width>
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
}
}
}