Rename block sparse (#1149)

* block_sparse_mm to gather_mm

* rename

* nit

* nit
This commit is contained in:
Awni Hannun 2024-05-22 07:48:34 -07:00 committed by GitHub
parent e6fecbb3e1
commit d568c7ee36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 120 additions and 111 deletions

View File

@ -35,7 +35,6 @@ Operations
bitwise_or bitwise_or
bitwise_xor bitwise_xor
block_masked_mm block_masked_mm
block_sparse_mm
broadcast_to broadcast_to
ceil ceil
clip clip
@ -69,6 +68,8 @@ Operations
floor floor
floor_divide floor_divide
full full
gather_mm
gather_qmm
greater greater
greater_equal greater_equal
identity identity

View File

@ -21,4 +21,4 @@ python setup.py build_ext -j8 --inplace
``` ```
python test.py python test.py
` ```

View File

@ -32,8 +32,6 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort) DEFAULT(ArgSort)
DEFAULT(AsStrided) DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM) DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast) DEFAULT(Broadcast)
DEFAULT(Ceil) DEFAULT(Ceil)
DEFAULT(Concatenate) DEFAULT(Concatenate)
@ -49,6 +47,8 @@ DEFAULT(ErfInv)
DEFAULT(FFT) DEFAULT(FFT)
DEFAULT(Floor) DEFAULT(Floor)
DEFAULT(Gather) DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater) DEFAULT(Greater)
DEFAULT(GreaterEqual) DEFAULT(GreaterEqual)
DEFAULT(Less) DEFAULT(Less)

View File

@ -43,8 +43,8 @@ DEFAULT(AsType)
DEFAULT(AsStrided) DEFAULT(AsStrided)
DEFAULT(Broadcast) DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM) DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM) DEFAULT(GatherMM)
DEFAULT(BlockSparseQMM) DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod) DEFAULT_MULTI(DivMod)
DEFAULT(Ceil) DEFAULT(Ceil)
DEFAULT(Concatenate) DEFAULT(Concatenate)

View File

@ -190,10 +190,10 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) { void GatherMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) { if (out.dtype() != float32) {
throw std::runtime_error( throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32."); "[GatherMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));

View File

@ -357,7 +357,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
} }
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) { void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6); assert(inputs.size() == 6);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];

View File

@ -1435,12 +1435,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) { void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel; using namespace mlx::steel;
// assert(inputs.size() == 2); // assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) { if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error( throw std::runtime_error(
"[matmul] Does not yet support non-floating point types."); "[GatherMM] Does not yet support non-floating point types.");
} }
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);

View File

@ -196,7 +196,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
[copies](MTL::CommandBuffer*) mutable { copies.clear(); }); [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
} }
void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6); assert(inputs.size() == 6);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));

View File

@ -34,8 +34,6 @@ NO_GPU(AsType)
NO_GPU(AsStrided) NO_GPU(AsStrided)
NO_GPU(BitwiseBinary) NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM) NO_GPU(BlockMaskedMM)
NO_GPU(BlockSparseMM)
NO_GPU(BlockSparseQMM)
NO_GPU(Broadcast) NO_GPU(Broadcast)
NO_GPU(Ceil) NO_GPU(Ceil)
NO_GPU_MULTI(Compiled) NO_GPU_MULTI(Compiled)
@ -60,6 +58,8 @@ NO_GPU(FFT)
NO_GPU(Floor) NO_GPU(Floor)
NO_GPU(Full) NO_GPU(Full)
NO_GPU(Gather) NO_GPU(Gather)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater) NO_GPU(Greater)
NO_GPU(GreaterEqual) NO_GPU(GreaterEqual)
NO_GPU(Less) NO_GPU(Less)

View File

@ -3491,7 +3491,7 @@ array dequantize(
return w_full; return w_full;
} }
array block_sparse_qmm( array gather_qmm(
const array& x, const array& x,
const array& w, const array& w,
const array& scales, const array& scales,
@ -3508,7 +3508,7 @@ array block_sparse_qmm(
} }
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits); "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
// Extract indices and broadcast them // Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s); array lhs_indices = indices_or_default(lhs_indices_, x, s);
@ -3529,8 +3529,7 @@ array block_sparse_qmm(
auto out = array( auto out = array(
std::move(out_shape), std::move(out_shape),
out_type, out_type,
std::make_shared<BlockSparseQMM>( std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
to_stream(s), group_size, bits, transpose),
{astype(x, out_type, s), {astype(x, out_type, s),
w, w,
astype(scales, out_type, s), astype(scales, out_type, s),
@ -3937,7 +3936,7 @@ array block_masked_mm(
} }
/** Compute matrix product with matrix-level gather */ /** Compute matrix product with matrix-level gather */
array block_sparse_mm( array gather_mm(
array a, array a,
array b, array b,
std::optional<array> lhs_indices_ /* = std::nullopt */, std::optional<array> lhs_indices_ /* = std::nullopt */,
@ -3954,7 +3953,7 @@ array block_sparse_mm(
if (a.ndim() == 0 || b.ndim() == 0) { if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"[block_sparse_mm] Got 0 dimension input. Inputs must " "[gather_mm] Got 0 dimension input. Inputs must "
"have at least one dimension."); "have at least one dimension.");
} }
@ -3969,8 +3968,8 @@ array block_sparse_mm(
if (a.shape(-1) != b.shape(-2)) { if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[block_sparse_mm] Last dimension of first input with shape " msg << "[gather_mm] Last dimension of first input with shape " << a.shape()
<< a.shape() << " must match second to last dimension of" << " must match second to last dimension of"
<< " second input with shape " << b.shape() << "."; << " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -3979,7 +3978,7 @@ array block_sparse_mm(
auto out_type = result_type(a, b); auto out_type = result_type(a, b);
if (!issubdtype(out_type, floating)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[block_sparse_mm] Only real floating point types are supported but " msg << "[gather_mm] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type << " were provided which results in " << out_type
<< ", which is not a real floating point type."; << ", which is not a real floating point type.";
@ -3995,12 +3994,12 @@ array block_sparse_mm(
if (!issubdtype(lhs_indices.dtype(), integer)) { if (!issubdtype(lhs_indices.dtype(), integer)) {
throw std::invalid_argument( throw std::invalid_argument(
"[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral."); "[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
} }
if (!issubdtype(rhs_indices.dtype(), integer)) { if (!issubdtype(rhs_indices.dtype(), integer)) {
throw std::invalid_argument( throw std::invalid_argument(
"[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral."); "[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
} }
lhs_indices = astype(lhs_indices, uint32, s); lhs_indices = astype(lhs_indices, uint32, s);
@ -4024,7 +4023,7 @@ array block_sparse_mm(
auto out = array( auto out = array(
out_shape, out_shape,
out_type, out_type,
std::make_shared<BlockSparseMM>(to_stream(s)), std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices}); {a, b, lhs_indices, rhs_indices});
// Remove the possibly inserted singleton dimensions // Remove the possibly inserted singleton dimensions

View File

@ -1158,7 +1158,7 @@ array dequantize(
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Compute matrix products with matrix-level gather. */ /** Compute matrix products with matrix-level gather. */
array block_sparse_qmm( array gather_qmm(
const array& x, const array& x,
const array& w, const array& w,
const array& scales, const array& scales,
@ -1210,7 +1210,7 @@ array block_masked_mm(
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Compute matrix product with matrix-level gather */ /** Compute matrix product with matrix-level gather */
array block_sparse_mm( array gather_mm(
array a, array a,
array b, array b,
std::optional<array> lhs_indices = std::nullopt, std::optional<array> lhs_indices = std::nullopt,

View File

@ -2376,13 +2376,13 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
transpose_ == qm_other.transpose_; transpose_ == qm_other.transpose_;
} }
std::pair<std::vector<array>, std::vector<int>> BlockSparseQMM::vmap( std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
throw std::runtime_error("BlockSparseQMM::vmap NYI"); throw std::runtime_error("GatherQMM::vmap NYI");
} }
std::vector<array> BlockSparseQMM::vjp( std::vector<array> GatherQMM::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>& argnums,
@ -2406,7 +2406,7 @@ std::vector<array> BlockSparseQMM::vjp(
flatten(zeros_like(x, stream()), 0, -3, stream()), flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices, lhs_indices,
expand_dims( expand_dims(
block_sparse_qmm( gather_qmm(
cotan, cotan,
w, w,
scales, scales,
@ -2428,27 +2428,27 @@ std::vector<array> BlockSparseQMM::vjp(
// gradient wrt to the indices is undefined // gradient wrt to the indices is undefined
else if (arg > 3) { else if (arg > 3) {
throw std::runtime_error( throw std::runtime_error(
"BlockSparseQMM::vjp cannot compute the gradient wrt the indices."); "GatherQMM::vjp cannot compute the gradient wrt the indices.");
} }
// gradient wrt to w_q, scales or biases // gradient wrt to w_q, scales or biases
else { else {
throw std::runtime_error( throw std::runtime_error(
"BlockSparseQMM::vjp no gradient wrt the quantized matrix yet."); "GatherQMM::vjp no gradient wrt the quantized matrix yet.");
} }
} }
return vjps; return vjps;
} }
std::vector<array> BlockSparseQMM::jvp( std::vector<array> GatherQMM::jvp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
throw std::runtime_error("BlockSparseQMM::jvp NYI"); throw std::runtime_error("GatherQMM::jvp NYI");
} }
bool BlockSparseQMM::is_equivalent(const Primitive& other) const { bool GatherQMM::is_equivalent(const Primitive& other) const {
const BlockSparseQMM& qm_other = static_cast<const BlockSparseQMM&>(other); const GatherQMM& qm_other = static_cast<const GatherQMM&>(other);
return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
transpose_ == qm_other.transpose_; transpose_ == qm_other.transpose_;
} }
@ -3523,7 +3523,7 @@ std::vector<array> BlockMaskedMM::vjp(
return vjps; return vjps;
} }
std::vector<array> BlockSparseMM::vjp( std::vector<array> GatherMM::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>& argnums,
@ -3548,7 +3548,7 @@ std::vector<array> BlockSparseMM::vjp(
base = reshape(base, {-1, M, K}, stream()); base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K) // g : (out_batch_shape) + (M, K)
auto g = block_sparse_mm(cotan, bt, std::nullopt, rhs_indices, stream()); auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
g = expand_dims(g, -3, stream()); g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
@ -3563,14 +3563,14 @@ std::vector<array> BlockSparseMM::vjp(
base = reshape(base, {-1, K, N}, stream()); base = reshape(base, {-1, K, N}, stream());
// g : (out_batch_shape) + (K, N) // g : (out_batch_shape) + (K, N)
auto g = block_sparse_mm(at, cotan, lhs_indices, std::nullopt, stream()); auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
g = expand_dims(g, -3, stream()); g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream())); vjps.push_back(reshape(gacc, base_shape, stream()));
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[BlockSparseMM] Cannot calculate VJP with respect to indices."); "[GatherMM] Cannot calculate VJP with respect to indices.");
} }
} }
return vjps; return vjps;

View File

@ -502,9 +502,9 @@ class BlockMaskedMM : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class BlockSparseMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive {
public: public:
explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {}; explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -515,7 +515,7 @@ class BlockSparseMM : public UnaryPrimitive {
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
DEFINE_PRINT(BlockSparseMM) DEFINE_PRINT(GatherMM)
DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_DEFAULT_IS_EQUIVALENT()
private: private:
@ -1467,13 +1467,9 @@ class QuantizedMatmul : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class BlockSparseQMM : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive {
public: public:
explicit BlockSparseQMM( explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Stream stream,
int group_size,
int bits,
bool transpose)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
@ -1484,7 +1480,7 @@ class BlockSparseQMM : public UnaryPrimitive {
DEFINE_VMAP() DEFINE_VMAP()
DEFINE_GRADS() DEFINE_GRADS()
DEFINE_PRINT(BlockSparseQMM) DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:

View File

@ -40,6 +40,15 @@ double scalar_to_double(Scalar s) {
} }
void init_ops(nb::module_& m) { void init_ops(nb::module_& m) {
// TODO, remove deprecation errors in a future release
m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
});
m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
throw std::invalid_argument(
"block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
});
m.def( m.def(
"reshape", "reshape",
&reshape, &reshape,
@ -3748,8 +3757,8 @@ void init_ops(nb::module_& m) {
array: The dequantized version of ``w`` array: The dequantized version of ``w``
)pbdoc"); )pbdoc");
m.def( m.def(
"block_sparse_qmm", "gater_qmm",
&block_sparse_qmm, &gather_qmm,
nb::arg(), nb::arg(),
nb::arg(), nb::arg(),
"scales"_a, "scales"_a,
@ -3762,12 +3771,12 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather. Perform quantized matrix multiplication with matrix-level gather.
This operation is the quantized equivalent to :func:`block_sparse_mm`. This operation is the quantized equivalent to :func:`gather_mm`.
Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
``rhs_indices`` contain flat indices along the batch dimensions (i.e. ``rhs_indices`` contain flat indices along the batch dimensions (i.e.
all but the last two dimensions) of ``x`` and ``w`` respectively. all but the last two dimensions) of ``x`` and ``w`` respectively.
@ -3965,8 +3974,8 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"block_sparse_mm", "gather_mm",
&block_sparse_mm, &gather_mm,
nb::arg(), nb::arg(),
nb::arg(), nb::arg(),
"lhs_indices"_a = nb::none(), "lhs_indices"_a = nb::none(),
@ -3974,20 +3983,24 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Matrix multiplication with matrix-level gather. Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. Performs a gather of the operands with the given indices followed by a
This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`. (possibly batched) matrix multiplication of two arrays. This operation
is more efficient than explicitly applying a :func:`take` followed by a
:func:`matmul`.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively. The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
along the batch dimensions (i.e. all but the last two dimensions) of
``a`` and ``b`` respectively.
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
Args: Args:
a (array): Input array. a (array): Input array.

View File

@ -810,8 +810,8 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5)) self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
def test_block_sparse_matmul(self): def test_gather_matmul(self):
def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None): def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
a = a.reshape((-1, a.shape[-2], a.shape[-1])) a = a.reshape((-1, a.shape[-2], a.shape[-1]))
b = b.reshape((-1, b.shape[-2], b.shape[-1])) b = b.reshape((-1, b.shape[-2], b.shape[-1]))
lhs_indices = lhs_indices or np.arange(a.shape[0]) lhs_indices = lhs_indices or np.arange(a.shape[0])
@ -848,12 +848,12 @@ class TestBlas(mlx_tests.MLXTestCase):
a_mx = mx.array(a_np) a_mx = mx.array(a_np)
b_mx = mx.array(b_np) b_mx = mx.array(b_np)
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices) out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices) lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices) rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
@ -920,7 +920,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices = [0, 13, 12] lhs_indices = [0, 13, 12]
rhs_indices = [0, 3, 5] rhs_indices = [0, 3, 5]
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices) out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
# MLX # MLX
a_mx = a_mx.reshape((5, 1, 32, 32)) a_mx = a_mx.reshape((5, 1, 32, 32))
@ -932,7 +932,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices_mx = mx.array(lhs_indices) lhs_indices_mx = mx.array(lhs_indices)
rhs_indices_mx = mx.array(rhs_indices) rhs_indices_mx = mx.array(rhs_indices)
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
@ -946,17 +946,17 @@ class TestBlas(mlx_tests.MLXTestCase):
rhs_indices = [0, 2] rhs_indices = [0, 2]
b_np_t = np.swapaxes(b_np, -1, -2) b_np_t = np.swapaxes(b_np, -1, -2)
out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices) out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices)
lhs_indices_mx = mx.array(lhs_indices) lhs_indices_mx = mx.array(lhs_indices)
rhs_indices_mx = mx.array(rhs_indices) rhs_indices_mx = mx.array(rhs_indices)
b_mx_t = mx.swapaxes(b_mx, -1, -2) b_mx_t = mx.swapaxes(b_mx, -1, -2)
out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx) out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
def test_block_sparse_matmul_grad(self): def test_gather_matmul_grad(self):
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32) lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32) rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
@ -977,7 +977,7 @@ class TestBlas(mlx_tests.MLXTestCase):
return a @ b return a @ b
def f_test(a, b): def f_test(a, b):
return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices) return mx.gather_mm(a, b, lhs_indices, rhs_indices)
a_mx = mx.random.normal((4, 2, 32, 32)) a_mx = mx.random.normal((4, 2, 32, 32))
b_mx = mx.random.normal((4, 1, 32, 32)) b_mx = mx.random.normal((4, 1, 32, 32))

View File

@ -277,7 +277,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_block_sparse_qmm(self): def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4): def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@ -322,8 +322,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
if rhs_indices is not None: if rhs_indices is not None:
rhs_indices = mx.array(rhs_indices) rhs_indices = mx.array(rhs_indices)
c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices) c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
c2 = mx.block_sparse_qmm( c2 = mx.gather_qmm(
x, x,
qw, qw,
s, s,
@ -390,7 +390,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
test_shape(32, 512, 32, transpose=False, **kwargs) test_shape(32, 512, 32, transpose=False, **kwargs)
test_shape(1, 512, 32, transpose=False, **kwargs) test_shape(1, 512, 32, transpose=False, **kwargs)
def test_block_sparse_matmul_grad(self): def test_gather_matmul_grad(self):
def quantize(w, transpose=True, group_size=64, bits=4): def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@ -406,10 +406,10 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat, qw, s, b = quantize(w) w_hat, qw, s, b = quantize(w)
def f_ref(x, w, i1, i2): def f_ref(x, w, i1, i2):
return mx.block_sparse_mm(x, w, i1, i2).sum() return mx.gather_mm(x, w, i1, i2).sum()
def f_test(x, qw, s, b, i1, i2): def f_test(x, qw, s, b, i1, i2):
return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum() return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
r1 = f_ref(x, w_hat, lhs_indices, rhs_indices) r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices) r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)