Block sparse qmm (#1124)

This commit is contained in:
Angelos Katharopoulos
2024-05-16 15:24:14 -07:00
committed by GitHub
parent 1873ffda01
commit e78a6518fa
15 changed files with 1724 additions and 164 deletions

View File

@@ -1467,6 +1467,34 @@ class QuantizedMatmul : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BlockSparseQMM : public UnaryPrimitive {
public:
explicit BlockSparseQMM(
Stream stream,
int group_size,
int bits,
bool transpose)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(BlockSparseQMM)
bool is_equivalent(const Primitive& other) const override;
private:
int group_size_;
int bits_;
bool transpose_;
void eval(const std::vector<array>& inputs, array& out);
};
class RandomBits : public UnaryPrimitive {
public:
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)