mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
42
mlx/fast.cpp
42
mlx/fast.cpp
@@ -773,20 +773,15 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
};
|
||||
};
|
||||
|
||||
std::vector<array> outputs;
|
||||
if (s.device == Device::gpu) {
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) / el_per_int;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size;
|
||||
outputs = array::make_arrays(
|
||||
{wq_shape, sshape, sshape},
|
||||
{uint32, w.dtype(), w.dtype()},
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
|
||||
{w});
|
||||
} else {
|
||||
outputs = fallback({w});
|
||||
}
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) / el_per_int;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size;
|
||||
auto outputs = array::make_arrays(
|
||||
{std::move(wq_shape), sshape, sshape},
|
||||
{uint32, w.dtype(), w.dtype()},
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
|
||||
{w});
|
||||
return {outputs[0], outputs[1], outputs[2]};
|
||||
}
|
||||
|
||||
@@ -814,16 +809,13 @@ array affine_quantize(
|
||||
return {reshape(packed_w, wshape, s)};
|
||||
};
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = w.shape(-1) / el_per_int;
|
||||
return array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
|
||||
{w, scales, biases});
|
||||
}
|
||||
return fallback({w, scales, biases})[0];
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = w.shape(-1) / el_per_int;
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
|
||||
{w, scales, biases});
|
||||
}
|
||||
|
||||
array affine_dequantize(
|
||||
@@ -916,7 +908,7 @@ array affine_dequantize(
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = w.shape(-1) * el_per_int;
|
||||
return array(
|
||||
out_shape,
|
||||
std::move(out_shape),
|
||||
scales.dtype(),
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
|
||||
{w, scales, biases});
|
||||
|
||||
Reference in New Issue
Block a user