mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 10:18:10 +08:00
Rename block sparse (#1149)
* block_sparse_mm to gather_mm * rename * nit * nit
This commit is contained in:
@@ -40,6 +40,15 @@ double scalar_to_double(Scalar s) {
|
||||
}
|
||||
|
||||
void init_ops(nb::module_& m) {
|
||||
// TODO, remove deprecation errors in a future release
|
||||
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
|
||||
throw std::invalid_argument(
|
||||
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
|
||||
});
|
||||
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
|
||||
throw std::invalid_argument(
|
||||
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
|
||||
});
|
||||
m.def(
|
||||
"reshape",
|
||||
&reshape,
|
||||
@@ -3748,8 +3757,8 @@ void init_ops(nb::module_& m) {
|
||||
array: The dequantized version of ``w``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"block_sparse_qmm",
|
||||
&block_sparse_qmm,
|
||||
"gather_qmm",
|
||||
&gather_qmm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
@@ -3762,12 +3771,12 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform quantized matrix multiplication with matrix-level gather.
|
||||
|
||||
This operation is the quantized equivalent to :func:`block_sparse_mm`.
|
||||
Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
|
||||
This operation is the quantized equivalent to :func:`gather_mm`.
|
||||
Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
|
||||
``rhs_indices`` contain flat indices along the batch dimensions (i.e.
|
||||
all but the last two dimensions) of ``x`` and ``w`` respectively.
|
||||
|
||||
@@ -3965,8 +3974,8 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"block_sparse_mm",
|
||||
&block_sparse_mm,
|
||||
"gather_mm",
|
||||
&gather_mm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"lhs_indices"_a = nb::none(),
|
||||
@@ -3974,20 +3983,24 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix multiplication with matrix-level gather.
|
||||
|
||||
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
|
||||
This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
|
||||
Performs a gather of the operands with the given indices followed by a
|
||||
(possibly batched) matrix multiplication of two arrays. This operation
|
||||
is more efficient than explicitly applying a :func:`take` followed by a
|
||||
:func:`matmul`.
|
||||
|
||||
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
|
||||
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
|
||||
along the batch dimensions (i.e. all but the last two dimensions) of
|
||||
``a`` and ``b`` respectively.
|
||||
|
||||
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
|
||||
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
|
||||
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
|
||||
contains indices from the range ``[0, A1 * A2 * ... * AS)``
|
||||
|
||||
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
|
||||
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
|
||||
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
|
||||
contains indices from the range ``[0, B1 * B2 * ... * BS)``
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
Reference in New Issue
Block a user