Block sparse qmm (#1124)

This commit is contained in:
Angelos Katharopoulos
2024-05-16 15:24:14 -07:00
committed by GitHub
parent 1873ffda01
commit e78a6518fa
15 changed files with 1724 additions and 164 deletions

View File

@@ -1157,6 +1157,19 @@ array dequantize(
int bits = 4,
StreamOrDevice s = {});
/** Compute matrix products with matrix-level gather. */
array block_sparse_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
bool transpose = true,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */
array tensordot(
const array& a,