mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Block sparse qmm (#1124)
This commit is contained in:
committed by
GitHub
parent
1873ffda01
commit
e78a6518fa
13
mlx/ops.h
13
mlx/ops.h
@@ -1157,6 +1157,19 @@ array dequantize(
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Compute matrix products with matrix-level gather. */
|
||||
array block_sparse_qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
std::optional<array> lhs_indices = std::nullopt,
|
||||
std::optional<array> rhs_indices = std::nullopt,
|
||||
bool transpose = true,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Returns a contraction of a and b over multiple dimensions. */
|
||||
array tensordot(
|
||||
const array& a,
|
||||
|
||||
Reference in New Issue
Block a user