Fix batched qmv bug (#1758)

This commit is contained in:
Alex Barron
2025-01-09 11:45:57 -08:00
committed by GitHub
parent da8c885784
commit c7b0300af5
2 changed files with 22 additions and 13 deletions

View File

@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
lhs_indices,
rhs_indices,
y,
out_vec_size,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
lhs_indices,
rhs_indices,
y,
out_vec_size,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(
x,
w,
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
lhs_indices,
rhs_indices,
y,
out_vec_size,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,