Rename block sparse (#1149)

* block_sparse_mm to gather_mm

* rename

* nit

* nit
This commit is contained in:
Awni Hannun
2024-05-22 07:48:34 -07:00
committed by GitHub
parent e6fecbb3e1
commit d568c7ee36
16 changed files with 120 additions and 111 deletions

View File

@@ -3491,7 +3491,7 @@ array dequantize(
return w_full;
}
array block_sparse_qmm(
array gather_qmm(
const array& x,
const array& w,
const array& scales,
@@ -3508,7 +3508,7 @@ array block_sparse_qmm(
}
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3529,8 +3529,7 @@ array block_sparse_qmm(
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<BlockSparseQMM>(
to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
{astype(x, out_type, s),
w,
astype(scales, out_type, s),
@@ -3937,7 +3936,7 @@ array block_masked_mm(
}
/** Compute matrix product with matrix-level gather */
array block_sparse_mm(
array gather_mm(
array a,
array b,
std::optional<array> lhs_indices_ /* = std::nullopt */,
@@ -3954,7 +3953,7 @@ array block_sparse_mm(
if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument(
"[block_sparse_mm] Got 0 dimension input. Inputs must "
"[gather_mm] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
@@ -3969,8 +3968,8 @@ array block_sparse_mm(
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[block_sparse_mm] Last dimension of first input with shape "
<< a.shape() << " must match second to last dimension of"
msg << "[gather_mm] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
@@ -3979,7 +3978,7 @@ array block_sparse_mm(
auto out_type = result_type(a, b);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[block_sparse_mm] Only real floating point types are supported but "
msg << "[gather_mm] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type
<< ", which is not a real floating point type.";
@@ -3995,12 +3994,12 @@ array block_sparse_mm(
if (!issubdtype(lhs_indices.dtype(), integer)) {
throw std::invalid_argument(
"[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
"[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
}
if (!issubdtype(rhs_indices.dtype(), integer)) {
throw std::invalid_argument(
"[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
"[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
}
lhs_indices = astype(lhs_indices, uint32, s);
@@ -4024,7 +4023,7 @@ array block_sparse_mm(
auto out = array(
out_shape,
out_type,
std::make_shared<BlockSparseMM>(to_stream(s)),
std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices});
// Remove the possibly inserted singleton dimensions