Fix gemv broadcasting bug (#6)

* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
Author: Jagrit Digani
Date: 2023-12-05 14:15:43 -08:00
Committed by: GitHub
parent 49cda449b1
commit d518b3b6a5
4 changed files with 492 additions and 203 deletions
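For context, a minimal sketch of the broadcast gemv pattern this fix (and the new cases in test_blas.py) exercises, using the MLX C++ API; the shapes here are illustrative, not taken from the commit:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // A batch of 4 matrices times a single vector: the vector has batch
  // size 1, so the kernel must broadcast it (batch stride 0).
  array mat = ones({4, 8, 16});
  array vec = ones({16});
  array out = matmul(mat, vec); // shape {4, 8}; dispatched to gemv
  eval({out});
  return 0;
}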

mlx/backend/metal/matmul.cpp

@@ -343,10 +343,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
   int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
-  int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
+  int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;
   int batch_size_vec = vec.data_size() / in_vector_len;
-  int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
+  int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;
+
+  // Determine if inputs have simple batching / broadcasting
+  bool contiguous_kernel =
+      (batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
+       (batch_size_mat == batch_size_vec ||
+        std::min(batch_size_mat, batch_size_vec) == 1));
+
+  int nc_dim = out.ndim() - 2;
 
   // Determine dispatch kernel
   int tm = 4, tn = 4;
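The core of the fix is the stride rule in this hunk: an input's batch stride depends on whether the input itself is batched, not on how its batch count compares to the output's. A standalone sketch of that rule and of the companion contiguous_kernel test (function names here are illustrative, not MLX source):

#include <algorithm>

// A batch size of 1 means the operand is broadcast: every output batch
// re-reads the same data, so its per-batch stride must be 0. Otherwise
// it advances by one full matrix (or vector) per output batch.
inline int batch_stride(int batch_size, int elems_per_batch) {
  return batch_size == 1 ? 0 : elems_per_batch;
}

// Mirrors contiguous_kernel above: the fast path covers equal batch
// counts or plain broadcasting (the smaller count is 1), with the output
// batch equal to the larger of the two.
inline bool simple_batching(int bs_mat, int bs_vec, int bs_out) {
  return bs_out == std::max(bs_mat, bs_vec) &&
         (bs_mat == bs_vec || std::min(bs_mat, bs_vec) == 1);
}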
@@ -383,6 +391,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
+  if (!contiguous_kernel) {
+    kname << "_nc";
+  }
 
   // Encode and dispatch kernel
   auto compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
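The "_nc" suffix routes non-contiguous broadcasting to a separate kernel specialization. A hedged sketch of how such a name could be assembled; the "gemv_" prefix and type token are assumptions, only the block-size and "_nc" pieces appear in the diff:

#include <sstream>
#include <string>

std::string gemv_kernel_name(
    const std::string& type_name, // e.g. "float32" (assumed)
    int bm, int bn, int tm, int tn, bool contiguous_kernel) {
  std::ostringstream kname;
  kname << "gemv_" << type_name; // assumed prefix, not shown in the diff
  kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
  if (!contiguous_kernel) {
    kname << "_nc"; // select the general-broadcasting variant
  }
  return kname.str();
}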
@@ -398,8 +410,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
   compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
-  compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
-  compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
+  if (contiguous_kernel) {
+    compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
+    compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
+  } else {
+    // In case of complex broadcasting, we consider the shape[:-2] and
+    // strides [:-2] to determine the location of a batch
+    // nc_dim = out.ndim() - 2
+    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
+    compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
+    compute_encoder->setBytes(
+        vec.strides().data(), nc_dim * sizeof(size_t), 7);
+    compute_encoder->setBytes(
+        mat.strides().data(), nc_dim * sizeof(size_t), 8);
+  }
 
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
   d.get_command_buffer(s.index)->addCompletedHandler(
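On the device side, the shape[:-2] and strides[:-2] buffers bound above let each threadgroup locate its batch. A minimal host-side sketch of that index arithmetic (assumed; the Metal kernel itself is not part of this diff): the flat batch index is peeled apart against the output's batch shape, and each input's own batch strides, which are 0 in broadcast dimensions, turn it into an element offset:

#include <cstddef>

// Map a flat batch index to an element offset for one input, given the
// output's batch shape (shape[:-2]) and the input's strides (strides[:-2]).
// Broadcast dimensions carry stride 0, so the same data is reused.
size_t batch_offset(
    int batch_idx,
    const int* batch_shape,
    const size_t* batch_strides,
    int nc_dim) {
  size_t offset = 0;
  for (int i = nc_dim - 1; i >= 0; --i) {
    offset += size_t(batch_idx % batch_shape[i]) * batch_strides[i];
    batch_idx /= batch_shape[i];
  }
  return offset;
}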