mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix overflow in quantize/dequantize (#1379)
* add 2d indices to prevent overflow
* use nthreads not out size
This commit is contained in:
@@ -584,8 +584,19 @@ void fast::AffineQuantize::eval_gpu(
|
||||
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
|
||||
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
auto group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
auto grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
bool use_2d = nthreads > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
if (dequantize_) {
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
} else {
|
||||
grid_shape.back() /= per_thread;
|
||||
}
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
|
||||
Reference in New Issue
Block a user