mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
metal kernels
This commit is contained in:
@@ -120,7 +120,7 @@ Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
|
|||||||
struct ToFP8 {
|
struct ToFP8 {
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
Simd<uint8_t, N> operator()(Simd<T, N> f) {
|
Simd<uint8_t, N> operator()(Simd<T, N> f) {
|
||||||
uint32_t fp8_max = 1087 << 20;
|
uint32_t fp8_max = 543 << 21;
|
||||||
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
|
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
|
||||||
Simd<uint32_t, N> f_bits;
|
Simd<uint32_t, N> f_bits;
|
||||||
Simd<float, N> f32 = f;
|
Simd<float, N> f32 = f;
|
||||||
|
|||||||
@@ -1,7 +1,22 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
struct __nv_fp8_e8m0 {
|
struct __nv_fp8_e8m0 {
|
||||||
__device__ __nv_fp8_e8m0(uint8_t x) : __x(x) {}
|
__device__ __nv_fp8_e8m0(float x) {
|
||||||
|
if (!std::isfinite(x)) {
|
||||||
|
__x = 0xFF;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (x < 0.0f) {
|
||||||
|
__x = 0x00;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
float le = std::log2f(x);
|
||||||
|
int n = static_cast<int>(std::nearbyintf(le));
|
||||||
|
|
||||||
|
n = n < -127 ? -127 : n;
|
||||||
|
n = n > 127 ? 127 : n;
|
||||||
|
__x = static_cast<uint8_t>(n + 127);
|
||||||
|
}
|
||||||
|
|
||||||
__device__ operator float() {
|
__device__ operator float() {
|
||||||
if (__x == 0xFF) {
|
if (__x == 0xFF) {
|
||||||
|
|||||||
@@ -49,13 +49,12 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
|||||||
|
|
||||||
auto grid_dim_x =
|
auto grid_dim_x =
|
||||||
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
||||||
size_t out_index = tidx + grid_dim_x * size_t(tidy);
|
size_t index = tidx + grid_dim_x * size_t(tidy);
|
||||||
size_t in_index = out_index;
|
if (index >= size) {
|
||||||
if (in_index >= size) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float w_thread = w[in_index];
|
float w_thread = w[index];
|
||||||
|
|
||||||
cg::greater<float> max_op;
|
cg::greater<float> max_op;
|
||||||
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
|
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
|
||||||
@@ -70,21 +69,19 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
|||||||
scale = float(s);
|
scale = float(s);
|
||||||
|
|
||||||
// Write out the scales
|
// Write out the scales
|
||||||
size_t gindex = in_index / group_size;
|
size_t gindex = index / group_size;
|
||||||
if (in_index % group_size == 0) {
|
if (index % group_size == 0) {
|
||||||
scales[gindex] = q_scale;
|
scales[gindex] = q_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint8_t output = 0;
|
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
||||||
uint8_t val = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
|
||||||
output = val;
|
|
||||||
if (bits == 4) {
|
if (bits == 4) {
|
||||||
uint8_t sval = warp.shfl_down(val, 1);
|
uint8_t sval = warp.shfl_down(output, 1);
|
||||||
output |= sval << bits;
|
output |= sval << bits;
|
||||||
}
|
}
|
||||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||||
if (out_index % pack_factor == 0) {
|
if (index % pack_factor == 0) {
|
||||||
out[out_index / pack_factor] = output;
|
out[index / pack_factor] = output;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ make_jit_source(
|
|||||||
kernels/bf16_math.h
|
kernels/bf16_math.h
|
||||||
kernels/complex.h
|
kernels/complex.h
|
||||||
kernels/defines.h)
|
kernels/defines.h)
|
||||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
|
||||||
make_jit_source(binary_ops)
|
make_jit_source(binary_ops)
|
||||||
make_jit_source(ternary_ops)
|
make_jit_source(ternary_ops)
|
||||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ set(BASE_HEADERS
|
|||||||
defines.h
|
defines.h
|
||||||
erf.h
|
erf.h
|
||||||
expm1f.h
|
expm1f.h
|
||||||
|
fp8.h
|
||||||
utils.h)
|
utils.h)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
@@ -109,7 +110,7 @@ if(NOT MLX_METAL_JIT)
|
|||||||
reduction/reduce_col.h
|
reduction/reduce_col.h
|
||||||
reduction/reduce_row.h)
|
reduction/reduce_row.h)
|
||||||
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
|
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||||
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
|
build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||||
build_kernel(scan scan.h)
|
build_kernel(scan scan.h)
|
||||||
build_kernel(softmax softmax.h)
|
build_kernel(softmax softmax.h)
|
||||||
build_kernel(logsumexp logsumexp.h)
|
build_kernel(logsumexp logsumexp.h)
|
||||||
|
|||||||
56
mlx/backend/metal/kernels/fp4.h
Normal file
56
mlx/backend/metal/kernels/fp4.h
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
constexpr constant static float FP4_LUT[16] = {
|
||||||
|
+0.0f,
|
||||||
|
+0.5f,
|
||||||
|
+1.0f,
|
||||||
|
+1.5f,
|
||||||
|
+2.0f,
|
||||||
|
+3.0f,
|
||||||
|
+4.0f,
|
||||||
|
+6.0f,
|
||||||
|
-0.0f,
|
||||||
|
-0.5f,
|
||||||
|
-1.0f,
|
||||||
|
-1.5f,
|
||||||
|
-2.0f,
|
||||||
|
-3.0f,
|
||||||
|
-4.0f,
|
||||||
|
-6.0f};
|
||||||
|
|
||||||
|
struct fp4_e2m1 {
|
||||||
|
fp4_e2m1(float x) {
|
||||||
|
if (metal::isnan(x)) {
|
||||||
|
bits = 0x7;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
|
||||||
|
x = metal::abs(x);
|
||||||
|
|
||||||
|
if (x > 5.0f) {
|
||||||
|
bits = 0x7;
|
||||||
|
} else if (x >= 3.5f) {
|
||||||
|
bits = 0x6;
|
||||||
|
} else if (x > 2.5f) {
|
||||||
|
bits = 0x5;
|
||||||
|
} else if (x >= 1.75f) {
|
||||||
|
bits = 0x4;
|
||||||
|
} else if (x > 1.25f) {
|
||||||
|
bits = 0x3;
|
||||||
|
} else if (x >= 0.75f) {
|
||||||
|
bits = 0x2;
|
||||||
|
} else if (x > 0.25f) {
|
||||||
|
bits = 0x1;
|
||||||
|
} else {
|
||||||
|
bits = 0x0;
|
||||||
|
}
|
||||||
|
bits |= sign_bit;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator float() {
|
||||||
|
return FP4_LUT[bits];
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t bits;
|
||||||
|
};
|
||||||
88
mlx/backend/metal/kernels/fp8.h
Normal file
88
mlx/backend/metal/kernels/fp8.h
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
inline float fp32_from_bits(uint32_t bits) {
|
||||||
|
return *(reinterpret_cast<thread float*>(&bits));
|
||||||
|
}
|
||||||
|
inline float fp32_to_bits(float x) {
|
||||||
|
return *(reinterpret_cast<thread uint32_t*>(&x));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct fp8_e4m3 {
|
||||||
|
template <typename T>
|
||||||
|
fp8_e4m3(T f) {
|
||||||
|
// From PyTorch
|
||||||
|
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
|
||||||
|
uint32_t fp8_max = 543 << 21;
|
||||||
|
uint32_t denorm_mask = 141 << 23;
|
||||||
|
uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
|
||||||
|
uint32_t sign = f_bits & 0x80000000;
|
||||||
|
f_bits ^= sign;
|
||||||
|
if (f_bits >= fp8_max) {
|
||||||
|
// Default behavior saturates to min/max
|
||||||
|
bits = 0x7E;
|
||||||
|
} else {
|
||||||
|
if (f_bits < (121 << 23)) {
|
||||||
|
f_bits =
|
||||||
|
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||||
|
bits = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||||
|
} else {
|
||||||
|
// resulting mantissa is odd
|
||||||
|
uint8_t mant_odd = (f_bits >> 20) & 1;
|
||||||
|
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
|
||||||
|
f_bits += mant_odd;
|
||||||
|
bits = static_cast<uint8_t>(f_bits >> 20);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bits |= static_cast<uint8_t>(sign >> 24);
|
||||||
|
}
|
||||||
|
|
||||||
|
operator float() {
|
||||||
|
// From PyTorch:
|
||||||
|
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
|
||||||
|
uint32_t w = static_cast<uint32_t>(bits) << 24;
|
||||||
|
uint32_t sign = w & 0x80000000;
|
||||||
|
uint32_t nonsign = w & 0x7FFFFFFF;
|
||||||
|
|
||||||
|
uint32_t renorm_shift = metal::clz(nonsign);
|
||||||
|
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
|
||||||
|
|
||||||
|
int32_t inf_nan_mask =
|
||||||
|
(static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
|
||||||
|
int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
|
||||||
|
uint32_t result = sign |
|
||||||
|
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||||
|
inf_nan_mask) &
|
||||||
|
~zero_mask);
|
||||||
|
return fp32_from_bits(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t bits;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct fp8_e8m0 {
|
||||||
|
fp8_e8m0(float x) {
|
||||||
|
if (!metal::isfinite(x)) {
|
||||||
|
bits = 0xFF;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (x < 0.0f) {
|
||||||
|
bits = 0x00;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
float le = metal::log2(x);
|
||||||
|
int n = int(metal::round(le));
|
||||||
|
|
||||||
|
n = n < -127 ? -127 : n;
|
||||||
|
n = n > 127 ? 127 : n;
|
||||||
|
bits = static_cast<uint8_t>(n + 127);
|
||||||
|
}
|
||||||
|
|
||||||
|
operator float() {
|
||||||
|
if (bits == 0xFF) {
|
||||||
|
return metal::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t bits;
|
||||||
|
};
|
||||||
@@ -59,28 +59,10 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr constant static float MXFP4_LUT[16] = {
|
|
||||||
+0.0f,
|
|
||||||
+0.5f,
|
|
||||||
+1.0f,
|
|
||||||
+1.5f,
|
|
||||||
+2.0f,
|
|
||||||
+3.0f,
|
|
||||||
+4.0f,
|
|
||||||
+6.0f,
|
|
||||||
-0.0f,
|
|
||||||
-0.5f,
|
|
||||||
-1.0f,
|
|
||||||
-1.5f,
|
|
||||||
-2.0f,
|
|
||||||
-3.0f,
|
|
||||||
-4.0f,
|
|
||||||
-6.0f};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||||
if (simd_gid == 0 && simd_lid < 16) {
|
if (simd_gid == 0 && simd_lid < 16) {
|
||||||
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
|
lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
@@ -1789,3 +1771,100 @@ template <
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int bits>
|
||||||
|
struct Quantize {
|
||||||
|
uint8_t operator()(float x) {
|
||||||
|
if constexpr (bits == 8) {
|
||||||
|
return fp8_e4m3(x).bits;
|
||||||
|
} else {
|
||||||
|
return fp4_e2m1(x).bits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int bits>
|
||||||
|
struct Dequantize {
|
||||||
|
float operator()(uint8_t x) {
|
||||||
|
if constexpr (bits == 8) {
|
||||||
|
return float(*(thread fp8_e4m3*)(&x));
|
||||||
|
} else {
|
||||||
|
return float(*(thread fp4_e2m1*)(&x));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, const int group_size, const int bits>
|
||||||
|
[[kernel]] void fp_quantize(
|
||||||
|
const device T* w [[buffer(0)]],
|
||||||
|
device uint8_t* out [[buffer(1)]],
|
||||||
|
device uint8_t* scales [[buffer(2)]],
|
||||||
|
uint2 tidx [[thread_position_in_grid]],
|
||||||
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
|
constexpr bool use_mx_scale = group_size == 32;
|
||||||
|
size_t index = tidx.x + grid_dim.x * size_t(tidx.y);
|
||||||
|
|
||||||
|
float scale;
|
||||||
|
float w_thread = w[index];
|
||||||
|
if (use_mx_scale) {
|
||||||
|
scale = simd_max(abs(w_thread));
|
||||||
|
} else {
|
||||||
|
float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
|
||||||
|
float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
|
||||||
|
scale = tidx.x < 16 ? w_max_l : w_max_r;
|
||||||
|
}
|
||||||
|
scale /= bits == 4 ? 6.0f : 448.0f;
|
||||||
|
|
||||||
|
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
|
||||||
|
auto s = ScaleType(scale);
|
||||||
|
uint8_t q_scale = s.bits;
|
||||||
|
scale = float(s);
|
||||||
|
|
||||||
|
// Write out the scales and biases
|
||||||
|
size_t gindex = index / group_size;
|
||||||
|
if (index % group_size == 0) {
|
||||||
|
scales[gindex] = q_scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
||||||
|
if (bits == 4) {
|
||||||
|
uint8_t sval = simd_shuffle_down(output, 1);
|
||||||
|
output |= sval << bits;
|
||||||
|
}
|
||||||
|
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||||
|
if (index % pack_factor == 0) {
|
||||||
|
out[index / pack_factor] = output;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, const int group_size, const int bits>
|
||||||
|
[[kernel]] void fp_dequantize(
|
||||||
|
const device uint8_t* w [[buffer(0)]],
|
||||||
|
const device T* scales [[buffer(1)]],
|
||||||
|
device T* out [[buffer(3)]],
|
||||||
|
uint2 index [[thread_position_in_grid]],
|
||||||
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
|
constexpr bool use_mx_scale = group_size == 32;
|
||||||
|
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||||
|
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||||
|
size_t oindex = offset * pack_factor;
|
||||||
|
size_t gindex = oindex / group_size;
|
||||||
|
|
||||||
|
out += oindex;
|
||||||
|
|
||||||
|
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
|
||||||
|
auto q_scale = ((device ScaleType*)(scales))[gindex];
|
||||||
|
auto scale = float(q_scale);
|
||||||
|
|
||||||
|
uint val = w[offset];
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for (int i = 0; i < pack_factor; i++) {
|
||||||
|
uint8_t d;
|
||||||
|
if (bits == 4) {
|
||||||
|
d = (val >> (bits * i)) & 0x0f;
|
||||||
|
} else if (bits == 8) {
|
||||||
|
d = val;
|
||||||
|
}
|
||||||
|
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,7 +4,9 @@
|
|||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||||
#include "mlx/backend/metal/kernels/fp4_quantized.h"
|
#include "mlx/backend/metal/kernels/fp8.h"
|
||||||
|
#include "mlx/backend/metal/kernels/fp4.h"
|
||||||
|
#include "mlx/backend/metal/kernels/fp_quantized.h"
|
||||||
|
|
||||||
#define instantiate_quantized(name, type) \
|
#define instantiate_quantized(name, type) \
|
||||||
instantiate_kernel( \
|
instantiate_kernel( \
|
||||||
@@ -113,13 +115,33 @@
|
|||||||
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
|
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
|
||||||
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
|
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
|
||||||
|
|
||||||
|
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||||
|
fp_quantize, \
|
||||||
|
type, \
|
||||||
|
group_size, \
|
||||||
|
bits) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||||
|
fp_dequantize, \
|
||||||
|
type, \
|
||||||
|
group_size, \
|
||||||
|
bits)
|
||||||
|
|
||||||
|
#define instantiate_quantize_dequantize_modes(type) \
|
||||||
|
instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \
|
||||||
|
instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \
|
||||||
|
instantiate_quantize_dequantize(type, "mxfp8", 32, 8)
|
||||||
|
|
||||||
#define instantiate_quantized_types(type) \
|
#define instantiate_quantized_types(type) \
|
||||||
instantiate_quantized_all_batched(type) \
|
instantiate_quantized_all_batched(type) \
|
||||||
instantiate_quantized_all_quad(type) \
|
instantiate_quantized_all_quad(type) \
|
||||||
instantiate_quantized_all_splitk(type) \
|
instantiate_quantized_all_splitk(type) \
|
||||||
instantiate_quantized_all_single(type) \
|
instantiate_quantized_all_single(type) \
|
||||||
instantiate_quantized_all_aligned(type) \
|
instantiate_quantized_all_aligned(type) \
|
||||||
instantiate_quantized_all_rhs(type)
|
instantiate_quantized_all_rhs(type) \
|
||||||
|
instantiate_quantize_dequantize_modes(type)
|
||||||
|
|
||||||
instantiate_quantized_types(float)
|
instantiate_quantized_types(float)
|
||||||
instantiate_quantized_types(bfloat16_t)
|
instantiate_quantized_types(bfloat16_t)
|
||||||
@@ -8,6 +8,7 @@
|
|||||||
#include "mlx/backend/metal/kernels/cexpf.h"
|
#include "mlx/backend/metal/kernels/cexpf.h"
|
||||||
#include "mlx/backend/metal/kernels/erf.h"
|
#include "mlx/backend/metal/kernels/erf.h"
|
||||||
#include "mlx/backend/metal/kernels/expm1f.h"
|
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||||
|
#include "mlx/backend/metal/kernels/fp8.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constant float inf = metal::numeric_limits<float>::infinity();
|
constant float inf = metal::numeric_limits<float>::infinity();
|
||||||
@@ -439,63 +440,15 @@ complex64_t ArcTan::operator()(complex64_t x) {
|
|||||||
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||||
};
|
};
|
||||||
|
|
||||||
inline float fp32_from_bits(uint32_t bits) {
|
|
||||||
return *(reinterpret_cast<thread float*>(&bits));
|
|
||||||
}
|
|
||||||
inline float fp32_to_bits(float x) {
|
|
||||||
return *(reinterpret_cast<thread uint32_t*>(&x));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ToFP8 {
|
struct ToFP8 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
uint8_t operator()(T f) {
|
uint8_t operator()(T f) {
|
||||||
// From PyTorch
|
return fp8_e4m3(f).bits;
|
||||||
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
|
|
||||||
uint32_t fp8_max = 1087 << 20;
|
|
||||||
uint32_t denorm_mask = 141 << 23;
|
|
||||||
uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
|
|
||||||
uint8_t result = 0u;
|
|
||||||
uint32_t sign = f_bits & 0x80000000;
|
|
||||||
f_bits ^= sign;
|
|
||||||
if (f_bits >= fp8_max) {
|
|
||||||
// Default behavior saturates to min/max
|
|
||||||
result = 0x7E;
|
|
||||||
} else {
|
|
||||||
if (f_bits < (121 << 23)) {
|
|
||||||
f_bits =
|
|
||||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
|
||||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
|
||||||
} else {
|
|
||||||
// resulting mantissa is odd
|
|
||||||
uint8_t mant_odd = (f_bits >> 20) & 1;
|
|
||||||
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
|
|
||||||
f_bits += mant_odd;
|
|
||||||
result = static_cast<uint8_t>(f_bits >> 20);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result |= static_cast<uint8_t>(sign >> 24);
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FromFP8 {
|
struct FromFP8 {
|
||||||
float operator()(uint8_t x) {
|
float operator()(uint8_t x) {
|
||||||
// From PyTorch:
|
return float(*(thread fp8_e4m3*)(&x));
|
||||||
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
|
|
||||||
uint32_t w = static_cast<uint32_t>(x) << 24;
|
|
||||||
uint32_t sign = w & 0x80000000;
|
|
||||||
uint32_t nonsign = w & 0x7FFFFFFF;
|
|
||||||
|
|
||||||
uint32_t renorm_shift = metal::clz(nonsign);
|
|
||||||
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
|
|
||||||
|
|
||||||
int32_t inf_nan_mask =
|
|
||||||
(static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
|
|
||||||
int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
|
|
||||||
uint32_t result = sign |
|
|
||||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
|
||||||
inf_nan_mask) &
|
|
||||||
~zero_mask);
|
|
||||||
return fp32_from_bits(result);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1045,26 +1045,31 @@ void fast::Quantize::eval_gpu(
|
|||||||
compute_encoder.set_input_array(w, 0);
|
compute_encoder.set_input_array(w, 0);
|
||||||
if (dequantize_) {
|
if (dequantize_) {
|
||||||
auto scales = ensure_row_contiguous(inputs[1], d, s);
|
auto scales = ensure_row_contiguous(inputs[1], d, s);
|
||||||
auto biases = ensure_row_contiguous(inputs[2], d, s);
|
|
||||||
compute_encoder.set_input_array(scales, 1);
|
compute_encoder.set_input_array(scales, 1);
|
||||||
compute_encoder.set_input_array(biases, 2);
|
|
||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
if (mode_ == QuantizationMode::Affine) {
|
||||||
|
auto biases = ensure_row_contiguous(inputs[2], d, s);
|
||||||
|
compute_encoder.set_input_array(biases, 2);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
auto& scales = outputs[1];
|
auto& scales = outputs[1];
|
||||||
auto& biases = outputs[2];
|
|
||||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
compute_encoder.set_output_array(scales, 2);
|
compute_encoder.set_output_array(scales, 2);
|
||||||
compute_encoder.set_output_array(biases, 3);
|
if (mode_ == QuantizationMode::Affine) {
|
||||||
|
auto& biases = outputs[2];
|
||||||
|
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||||
|
compute_encoder.set_output_array(biases, 3);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type_string = dequantize_ ? get_type_string(out.dtype())
|
auto type_string = dequantize_ ? get_type_string(out.dtype())
|
||||||
: get_type_string(w_pre.dtype());
|
: get_type_string(w_pre.dtype());
|
||||||
|
auto mode = quantization_mode_to_string(mode_);
|
||||||
std::string kname;
|
std::string kname;
|
||||||
concatenate(
|
concatenate(
|
||||||
kname,
|
kname,
|
||||||
dequantize_ ? "affine_dequantize" : "affine_quantize",
|
mode + (dequantize_ ? "_dequantize" : "_quantize"),
|
||||||
"_",
|
"_",
|
||||||
type_string,
|
type_string,
|
||||||
"_gs_",
|
"_gs_",
|
||||||
@@ -1075,7 +1080,7 @@ void fast::Quantize::eval_gpu(
|
|||||||
d,
|
d,
|
||||||
kname,
|
kname,
|
||||||
dequantize_ ? "dequantize" : "quantize",
|
dequantize_ ? "dequantize" : "quantize",
|
||||||
"affine",
|
mode,
|
||||||
type_string,
|
type_string,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_);
|
bits_);
|
||||||
@@ -1088,7 +1093,8 @@ void fast::Quantize::eval_gpu(
|
|||||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
||||||
: bits_ == 6 ? 4
|
: bits_ == 6 ? 4
|
||||||
: 8 / bits_;
|
: 8 / bits_;
|
||||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
|
int per_thread =
|
||||||
|
dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1);
|
||||||
size_t nthreads =
|
size_t nthreads =
|
||||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
// Required for using M_PI in MSVC.
|
// Required for using M_PI in MSVC.
|
||||||
#define _USE_MATH_DEFINES
|
#define _USE_MATH_DEFINES
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
@@ -4259,8 +4258,11 @@ std::vector<array> fp_quantize(
|
|||||||
} else {
|
} else {
|
||||||
// convert to e8m0
|
// convert to e8m0
|
||||||
auto z = array(0, scales.dtype());
|
auto z = array(0, scales.dtype());
|
||||||
scales =
|
scales = where(
|
||||||
where(equal(scales, z, s), z, astype(log2(scales, s), int32, s), s);
|
equal(scales, z, s),
|
||||||
|
z,
|
||||||
|
astype(round(log2(scales, s), s), int32, s),
|
||||||
|
s);
|
||||||
|
|
||||||
wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
|
wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
|
||||||
scales = astype(add(scales, array(127, int32), s), uint8, s);
|
scales = astype(add(scales, array(127, int32), s), uint8, s);
|
||||||
|
|||||||
@@ -92,7 +92,6 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
|
mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
|
||||||
|
|
||||||
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
|
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@@ -102,7 +101,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
|
mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
|
||||||
|
|
||||||
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
|
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
|
||||||
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
|
|
||||||
|
self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1))
|
||||||
|
|
||||||
# test quantize/dequantize 0s
|
# test quantize/dequantize 0s
|
||||||
a = mx.zeros((256, 512))
|
a = mx.zeros((256, 512))
|
||||||
|
|||||||
Reference in New Issue
Block a user