mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 20:11:19 +08:00
fix integer overflow in qmm
This commit is contained in:
parent
87720a8908
commit
45df803538
@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
|
|||||||
|
|
||||||
auto wl = (const device uint8_t*)w;
|
auto wl = (const device uint8_t*)w;
|
||||||
|
|
||||||
x += y_row * K;
|
x += y_row * static_cast<int64_t>(K);
|
||||||
wl += y_col * K_w;
|
wl += y_col * K_w;
|
||||||
scales += y_col * K_g;
|
scales += y_col * K_g;
|
||||||
biases += y_col * K_g;
|
biases += y_col * K_g;
|
||||||
y += y_row * N + y_col;
|
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||||
|
|
||||||
// Make the x loader and mma operation
|
// Make the x loader and mma operation
|
||||||
const short num_els = min(BM, M - y_row);
|
const short num_els = min(BM, M - y_row);
|
||||||
@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
|
|||||||
// Set the block
|
// Set the block
|
||||||
const int y_row = tid.y * BM;
|
const int y_row = tid.y * BM;
|
||||||
const int y_col = tid.x * BN;
|
const int y_col = tid.x * BN;
|
||||||
x += y_row * K;
|
x += y_row * static_cast<int64_t>(K);
|
||||||
wl += y_col * bytes_per_pack / pack_factor;
|
wl += y_col * bytes_per_pack / pack_factor;
|
||||||
scales += y_col / group_size;
|
scales += y_col / group_size;
|
||||||
biases += y_col / group_size;
|
biases += y_col / group_size;
|
||||||
y += y_row * N + y_col;
|
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||||
|
|
||||||
// Make the x loader and mma operation
|
// Make the x loader and mma operation
|
||||||
const short num_els = min(BM, M - y_row);
|
const short num_els = min(BM, M - y_row);
|
||||||
|
Loading…
Reference in New Issue
Block a user