mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Block sparse mm (#1058)
This commit is contained in:
118
mlx/ops.cpp
118
mlx/ops.cpp
@@ -3785,6 +3785,124 @@ array block_masked_mm(
|
||||
return out;
|
||||
}
|
||||
|
||||
/** Compute matrix product with matrix-level gather */
|
||||
array block_sparse_mm(
|
||||
array a,
|
||||
array b,
|
||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// If no indices, fall back to full matmul
|
||||
if (!lhs_indices_ && !rhs_indices_) {
|
||||
return matmul(a, b, s);
|
||||
}
|
||||
|
||||
// Do shape checks for operands
|
||||
int in_a_ndim = a.ndim();
|
||||
int in_b_ndim = b.ndim();
|
||||
|
||||
if (a.ndim() == 0 || b.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got 0 dimension input. Inputs must "
|
||||
"have at least one dimension.");
|
||||
}
|
||||
|
||||
if (a.ndim() == 1) {
|
||||
// Insert a singleton dim in the beginning
|
||||
a = reshape(a, {1, -1}, s);
|
||||
}
|
||||
if (b.ndim() == 1) {
|
||||
// Insert a singleton dim at the end
|
||||
b = reshape(b, {-1, 1}, s);
|
||||
}
|
||||
|
||||
if (a.shape(-1) != b.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[block_sparse_mm] Last dimension of first input with shape "
|
||||
<< a.shape() << " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b);
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[block_sparse_mm] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype()
|
||||
<< " were provided which results in " << out_type
|
||||
<< ", which is not a real floating point type.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
a = astype(a, out_type, s);
|
||||
b = astype(b, out_type, s);
|
||||
|
||||
// Handle broadcasting
|
||||
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
|
||||
auto indices_or_default = [&](const std::optional<array>& indices,
|
||||
const std::vector<int>& bsx_shape) {
|
||||
if (indices.has_value()) {
|
||||
return indices.value();
|
||||
} else {
|
||||
int n_batch = 1;
|
||||
for (auto& i : bsx_shape)
|
||||
n_batch *= i;
|
||||
return reshape(arange(n_batch, uint32, s), bsx_shape, s);
|
||||
}
|
||||
};
|
||||
|
||||
// Pull and broadcast indices
|
||||
array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
|
||||
array rhs_indices = indices_or_default(rhs_indices_, bsx_b);
|
||||
|
||||
if (!issubdtype(lhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
if (!issubdtype(rhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
lhs_indices = astype(lhs_indices, uint32, s);
|
||||
rhs_indices = astype(rhs_indices, uint32, s);
|
||||
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
auto out_bsx_shape =
|
||||
broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
|
||||
|
||||
lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
|
||||
rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
|
||||
|
||||
auto out_shape = out_bsx_shape;
|
||||
out_shape.push_back(M);
|
||||
out_shape.push_back(N);
|
||||
|
||||
// Caculate array
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<BlockSparseMM>(to_stream(s)),
|
||||
{a, b, lhs_indices, rhs_indices});
|
||||
|
||||
// Remove the possibly inserted singleton dimensions
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
out_shape.erase(
|
||||
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
|
||||
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
|
||||
out = reshape(out, out_shape, s);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
array diagonal(
|
||||
const array& a,
|
||||
int offset /* = 0 */,
|
||||
|
||||
Reference in New Issue
Block a user