Mirror of https://github.com/ml-explore/mlx.git
Fix mask broadcasting bug and add relevant test (#1003)
parent 581b699ac9
commit 85c8a91a27
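Summary: `BlockMaskedMM` previously took the batch strides for the output and operand masks directly from each mask's own strides, which is only correct when a mask's batch shape already matches the output's. This commit collapses the batch dimensions of all inputs (operands and masks) against the output up front, hands the kernel one flat, broadcast-aware stride table, and adds a regression test with broadcasting batch shapes.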
```diff
@@ -65,6 +65,12 @@ template <typename T,
       lhs_mask += batch_offsets.x;
       rhs_mask += batch_offsets.y;
     }
+  } else {
+    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
+    if (has_operand_mask) {
+      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
+      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
+    }
   }

   // Adjust for batch
```
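The added else-branch is the fast path for a single batch dimension: each operand's batch stride lives at a fixed slot in the flat `batch_strides` table. A minimal Python sketch of that offset arithmetic; the table layout `[A, B, out_mask, lhs_mask, rhs_mask]`, one entry per operand when `batch_ndim == 1`, is inferred from the `2 *`, `3 *`, `4 * params->batch_ndim` indices in the hunk:

```python
# Sketch of the kernel's else-branch offset math (layout inferred, see above).
batch_ndim = 1
# Hypothetical strides: lhs_mask broadcasts over the batch, so its stride is 0.
batch_strides = [64 * 64, 64 * 64, 2 * 2, 0, 2 * 2]

def mask_offsets(tid_z):
    out_off = tid_z * batch_strides[2 * batch_ndim]
    lhs_off = tid_z * batch_strides[3 * batch_ndim]
    rhs_off = tid_z * batch_strides[4 * batch_ndim]
    return out_off, lhs_off, rhs_off

print(mask_offsets(3))  # (12, 0, 12): the broadcast lhs_mask reuses one matrix
```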
```diff
@@ -1122,7 +1122,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions

-  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
+  bool has_op_mask = inputs.size() > 3;
+  auto& out_mask = inputs[2];
+
+  std::vector<int> batch_shape{1};
+  int A_batch_str = 0;
+  int B_batch_str = 0;
+
+  std::vector<size_t> batch_strides;
+
+  if (out.ndim() > 2) {
+    auto get_batch_dims = [](const auto& v) {
+      return decltype(v){v.begin(), v.end() - 2};
+    };
+
+    std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
+    std::vector<std::vector<size_t>> bstrides;
+
+    for (auto& arr : inputs) {
+      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
+    }
+
+    auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
+    batch_shape = bshape_c;
+    A_batch_str = int(bstrides_c[0].back());
+    B_batch_str = int(bstrides_c[1].back());
+
+    for (auto& bstr : bstrides_c) {
+      batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
+    }
+  } else {
+    batch_strides = std::vector<size_t>(inputs.size(), 0);
+  }

   auto batch_size_out = out.size() / (M * N);
   int matrix_stride_out = M * N;
```
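This block replaces the old `collapse_batches(a, b)`, which only looked at the two operands, with a pass over all inputs, masks included: every input's batch dimensions are collapsed jointly against the output's batch shape, and the per-input stride vectors are concatenated into one `batch_strides` table. A rough NumPy sketch of the same idea, where broadcast batch dimensions naturally pick up stride 0 (`batch_strides_for` is an illustrative helper, not an MLX API):

```python
import numpy as np

def batch_strides_for(arrs, out_batch_shape):
    """Per-input batch strides (in elements) after broadcasting to the
    output batch shape; broadcast dimensions get stride 0."""
    strides = []
    for arr in arrs:
        view = np.broadcast_to(arr, out_batch_shape + arr.shape[-2:])
        strides.append(
            [s // arr.itemsize for s in view.strides[: len(out_batch_shape)]]
        )
    return strides

a = np.zeros((1, 2, 64, 64), np.float32)
b = np.zeros((2, 2, 64, 64), np.float32)
print(batch_strides_for([a, b], (2, 2)))  # [[0, 4096], [8192, 4096]]
```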
```diff
@@ -1142,7 +1173,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
-       << (inputs.size() > 3 ? "T" : "N");
+       << (has_op_mask ? "T" : "N");

   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
```
```diff
@@ -1166,8 +1197,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const int batch_stride_a = */ int(A_batch_stride.back()),
-      /* const int batch_stride_b = */ int(B_batch_stride.back()),
+      /* const int batch_stride_a = */ A_batch_str,
+      /* const int batch_stride_b = */ B_batch_str,
       /* const int batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
```
```diff
@@ -1181,42 +1212,21 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

-  std::vector<size_t> batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-
   std::vector<int> mask_strides;

-  auto& out_mask = inputs[2];
   mask_strides.push_back(*(out_mask.strides().end() - 1));
   mask_strides.push_back(*(out_mask.strides().end() - 2));

-  batch_strides.insert(
-      batch_strides.end(),
-      out_mask.strides().begin(),
-      out_mask.strides().end() - 2);
-
-  if (inputs.size() > 3) {
+  if (has_op_mask) {
     auto& lhs_mask = inputs[3];
     mask_strides.push_back(*(lhs_mask.strides().end() - 1));
     mask_strides.push_back(*(lhs_mask.strides().end() - 2));

-    batch_strides.insert(
-        batch_strides.end(),
-        lhs_mask.strides().begin(),
-        lhs_mask.strides().end() - 2);
-
     compute_encoder.set_input_array(lhs_mask, 11);

     auto& rhs_mask = inputs[4];
     mask_strides.push_back(*(rhs_mask.strides().end() - 1));
     mask_strides.push_back(*(rhs_mask.strides().end() - 2));

-    batch_strides.insert(
-        batch_strides.end(),
-        rhs_mask.strides().begin(),
-        rhs_mask.strides().end() - 2);
-
     compute_encoder.set_input_array(rhs_mask, 12);
   }
```
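The per-mask `batch_strides.insert(...)` calls removed here appear to be the source of the bug: they appended each mask's raw batch strides, which line up with the output only when the mask's batch shape already matches it. With the strides now collapsed and broadcast against the output in the earlier hunk, this block is left recording just the innermost row/column `mask_strides` and binding the mask buffers.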
```diff
@@ -717,7 +717,16 @@ class TestBlas(mlx_tests.MLXTestCase):
             out = out * out_mask
             return out

-        def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
+        def test_shape(
+            M,
+            N,
+            K,
+            block_size,
+            transpose=False,
+            np_dtype=np.float32,
+            batch_A=(),
+            batch_B=(),
+        ):
             with self.subTest(
                 M=M,
                 N=N,
```
```diff
@@ -725,28 +734,32 @@ class TestBlas(mlx_tests.MLXTestCase):
                 block_size=block_size,
                 np_dtype=np_dtype,
                 transpose=transpose,
+                batch_A=batch_A,
+                batch_B=batch_B,
             ):
                 tm = (M + block_size - 1) // block_size
                 tn = (N + block_size - 1) // block_size
                 tk = (K + block_size - 1) // block_size

-                a_np = np.random.normal(size=(M, K)).astype(np_dtype)
-                b_np = np.random.normal(size=(K, N)).astype(np_dtype)
+                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
+                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)

-                a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
-                b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
-                out_np_mask = np.random.normal(size=(tm, tn)) < 0.0
+                batch_out = np.broadcast_shapes(batch_A, batch_B)
+
+                a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
+                b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
+                out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0

                 a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
                     mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
                 )

                 if transpose:
-                    b_np = np.random.normal(size=(N, K)).astype(np_dtype)
+                    b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
                     b_mx = mx.array(b_np)

-                    b_np = b_np.T
-                    b_mx = b_mx.T
+                    b_np = np.swapaxes(b_np, -2, -1)
+                    b_mx = mx.swapaxes(b_mx, -2, -1)

                 out_np = np_block_masked_mm(
                     a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
```
```diff
@@ -779,6 +792,9 @@ class TestBlas(mlx_tests.MLXTestCase):
         test_shape(M, N, K, block_size, transpose=False)
         test_shape(M, N, K, block_size, transpose=True)

+        # Test broadcasting
+        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
+
         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)
         b_np = np.random.normal(size=(64,)).astype(np.float32)
```
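The new broadcasting test exercises batch shapes `(1, 2)` and `(2, 2)`, which broadcast to an output batch of `(2, 2)`. A minimal standalone repro in the same spirit, assuming the `mx.block_masked_mm` Python API with `block_size`, `mask_out`, `mask_lhs`, and `mask_rhs` keyword arguments (the test's call site is not shown in this diff):

```python
import mlx.core as mx
import numpy as np

block_size = 32
M = N = K = 64
tm = tn = tk = M // block_size  # 64 / 32 = 2 mask blocks per dimension

a = mx.array(np.random.normal(size=(1, 2, M, K)).astype(np.float32))
b = mx.array(np.random.normal(size=(2, 2, K, N)).astype(np.float32))
a_mask = mx.array(np.random.normal(size=(1, 2, tm, tk)) < 0.0)
b_mask = mx.array(np.random.normal(size=(2, 2, tk, tn)) < 0.0)
out_mask = mx.array(np.random.normal(size=(2, 2, tm, tn)) < 0.0)

out = mx.block_masked_mm(
    a, b, block_size=block_size, mask_out=out_mask, mask_lhs=a_mask, mask_rhs=b_mask
)
print(out.shape)  # (2, 2, 64, 64): batch shapes (1, 2) and (2, 2) broadcast
```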