mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 19:28:14 +08:00
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard
* add scale
* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -80,7 +80,7 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < max_radix; r++) {
|
||||
buf[j + h * r] = x[r];
|
||||
buf[j + h * r] = T(x[r]);
|
||||
}
|
||||
|
||||
h <<= logR;
|
||||
@@ -106,7 +106,7 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < final_radix; r++) {
|
||||
buf[j + h * r] = x[r];
|
||||
buf[j + h * r] = T(x[r]);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@@ -118,7 +118,7 @@ template <typename T, int N, int max_radix, int read_width>
|
||||
short index = j * read_width * num_threads + i * read_width;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + index + r] = buf[index + r] * scale;
|
||||
out[batch_idx + index + r] = T(buf[index + r] * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,7 +161,7 @@ template <typename T, int N, int M, int read_width>
|
||||
for (short c = 0; c < M; c++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short r = 0; r < read_width; r++) {
|
||||
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
|
||||
out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user