Rename block sparse (#1149)

* block_sparse_mm to gather_mm

* rename

* nit

* nit
This commit is contained in:
Awni Hannun
2024-05-22 07:48:34 -07:00
committed by GitHub
parent e6fecbb3e1
commit d568c7ee36
16 changed files with 120 additions and 111 deletions

View File

@@ -40,6 +40,15 @@ double scalar_to_double(Scalar s) {
}
void init_ops(nb::module_& m) {
// TODO, remove deprecation errors in a future release
// Deprecated name: `block_sparse_mm` is still registered, but the bound
// callable takes arbitrary positional/keyword arguments and unconditionally
// throws, directing callers to `gather_mm`.
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
});
// Deprecated name for the quantized variant: same always-throw behavior,
// directing callers to `gather_qmm`.
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
});
m.def(
"reshape",
&reshape,
@@ -3748,8 +3757,8 @@ void init_ops(nb::module_& m) {
array: The dequantized version of ``w``
)pbdoc");
m.def(
"block_sparse_qmm",
&block_sparse_qmm,
// Python-visible name; must match the `def gather_qmm(...)` signature string.
"gather_qmm",
&gather_qmm,
nb::arg(),
nb::arg(),
"scales"_a,
@@ -3762,12 +3771,12 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
This operation is the quantized equivalent to :func:`block_sparse_mm`.
Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
This operation is the quantized equivalent to :func:`gather_mm`.
Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
``rhs_indices`` contain flat indices along the batch dimensions (i.e.
all but the last two dimensions) of ``x`` and ``w`` respectively.
@@ -3965,8 +3974,8 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"block_sparse_mm",
&block_sparse_mm,
"gather_mm",
&gather_mm,
nb::arg(),
nb::arg(),
"lhs_indices"_a = nb::none(),
@@ -3974,20 +3983,24 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
Performs a gather of the operands with the given indices followed by a
(possibly batched) matrix multiplication of two arrays. This operation
is more efficient than explicitly applying a :func:`take` followed by a
:func:`matmul`.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
along the batch dimensions (i.e. all but the last two dimensions) of
``a`` and ``b`` respectively.
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
For ``b`` with shape ``(B1, B2, ..., BS, K, N)``, ``rhs_indices``
contains indices from the range ``[0, B1 * B2 * ... * BS)``
Args:
a (array): Input array.