mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix overflow in quantize/dequantize (#1379)
* add 2d indices to prevent overflow * use nthreads not out size
This commit is contained in:
@@ -1460,7 +1460,8 @@ template <typename T, const int group_size, const int bits>
|
||||
device uint8_t* out [[buffer(1)]],
|
||||
device T* scales [[buffer(2)]],
|
||||
device T* biases [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
constexpr T eps = T(1e-7);
|
||||
constexpr int simd_size = 32;
|
||||
constexpr int uint8_bits = 8;
|
||||
@@ -1475,8 +1476,9 @@ template <typename T, const int group_size, const int bits>
|
||||
group_size % simd_size == 0,
|
||||
"Group size must be divisible by simd size.");
|
||||
|
||||
int in_index = index * values_per_reduce;
|
||||
int out_index = index * writes_per_pack;
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
size_t in_index = offset * values_per_reduce;
|
||||
size_t out_index = offset * writes_per_pack;
|
||||
|
||||
T w_thread[values_per_reduce];
|
||||
T w_min = Limits<T>::max;
|
||||
@@ -1503,7 +1505,7 @@ template <typename T, const int group_size, const int bits>
|
||||
T bias = at_zero ? T(0) : edge;
|
||||
|
||||
// Write out the scales and biases
|
||||
int gindex = in_index / group_size;
|
||||
size_t gindex = in_index / group_size;
|
||||
if (in_index % group_size == 0) {
|
||||
scales[gindex] = scale;
|
||||
biases[gindex] = bias;
|
||||
@@ -1542,13 +1544,16 @@ template <typename T, const int group_size, const int bits>
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
device uint8_t* out [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
constexpr int uint8_bits = 8;
|
||||
constexpr int packs_per_int = uint8_bits / bits;
|
||||
constexpr T n_bins = (1 << bits) - 1;
|
||||
|
||||
int in_index = index * packs_per_int;
|
||||
int gindex = in_index / group_size;
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
size_t in_index = offset * packs_per_int;
|
||||
size_t gindex = in_index / group_size;
|
||||
|
||||
T scale = scales[gindex];
|
||||
T bias = biases[gindex];
|
||||
|
||||
@@ -1562,7 +1567,7 @@ template <typename T, const int group_size, const int bits>
|
||||
output += val << (bits * i);
|
||||
}
|
||||
}
|
||||
out[index] = output;
|
||||
out[offset] = output;
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
@@ -1571,15 +1576,17 @@ template <typename T, const int group_size, const int bits>
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
device T* out [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
constexpr int uint8_bits = 8;
|
||||
constexpr int packs_per_int = uint8_bits / bits;
|
||||
|
||||
int oindex = index * packs_per_int;
|
||||
int gindex = oindex / group_size;
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
size_t oindex = offset * packs_per_int;
|
||||
size_t gindex = oindex / group_size;
|
||||
T scale = scales[gindex];
|
||||
T bias = biases[gindex];
|
||||
uint val = w[index];
|
||||
uint val = w[offset];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < packs_per_int; i++) {
|
||||
|
||||
Reference in New Issue
Block a user