mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape * more shape * fix * fix
This commit is contained in:
25
mlx/fast.cpp
25
mlx/fast.cpp
@@ -915,6 +915,31 @@ array affine_dequantize(
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
bool AffineQuantize::is_equivalent(const Primitive& other) const {
|
||||
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
|
||||
return (
|
||||
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
|
||||
p_other.dequantize_ == dequantize_);
|
||||
}
|
||||
|
||||
std::vector<Shape> AffineQuantize::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
auto& w = inputs[0];
|
||||
if (dequantize_) {
|
||||
auto out_size = w.shape(-1) * 32 / bits_;
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = out_size;
|
||||
return {std::move(out_shape)};
|
||||
} else {
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) * bits_ / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size_;
|
||||
auto bshape = sshape;
|
||||
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
|
||||
}
|
||||
}
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
|
||||
Reference in New Issue
Block a user