From 51854d1a191de00f09c0a2e27c15f03b08b3b8a5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 10 Jul 2025 14:34:01 -0700 Subject: [PATCH] format --- mlx/backend/cuda/quantized.cu | 42 +++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized.cu index ae44a760a..12a1f6fe4 100644 --- a/mlx/backend/cuda/quantized.cu +++ b/mlx/backend/cuda/quantized.cu @@ -50,7 +50,7 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) { size_t offset = tidx + grid_dim.x * size_t(tidy); size_t in_index = offset * values_per_reduce; - if (in_index > size) { + if (in_index >= size) { return; } size_t out_index = power_of_2_bits @@ -174,27 +174,51 @@ __global__ void affine_dequantize( w += offset * bytes_per_pack; out[0] = static_cast(w[0] & 0x7) * scale + bias; out[1] = static_cast((w[0] & 0x38) >> 3) * scale + bias; - out[2] = static_cast(((w[0] & 0xc0) >> 6) + static_cast((w[1] & 0x1) << 2)) * scale + bias; + out[2] = (static_cast((w[0] & 0xc0) >> 6) + + static_cast((w[1] & 0x1) << 2)) * + scale + + bias; out[3] = static_cast((w[1] & 0xe) >> 1) * scale + bias; out[4] = static_cast((w[1] & 0x70) >> 4) * scale + bias; - out[5] = static_cast(((w[1] & 0x80) >> 7) + static_cast((w[2] & 0x3) << 1)) * scale + bias; + out[5] = (static_cast((w[1] & 0x80) >> 7) + + static_cast((w[2] & 0x3) << 1)) * + scale + + bias; out[6] = static_cast((w[2] & 0x1c) >> 2) * scale + bias; out[7] = static_cast((w[2] & 0xe0) >> 5) * scale + bias; } else if constexpr (bits == 5) { w += offset * bytes_per_pack; out[0] = static_cast(w[0] & 0x1f) * scale + bias; - out[1] = static_cast(((w[0] & 0xe0) >> 5) + static_cast((w[1] & 0x3) << 3)) * scale + bias; + out[1] = (static_cast((w[0] & 0xe0) >> 5) + + static_cast((w[1] & 0x3) << 3)) * + scale + + bias; out[2] = static_cast((w[1] & 0x7c) >> 2) * scale + bias; - out[3] = static_cast(((w[1] & 0x80) >> 7) + static_cast((w[2] & 0xf) << 1)) * scale + bias; - out[4] = static_cast(((w[2] & 0xf0) >> 4) + static_cast((w[3] & 0x1) << 4)) * scale + bias; + out[3] = (static_cast((w[1] & 0x80) >> 7) + + static_cast((w[2] & 0xf) << 1)) * + scale + + bias; + out[4] = (static_cast((w[2] & 0xf0) >> 4) + + static_cast((w[3] & 0x1) << 4)) * + scale + + bias; out[5] = static_cast((w[3] & 0x3e) >> 1) * scale + bias; - out[6] = static_cast(((w[3] & 0xc0) >> 6) + static_cast((w[4] & 0x7) << 2)) * scale + bias; + out[6] = (static_cast((w[3] & 0xc0) >> 6) + + static_cast((w[4] & 0x7) << 2)) * + scale + + bias; out[7] = static_cast((w[4] & 0xf8) >> 3) * scale + bias; } else if constexpr (bits == 6) { w += offset * bytes_per_pack; out[0] = static_cast(w[0] & 0x3f) * scale + bias; - out[1] = static_cast(((w[0] >> 6) & 0x03) + static_cast((w[1] & 0x0f) << 2)) * scale + bias; - out[2] = static_cast(((w[1] >> 4) & 0x0f) + static_cast((w[2] & 0x03) << 4)) * scale + bias; + out[1] = (static_cast((w[0] >> 6) & 0x03) + + static_cast((w[1] & 0x0f) << 2)) * + scale + + bias; + out[2] = (static_cast((w[1] >> 4) & 0x0f) + + static_cast((w[2] & 0x03) << 4)) * + scale + + bias; out[3] = static_cast((w[2] >> 2) & 0x3f) * scale + bias; } else { uint val = w[offset];