Gather qmm batched kernel and refactoring of quantized (#2078)

This commit is contained in:
Angelos Katharopoulos
2025-04-17 13:53:11 -07:00
committed by GitHub
parent 99eefd2ec0
commit 5de6d94a90
15 changed files with 1479 additions and 449 deletions

View File

@@ -1591,11 +1591,19 @@ class QuantizedMatmul : public UnaryPrimitive {
class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
bool left_sorted = false,
bool right_sorted = false)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
left_sorted_(left_sorted),
right_sorted_(right_sorted) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1605,13 +1613,16 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(
group_size_, bits_, transpose_, left_sorted_, right_sorted_);
}
private:
int group_size_;
int bits_;
bool transpose_;
bool left_sorted_;
bool right_sorted_;
};
class RandomBits : public UnaryPrimitive {