Fix mask broadcasting bug and add relevant test (#1003)

Author: Jagrit Digani (2024-04-17 17:33:48 -07:00), committed by GitHub
Parent: 581b699ac9
Commit: 85c8a91a27
3 changed files with 67 additions and 35 deletions
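
For context, the bug shows up when the operands and their block masks only share a batch shape after broadcasting. A minimal sketch of the case the new test exercises, assuming the public mx.block_masked_mm(a, b, block_size, mask_out, mask_lhs, mask_rhs) signature:

import numpy as np
import mlx.core as mx

M = N = K = 64
block_size = 32
tm, tn, tk = M // block_size, N // block_size, K // block_size

# Operand batch shapes (1, 2) and (2, 2) only match after broadcasting.
a = mx.array(np.random.normal(size=(1, 2, M, K)).astype(np.float32))
b = mx.array(np.random.normal(size=(2, 2, K, N)).astype(np.float32))

# Each block mask carries the batch shape of the array it masks.
a_mask = mx.array(np.random.normal(size=(1, 2, tm, tk)) < 0.0)
b_mask = mx.array(np.random.normal(size=(2, 2, tk, tn)) < 0.0)
out_mask = mx.array(np.random.normal(size=(2, 2, tm, tn)) < 0.0)

out = mx.block_masked_mm(a, b, block_size, out_mask, a_mask, b_mask)
print(out.shape)  # (2, 2, 64, 64)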


@@ -65,6 +65,12 @@ template <typename T,
       lhs_mask += batch_offsets.x;
       rhs_mask += batch_offsets.y;
     }
+  } else {
+    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
+    if(has_operand_mask) {
+      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
+      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
+    }
   }
   // Adjust for batch
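
The added else branch handles the collapsed single-batch-dimension case: the kernel reads one flat batch_strides buffer packed as the A, B, out-mask, lhs-mask, and rhs-mask batch strides, each params->batch_ndim entries long (see the host-side hunks below). A rough Python model of that indexing, not a transcription of the Metal code:

def mask_batch_offsets(tid_z, batch_strides, batch_ndim, has_operand_mask):
    # Assumed packed layout, one segment of batch_ndim entries per array:
    # [ A | B | out_mask | lhs_mask | rhs_mask ]
    out_off = tid_z * batch_strides[2 * batch_ndim]      # first out_mask stride
    lhs_off = rhs_off = 0
    if has_operand_mask:
        lhs_off = tid_z * batch_strides[3 * batch_ndim]  # first lhs_mask stride
        rhs_off = tid_z * batch_strides[4 * batch_ndim]  # first rhs_mask stride
    return out_off, lhs_off, rhs_off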


@@ -1122,7 +1122,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
-  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
+  bool has_op_mask = inputs.size() > 3;
+  auto& out_mask = inputs[2];
+  std::vector<int> batch_shape{1};
+  int A_batch_str = 0;
+  int B_batch_str = 0;
+  std::vector<size_t> batch_strides;
+  if (out.ndim() > 2) {
+    auto get_batch_dims = [](const auto& v) {
+      return decltype(v){v.begin(), v.end() - 2};
+    };
+    std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
+    std::vector<std::vector<size_t>> bstrides;
+    for (auto& arr : inputs) {
+      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
+    }
+    auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
+    batch_shape = bshape_c;
+    A_batch_str = int(bstrides_c[0].back());
+    B_batch_str = int(bstrides_c[1].back());
+    for (auto& bstr : bstrides_c) {
+      batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
+    }
+  } else {
+    batch_strides = std::vector<size_t>(inputs.size(), 0);
+  }
   auto batch_size_out = out.size() / (M * N);
   int matrix_stride_out = M * N;
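
The point of this hunk is that every input, the three masks included, now contributes its own batch strides: a broadcast batch dimension shows up with stride 0, so the mask strides generally differ from the A/B strides and cannot be reused. A small NumPy illustration of such stride-0 views (NumPy strides are in bytes while MLX strides are in elements, but the idea is the same):

import numpy as np

a = np.random.normal(size=(1, 2, 64, 64)).astype(np.float32)
a_b = np.broadcast_to(a, (2, 2, 64, 64))

# The broadcast batch axis gets stride 0: stepping along it revisits
# the same matrices instead of advancing through memory.
print(a.strides[:2])    # (32768, 16384)
print(a_b.strides[:2])  # (0, 16384)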
@@ -1142,7 +1173,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << "_wm" << wm << "_wn" << wn << "_MN_"
         << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
         << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
-        << (inputs.size() > 3 ? "T" : "N");
+        << (has_op_mask ? "T" : "N");
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1166,8 +1197,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const int batch_stride_a = */ int(A_batch_stride.back()),
-      /* const int batch_stride_b = */ int(B_batch_stride.back()),
+      /* const int batch_stride_a = */ A_batch_str,
+      /* const int batch_stride_b = */ B_batch_str,
       /* const int batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
@@ -1181,42 +1212,21 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
-  std::vector<size_t> batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
   std::vector<int> mask_strides;
-  auto& out_mask = inputs[2];
   mask_strides.push_back(*(out_mask.strides().end() - 1));
   mask_strides.push_back(*(out_mask.strides().end() - 2));
-  batch_strides.insert(
-      batch_strides.end(),
-      out_mask.strides().begin(),
-      out_mask.strides().end() - 2);
-  if (inputs.size() > 3) {
+  if (has_op_mask) {
     auto& lhs_mask = inputs[3];
     mask_strides.push_back(*(lhs_mask.strides().end() - 1));
     mask_strides.push_back(*(lhs_mask.strides().end() - 2));
-    batch_strides.insert(
-        batch_strides.end(),
-        lhs_mask.strides().begin(),
-        lhs_mask.strides().end() - 2);
     compute_encoder.set_input_array(lhs_mask, 11);
     auto& rhs_mask = inputs[4];
     mask_strides.push_back(*(rhs_mask.strides().end() - 1));
     mask_strides.push_back(*(rhs_mask.strides().end() - 2));
-    batch_strides.insert(
-        batch_strides.end(),
-        rhs_mask.strides().begin(),
-        rhs_mask.strides().end() - 2);
     compute_encoder.set_input_array(rhs_mask, 12);
   }
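
Taken together with the previous hunks, the kernel now receives two flat buffers: mask_strides holds each mask's two innermost strides, and batch_strides holds the collapsed per-input batch strides in input order. A hypothetical packing helper that mirrors this layout (names are illustrative, not MLX API):

def pack_mask_buffers(inputs_strides, inputs_batch_strides, has_op_mask):
    # Illustrative only: inputs are ordered [a, b, out_mask, lhs_mask, rhs_mask],
    # strides are in elements, batch strides are the collapsed ones built earlier.
    n_masks = 3 if has_op_mask else 1
    mask_strides = []
    for s in inputs_strides[2:2 + n_masks]:
        mask_strides += [s[-1], s[-2]]  # innermost (x) stride, then row (y) stride
    # Per-input batch strides, concatenated in input order.
    batch_strides = [b for bs in inputs_batch_strides for b in bs]
    return mask_strides, batch_strides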


@@ -717,7 +717,16 @@ class TestBlas(mlx_tests.MLXTestCase):
             out = out * out_mask
             return out
-        def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
+        def test_shape(
+            M,
+            N,
+            K,
+            block_size,
+            transpose=False,
+            np_dtype=np.float32,
+            batch_A=(),
+            batch_B=(),
+        ):
             with self.subTest(
                 M=M,
                 N=N,
@@ -725,28 +734,32 @@ class TestBlas(mlx_tests.MLXTestCase):
                 block_size=block_size,
                 np_dtype=np_dtype,
                 transpose=transpose,
+                batch_A=batch_A,
+                batch_B=batch_B,
             ):
                 tm = (M + block_size - 1) // block_size
                 tn = (N + block_size - 1) // block_size
                 tk = (K + block_size - 1) // block_size
-                a_np = np.random.normal(size=(M, K)).astype(np_dtype)
-                b_np = np.random.normal(size=(K, N)).astype(np_dtype)
+                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
+                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
-                a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
-                b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
-                out_np_mask = np.random.normal(size=(tm, tn)) < 0.0
+                batch_out = np.broadcast_shapes(batch_A, batch_B)
+                a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
+                b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
+                out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0
                 a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
                     mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
                 )
                 if transpose:
-                    b_np = np.random.normal(size=(N, K)).astype(np_dtype)
+                    b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
                     b_mx = mx.array(b_np)
-                    b_np = b_np.T
-                    b_mx = b_mx.T
+                    b_np = np.swapaxes(b_np, -2, -1)
+                    b_mx = mx.swapaxes(b_mx, -2, -1)
                 out_np = np_block_masked_mm(
                     a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
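
One detail worth noting in the test: with batched operands, .T is no longer the right way to build the transposed case, since it reverses every axis, batch dimensions included, whereas swapaxes(-2, -1) transposes only the matrix axes. A quick NumPy check:

import numpy as np

b = np.zeros((2, 2, 64, 32))
print(b.T.shape)                     # (32, 64, 2, 2) -- batch axes reversed too
print(np.swapaxes(b, -2, -1).shape)  # (2, 2, 32, 64) -- only the matrix is transposed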
@@ -779,6 +792,9 @@ class TestBlas(mlx_tests.MLXTestCase):
             test_shape(M, N, K, block_size, transpose=False)
             test_shape(M, N, K, block_size, transpose=True)
+        # Test broadcasting
+        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)
         b_np = np.random.normal(size=(64,)).astype(np.float32)