Fix overflow in quantize/dequantize (#1379)

* add 2d indices to prevent overflow

* use nthreads not out size
This commit is contained in:
Alex Barron 2024-08-30 13:32:41 -07:00 committed by GitHub
parent 1600092e92
commit da691257ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 13 deletions

View File

@ -1460,7 +1460,8 @@ template <typename T, const int group_size, const int bits>
device uint8_t* out [[buffer(1)]], device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]], device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]], device T* biases [[buffer(3)]],
uint index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr T eps = T(1e-7); constexpr T eps = T(1e-7);
constexpr int simd_size = 32; constexpr int simd_size = 32;
constexpr int uint8_bits = 8; constexpr int uint8_bits = 8;
@ -1475,8 +1476,9 @@ template <typename T, const int group_size, const int bits>
group_size % simd_size == 0, group_size % simd_size == 0,
"Group size must be divisible by simd size."); "Group size must be divisible by simd size.");
int in_index = index * values_per_reduce; size_t offset = index.x + grid_dim.x * size_t(index.y);
int out_index = index * writes_per_pack; size_t in_index = offset * values_per_reduce;
size_t out_index = offset * writes_per_pack;
T w_thread[values_per_reduce]; T w_thread[values_per_reduce];
T w_min = Limits<T>::max; T w_min = Limits<T>::max;
@ -1503,7 +1505,7 @@ template <typename T, const int group_size, const int bits>
T bias = at_zero ? T(0) : edge; T bias = at_zero ? T(0) : edge;
// Write out the scales and biases // Write out the scales and biases
int gindex = in_index / group_size; size_t gindex = in_index / group_size;
if (in_index % group_size == 0) { if (in_index % group_size == 0) {
scales[gindex] = scale; scales[gindex] = scale;
biases[gindex] = bias; biases[gindex] = bias;
@ -1542,13 +1544,16 @@ template <typename T, const int group_size, const int bits>
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]], device uint8_t* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int uint8_bits = 8; constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits; constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1; constexpr T n_bins = (1 << bits) - 1;
int in_index = index * packs_per_int; size_t offset = index.x + grid_dim.x * size_t(index.y);
int gindex = in_index / group_size; size_t in_index = offset * packs_per_int;
size_t gindex = in_index / group_size;
T scale = scales[gindex]; T scale = scales[gindex];
T bias = biases[gindex]; T bias = biases[gindex];
@ -1562,7 +1567,7 @@ template <typename T, const int group_size, const int bits>
output += val << (bits * i); output += val << (bits * i);
} }
} }
out[index] = output; out[offset] = output;
} }
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
@ -1571,15 +1576,17 @@ template <typename T, const int group_size, const int bits>
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
device T* out [[buffer(3)]], device T* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int uint8_bits = 8; constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits; constexpr int packs_per_int = uint8_bits / bits;
int oindex = index * packs_per_int; size_t offset = index.x + grid_dim.x * size_t(index.y);
int gindex = oindex / group_size; size_t oindex = offset * packs_per_int;
size_t gindex = oindex / group_size;
T scale = scales[gindex]; T scale = scales[gindex];
T bias = biases[gindex]; T bias = biases[gindex];
uint val = w[index]; uint val = w[offset];
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) { for (int i = 0; i < packs_per_int; i++) {

View File

@ -584,8 +584,19 @@ void fast::AffineQuantize::eval_gpu(
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread; dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
auto group_dims = MTL::Size(thread_group_size, 1, 1); auto group_dims = MTL::Size(thread_group_size, 1, 1);
auto grid_dims = MTL::Size(nthreads, 1, 1); bool use_2d = nthreads > UINT_MAX;
auto grid_shape = w.shape();
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread;
}
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler( d.get_command_buffer(s.index)->addCompletedHandler(