diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index 38ca9f371..da7788971 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -543,8 +543,8 @@ void quantize(
   T* scales = scales_.data<T>();
   T* biases = biases_.data<T>();
 
-  T n_bins = (1 << bits) - 1;
-  T eps = 1e-7;
+  float n_bins = (1 << bits) - 1;
+  float eps = 1e-7;
   bool power_of_2_bits = is_power_of_2(bits);
   int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
   // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
@@ -554,32 +554,30 @@ void quantize(
   for (size_t i = 0; i < n_groups; ++i) {
     size_t w_idx = i * group_size;
-    T w_min = std::numeric_limits<float>::infinity();
-    T w_max = -w_min;
+    float w_min = std::numeric_limits<float>::infinity();
+    float w_max = -w_min;
     for (int j = 0; j < group_size; ++j) {
-      w_max = std::max(w_max, w[w_idx + j]);
-      w_min = std::min(w_min, w[w_idx + j]);
+      w_max = std::max(w_max, (float)w[w_idx + j]);
+      w_min = std::min(w_min, (float)w[w_idx + j]);
     }
 
     bool mask = std::abs(w_min) > std::abs(w_max);
-    T scale = std::max(T((w_max - w_min) / n_bins), eps);
+    float scale = std::max((w_max - w_min) / n_bins, eps);
     scale = mask ? scale : -scale;
-    auto edge = mask ? w_min : w_max;
-    auto q0 = std::rint(edge / scale);
-    if (q0 == 0) {
-      scales[i] = scale;
-      biases[i] = 0;
-    } else {
-      scales[i] = edge / q0;
-      biases[i] = edge;
+    float edge = mask ? w_min : w_max;
+    float q0 = std::rint(edge / scale);
+    float bias = 0;
+    if (q0 != 0) {
+      scale = edge / q0;
+      bias = edge;
     }
 
     size_t out_idx = i * int_per_group;
     for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
       uint32_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
-        T w_el = w[w_idx + j * el_per_int + k];
-        w_el = std::rint((w_el - biases[i]) / scales[i]);
-        w_el = std::min(std::max(w_el, T(0)), n_bins);
+        float w_el = w[w_idx + j * el_per_int + k];
+        w_el = std::rint((w_el - bias) / scale);
+        w_el = std::min(std::max(w_el, 0.0f), n_bins);
         out_el |= static_cast<uint32_t>(w_el) << (k * bits);
       }
       if (power_of_2_bits) {
@@ -590,6 +588,8 @@ void quantize(
         out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
       }
     }
+    scales[i] = static_cast<T>(scale);
+    biases[i] = static_cast<T>(bias);
   }
 }
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 1652207e3..3af3c971f 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2015,9 +2015,9 @@ template <typename T, const int group_size, const int bits>
     device T* biases [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  constexpr T eps = T(1e-7);
+  constexpr float eps = 1e-7;
   constexpr int simd_size = 32;
-  constexpr T n_bins = (1 << bits) - 1;
+  constexpr float n_bins = (1 << bits) - 1;
   constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
   constexpr int values_per_reduce = group_size / simd_size;
   constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
@@ -2036,13 +2036,13 @@ template <typename T, const int group_size, const int bits>
       ? offset * writes_per_pack
       : offset * bytes_per_pack / writes_per_reduce;
 
-  T w_thread[values_per_reduce];
-  T w_min = Limits<float>::max;
-  T w_max = 0;
+  float w_thread[values_per_reduce];
+  float w_min = Limits<float>::max;
+  float w_max = 0;
 
 #pragma clang loop unroll(full)
   for (int i = 0; i < values_per_reduce; i++) {
-    T val = w[in_index + i];
+    float val = w[in_index + i];
     w_thread[i] = val;
     w_min = min(w_min, val);
     w_max = max(w_max, val);
@@ -2051,20 +2051,20 @@ template <typename T, const int group_size, const int bits>
   w_min = simd_min(w_min);
   w_max = simd_max(w_max);
 
-  T scale = max((w_max - w_min) / n_bins, eps);
+  float scale = max((w_max - w_min) / n_bins, eps);
   bool side = abs(w_min) > abs(w_max);
   scale = side ? scale : -scale;
-  T edge = side ? w_min : w_max;
-  T q0 = round(edge / scale);
+  float edge = side ? w_min : w_max;
+  float q0 = round(edge / scale);
   bool at_zero = q0 == 0.0f;
   scale = at_zero ? scale : edge / q0;
-  T bias = at_zero ? T(0) : edge;
+  float bias = at_zero ? 0 : edge;
 
   // Write out the scales and biases
   size_t gindex = in_index / group_size;
   if (in_index % group_size == 0) {
-    scales[gindex] = scale;
-    biases[gindex] = bias;
+    scales[gindex] = static_cast<T>(scale);
+    biases[gindex] = static_cast<T>(bias);
   }
 
   // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 8ef6a8469..136c7796a 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -827,14 +827,17 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
   auto wshape = w.shape();
   wshape.back() = -1;
 
-  array zero(0, w.dtype());
-  array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
-  array eps(1e-7, w.dtype());
+  array zero(0, float32);
+  array n_bins((1 << bits) - 1, float32); // 2**bits - 1
+  array eps(1e-7, float32);
 
   array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
 
   array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
   array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
+  w_max = astype(w_max, float32, s);
+  w_min = astype(w_min, float32, s);
+
   array mask = greater(abs(w_min, s), abs(w_max, s), s);
   array scales =
       maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
@@ -845,6 +848,9 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
   array biases = where(equal(q0, zero, s), zero, edge, s);
 
   packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
+
+  scales = astype(scales, w.dtype(), s);
+  biases = astype(biases, w.dtype(), s);
   return {
       reshape(packed_w, wshape, s),
       reshape(scales, wshape, s),