mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv * Add relevant tests in test_blas.py
This commit is contained in:
@@ -343,10 +343,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
|
||||
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
|
||||
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
|
||||
int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;
|
||||
|
||||
int batch_size_vec = vec.data_size() / in_vector_len;
|
||||
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
|
||||
int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;
|
||||
|
||||
// Determine if inputs have simple batching / broadcasting
|
||||
bool contiguous_kernel =
|
||||
(batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
|
||||
(batch_size_mat == batch_size_vec ||
|
||||
std::min(batch_size_mat, batch_size_vec) == 1));
|
||||
|
||||
int nc_dim = out.ndim() - 2;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
@@ -383,6 +391,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||
|
||||
if (!contiguous_kernel) {
|
||||
kname << "_nc";
|
||||
}
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
@@ -398,8 +410,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
|
||||
|
||||
if (contiguous_kernel) {
|
||||
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
|
||||
} else {
|
||||
// In case of complex broadcasting, we consider the shape[:-2] and
|
||||
// strides [:-2] to determine the location of a batch
|
||||
// nc_dim = out.ndim() - 2
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(
|
||||
vec.strides().data(), nc_dim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(
|
||||
mat.strides().data(), nc_dim * sizeof(size_t), 8);
|
||||
}
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
|
||||
Reference in New Issue
Block a user