mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 19:31:16 +08:00
Add a quantization type in the ops
This commit is contained in:
parent
50f3535693
commit
bdd68bd893
136
mlx/ops.cpp
136
mlx/ops.cpp
@ -75,10 +75,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
const array& x,
|
const array& x,
|
||||||
const array& w,
|
const array& w,
|
||||||
const array& scales,
|
const array& scales,
|
||||||
const array& biases,
|
const std::optional<array>& biases,
|
||||||
bool transpose,
|
bool transpose,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits) {
|
int bits,
|
||||||
|
QuantizationType type) {
|
||||||
|
// Check if we have biases as expected
|
||||||
|
switch (type) {
|
||||||
|
case QuantizationType::Affine:
|
||||||
|
if (!biases.has_value()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[" << tag
|
||||||
|
<< "] The biases argument is required for quantization "
|
||||||
|
<< "type '" << type << "'";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case QuantizationType::AffinePacked:
|
||||||
|
if (biases.has_value()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[" << tag << "] Quantization type '" << type
|
||||||
|
<< "' does not use "
|
||||||
|
<< "biases but biases were provided";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if (w.dtype() != uint32) {
|
if (w.dtype() != uint32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag << "] The weight matrix should be uint32 "
|
msg << "[" << tag << "] The weight matrix should be uint32 "
|
||||||
@ -86,11 +109,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (scales.shape() != biases.shape()) {
|
if (biases.has_value() && scales.shape() != biases.value().shape()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag << "] Scales and biases should have the same shape. "
|
msg << "[" << tag << "] Scales and biases should have the same shape. "
|
||||||
<< "Received scales with shape " << scales.shape()
|
<< "Received scales with shape " << scales.shape()
|
||||||
<< " and biases with " << biases.shape();
|
<< " and biases with " << biases.value().shape();
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,25 +122,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag
|
msg << "[" << tag
|
||||||
<< "] Weight, scales and biases should have the same batch shape. "
|
<< "] Weight, scales and biases should have the same batch shape. "
|
||||||
<< "Received weight with shape " << w.shape() << ", scales with "
|
<< "Received weight with shape " << w.shape()
|
||||||
<< scales.shape() << " and biases with " << biases.shape();
|
<< " and scales/biases with " << scales.shape();
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
|
int weight_dims = w.shape(-1) * 32 / bits;
|
||||||
|
int scales_dims = scales.shape(-1) * group_size;
|
||||||
|
if (type == QuantizationType::AffinePacked) {
|
||||||
|
scales_dims /= 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weight_dims != scales_dims) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag << "] The shapes of the weight and scales are "
|
msg << "[" << tag << "] The shapes of the weight and scales are "
|
||||||
<< "incompatible based on bits and group_size. w.shape() == "
|
<< "incompatible based on bits, group_size and quantization type. "
|
||||||
<< w.shape() << " and scales.shape() == " << scales.shape()
|
<< "w.shape() == " << w.shape()
|
||||||
<< " with group_size=" << group_size << " and bits=" << bits;
|
<< " and scales.shape() == " << scales.shape()
|
||||||
|
<< " with group_size=" << group_size << ", bits=" << bits
|
||||||
|
<< " and type='" << type << "'";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
int x_inner_dims = x.shape(-1);
|
int x_inner_dims = x.shape(-1);
|
||||||
|
|
||||||
// Calculate the expanded w's dims
|
// Calculate the expanded w's dims
|
||||||
int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
|
int w_inner_dims = (transpose) ? weight_dims : w.shape(-2);
|
||||||
int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
|
int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims;
|
||||||
|
|
||||||
if (w_inner_dims != x_inner_dims) {
|
if (w_inner_dims != x_inner_dims) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -3662,14 +3693,23 @@ array quantized_matmul(
|
|||||||
array x,
|
array x,
|
||||||
array w,
|
array w,
|
||||||
array scales,
|
array scales,
|
||||||
array biases,
|
std::optional<array> biases,
|
||||||
bool transpose /* = true */,
|
bool transpose /* = true */,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
|
QuantizationType type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
// Check and extract the quantized matrix shape against x
|
// Check and extract the quantized matrix shape against x
|
||||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||||
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
|
"quantized_matmul",
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
transpose,
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
type);
|
||||||
|
|
||||||
// QuantizedMatmul handles w.ndim == 2 case.
|
// QuantizedMatmul handles w.ndim == 2 case.
|
||||||
if (x.ndim() > 2 && w.ndim() > 2) {
|
if (x.ndim() > 2 && w.ndim() > 2) {
|
||||||
@ -3690,37 +3730,53 @@ array quantized_matmul(
|
|||||||
*(inner_shape.end() - 1) = scales.shape(-1);
|
*(inner_shape.end() - 1) = scales.shape(-1);
|
||||||
scales = broadcast_to(scales, inner_shape, s);
|
scales = broadcast_to(scales, inner_shape, s);
|
||||||
|
|
||||||
*(inner_shape.end() - 1) = biases.shape(-1);
|
if (biases.has_value()) {
|
||||||
biases = broadcast_to(biases, inner_shape, s);
|
*(inner_shape.end() - 1) = biases.value().shape(-1);
|
||||||
|
biases = broadcast_to(biases.value(), inner_shape, s);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto dtype = result_type(x, scales, biases);
|
auto dtype = result_type(x, scales);
|
||||||
|
if (biases.has_value()) {
|
||||||
|
dtype = promote_types(dtype, biases.value().dtype());
|
||||||
|
}
|
||||||
if (!issubdtype(dtype, floating)) {
|
if (!issubdtype(dtype, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||||
<< "the passed types where x.dtype() == " << x.dtype()
|
<< "the passed types where x.dtype() == " << x.dtype()
|
||||||
<< ", scales.dtype() == " << scales.dtype()
|
<< ", scales.dtype() == " << scales.dtype();
|
||||||
<< " and biases.dtype() == " << biases.dtype();
|
if (biases.has_value()) {
|
||||||
|
msg << " and biases.dtype() == " << biases.value().dtype();
|
||||||
|
}
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prepare the inputs vector
|
||||||
|
std::vector<array> inputs;
|
||||||
|
inputs.reserve(4);
|
||||||
|
inputs.push_back(astype(x, dtype, s));
|
||||||
|
inputs.push_back(w);
|
||||||
|
inputs.push_back(astype(scales, dtype, s));
|
||||||
|
if (biases.has_value()) {
|
||||||
|
inputs.push_back(astype(biases.value(), dtype, s));
|
||||||
|
}
|
||||||
|
|
||||||
auto out_shape = x.shape();
|
auto out_shape = x.shape();
|
||||||
out_shape.back() = w_outer_dims;
|
out_shape.back() = w_outer_dims;
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_shared<QuantizedMatmul>(
|
std::make_shared<QuantizedMatmul>(
|
||||||
to_stream(s), group_size, bits, transpose),
|
to_stream(s), type, group_size, bits, transpose),
|
||||||
{astype(x, dtype, s),
|
std::move(inputs));
|
||||||
w,
|
|
||||||
astype(scales, dtype, s),
|
|
||||||
astype(biases, dtype, s)});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quantize `w` by forwarding to fast::affine_quantize with the given
// group_size, bits and stream/device.
// NOTE(review): this span is a side-by-side diff rendering; each row shows the
// pre-change line, then the post-change line, separated by `|` markers.
// NOTE(review): `type` exists only on the post-change side and is not passed to
// fast::affine_quantize, so the affine path runs for every QuantizationType.
// Confirm whether AffinePacked should be handled or rejected here.
std::tuple<array, array, array> quantize(
|
std::tuple<array, array, std::optional<array>> quantize(
|
||||||
const array& w,
|
const array& w,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
|
QuantizationType type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return fast::affine_quantize(w, group_size, bits, s);
|
return fast::affine_quantize(w, group_size, bits, s);
|
||||||
}
|
}
|
||||||
@ -3728,31 +3784,40 @@ std::tuple<array, array, array> quantize(
|
|||||||
/**
 * Dequantize a matrix by forwarding to fast::affine_dequantize.
 *
 * fast::affine_dequantize is called with an explicit biases array, so an
 * empty `biases` cannot be served by this function.
 *
 * @param w          Forwarded unchanged to fast::affine_dequantize.
 * @param scales     Forwarded unchanged to fast::affine_dequantize.
 * @param biases     Optional biases; must hold a value.
 * @param group_size Forwarded unchanged to fast::affine_dequantize.
 * @param bits       Forwarded unchanged to fast::affine_dequantize.
 * @param type       Quantization type; only used in the error message, the
 *                   affine path is taken regardless of its value.
 * @param s          Stream or device forwarded to fast::affine_dequantize.
 * @throws std::invalid_argument if `biases` does not hold a value.
 */
array dequantize(
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    QuantizationType type /* = QuantizationType::Affine */,
    StreamOrDevice s /* = {} */) {
  // Calling value() on an empty optional would surface as an uninformative
  // std::bad_optional_access; report it like the other argument checks in
  // this file instead.
  if (!biases.has_value()) {
    std::ostringstream msg;
    msg << "[dequantize] The biases argument is required to dequantize but "
        << "was not provided (quantization type '" << type << "')";
    throw std::invalid_argument(msg.str());
  }

  return fast::affine_dequantize(w, scales, *biases, group_size, bits, s);
}
|
||||||
|
|
||||||
array gather_qmm(
|
array gather_qmm(
|
||||||
const array& x,
|
const array& x,
|
||||||
const array& w,
|
const array& w,
|
||||||
const array& scales,
|
const array& scales,
|
||||||
const array& biases,
|
const std::optional<array>& biases,
|
||||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
const std::optional<array>& lhs_indices_ /* = std::nullopt */,
|
||||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
const std::optional<array>& rhs_indices_ /* = std::nullopt */,
|
||||||
bool transpose /* = true */,
|
bool transpose /* = true */,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
|
QuantizationType type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!lhs_indices_ && !rhs_indices_) {
|
if (!lhs_indices_ && !rhs_indices_) {
|
||||||
return quantized_matmul(
|
return quantized_matmul(
|
||||||
x, w, scales, biases, transpose, group_size, bits, s);
|
x, w, scales, biases, transpose, group_size, bits, type, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type != QuantizationType::Affine) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[gather_qmm] Only quantization type '" << type << "' supported";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
|
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
|
||||||
|
|
||||||
// Extract indices and broadcast them
|
// Extract indices and broadcast them
|
||||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||||
@ -3768,16 +3833,17 @@ array gather_qmm(
|
|||||||
out_shape.push_back(w_outer_dims);
|
out_shape.push_back(w_outer_dims);
|
||||||
|
|
||||||
// and output type
|
// and output type
|
||||||
auto out_type = result_type(x, scales, biases);
|
auto out_type = result_type(x, scales, biases.value());
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
|
std::make_shared<GatherQMM>(
|
||||||
|
to_stream(s), type, group_size, bits, transpose),
|
||||||
{astype(x, out_type, s),
|
{astype(x, out_type, s),
|
||||||
w,
|
w,
|
||||||
astype(scales, out_type, s),
|
astype(scales, out_type, s),
|
||||||
astype(biases, out_type, s),
|
astype(biases.value(), out_type, s),
|
||||||
lhs_indices,
|
lhs_indices,
|
||||||
rhs_indices});
|
rhs_indices});
|
||||||
}
|
}
|
||||||
|
18
mlx/ops.h
18
mlx/ops.h
@ -1277,31 +1277,34 @@ array conv_transpose3d(
|
|||||||
int groups = 1,
|
int groups = 1,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Quantized matmul multiplies x with a quantized matrix w*/
|
/** Quantized matmul multiplies x with a quantized matrix w */
|
||||||
array quantized_matmul(
|
array quantized_matmul(
|
||||||
array x,
|
array x,
|
||||||
array w,
|
array w,
|
||||||
array scales,
|
array scales,
|
||||||
array biases,
|
std::optional<array> biases,
|
||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
QuantizationType type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Quantize a matrix along its last axis */
|
/** Quantize a matrix along its last axis */
|
||||||
std::tuple<array, array, array> quantize(
|
std::tuple<array, array, std::optional<array>> quantize(
|
||||||
const array& w,
|
const array& w,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
QuantizationType type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Dequantize a matrix produced by quantize() */
|
/** Dequantize a matrix produced by quantize() */
|
||||||
array dequantize(
|
array dequantize(
|
||||||
const array& w,
|
const array& w,
|
||||||
const array& scales,
|
const array& scales,
|
||||||
const array& biases,
|
const std::optional<array>& biases,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
QuantizationType type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Compute matrix products with matrix-level gather. */
|
/** Compute matrix products with matrix-level gather. */
|
||||||
@ -1309,12 +1312,13 @@ array gather_qmm(
|
|||||||
const array& x,
|
const array& x,
|
||||||
const array& w,
|
const array& w,
|
||||||
const array& scales,
|
const array& scales,
|
||||||
const array& biases,
|
const std::optional<array>& biases,
|
||||||
std::optional<array> lhs_indices = std::nullopt,
|
const std::optional<array>& lhs_indices = std::nullopt,
|
||||||
std::optional<array> rhs_indices = std::nullopt,
|
const std::optional<array>& rhs_indices = std::nullopt,
|
||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
QuantizationType type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Returns a contraction of a and b over multiple dimensions. */
|
/** Returns a contraction of a and b over multiple dimensions. */
|
||||||
|
@ -2777,10 +2777,11 @@ std::vector<array> QuantizedMatmul::vjp(
|
|||||||
cotangents[0],
|
cotangents[0],
|
||||||
primals[1],
|
primals[1],
|
||||||
primals[2],
|
primals[2],
|
||||||
primals[3],
|
(primals.size() > 3) ? std::optional(primals[3]) : std::nullopt,
|
||||||
!transpose_,
|
!transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
|
type_,
|
||||||
stream()));
|
stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2855,6 +2856,7 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
!transpose_,
|
!transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
|
type_,
|
||||||
stream()),
|
stream()),
|
||||||
-3,
|
-3,
|
||||||
stream()),
|
stream()),
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/io/load.h"
|
#include "mlx/io/load.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#define DEFINE_VMAP() \
|
#define DEFINE_VMAP() \
|
||||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
|
virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
|
||||||
@ -1568,10 +1569,12 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
public:
|
public:
|
||||||
explicit QuantizedMatmul(
|
explicit QuantizedMatmul(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
|
QuantizationType type,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits,
|
int bits,
|
||||||
bool transpose)
|
bool transpose)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
|
type_(type),
|
||||||
group_size_(group_size),
|
group_size_(group_size),
|
||||||
bits_(bits),
|
bits_(bits),
|
||||||
transpose_(transpose) {}
|
transpose_(transpose) {}
|
||||||
@ -1586,6 +1589,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
QuantizationType type_;
|
||||||
int group_size_;
|
int group_size_;
|
||||||
int bits_;
|
int bits_;
|
||||||
bool transpose_;
|
bool transpose_;
|
||||||
@ -1595,8 +1599,14 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
|
|
||||||
class GatherQMM : public UnaryPrimitive {
|
class GatherQMM : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
|
explicit GatherQMM(
|
||||||
|
Stream stream,
|
||||||
|
QuantizationType type,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
bool transpose)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
|
type_(type),
|
||||||
group_size_(group_size),
|
group_size_(group_size),
|
||||||
bits_(bits),
|
bits_(bits),
|
||||||
transpose_(transpose) {}
|
transpose_(transpose) {}
|
||||||
@ -1610,6 +1620,7 @@ class GatherQMM : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
QuantizationType type_;
|
||||||
int group_size_;
|
int group_size_;
|
||||||
int bits_;
|
int bits_;
|
||||||
bool transpose_;
|
bool transpose_;
|
||||||
|
@ -145,6 +145,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write the textual name of a QuantizationType to `os`. A value that matches
// no enumerator is written as an empty string_view.
std::ostream& operator<<(std::ostream& os, QuantizationType type) {
  const std::string_view name = [type]() -> std::string_view {
    switch (type) {
      case QuantizationType::Affine:
        return "affine";
      case QuantizationType::AffinePacked:
        return "affine-packed";
    }
    // No enumerator matched: fall back to the empty view.
    return {};
  }();
  return os << name;
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline size_t
|
inline size_t
|
||||||
|
12
mlx/utils.h
12
mlx/utils.h
@ -100,6 +100,18 @@ inline int next_power_of_2(int n) {
|
|||||||
return pow(2, std::ceil(std::log2(n)));
|
return pow(2, std::ceil(std::log2(n)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Enumerates the supported quantization types. */
|
||||||
|
enum class QuantizationType {
|
||||||
|
// Traditional affine quantization with separate scales and biases arrays.
|
||||||
|
Affine,
|
||||||
|
|
||||||
|
// The same quantization as affine, but with the scales and biases packed
|
||||||
|
// into a single array and laid out contiguously every 4 rows.
|
||||||
|
AffinePacked,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, QuantizationType type);
|
||||||
|
|
||||||
namespace env {
|
namespace env {
|
||||||
|
|
||||||
int get_var(const char* name, int default_value);
|
int get_var(const char* name, int default_value);
|
||||||
|
Loading…
Reference in New Issue
Block a user