Rename block sparse (#1149)

* block_sparse_mm to gather_mm

* rename

* nit

* nit
This commit is contained in:
Awni Hannun
2024-05-22 07:48:34 -07:00
committed by GitHub
parent e6fecbb3e1
commit d568c7ee36
16 changed files with 120 additions and 111 deletions

View File

@@ -502,9 +502,9 @@ class BlockMaskedMM : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BlockSparseMM : public UnaryPrimitive {
class GatherMM : public UnaryPrimitive {
public:
explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {};
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -515,7 +515,7 @@ class BlockSparseMM : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(BlockSparseMM)
DEFINE_PRINT(GatherMM)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
@@ -1467,13 +1467,9 @@ class QuantizedMatmul : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BlockSparseQMM : public UnaryPrimitive {
class GatherQMM : public UnaryPrimitive {
public:
explicit BlockSparseQMM(
Stream stream,
int group_size,
int bits,
bool transpose)
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
@@ -1484,7 +1480,7 @@ class BlockSparseQMM : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(BlockSparseQMM)
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
private: