mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Rename block sparse (#1149)
* block_sparse_mm to gather_mm * rename * nit * nit
This commit is contained in:
23
mlx/ops.cpp
23
mlx/ops.cpp
@@ -3491,7 +3491,7 @@ array dequantize(
|
||||
return w_full;
|
||||
}
|
||||
|
||||
array block_sparse_qmm(
|
||||
array gather_qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
@@ -3508,7 +3508,7 @@ array block_sparse_qmm(
|
||||
}
|
||||
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);
|
||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
|
||||
|
||||
// Extract indices and broadcast them
|
||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||
@@ -3529,8 +3529,7 @@ array block_sparse_qmm(
|
||||
auto out = array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<BlockSparseQMM>(
|
||||
to_stream(s), group_size, bits, transpose),
|
||||
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
|
||||
{astype(x, out_type, s),
|
||||
w,
|
||||
astype(scales, out_type, s),
|
||||
@@ -3937,7 +3936,7 @@ array block_masked_mm(
|
||||
}
|
||||
|
||||
/** Compute matrix product with matrix-level gather */
|
||||
array block_sparse_mm(
|
||||
array gather_mm(
|
||||
array a,
|
||||
array b,
|
||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||
@@ -3954,7 +3953,7 @@ array block_sparse_mm(
|
||||
|
||||
if (a.ndim() == 0 || b.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got 0 dimension input. Inputs must "
|
||||
"[gather_mm] Got 0 dimension input. Inputs must "
|
||||
"have at least one dimension.");
|
||||
}
|
||||
|
||||
@@ -3969,8 +3968,8 @@ array block_sparse_mm(
|
||||
|
||||
if (a.shape(-1) != b.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[block_sparse_mm] Last dimension of first input with shape "
|
||||
<< a.shape() << " must match second to last dimension of"
|
||||
msg << "[gather_mm] Last dimension of first input with shape " << a.shape()
|
||||
<< " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@@ -3979,7 +3978,7 @@ array block_sparse_mm(
|
||||
auto out_type = result_type(a, b);
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[block_sparse_mm] Only real floating point types are supported but "
|
||||
msg << "[gather_mm] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype()
|
||||
<< " were provided which results in " << out_type
|
||||
<< ", which is not a real floating point type.";
|
||||
@@ -3995,12 +3994,12 @@ array block_sparse_mm(
|
||||
|
||||
if (!issubdtype(lhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
|
||||
"[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
if (!issubdtype(rhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
"[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
lhs_indices = astype(lhs_indices, uint32, s);
|
||||
@@ -4024,7 +4023,7 @@ array block_sparse_mm(
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<BlockSparseMM>(to_stream(s)),
|
||||
std::make_shared<GatherMM>(to_stream(s)),
|
||||
{a, b, lhs_indices, rhs_indices});
|
||||
|
||||
// Remove the possibly inserted singleton dimensions
|
||||
|
||||
Reference in New Issue
Block a user