Fix mask broadcasting bug and add relevant test (#1003)

This commit is contained in:
Jagrit Digani
2024-04-17 17:33:48 -07:00
committed by GitHub
parent 581b699ac9
commit 85c8a91a27
3 changed files with 67 additions and 35 deletions

View File

@@ -65,6 +65,12 @@ template <typename T,
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if(has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Adjust for batch

View File

@@ -1122,7 +1122,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
bool has_op_mask = inputs.size() > 3;
auto& out_mask = inputs[2];
std::vector<int> batch_shape{1};
int A_batch_str = 0;
int B_batch_str = 0;
std::vector<size_t> batch_strides;
if (out.ndim() > 2) {
// Returns a copy of `v` with its last two entries dropped, i.e. everything
// except the trailing pair of a shape/stride list. Inputs holding two or fewer
// entries produce an empty container instead of forming an iterator that
// precedes begin().
// The result is built by copying `v` rather than by naming `decltype(v)`:
// for a `const auto&` parameter that type is a reference, and
// brace-constructing a reference type from two iterators is not portable.
// NOTE(review): no call to this helper is visible in this hunk - confirm it is
// still needed.
auto get_batch_dims = [](const auto& v) {
  auto batch_dims = v; // plain `auto` decays to the container's value type
  if (batch_dims.size() <= 2) {
    batch_dims.clear();
  } else {
    batch_dims.erase(batch_dims.end() - 2, batch_dims.end());
  }
  return batch_dims;
};
std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
std::vector<std::vector<size_t>> bstrides;
for (auto& arr : inputs) {
bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
}
auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape_c;
A_batch_str = int(bstrides_c[0].back());
B_batch_str = int(bstrides_c[1].back());
for (auto& bstr : bstrides_c) {
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
}
} else {
batch_strides = std::vector<size_t>(inputs.size(), 0);
}
auto batch_size_out = out.size() / (M * N);
int matrix_stride_out = M * N;
@@ -1142,7 +1173,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
<< (inputs.size() > 3 ? "T" : "N");
<< (has_op_mask ? "T" : "N");
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1166,8 +1197,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_a = */ A_batch_str,
/* const int batch_stride_b = */ B_batch_str,
/* const int batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
@@ -1181,42 +1212,21 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
std::vector<size_t> batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
std::vector<int> mask_strides;
auto& out_mask = inputs[2];
mask_strides.push_back(*(out_mask.strides().end() - 1));
mask_strides.push_back(*(out_mask.strides().end() - 2));
batch_strides.insert(
batch_strides.end(),
out_mask.strides().begin(),
out_mask.strides().end() - 2);
if (inputs.size() > 3) {
if (has_op_mask) {
auto& lhs_mask = inputs[3];
mask_strides.push_back(*(lhs_mask.strides().end() - 1));
mask_strides.push_back(*(lhs_mask.strides().end() - 2));
batch_strides.insert(
batch_strides.end(),
lhs_mask.strides().begin(),
lhs_mask.strides().end() - 2);
compute_encoder.set_input_array(lhs_mask, 11);
auto& rhs_mask = inputs[4];
mask_strides.push_back(*(rhs_mask.strides().end() - 1));
mask_strides.push_back(*(rhs_mask.strides().end() - 2));
batch_strides.insert(
batch_strides.end(),
rhs_mask.strides().begin(),
rhs_mask.strides().end() - 2);
compute_encoder.set_input_array(rhs_mask, 12);
}