mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
format
This commit is contained in:
@@ -50,7 +50,7 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
||||
|
||||
size_t offset = tidx + grid_dim.x * size_t(tidy);
|
||||
size_t in_index = offset * values_per_reduce;
|
||||
if (in_index > size) {
|
||||
if (in_index >= size) {
|
||||
return;
|
||||
}
|
||||
size_t out_index = power_of_2_bits
|
||||
@@ -174,27 +174,51 @@ __global__ void affine_dequantize(
|
||||
w += offset * bytes_per_pack;
|
||||
out[0] = static_cast<T>(w[0] & 0x7) * scale + bias;
|
||||
out[1] = static_cast<T>((w[0] & 0x38) >> 3) * scale + bias;
|
||||
out[2] = static_cast<T>(((w[0] & 0xc0) >> 6) + static_cast<T>((w[1] & 0x1) << 2)) * scale + bias;
|
||||
out[2] = (static_cast<T>((w[0] & 0xc0) >> 6) +
|
||||
static_cast<T>((w[1] & 0x1) << 2)) *
|
||||
scale +
|
||||
bias;
|
||||
out[3] = static_cast<T>((w[1] & 0xe) >> 1) * scale + bias;
|
||||
out[4] = static_cast<T>((w[1] & 0x70) >> 4) * scale + bias;
|
||||
out[5] = static_cast<T>(((w[1] & 0x80) >> 7) + static_cast<T>((w[2] & 0x3) << 1)) * scale + bias;
|
||||
out[5] = (static_cast<T>((w[1] & 0x80) >> 7) +
|
||||
static_cast<T>((w[2] & 0x3) << 1)) *
|
||||
scale +
|
||||
bias;
|
||||
out[6] = static_cast<T>((w[2] & 0x1c) >> 2) * scale + bias;
|
||||
out[7] = static_cast<T>((w[2] & 0xe0) >> 5) * scale + bias;
|
||||
} else if constexpr (bits == 5) {
|
||||
w += offset * bytes_per_pack;
|
||||
out[0] = static_cast<T>(w[0] & 0x1f) * scale + bias;
|
||||
out[1] = static_cast<T>(((w[0] & 0xe0) >> 5) + static_cast<T>((w[1] & 0x3) << 3)) * scale + bias;
|
||||
out[1] = (static_cast<T>((w[0] & 0xe0) >> 5) +
|
||||
static_cast<T>((w[1] & 0x3) << 3)) *
|
||||
scale +
|
||||
bias;
|
||||
out[2] = static_cast<T>((w[1] & 0x7c) >> 2) * scale + bias;
|
||||
out[3] = static_cast<T>(((w[1] & 0x80) >> 7) + static_cast<T>((w[2] & 0xf) << 1)) * scale + bias;
|
||||
out[4] = static_cast<T>(((w[2] & 0xf0) >> 4) + static_cast<T>((w[3] & 0x1) << 4)) * scale + bias;
|
||||
out[3] = (static_cast<T>((w[1] & 0x80) >> 7) +
|
||||
static_cast<T>((w[2] & 0xf) << 1)) *
|
||||
scale +
|
||||
bias;
|
||||
out[4] = (static_cast<T>((w[2] & 0xf0) >> 4) +
|
||||
static_cast<T>((w[3] & 0x1) << 4)) *
|
||||
scale +
|
||||
bias;
|
||||
out[5] = static_cast<T>((w[3] & 0x3e) >> 1) * scale + bias;
|
||||
out[6] = static_cast<T>(((w[3] & 0xc0) >> 6) + static_cast<T>((w[4] & 0x7) << 2)) * scale + bias;
|
||||
out[6] = (static_cast<T>((w[3] & 0xc0) >> 6) +
|
||||
static_cast<T>((w[4] & 0x7) << 2)) *
|
||||
scale +
|
||||
bias;
|
||||
out[7] = static_cast<T>((w[4] & 0xf8) >> 3) * scale + bias;
|
||||
} else if constexpr (bits == 6) {
|
||||
w += offset * bytes_per_pack;
|
||||
out[0] = static_cast<T>(w[0] & 0x3f) * scale + bias;
|
||||
out[1] = static_cast<T>(((w[0] >> 6) & 0x03) + static_cast<T>((w[1] & 0x0f) << 2)) * scale + bias;
|
||||
out[2] = static_cast<T>(((w[1] >> 4) & 0x0f) + static_cast<T>((w[2] & 0x03) << 4)) * scale + bias;
|
||||
out[1] = (static_cast<T>((w[0] >> 6) & 0x03) +
|
||||
static_cast<T>((w[1] & 0x0f) << 2)) *
|
||||
scale +
|
||||
bias;
|
||||
out[2] = (static_cast<T>((w[1] >> 4) & 0x0f) +
|
||||
static_cast<T>((w[2] & 0x03) << 4)) *
|
||||
scale +
|
||||
bias;
|
||||
out[3] = static_cast<T>((w[2] >> 2) & 0x3f) * scale + bias;
|
||||
} else {
|
||||
uint val = w[offset];
|
||||
|
||||
Reference in New Issue
Block a user