Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Rename block sparse (#1149)

* block_sparse_mm to gather_mm
* rename
* nit
* nit

This commit is contained in:
parent  e6fecbb3e1
commit  d568c7ee36
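In short: the Python ops `block_sparse_mm` and `block_sparse_qmm` become `gather_mm` and `gather_qmm`, and the C++ primitives `BlockSparseMM`/`BlockSparseQMM` become `GatherMM`/`GatherQMM`. A minimal sketch of the renamed Python API follows; the shapes and index values are illustrative, not taken from the diff.

```python
import mlx.core as mx

# Stacks of matrices; the last two axes are the matmul axes.
a = mx.random.normal((4, 8, 16))   # 4 candidate LHS matrices
b = mx.random.normal((3, 16, 32))  # 3 candidate RHS matrices

# Flat batch indices selecting which matrices to multiply together.
lhs_indices = mx.array([0, 3], dtype=mx.uint32)
rhs_indices = mx.array([1, 2], dtype=mx.uint32)

# Previously mx.block_sparse_mm; after this commit the op is mx.gather_mm.
out = mx.gather_mm(a, b, lhs_indices, rhs_indices)
print(out.shape)  # one matmul per index pair, here (2, 8, 32)
```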
@@ -35,7 +35,6 @@ Operations
 bitwise_or
 bitwise_xor
 block_masked_mm
-block_sparse_mm
 broadcast_to
 ceil
 clip
@@ -69,6 +68,8 @@ Operations
 floor
 floor_divide
 full
+gather_mm
+gather_qmm
 greater
 greater_equal
 identity
@@ -21,4 +21,4 @@ python setup.py build_ext -j8 --inplace

 ```
 python test.py
-`
+```
@@ -32,8 +32,6 @@ DEFAULT(ArgReduce)
 DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
-DEFAULT(BlockSparseMM)
-DEFAULT(BlockSparseQMM)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
@@ -49,6 +47,8 @@ DEFAULT(ErfInv)
 DEFAULT(FFT)
 DEFAULT(Floor)
 DEFAULT(Gather)
+DEFAULT(GatherMM)
+DEFAULT(GatherQMM)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
 DEFAULT(Less)
@@ -43,8 +43,8 @@ DEFAULT(AsType)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
 DEFAULT(BlockMaskedMM)
-DEFAULT(BlockSparseMM)
-DEFAULT(BlockSparseQMM)
+DEFAULT(GatherMM)
+DEFAULT(GatherQMM)
 DEFAULT_MULTI(DivMod)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
@@ -190,10 +190,10 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
   }
 }

-void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherMM::eval(const std::vector<array>& inputs, array& out) {
   if (out.dtype() != float32) {
     throw std::runtime_error(
-        "[BlockSparseMM::eval] Currently only supports float32.");
+        "[GatherMM::eval] Currently only supports float32.");
   }
   out.set_data(allocator::malloc_or_wait(out.nbytes()));

@@ -277,4 +277,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
   }
 }

 } // namespace mlx::core
@@ -357,7 +357,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }

-void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 6);

   auto& x_pre = inputs[0];
@@ -324,12 +324,12 @@ void steel_matmul_conv_groups(
   };

   // clang-format off
   kname << "_has_batch_" << (has_batch ? 't' : 'n')
         << "_use_out_source_" << (use_out_source ? 't' : 'n')
         << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
         << "_align_N_" << (align_N ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on

   std::string hash_name = kname.str();
@@ -575,12 +575,12 @@ void steel_matmul(
   };

   // clang-format off
   kname << "_has_batch_" << (has_batch ? 't' : 'n')
         << "_use_out_source_" << (use_out_source ? 't' : 'n')
         << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
         << "_align_N_" << (align_N ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on

   std::string hash_name = kname.str();
@@ -1170,12 +1170,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   };

   // clang-format off
   kname << "_has_batch_" << (has_batch ? 't' : 'n')
         << "_use_out_source_" << (use_out_source ? 't' : 'n')
         << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
         << "_align_N_" << (align_N ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on

   std::string hash_name = kname.str();
@@ -1435,12 +1435,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   using namespace mlx::steel;
   // assert(inputs.size() == 2);
   if (!issubdtype(out.dtype(), floating)) {
     throw std::runtime_error(
-        "[matmul] Does not yet support non-floating point types.");
+        "[GatherMM] Does not yet support non-floating point types.");
   }
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -1700,12 +1700,12 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   };

   // clang-format off
   kname << "_has_batch_" << (has_batch ? 't' : 'n')
         << "_use_out_source_" << (use_out_source ? 't' : 'n')
         << "_do_axpby_" << (do_axpby ? 't' : 'n')
         << "_align_M_" << (align_M ? 't' : 'n')
         << "_align_N_" << (align_N ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n')
         << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on

   std::string hash_name = kname.str();
@@ -196,7 +196,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }

-void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 6);

   out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -34,8 +34,6 @@ NO_GPU(AsType)
 NO_GPU(AsStrided)
 NO_GPU(BitwiseBinary)
 NO_GPU(BlockMaskedMM)
-NO_GPU(BlockSparseMM)
-NO_GPU(BlockSparseQMM)
 NO_GPU(Broadcast)
 NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
@@ -60,6 +58,8 @@ NO_GPU(FFT)
 NO_GPU(Floor)
 NO_GPU(Full)
 NO_GPU(Gather)
+NO_GPU(GatherMM)
+NO_GPU(GatherQMM)
 NO_GPU(Greater)
 NO_GPU(GreaterEqual)
 NO_GPU(Less)
mlx/ops.cpp (23 changed lines)
@@ -3491,7 +3491,7 @@ array dequantize(
   return w_full;
 }

-array block_sparse_qmm(
+array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
@@ -3508,7 +3508,7 @@ array block_sparse_qmm(
   }

   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);

   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3529,8 +3529,7 @@ array block_sparse_qmm(
   auto out = array(
       std::move(out_shape),
       out_type,
-      std::make_shared<BlockSparseQMM>(
-          to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),
@@ -3937,7 +3936,7 @@ array block_masked_mm(
 }

 /** Compute matrix product with matrix-level gather */
-array block_sparse_mm(
+array gather_mm(
     array a,
     array b,
     std::optional<array> lhs_indices_ /* = std::nullopt */,
@@ -3954,7 +3953,7 @@ array block_sparse_mm(

   if (a.ndim() == 0 || b.ndim() == 0) {
     throw std::invalid_argument(
-        "[block_sparse_mm] Got 0 dimension input. Inputs must "
+        "[gather_mm] Got 0 dimension input. Inputs must "
         "have at least one dimension.");
   }

@@ -3969,8 +3968,8 @@ array block_sparse_mm(

   if (a.shape(-1) != b.shape(-2)) {
     std::ostringstream msg;
-    msg << "[block_sparse_mm] Last dimension of first input with shape "
-        << a.shape() << " must match second to last dimension of"
+    msg << "[gather_mm] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
         << " second input with shape " << b.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
@@ -3979,7 +3978,7 @@ array block_sparse_mm(
   auto out_type = result_type(a, b);
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
-    msg << "[block_sparse_mm] Only real floating point types are supported but "
+    msg << "[gather_mm] Only real floating point types are supported but "
        << a.dtype() << " and " << b.dtype()
        << " were provided which results in " << out_type
        << ", which is not a real floating point type.";
@@ -3995,12 +3994,12 @@ array block_sparse_mm(

   if (!issubdtype(lhs_indices.dtype(), integer)) {
     throw std::invalid_argument(
-        "[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
+        "[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
   }

   if (!issubdtype(rhs_indices.dtype(), integer)) {
     throw std::invalid_argument(
-        "[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
+        "[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
   }

   lhs_indices = astype(lhs_indices, uint32, s);
@@ -4024,7 +4023,7 @@ array block_sparse_mm(
   auto out = array(
       out_shape,
       out_type,
-      std::make_shared<BlockSparseMM>(to_stream(s)),
+      std::make_shared<GatherMM>(to_stream(s)),
       {a, b, lhs_indices, rhs_indices});

   // Remove the possibly inserted singleton dimensions
@@ -1158,7 +1158,7 @@ array dequantize(
     StreamOrDevice s = {});

 /** Compute matrix products with matrix-level gather. */
-array block_sparse_qmm(
+array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
@@ -1210,7 +1210,7 @@ array block_masked_mm(
     StreamOrDevice s = {});

 /** Compute matrix product with matrix-level gather */
-array block_sparse_mm(
+array gather_mm(
     array a,
     array b,
     std::optional<array> lhs_indices = std::nullopt,
@@ -2376,13 +2376,13 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
       transpose_ == qm_other.transpose_;
 }

-std::pair<std::vector<array>, std::vector<int>> BlockSparseQMM::vmap(
+std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  throw std::runtime_error("BlockSparseQMM::vmap NYI");
+  throw std::runtime_error("GatherQMM::vmap NYI");
 }

-std::vector<array> BlockSparseQMM::vjp(
+std::vector<array> GatherQMM::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
@@ -2406,7 +2406,7 @@ std::vector<array> BlockSparseQMM::vjp(
             flatten(zeros_like(x, stream()), 0, -3, stream()),
             lhs_indices,
             expand_dims(
-                block_sparse_qmm(
+                gather_qmm(
                     cotan,
                     w,
                     scales,
@@ -2428,27 +2428,27 @@ std::vector<array> BlockSparseQMM::vjp(
     // gradient wrt to the indices is undefined
     else if (arg > 3) {
       throw std::runtime_error(
-          "BlockSparseQMM::vjp cannot compute the gradient wrt the indices.");
+          "GatherQMM::vjp cannot compute the gradient wrt the indices.");
     }

     // gradient wrt to w_q, scales or biases
     else {
       throw std::runtime_error(
-          "BlockSparseQMM::vjp no gradient wrt the quantized matrix yet.");
+          "GatherQMM::vjp no gradient wrt the quantized matrix yet.");
     }
   }
   return vjps;
 }

-std::vector<array> BlockSparseQMM::jvp(
+std::vector<array> GatherQMM::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  throw std::runtime_error("BlockSparseQMM::jvp NYI");
+  throw std::runtime_error("GatherQMM::jvp NYI");
 }

-bool BlockSparseQMM::is_equivalent(const Primitive& other) const {
-  const BlockSparseQMM& qm_other = static_cast<const BlockSparseQMM&>(other);
+bool GatherQMM::is_equivalent(const Primitive& other) const {
+  const GatherQMM& qm_other = static_cast<const GatherQMM&>(other);
   return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
       transpose_ == qm_other.transpose_;
 }
@@ -3523,7 +3523,7 @@ std::vector<array> BlockMaskedMM::vjp(
   return vjps;
 }

-std::vector<array> BlockSparseMM::vjp(
+std::vector<array> GatherMM::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
@@ -3548,7 +3548,7 @@ std::vector<array> BlockSparseMM::vjp(
       base = reshape(base, {-1, M, K}, stream());

       // g : (out_batch_shape) + (M, K)
-      auto g = block_sparse_mm(cotan, bt, std::nullopt, rhs_indices, stream());
+      auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
       g = expand_dims(g, -3, stream());
       auto gacc = scatter_add(base, lhs_indices, g, 0, stream());

@@ -3563,14 +3563,14 @@ std::vector<array> BlockSparseMM::vjp(
       base = reshape(base, {-1, K, N}, stream());

       // g : (out_batch_shape) + (K, N)
-      auto g = block_sparse_mm(at, cotan, lhs_indices, std::nullopt, stream());
+      auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
       g = expand_dims(g, -3, stream());
       auto gacc = scatter_add(base, rhs_indices, g, 0, stream());

       vjps.push_back(reshape(gacc, base_shape, stream()));
     } else {
       throw std::invalid_argument(
-          "[BlockSparseMM] Cannot calculate VJP with respect to indices.");
+          "[GatherMM] Cannot calculate VJP with respect to indices.");
     }
   }
   return vjps;
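The VJP above routes the cotangent back through gather_mm and scatter_add, so gradients flow to the gathered operands but not to the indices. A hedged Python-level sketch of exercising this, mirroring the renamed gradient test further below:

```python
import mlx.core as mx

# Index arrays and operand shapes follow test_gather_matmul_grad.
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

def f(a, b):
    # Reduce to a scalar so mx.grad applies directly.
    return mx.gather_mm(a, b, lhs_indices, rhs_indices).sum()

a = mx.random.normal((4, 2, 32, 32))  # 8 flat LHS matrices
b = mx.random.normal((4, 1, 32, 32))  # 4 flat RHS matrices

grad_a, grad_b = mx.grad(f, argnums=(0, 1))(a, b)
print(grad_a.shape, grad_b.shape)  # gradients have the operand shapes
```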
@@ -502,9 +502,9 @@ class BlockMaskedMM : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };

-class BlockSparseMM : public UnaryPrimitive {
+class GatherMM : public UnaryPrimitive {
  public:
-  explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {};
+  explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -515,7 +515,7 @@ class BlockSparseMM : public UnaryPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;

-  DEFINE_PRINT(BlockSparseMM)
+  DEFINE_PRINT(GatherMM)
   DEFINE_DEFAULT_IS_EQUIVALENT()

  private:
@@ -1467,13 +1467,9 @@ class QuantizedMatmul : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };

-class BlockSparseQMM : public UnaryPrimitive {
+class GatherQMM : public UnaryPrimitive {
  public:
-  explicit BlockSparseQMM(
-      Stream stream,
-      int group_size,
-      int bits,
-      bool transpose)
+  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
@@ -1484,7 +1480,7 @@ class BlockSparseQMM : public UnaryPrimitive {

   DEFINE_VMAP()
   DEFINE_GRADS()
-  DEFINE_PRINT(BlockSparseQMM)
+  DEFINE_PRINT(GatherQMM)
   bool is_equivalent(const Primitive& other) const override;

  private:
@@ -40,6 +40,15 @@ double scalar_to_double(Scalar s) {
 }

 void init_ops(nb::module_& m) {
+  // TODO, remove deprecation errors in a future release
+  m.def("block_sparse_mm", [](nb::args, nb::kwargs) {
+    throw std::invalid_argument(
+        "block_sparse_mm is deprecated. Please use gather_mm which has the same signature");
+  });
+  m.def("block_sparse_qmm", [](nb::args, nb::kwargs) {
+    throw std::invalid_argument(
+        "block_sparse_qmm is deprecated. Please use gather_qmm which has the same signature");
+  });
   m.def(
       "reshape",
       &reshape,
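With the shim above, the old names remain importable but fail fast with a pointer to the new names. A hedged sketch of what calling code sees, assuming the usual nanobind mapping of std::invalid_argument to Python's ValueError:

```python
import mlx.core as mx

a = mx.random.normal((4, 8, 16))
b = mx.random.normal((4, 16, 32))
lhs_indices = mx.array([0, 3], dtype=mx.uint32)
rhs_indices = mx.array([1, 2], dtype=mx.uint32)

try:
    mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)   # old name
except ValueError as e:
    print(e)  # block_sparse_mm is deprecated. Please use gather_mm ...

out = mx.gather_mm(a, b, lhs_indices, rhs_indices)       # new name, same arguments
```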
@@ -3748,8 +3757,8 @@ void init_ops(nb::module_& m) {
         array: The dequantized version of ``w``
       )pbdoc");
   m.def(
-      "block_sparse_qmm",
-      &block_sparse_qmm,
+      "gater_qmm",
+      &gather_qmm,
       nb::arg(),
       nb::arg(),
       "scales"_a,
@@ -3762,12 +3771,12 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.

-        This operation is the quantized equivalent to :func:`block_sparse_mm`.
-        Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
+        This operation is the quantized equivalent to :func:`gather_mm`.
+        Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
         ``rhs_indices`` contain flat indices along the batch dimensions (i.e.
         all but the last two dimensions) of ``x`` and ``w`` respectively.

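Since the docstring presents gather_qmm as the quantized counterpart of gather_mm, a hedged usage sketch follows, mirroring the pattern in the quantized tests further below; the shapes and the transpose=False weight layout are illustrative assumptions:

```python
import mlx.core as mx

group_size, bits = 64, 4

x = mx.random.normal((3, 32, 512))   # 3 candidate activation matrices
w = mx.random.normal((4, 512, 512))  # 4 candidate weight matrices (K x N, used with transpose=False)

qw, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, scales, biases, group_size=group_size, bits=bits)

lhs_indices = mx.array([0, 2], dtype=mx.uint32)
rhs_indices = mx.array([1, 3], dtype=mx.uint32)

# Quantized gather-matmul vs. the float path on the dequantized weights.
out_q = mx.gather_qmm(x, qw, scales, biases, lhs_indices, rhs_indices,
                      transpose=False, group_size=group_size, bits=bits)
out_f = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
print((out_q - out_f).abs().max())  # should be small
```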
@@ -3965,8 +3974,8 @@ void init_ops(nb::module_& m) {

       )pbdoc");
   m.def(
-      "block_sparse_mm",
-      &block_sparse_mm,
+      "gather_mm",
+      &gather_mm,
       nb::arg(),
       nb::arg(),
       "lhs_indices"_a = nb::none(),
@@ -3974,20 +3983,24 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Matrix multiplication with matrix-level gather.

-        Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
-        This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
+        Performs a gather of the operands with the given indices followed by a
+        (possibly batched) matrix multiplication of two arrays. This operation
+        is more efficient than explicitly applying a :func:`take` followed by a
+        :func:`matmul`.

-        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
+        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
+        along the batch dimensions (i.e. all but the last two dimensions) of
+        ``a`` and ``b`` respectively.

-        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
-        ``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
+        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
+        contains indices from the range ``[0, A1 * A2 * ... * AS)``

-        For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
-        ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
+        For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
+        contains indices from the range ``[0, B1 * B2 * ... * BS)``

         Args:
           a (array): Input array.
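The docstring's claim that gather_mm fuses a take over the flattened batch dimensions with a batched matmul can be illustrated with a hedged equivalence check; shapes and indices here are illustrative:

```python
import mlx.core as mx

a = mx.random.normal((2, 3, 8, 16))  # batch dims (2, 3): 6 flat LHS matrices
b = mx.random.normal((4, 16, 32))    # batch dims (4,):   4 flat RHS matrices

lhs_indices = mx.array([0, 5], dtype=mx.uint32)  # flat indices in [0, 6)
rhs_indices = mx.array([1, 3], dtype=mx.uint32)  # flat indices in [0, 4)

fused = mx.gather_mm(a, b, lhs_indices, rhs_indices)

# Reference: explicit gather over the flattened batch dims, then matmul.
a_flat = a.reshape((-1, 8, 16))
b_flat = b.reshape((-1, 16, 32))
ref = mx.take(a_flat, lhs_indices, axis=0) @ mx.take(b_flat, rhs_indices, axis=0)

print((fused - ref).abs().max())  # should be ~0
```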
@@ -408,9 +408,9 @@ class TestBlas(mlx_tests.MLXTestCase):
             with self.subTest(
                 B=B,  # Batch size
                 D=D,  # Dimension of mm
                 n_kv_heads=n_kv_heads,  # key-value heads
                 factor=factor,  # factor to get query heads
                 qsl=qsl,  # Query sequence length
                 ksl=ksl,  # Key sequence length
                 dtype=dtype  # Data type
             ):
@@ -432,22 +432,22 @@ class TestBlas(mlx_tests.MLXTestCase):
             k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
             v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)

             # Rearrange to move heads up
             q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
             k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
             v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)

             # Do attn style matmul
             s_np = q_np_reshape @ k_np_reshape
             o_np = s_np @ v_np_reshape
             o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)

             # Test mlx
             q_mx = mx.array(q_np)
             k_mx = mx.array(k_np)
             v_mx = mx.array(v_np)

             # Rearrange to move heads up
             q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
             k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
             v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
@@ -810,8 +810,8 @@ class TestBlas(mlx_tests.MLXTestCase):

         self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))

-    def test_block_sparse_matmul(self):
-        def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None):
+    def test_gather_matmul(self):
+        def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
             a = a.reshape((-1, a.shape[-2], a.shape[-1]))
             b = b.reshape((-1, b.shape[-2], b.shape[-1]))
             lhs_indices = lhs_indices or np.arange(a.shape[0])
@@ -848,12 +848,12 @@ class TestBlas(mlx_tests.MLXTestCase):
             a_mx = mx.array(a_np)
             b_mx = mx.array(b_np)

-            out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
+            out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)

             lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
             rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)

-            out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
+            out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)

             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

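Only the first few lines of the test's NumPy reference are visible in this hunk; a hedged reconstruction of a complete reference with the same name and defaults is sketched here (the pairwise-matmul body is an assumption based on the op's documented semantics):

```python
import numpy as np

def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
    # Flatten the batch dims so the indices address whole matrices.
    a = a.reshape((-1, a.shape[-2], a.shape[-1]))
    b = b.reshape((-1, b.shape[-2], b.shape[-1]))
    lhs_indices = lhs_indices or np.arange(a.shape[0])
    rhs_indices = rhs_indices or np.arange(b.shape[0])
    # Gather the selected matrices and multiply them pairwise.
    return np.matmul(a[np.array(lhs_indices)], b[np.array(rhs_indices)])
```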
@@ -920,7 +920,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         lhs_indices = [0, 13, 12]
         rhs_indices = [0, 3, 5]

-        out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
+        out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)

         # MLX
         a_mx = a_mx.reshape((5, 1, 32, 32))
@@ -932,7 +932,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         lhs_indices_mx = mx.array(lhs_indices)
         rhs_indices_mx = mx.array(rhs_indices)

-        out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
+        out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)

         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

@@ -946,17 +946,17 @@ class TestBlas(mlx_tests.MLXTestCase):
         rhs_indices = [0, 2]

         b_np_t = np.swapaxes(b_np, -1, -2)
-        out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices)
+        out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices)

         lhs_indices_mx = mx.array(lhs_indices)
         rhs_indices_mx = mx.array(rhs_indices)

         b_mx_t = mx.swapaxes(b_mx, -1, -2)
-        out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
+        out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)

         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

-    def test_block_sparse_matmul_grad(self):
+    def test_gather_matmul_grad(self):

         lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
         rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
@@ -977,7 +977,7 @@ class TestBlas(mlx_tests.MLXTestCase):
             return a @ b

         def f_test(a, b):
-            return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
+            return mx.gather_mm(a, b, lhs_indices, rhs_indices)

         a_mx = mx.random.normal((4, 2, 32, 32))
         b_mx = mx.random.normal((4, 1, 32, 32))
@@ -277,7 +277,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)

-    def test_block_sparse_qmm(self):
+    def test_gather_qmm(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
             w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -322,8 +322,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
             if rhs_indices is not None:
                 rhs_indices = mx.array(rhs_indices)

-            c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
-            c2 = mx.block_sparse_qmm(
+            c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
+            c2 = mx.gather_qmm(
                 x,
                 qw,
                 s,
@@ -390,7 +390,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         test_shape(32, 512, 32, transpose=False, **kwargs)
         test_shape(1, 512, 32, transpose=False, **kwargs)

-    def test_block_sparse_matmul_grad(self):
+    def test_gather_matmul_grad(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
             w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
@@ -406,10 +406,10 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat, qw, s, b = quantize(w)

         def f_ref(x, w, i1, i2):
-            return mx.block_sparse_mm(x, w, i1, i2).sum()
+            return mx.gather_mm(x, w, i1, i2).sum()

         def f_test(x, qw, s, b, i1, i2):
-            return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
+            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()

         r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
         r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)