diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index c88885101..ced813e90 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -35,7 +35,6 @@ Operations
    bitwise_or
    bitwise_xor
    block_masked_mm
-   block_sparse_mm
    broadcast_to
    ceil
    clip
@@ -69,6 +68,8 @@ Operations
    floor
    floor_divide
    full
+   gather_mm
+   gather_qmm
    greater
    greater_equal
    identity
diff --git a/examples/extensions/README.md b/examples/extensions/README.md
index 1cb113459..610de99f4 100644
--- a/examples/extensions/README.md
+++ b/examples/extensions/README.md
@@ -21,4 +21,4 @@ python setup.py build_ext -j8 --inplace
 
 ```
 python test.py
-`
+```
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 1187015cf..ad778cbc7 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -32,8 +32,6 @@ DEFAULT(ArgReduce)
 DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
-DEFAULT(BlockSparseMM)
-DEFAULT(BlockSparseQMM)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
@@ -49,6 +47,8 @@ DEFAULT(ErfInv)
 DEFAULT(FFT)
 DEFAULT(Floor)
 DEFAULT(Gather)
+DEFAULT(GatherMM)
+DEFAULT(GatherQMM)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 0b502615c..6c9648461 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -43,8 +43,8 @@ DEFAULT(AsType)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
 DEFAULT(BlockMaskedMM)
-DEFAULT(BlockSparseMM)
-DEFAULT(BlockSparseQMM)
+DEFAULT(GatherMM)
+DEFAULT(GatherQMM)
 DEFAULT_MULTI(DivMod)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp
index fa2dd32af..655e260e5 100644
--- a/mlx/backend/common/masked_mm.cpp
+++ b/mlx/backend/common/masked_mm.cpp
@@ -190,10 +190,10 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherMM::eval(const std::vector<array>& inputs, array& out) {
   if (out.dtype() != float32) {
     throw std::runtime_error(
-        "[BlockSparseMM::eval] Currently only supports float32.");
+        "[GatherMM::eval] Currently only supports float32.");
   }
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -277,4 +277,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core
diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp
index 4dfb1780e..7f37e7bc2 100644
--- a/mlx/backend/common/quantized.cpp
+++ b/mlx/backend/common/quantized.cpp
@@ -357,7 +357,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }
 
-void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 6);
 
   auto& x_pre = inputs[0];
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 255391238..4de064905 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -324,12 +324,12 @@ void steel_matmul_conv_groups(
   };
 
   // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n') 
-        << "_use_out_source_" << (use_out_source ? 't' : 'n') 
-        << "_do_axpby_" << (do_axpby ? 't' : 'n') 
+  kname << "_has_batch_" << (has_batch ? 't' : 'n')
+        << "_use_out_source_" << (use_out_source ? 't' : 'n')
+        << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n') 
-        << "_align_K_" << (align_K ? 't' : 'n') 
+        << "_align_N_" << (align_N ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n');
   // clang-format on
   std::string hash_name = kname.str();
@@ -575,12 +575,12 @@ void steel_matmul(
   };
 
   // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n') 
-        << "_use_out_source_" << (use_out_source ? 't' : 'n') 
-        << "_do_axpby_" << (do_axpby ? 't' : 'n') 
+  kname << "_has_batch_" << (has_batch ? 't' : 'n')
+        << "_use_out_source_" << (use_out_source ? 't' : 'n')
+        << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n') 
-        << "_align_K_" << (align_K ? 't' : 'n') 
+        << "_align_N_" << (align_N ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n');
   // clang-format on
   std::string hash_name = kname.str();
@@ -1170,12 +1170,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   };
 
   // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n') 
-        << "_use_out_source_" << (use_out_source ? 't' : 'n') 
-        << "_do_axpby_" << (do_axpby ? 't' : 'n') 
+  kname << "_has_batch_" << (has_batch ? 't' : 'n')
+        << "_use_out_source_" << (use_out_source ? 't' : 'n')
+        << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n') 
-        << "_align_K_" << (align_K ? 't' : 'n') 
+        << "_align_N_" << (align_N ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n');
   // clang-format on
   std::string hash_name = kname.str();
@@ -1435,12 +1435,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   return;
 }
 
-void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   using namespace mlx::steel;
   // assert(inputs.size() == 2);
   if (!issubdtype(out.dtype(), floating)) {
     throw std::runtime_error(
-        "[matmul] Does not yet support non-floating point types.");
+        "[GatherMM] Does not yet support non-floating point types.");
   }
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -1700,12 +1700,12 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   };
 
   // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n') 
-        << "_use_out_source_" << (use_out_source ? 't' : 'n') 
-        << "_do_axpby_" << (do_axpby ? 't' : 'n') 
+  kname << "_has_batch_" << (has_batch ? 't' : 'n')
+        << "_use_out_source_" << (use_out_source ? 't' : 'n')
+        << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n') 
-        << "_align_K_" << (align_K ? 't' : 'n') 
+        << "_align_N_" << (align_N ? 't' : 'n')
+        << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n');
   // clang-format on
   std::string hash_name = kname.str();
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 609d7bfac..48f1387a9 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -196,7 +196,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }
 
-void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 6);
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 43ee0efad..3398a023b 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -34,8 +34,6 @@ NO_GPU(AsType)
 NO_GPU(AsStrided)
 NO_GPU(BitwiseBinary)
 NO_GPU(BlockMaskedMM)
-NO_GPU(BlockSparseMM)
-NO_GPU(BlockSparseQMM)
 NO_GPU(Broadcast)
 NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
@@ -60,6 +58,8 @@ NO_GPU(FFT)
 NO_GPU(Floor)
 NO_GPU(Full)
 NO_GPU(Gather)
+NO_GPU(GatherMM)
+NO_GPU(GatherQMM)
 NO_GPU(Greater)
 NO_GPU(GreaterEqual)
 NO_GPU(Less)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 8f2062a6d..006789306 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3491,7 +3491,7 @@ array dequantize(
   return w_full;
 }
 
-array block_sparse_qmm(
+array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
@@ -3508,7 +3508,7 @@
   }
 
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
 
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3529,8 +3529,7 @@
   auto out = array(
       std::move(out_shape),
       out_type,
-      std::make_shared<BlockSparseQMM>(
-          to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
       astype(scales, out_type, s),
@@ -3937,7 +3936,7 @@
 }
 
 /** Compute matrix product with matrix-level gather */
-array block_sparse_mm(
+array gather_mm(
     array a,
     array b,
     std::optional<array> lhs_indices_ /* = std::nullopt */,
@@ -3954,7 +3953,7 @@
 
   if (a.ndim() == 0 || b.ndim() == 0) {
     throw std::invalid_argument(
-        "[block_sparse_mm] Got 0 dimension input. Inputs must "
+        "[gather_mm] Got 0 dimension input. Inputs must "
         "have at least one dimension.");
   }
 
@@ -3969,8 +3968,8 @@
   if (a.shape(-1) != b.shape(-2)) {
     std::ostringstream msg;
-    msg << "[block_sparse_mm] Last dimension of first input with shape "
-        << a.shape() << " must match second to last dimension of"
+    msg << "[gather_mm] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
         << " second input with shape " << b.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
@@ -3979,7 +3978,7 @@
   auto out_type = result_type(a, b);
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
-    msg << "[block_sparse_mm] Only real floating point types are supported but "
+    msg << "[gather_mm] Only real floating point types are supported but "
        << a.dtype() << " and " << b.dtype()
        << " were provided which results in " << out_type
        << ", which is not a real floating point type.";
@@ -3995,12 +3994,12 @@
   if (!issubdtype(lhs_indices.dtype(), integer)) {
     throw std::invalid_argument(
-        "[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
+        "[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
   }
 
   if (!issubdtype(rhs_indices.dtype(), integer)) {
     throw std::invalid_argument(
-        "[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
+        "[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
   }
 
   lhs_indices = astype(lhs_indices, uint32, s);
@@ -4024,7 +4023,7 @@
   auto out = array(
       out_shape,
       out_type,
-      std::make_shared<BlockSparseMM>(to_stream(s)),
+      std::make_shared<GatherMM>(to_stream(s)),
       {a, b, lhs_indices, rhs_indices});
 
   // Remove the possibly inserted singleton dimensions
diff --git a/mlx/ops.h b/mlx/ops.h
index c334024ac..6a6b7ac05 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1158,7 +1158,7 @@ array dequantize(
     StreamOrDevice s = {});
 
 /** Compute matrix products with matrix-level gather. */
-array block_sparse_qmm(
+array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
@@ -1210,7 +1210,7 @@ array block_masked_mm(
     StreamOrDevice s = {});
 
 /** Compute matrix product with matrix-level gather */
-array block_sparse_mm(
+array gather_mm(
     array a,
     array b,
     std::optional<array> lhs_indices = std::nullopt,
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 89454c035..4fb5d8754 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2376,13 +2376,13 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
          transpose_ == qm_other.transpose_;
 }
 
-std::pair<std::vector<array>, std::vector<int>> BlockSparseQMM::vmap(
+std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  throw std::runtime_error("BlockSparseQMM::vmap NYI");
+  throw std::runtime_error("GatherQMM::vmap NYI");
 }
 
-std::vector<array> BlockSparseQMM::vjp(
+std::vector<array> GatherQMM::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
@@ -2406,7 +2406,7 @@
             flatten(zeros_like(x, stream()), 0, -3, stream()),
             lhs_indices,
             expand_dims(
-                block_sparse_qmm(
+                gather_qmm(
                     cotan,
                     w,
                     scales,
@@ -2428,27 +2428,27 @@
     // gradient wrt to the indices is undefined
     else if (arg > 3) {
       throw std::runtime_error(
-          "BlockSparseQMM::vjp cannot compute the gradient wrt the indices.");
+          "GatherQMM::vjp cannot compute the gradient wrt the indices.");
     }
 
     // gradient wrt to w_q, scales or biases
     else {
       throw std::runtime_error(
-          "BlockSparseQMM::vjp no gradient wrt the quantized matrix yet.");
+          "GatherQMM::vjp no gradient wrt the quantized matrix yet.");
     }
   }
 
   return vjps;
 }
 
-std::vector<array> BlockSparseQMM::jvp(
+std::vector<array> GatherQMM::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  throw std::runtime_error("BlockSparseQMM::jvp NYI");
+  throw std::runtime_error("GatherQMM::jvp NYI");
 }
 
-bool BlockSparseQMM::is_equivalent(const Primitive& other) const {
-  const BlockSparseQMM& qm_other = static_cast<const BlockSparseQMM&>(other);
+bool GatherQMM::is_equivalent(const Primitive& other) const {
+  const GatherQMM& qm_other = static_cast<const GatherQMM&>(other);
   return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
          transpose_ == qm_other.transpose_;
 }
@@ -3523,7 +3523,7 @@
   return vjps;
 }
 
-std::vector<array> BlockSparseMM::vjp(
+std::vector<array> GatherMM::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
@@ -3548,7 +3548,7 @@
       base = reshape(base, {-1, M, K}, stream());
 
       // g : (out_batch_shape) + (M, K)
-      auto g = block_sparse_mm(cotan, bt, std::nullopt, rhs_indices, stream());
+      auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
       g = expand_dims(g, -3, stream());
       auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
 
@@ -3563,14 +3563,14 @@
      base = reshape(base, {-1, K, N}, stream());
 
       // g : (out_batch_shape) + (K, N)
-      auto g = block_sparse_mm(at, cotan, lhs_indices, std::nullopt, stream());
+      auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
       g = expand_dims(g, -3, stream());
       auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
 
       vjps.push_back(reshape(gacc, base_shape, stream()));
     } else {
       throw std::invalid_argument(
-          "[BlockSparseMM] Cannot calculate VJP with respect to indices.");
+          "[GatherMM] Cannot calculate VJP with respect to indices.");
     }
   }
   return vjps;
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 7569eb834..5dd99ea60 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -502,9 +502,9 @@ class BlockMaskedMM : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
-class BlockSparseMM : public UnaryPrimitive {
+class GatherMM : public UnaryPrimitive {
  public:
-  explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {};
+  explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {};
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -515,7 +515,7 @@ class BlockSparseMM : public UnaryPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
-  DEFINE_PRINT(BlockSparseMM)
+  DEFINE_PRINT(GatherMM)
   DEFINE_DEFAULT_IS_EQUIVALENT()
 
  private:
@@ -1467,13 +1467,9 @@ class QuantizedMatmul : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
-class BlockSparseQMM : public UnaryPrimitive {
+class GatherQMM : public UnaryPrimitive {
  public:
-  explicit BlockSparseQMM(
-      Stream stream,
-      int group_size,
-      int bits,
-      bool transpose)
+  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
@@ -1484,7 +1480,7 @@
 
   DEFINE_VMAP()
   DEFINE_GRADS()
-  DEFINE_PRINT(BlockSparseQMM)
+  DEFINE_PRINT(GatherQMM)
   bool is_equivalent(const Primitive& other) const override;
 
  private:
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 0f8f680a7..8a4be83b0 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -40,6 +40,15 @@ double scalar_to_double(Scalar s) {
 }
 
 void init_ops(nb::module_& m) {
+  // TODO, remove deprecation errors in a future release
+  m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
+    throw std::invalid_argument(
+        "block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
+  });
+  m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
+    throw std::invalid_argument(
+        "block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
+  });
   m.def(
       "reshape",
       &reshape,
@@ -3748,8 +3757,8 @@ void init_ops(nb::module_& m) {
         array: The dequantized version of ``w``
       )pbdoc");
   m.def(
-      "block_sparse_qmm",
-      &block_sparse_qmm,
+      "gather_qmm",
+      &gather_qmm,
       nb::arg(),
       nb::arg(),
       "scales"_a,
@@ -3762,12 +3771,12 @@
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.
 
-        This operation is the quantized equivalent to :func:`block_sparse_mm`.
-        Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
+        This operation is the quantized equivalent to :func:`gather_mm`.
+        Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
         ``rhs_indices`` contain flat indices along the batch dimensions
         (i.e. all but the last two dimensions) of ``x`` and ``w`` respectively.
@@ -3965,8 +3974,8 @@ void init_ops(nb::module_& m) {
       )pbdoc");
 
   m.def(
-      "block_sparse_mm",
-      &block_sparse_mm,
+      "gather_mm",
+      &gather_mm,
       nb::arg(),
       nb::arg(),
       "lhs_indices"_a = nb::none(),
@@ -3974,20 +3983,24 @@
       "rhs_indices"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Matrix multiplication with matrix-level gather.
 
-        Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
-        This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
+        Performs a gather of the operands with the given indices followed by a
+        (possibly batched) matrix multiplication of two arrays. This operation
+        is more efficient than explicitly applying a :func:`take` followed by a
+        :func:`matmul`.
 
-        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
+        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
+        along the batch dimensions (i.e. all but the last two dimensions) of
+        ``a`` and ``b`` respectively.
 
-        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
-        ``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
+        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
+        contains indices from the range ``[0, A1 * A2 * ... * AS)``
 
-        For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
-        ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
+        For ``b`` with shape ``(B1, B2, ..., BS, K, N)``, ``rhs_indices``
+        contains indices from the range ``[0, B1 * B2 * ... * BS)``
 
         Args:
             a (array): Input array.
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 884e8d73b..52aef4868 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -408,9 +408,9 @@ class TestBlas(mlx_tests.MLXTestCase):
             with self.subTest(
                 B=B,  # Batch size
                 D=D,  # Dimension of mm
-                n_kv_heads=n_kv_heads,  # key-value heads
-                factor=factor,  # factor to get query heads
-                qsl=qsl,  # Query sequence length
+                n_kv_heads=n_kv_heads,  # key-value heads
+                factor=factor,  # factor to get query heads
+                qsl=qsl,  # Query sequence length
                 ksl=ksl,  # Key sequence length
                 dtype=dtype  # Data type
             ):
@@ -432,22 +432,22 @@
                 k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
                 v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
 
-                # Rearrange to move heads up
+                # Rearrange to move heads up
                 q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
                 k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
                 v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
 
                 # Do attn style matmul
                 s_np = q_np_reshape @ k_np_reshape
-                o_np = s_np @ v_np_reshape
+                o_np = s_np @ v_np_reshape
                 o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
 
-                # Test mlx
+                # Test mlx
                 q_mx = mx.array(q_np)
                 k_mx = mx.array(k_np)
                 v_mx = mx.array(v_np)
 
-                # Rearrange to move heads up
+                # Rearrange to move heads up
                 q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
                 k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
                 v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
@@ -810,8 +810,8 @@
 
         self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
 
-    def test_block_sparse_matmul(self):
-        def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None):
+    def test_gather_matmul(self):
+        def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
             a = a.reshape((-1, a.shape[-2], a.shape[-1]))
             b = b.reshape((-1, b.shape[-2], b.shape[-1]))
             lhs_indices = lhs_indices or np.arange(a.shape[0])
@@ -848,12 +848,12 @@
             a_mx = mx.array(a_np)
             b_mx = mx.array(b_np)
 
-            out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
+            out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
 
             lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
             rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)
 
-            out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
+            out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
 
             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
@@ -920,7 +920,7 @@
         lhs_indices = [0, 13, 12]
         rhs_indices = [0, 3, 5]
 
-        out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
+        out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)
 
         # MLX
         a_mx = a_mx.reshape((5, 1, 32, 32))
@@ -932,7 +932,7 @@
         lhs_indices_mx = mx.array(lhs_indices)
         rhs_indices_mx = mx.array(rhs_indices)
 
-        out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
+        out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
 
         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
@@ -946,17 +946,17 @@
         rhs_indices = [0, 2]
 
         b_np_t = np.swapaxes(b_np, -1, -2)
-        out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices)
+        out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices)
 
         lhs_indices_mx = mx.array(lhs_indices)
         rhs_indices_mx = mx.array(rhs_indices)
 
         b_mx_t = mx.swapaxes(b_mx, -1, -2)
-        out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
+        out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
 
         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
 
-    def test_block_sparse_matmul_grad(self):
+    def test_gather_matmul_grad(self):
         lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
         rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
@@ -977,7 +977,7 @@
             return a @ b
 
         def f_test(a, b):
-            return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
+            return mx.gather_mm(a, b, lhs_indices, rhs_indices)
 
         a_mx = mx.random.normal((4, 2, 32, 32))
         b_mx = mx.random.normal((4, 1, 32, 32))
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 21dcc3103..92ad3d3e7 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -277,7 +277,7 @@
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-    def test_block_sparse_qmm(self):
+    def test_gather_qmm(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
             w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -322,8 +322,8 @@
             if rhs_indices is not None:
                 rhs_indices = mx.array(rhs_indices)
 
-            c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
-            c2 = mx.block_sparse_qmm(
+            c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
+            c2 = mx.gather_qmm(
                 x,
                 qw,
                 s,
@@ -390,7 +390,7 @@
         test_shape(32, 512, 32, transpose=False, **kwargs)
         test_shape(1, 512, 32, transpose=False, **kwargs)
 
-    def test_block_sparse_matmul_grad(self):
+    def test_gather_matmul_grad(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
             w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -406,10 +406,10 @@
         w_hat, qw, s, b = quantize(w)
 
         def f_ref(x, w, i1, i2):
-            return mx.block_sparse_mm(x, w, i1, i2).sum()
+            return mx.gather_mm(x, w, i1, i2).sum()
 
         def f_test(x, qw, s, b, i1, i2):
-            return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
+            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
 
         r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
         r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
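
---

For reviewers trying out the rename, here is a minimal sketch of the new names in action. It assumes a build that includes this patch; the shapes and index values are made up for illustration, and the semantics follow the `np_gather_mm` reference helper in `test_blas.py` (over the flattened batch dimensions, `out[i] = a[lhs_indices[i]] @ b[rhs_indices[i]]`):

```python
import mlx.core as mx

# a: 4 stacked (2, 3) matrices; b: 3 stacked (3, 5) matrices.
a = mx.random.normal((4, 2, 3))
b = mx.random.normal((3, 3, 5))

# Flat indices into the batch dimensions of a and b:
# out[i] = a[lhs_indices[i]] @ b[rhs_indices[i]]
lhs_indices = mx.array([0, 3], dtype=mx.uint32)
rhs_indices = mx.array([1, 2], dtype=mx.uint32)

out = mx.gather_mm(a, b, lhs_indices, rhs_indices)  # shape (2, 2, 5)

# Same result via an explicit (but less efficient) take + matmul:
ref = mx.take(a, lhs_indices, axis=0) @ mx.take(b, rhs_indices, axis=0)
assert mx.allclose(out, ref)

# The old name now raises (std::invalid_argument maps to ValueError)
# with a pointer to the replacement:
try:
    mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
except ValueError as e:
    print(e)  # block_sparse_mm is deprecated. Please use gather_mm ...
```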
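The quantized variant follows the same pattern; this sketch mirrors `test_gather_qmm` above, with hypothetical shapes chosen to look like expert routing. With `transpose=True`, each selected matrix is applied as `x[i] @ dequantize(qw, ...)[rhs_indices[i]].T`:

```python
import mlx.core as mx

# Two inputs routed to two of four quantized "expert" matrices.
x = mx.random.normal((2, 1, 8, 64))
w = mx.random.normal((4, 32, 64))  # (experts, out_dims, in_dims)
qw, scales, biases = mx.quantize(w, group_size=64, bits=4)

rhs_indices = mx.array([[3], [0]], dtype=mx.uint32)

# Quantized analogue of gather_mm; lhs_indices defaults to an arange
# over the batch dimensions of x, as in the unquantized op.
y = mx.gather_qmm(
    x, qw, scales, biases,
    rhs_indices=rhs_indices,
    transpose=True, group_size=64, bits=4,
)
print(y.shape)  # (2, 1, 8, 32)
```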