More primitives for compiling with shapeless (#1653)

* more shapeless and more Shape

* more shape

* fix

* fix
This commit is contained in:
Awni Hannun
2024-12-06 11:29:18 -08:00
committed by GitHub
parent 95c4a2e3af
commit d0b6cb0425
5 changed files with 160 additions and 81 deletions

View File

@@ -915,6 +915,31 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
bool AffineQuantize::is_equivalent(const Primitive& other) const {
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
return (
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
p_other.dequantize_ == dequantize_);
}
std::vector<Shape> AffineQuantize::output_shapes(
const std::vector<array>& inputs) {
auto& w = inputs[0];
if (dequantize_) {
auto out_size = w.shape(-1) * 32 / bits_;
auto out_shape = w.shape();
out_shape.back() = out_size;
return {std::move(out_shape)};
} else {
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) * bits_ / 32;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size_;
auto bshape = sshape;
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
}
}
std::string write_signature(
std::string func_name,
const std::string& header,