mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4
* fast cuda kernel for mx/nv quantization
* fallback for cuda < 12.8 (#2697)
* format (#2700)
* fix (#2701)
* metal kernels
* docs
* fix jit
* add default bits and group sizes
* improve quant docs
* fix output type of mxfp4 matmuls
@@ -120,7 +120,7 @@ Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
struct ToFP8 {
  template <typename T, int N>
  Simd<uint8_t, N> operator()(Simd<T, N> f) {
-    uint32_t fp8_max = 1087 << 20;
+    uint32_t fp8_max = 543 << 21;
    auto denorm_mask = Simd<uint32_t, N>(141 << 23);
    Simd<uint32_t, N> f_bits;
    Simd<float, N> f32 = f;
@@ -51,6 +51,7 @@ target_sources(
    ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
@@ -58,6 +59,11 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)

+# fp4 is not available on < 12.8
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
+  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
+endif()
+
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -306,7 +306,7 @@ void affine_dequantize(
  enc.set_input_array(scales);
  enc.set_input_array(biases);
  enc.set_output_array(w);
-  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
+  dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
mlx/backend/cuda/quantized/cuda_fp4.h (new file, 83 lines)
@@ -0,0 +1,83 @@
#pragma once

struct __nv_fp8_e8m0 {
  __device__ __nv_fp8_e8m0(float x) {
    if (!std::isfinite(x)) {
      __x = 0xFF;
      return;
    }
    if (x < 0.0f) {
      __x = 0x00;
      return;
    }
    float le = std::log2f(x);
    int n = static_cast<int>(std::nearbyintf(le));

    n = n < -127 ? -127 : n;
    n = n > 127 ? 127 : n;
    __x = static_cast<uint8_t>(n + 127);
  }

  __device__ operator float() {
    if (__x == 0xFF) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    return std::ldexp(1.0f, static_cast<int>(__x) - 127);
  }

  uint8_t __x{0};
};

struct __nv_fp4_e2m1 {
  __device__ __nv_fp4_e2m1(float x) {
    if (std::isnan(x)) {
      __x = 0x7;
      return;
    }

    const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
    x = std::abs(x);

    if (x > 5.0f) {
      __x = 0x7;
    } else if (x >= 3.5f) {
      __x = 0x6;
    } else if (x > 2.5f) {
      __x = 0x5;
    } else if (x >= 1.75f) {
      __x = 0x4;
    } else if (x > 1.25f) {
      __x = 0x3;
    } else if (x >= 0.75f) {
      __x = 0x2;
    } else if (x > 0.25f) {
      __x = 0x1;
    } else {
      __x = 0x0;
    }
    __x |= sign_bit;
  }

  __device__ operator float() {
    static const float LUT[16] = {
        0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

    return LUT[__x];
  }
  uint8_t __x{0};
};
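
Note: the mixed >/>= thresholds in __nv_fp4_e2m1 implement round-to-nearest-even onto the 8 e2m1 magnitudes. A minimal host-side sketch (not part of the commit; encode_e2m1 and kE2M1 are illustrative names) that checks the same rounding rule:

#include <cassert>
#include <cmath>
#include <cstdint>

// Host-side analog of __nv_fp4_e2m1, for illustration only.
// e2m1 represents {0, 0.5, 1, 1.5, 2, 3, 4, 6} plus a sign bit.
static const float kE2M1[16] = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,
                                4.0f,  6.0f,  -0.0f, -0.5f, -1.0f, -1.5f,
                                -2.0f, -3.0f, -4.0f, -6.0f};

uint8_t encode_e2m1(float x) {
  uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
  x = std::fabs(x);
  // Same thresholds as the kernel header above.
  uint8_t m = x > 5.0f ? 7 : x >= 3.5f ? 6 : x > 2.5f ? 5 : x >= 1.75f ? 4
      : x > 1.25f ? 3 : x >= 0.75f ? 2 : x > 0.25f ? 1 : 0;
  return sign | m;
}

int main() {
  // Every representable value round-trips exactly.
  for (int i = 0; i < 16; ++i) {
    assert(kE2M1[encode_e2m1(kE2M1[i])] == kE2M1[i]);
  }
  // 1.25 ties to even (1.0); 1.26 rounds up to 1.5.
  assert(kE2M1[encode_e2m1(1.25f)] == 1.0f);
  assert(kE2M1[encode_e2m1(1.26f)] == 1.5f);
}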
mlx/backend/cuda/quantized/fp_quantize.cu (new file, 216 lines)
@@ -0,0 +1,216 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/dtype_utils.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>

namespace mlx::core {
namespace cu {

template <int bits>
struct Quantize {
  __device__ uint8_t operator()(float x) {
    if constexpr (bits == 8) {
      return __nv_fp8_e4m3(x).__x;
    } else {
      return __nv_fp4_e2m1(x).__x;
    }
  }
};

template <int bits>
struct Dequantize {
  __device__ float operator()(uint8_t x) {
    if constexpr (bits == 8) {
      return float(*(__nv_fp8_e4m3*)(&x));
    } else {
      return float(*(__nv_fp4_e2m1*)(&x));
    }
  }
};

namespace cg = cooperative_groups;

template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
  auto block_size = cg::this_thread_block().dim_threads();
  auto block_idx = cg::this_thread_block().group_index();
  auto idx_in_block = cg::this_thread_block().thread_index();

  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
  auto tidy = block_idx.y * block_size.y + idx_in_block.y;

  auto grid_dim_x =
      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
  size_t index = tidx + grid_dim_x * size_t(tidy);
  if (index >= size) {
    return;
  }

  float w_thread = w[index];

  cg::greater<float> max_op;
  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());

  float scale = cg::reduce(warp, abs(w_thread), max_op);
  scale /= bits == 4 ? 6.0f : 448.0f;
  // Convert to mx scale or nv scale
  using ScaleType =
      std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
  auto s = ScaleType(scale);
  uint8_t q_scale = s.__x;
  scale = float(s);

  // Write out the scales
  size_t gindex = index / group_size;
  if (index % group_size == 0) {
    scales[gindex] = q_scale;
  }

  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
  if (bits == 4) {
    uint8_t sval = warp.shfl_down(output, 1);
    output |= sval << bits;
  }
  constexpr int pack_factor = bits == 8 ? 1 : 2;
  if (index % pack_factor == 0) {
    out[index / pack_factor] = output;
  }
}

template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
  auto block_size = cg::this_thread_block().dim_threads();
  auto block_idx = cg::this_thread_block().group_index();
  auto idx_in_block = cg::this_thread_block().thread_index();

  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
  auto tidy = block_idx.y * block_size.y + idx_in_block.y;

  auto grid_dim_x =
      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;

  constexpr int pack_factor = bits == 8 ? 1 : 2;
  size_t offset = tidx + grid_dim_x * size_t(tidy);
  size_t oindex = offset * pack_factor;

  if (oindex >= size) {
    return;
  }

  size_t gindex = oindex / group_size;
  using ScaleType =
      std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
  auto scale = float(((ScaleType*)(scales))[gindex]);

  out += oindex;

  uint val = w[offset];
#pragma clang loop unroll(full)
  for (int i = 0; i < pack_factor; i++) {
    uint8_t d;
    if (bits == 4) {
      d = (val >> (bits * i)) & 0x0f;
    } else if (bits == 8) {
      d = val;
    }
    out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
  }
}

} // namespace cu

void fp_quantize(
    const array& w,
    array& wq,
    array& scales,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s) {
  enc.set_input_array(w);
  enc.set_output_array(wq);
  enc.set_output_array(scales);
  dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    if constexpr (!std::is_same_v<T, double>) {
      auto kernel = cu::fp_quantize<T, 32, 4, true>;
      if (bits == 8) {
        kernel = cu::fp_quantize<T, 32, 8, true>;
      } else if (group_size == 16) {
        kernel = cu::fp_quantize<T, 16, 4, false>;
      }
      bool large = w.size() > UINT_MAX;
      auto [num_blocks, block_dims] =
          get_launch_args(w.size(), w.shape(), w.strides(), large);
      enc.add_kernel_node(
          kernel,
          num_blocks,
          block_dims,
          0,
          w.data<T>(),
          wq.data<uint8_t>(),
          scales.data<uint8_t>(),
          w.size());
    } else {
      throw std::runtime_error(
          "[Quantize::eval_gpu] Can not quantize input with type float64.");
    }
  });
}

void fp_dequantize(
    const array& wq,
    const array& scales,
    array& w,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s) {
  constexpr int uint8_per_uint32 = 4;
  int packs_per_int = 8 / bits;

  size_t size = w.size() / packs_per_int;
  bool large = size > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() *= uint8_per_uint32;

  enc.set_input_array(wq);
  enc.set_input_array(scales);
  enc.set_output_array(w);
  dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    if constexpr (!std::is_same_v<T, double>) {
      auto kernel = cu::fp_dequantize<T, 32, 4, true>;
      if (bits == 8) {
        kernel = cu::fp_dequantize<T, 32, 8, true>;
      } else if (group_size == 16) {
        kernel = cu::fp_dequantize<T, 16, 4, false>;
      }
      auto [num_blocks, block_dims] =
          get_launch_args(size, grid_shape, w.strides(), large);
      enc.add_kernel_node(
          kernel,
          num_blocks,
          block_dims,
          0,
          wq.data<uint8_t>(),
          scales.data<uint8_t>(),
          w.data<T>(),
          w.size());
    } else {
      throw std::runtime_error(
          "[Quantize::eval_gpu] Can not dequantize to output with type float64.");
    }
  });
}

} // namespace mlx::core
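
Note: the kernels above fix a simple storage layout: two fp4 values per byte (or one fp8 value per byte) plus one uint8 scale per group, with group_size 32 and e8m0 scales for the mx modes and group_size 16 with e4m3 scales for nvfp4. A small sketch of the implied buffer sizes (hypothetical helper, plain C++, assuming sizes divisible by the group size):

#include <cstddef>

struct QuantizedSizes {
  size_t data_bytes;  // packed elements: two fp4 values per byte
  size_t scale_bytes; // one uint8 scale per group
};

QuantizedSizes quantized_sizes(size_t n_elements, int group_size, int bits) {
  int pack_factor = bits == 8 ? 1 : 2; // matches the kernels above
  return {n_elements / pack_factor, n_elements / group_size};
}
// mxfp4: group_size 32 (e8m0 scale); nvfp4: group_size 16 (e4m3 scale);
// mxfp8: group_size 32 (e8m0 scale).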
@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
  if (dequantize_) {
    auto wq = ensure_row_contiguous(inputs[0], enc, s);
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
-    auto biases = ensure_row_contiguous(inputs[2], enc, s);
    auto& w = outputs[0];

    w.set_data(allocator::malloc(w.nbytes()));

-    affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
+    if (mode_ == QuantizationMode::Affine) {
+      auto biases = ensure_row_contiguous(inputs[2], enc, s);
+      affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
+    } else {
+      fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
+    }
  } else {
    auto w = ensure_row_contiguous(inputs[0], enc, s);
    auto& wq = outputs[0];
    auto& scales = outputs[1];
-    auto& biases = outputs[2];

    wq.set_data(allocator::malloc(wq.nbytes()));
    scales.set_data(allocator::malloc(scales.nbytes()));
-    biases.set_data(allocator::malloc(biases.nbytes()));

-    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+    if (mode_ == QuantizationMode::Affine) {
+      auto& biases = outputs[2];
+      biases.set_data(allocator::malloc(biases.nbytes()));
+      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+    } else {
+      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
+    }
  }
}
@@ -24,4 +24,22 @@ void affine_dequantize(
    cu::CommandEncoder& enc,
    const Stream& s);

+void fp_quantize(
+    const array& w,
+    array& wq,
+    array& scales,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
+void fp_dequantize(
+    const array& wq,
+    const array& scales,
+    array& w,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
} // namespace mlx::core
@@ -29,7 +29,7 @@ make_jit_source(
  kernels/bf16_math.h
  kernels/complex.h
  kernels/defines.h)
-make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
+make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)

make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h)
-make_jit_source(fp4_quantized kernels/quantized_utils.h)
+make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
+                kernels/fp4.h)
make_jit_source(gemv_masked)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
@@ -24,7 +24,7 @@ const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
-const char* fp4_quantized();
+const char* fp_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();
@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
        metal::utils(),
        metal::gemm(),
        metal::quantized_utils(),
-        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
+        (mode == "affine") ? metal::quantized() : metal::fp_quantized(),
        template_def);
    return kernel_source;
  });
@@ -856,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
    std::string kernel_source;
    concatenate(
        kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
-    if (mode == "affine") {
-      concatenate(
-          kernel_source,
-          metal::quantized(),
-          get_template_definition(
-              lib_name,
-              mode + "_gather_qmm_rhs",
-              get_type_string(x.dtype()),
-              group_size,
-              bits,
-              bm,
-              bn,
-              bk,
-              wm,
-              wn,
-              transpose));
-    } else {
-      concatenate(
-          kernel_source,
-          metal::fp4_quantized(),
-          get_template_definition(
-              lib_name,
-              mode + "_gather_qmm_rhs",
-              get_type_string(x.dtype()),
-              group_size,
-              "uint8_t",
-              bm,
-              bn,
-              bk,
-              wm,
-              wn,
-              transpose));
-    }
+    bool is_affine = mode == "affine";
+    concatenate(
+        kernel_source,
+        is_affine ? metal::quantized() : metal::fp_quantized(),
+        get_template_definition(
+            lib_name,
+            (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
+            get_type_string(x.dtype()),
+            group_size,
+            bits,
+            bm,
+            bn,
+            bk,
+            wm,
+            wn,
+            transpose));
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
@@ -6,6 +6,7 @@ set(BASE_HEADERS
    defines.h
    erf.h
    expm1f.h
+    fp8.h
    utils.h)

function(build_kernel_base TARGET SRCFILE DEPS)
@@ -109,7 +110,8 @@ if(NOT MLX_METAL_JIT)
    reduction/reduce_col.h
    reduction/reduce_row.h)
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
-build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
+build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
+             ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
mlx/backend/metal/kernels/fp4.h (new file, 56 lines)
@@ -0,0 +1,56 @@
#pragma once

constexpr constant static float FP4_LUT[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

struct fp4_e2m1 {
  fp4_e2m1(float x) {
    if (metal::isnan(x)) {
      bits = 0x7;
      return;
    }

    const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
    x = metal::abs(x);

    if (x > 5.0f) {
      bits = 0x7;
    } else if (x >= 3.5f) {
      bits = 0x6;
    } else if (x > 2.5f) {
      bits = 0x5;
    } else if (x >= 1.75f) {
      bits = 0x4;
    } else if (x > 1.25f) {
      bits = 0x3;
    } else if (x >= 0.75f) {
      bits = 0x2;
    } else if (x > 0.25f) {
      bits = 0x1;
    } else {
      bits = 0x0;
    }
    bits |= sign_bit;
  }

  operator float() {
    return FP4_LUT[bits];
  }

  uint8_t bits;
};
@@ -1,127 +0,0 @@
// Copyright © 2025 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"

#define instantiate_quantized(name, type) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4", \
      name, \
      type, \
      32, \
      uint8_t)

#define instantiate_quantized_batched(name, type, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_batch_" #batched, \
      name, \
      type, \
      32, \
      uint8_t, \
      batched)

#define instantiate_quantized_aligned(name, type, aligned) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_alN_" #aligned, \
      name, \
      type, \
      32, \
      uint8_t, \
      aligned)

#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
      name, \
      type, \
      32, \
      uint8_t, \
      aligned, \
      batched)

#define instantiate_quantized_quad(name, type, D, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
      name, \
      type, \
      32, \
      uint8_t, \
      D, \
      batched)

#define instantiate_quantized_split_k(name, type, split_k) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_spk_" #split_k, \
      name, \
      type, \
      32, \
      uint8_t, \
      split_k)

#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
      func, \
      type, \
      32, \
      uint8_t, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      transpose)

#define instantiate_quantized_batched_wrap(name, type) \
  instantiate_quantized_batched(name, type, 1) \
  instantiate_quantized_batched(name, type, 0)

#define instantiate_quantized_all_batched(type) \
  instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
  instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
  instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
  instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)

#define instantiate_quantized_all_single(type) \
  instantiate_quantized(mxfp4_gather_qmv_fast, type) \
  instantiate_quantized(mxfp4_gather_qmv, type) \
  instantiate_quantized(mxfp4_gather_qvm, type) \
  instantiate_quantized(mxfp4_gather_qmm_n, type)

#define instantiate_quantized_all_aligned(type) \
  instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
  instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)

#define instantiate_quantized_all_quad(type) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)

#define instantiate_quantized_all_splitk(type) \
  instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
  instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)

#define instantiate_quantized_all_rhs(type) \
  instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
  instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)

#define instantiate_quantized_types(type) \
  instantiate_quantized_all_batched(type) \
  instantiate_quantized_all_quad(type) \
  instantiate_quantized_all_splitk(type) \
  instantiate_quantized_all_single(type) \
  instantiate_quantized_all_aligned(type) \
  instantiate_quantized_all_rhs(type)

instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on
mlx/backend/metal/kernels/fp8.h (new file, 88 lines)
@@ -0,0 +1,88 @@
#pragma once

inline float fp32_from_bits(uint32_t bits) {
  return *(reinterpret_cast<thread float*>(&bits));
}
inline uint32_t fp32_to_bits(float x) {
  return *(reinterpret_cast<thread uint32_t*>(&x));
}

struct fp8_e4m3 {
  template <typename T>
  fp8_e4m3(T f) {
    // From PyTorch
    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
    uint32_t fp8_max = 543 << 21;
    uint32_t denorm_mask = 141 << 23;
    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
    uint32_t sign = f_bits & 0x80000000;
    f_bits ^= sign;
    if (f_bits >= fp8_max) {
      // Default behavior saturates to min/max
      bits = 0x7E;
    } else {
      if (f_bits < (121 << 23)) {
        f_bits =
            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
        bits = static_cast<uint8_t>(f_bits - denorm_mask);
      } else {
        // resulting mantissa is odd
        uint8_t mant_odd = (f_bits >> 20) & 1;
        f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
        f_bits += mant_odd;
        bits = static_cast<uint8_t>(f_bits >> 20);
      }
    }
    bits |= static_cast<uint8_t>(sign >> 24);
  }

  operator float() {
    // From PyTorch:
    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
    uint32_t w = static_cast<uint32_t>(bits) << 24;
    uint32_t sign = w & 0x80000000;
    uint32_t nonsign = w & 0x7FFFFFFF;

    uint32_t renorm_shift = metal::clz(nonsign);
    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

    int32_t inf_nan_mask =
        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
    int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
    uint32_t result = sign |
        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
          inf_nan_mask) &
         ~zero_mask);
    return fp32_from_bits(result);
  }

  uint8_t bits;
};

struct fp8_e8m0 {
  fp8_e8m0(float x) {
    if (!metal::isfinite(x)) {
      bits = 0xFF;
      return;
    }
    if (x < 0.0f) {
      bits = 0x00;
      return;
    }
    float le = metal::log2(x);
    int n = int(metal::round(le));

    n = n < -127 ? -127 : n;
    n = n > 127 ? 127 : n;
    bits = static_cast<uint8_t>(n + 127);
  }

  operator float() {
    if (bits == 0xFF) {
      return metal::numeric_limits<float>::quiet_NaN();
    }
    return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
  }

  uint8_t bits;
};
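
Note: a plain C++ reference decode of the e4m3 format used above, for illustration only (e4m3_decode_ref is a hypothetical name; the bit-twiddling version in fp8.h is equivalent but branch-free):

#include <cmath>
#include <cstdint>

// e4m3 is 1 sign, 4 exponent (bias 7), 3 mantissa bits. The largest finite
// value is 0x7E = 1.75 * 2^8 = 448, which is why the quantize kernels divide
// the group max by 448 when bits == 8.
// (0x7F encodes NaN in the e4m3fn variant; not handled here.)
float e4m3_decode_ref(uint8_t v) {
  bool sign = v >> 7;
  int exp = (v >> 3) & 0xF;
  int mant = v & 0x7;
  float m = (exp == 0)
      ? std::ldexp(static_cast<float>(mant), -9)  // subnormal: mant/8 * 2^-6
      : std::ldexp(1.0f + mant / 8.0f, exp - 7);  // normal
  return sign ? -m : m;
}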
@@ -3,6 +3,9 @@
#include <metal_simdgroup>
#include <metal_stdlib>

+#include "mlx/backend/metal/kernels/fp4.h"
+#include "mlx/backend/metal/kernels/fp8.h"
+
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
@@ -59,28 +62,10 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
  }
}

-constexpr constant static float MXFP4_LUT[16] = {
-    +0.0f,
-    +0.5f,
-    +1.0f,
-    +1.5f,
-    +2.0f,
-    +3.0f,
-    +4.0f,
-    +6.0f,
-    -0.0f,
-    -0.5f,
-    -1.0f,
-    -1.5f,
-    -2.0f,
-    -3.0f,
-    -4.0f,
-    -6.0f};
-
template <typename T>
-void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
+void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
  if (simd_gid == 0 && simd_lid < 16) {
-    lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
+    lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
}
@@ -155,8 +140,7 @@ template <
    short dst_ld,
    short reduction_dim,
    short tgp_size,
-    short group_size,
-    typename S>
+    short group_size>
struct QuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
@@ -183,12 +167,12 @@ struct QuantizedBlockLoader {

  threadgroup T* dst;
  const device uint8_t* src;
-  const device S* scales;
+  const device uint8_t* scales;
  threadgroup T* lut;

  QuantizedBlockLoader(
      const device uint8_t* src_,
-      const device S* scales_,
+      const device uint8_t* scales_,
      const int src_ld_,
      threadgroup T* dst_,
      threadgroup T* lut_,
@@ -208,7 +192,7 @@ struct QuantizedBlockLoader {
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size),
        lut(lut_) {
-    load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
+    load_fp4_lut(lut, simd_group_id, simd_lane_id);
  }

  void load_unsafe() const {
@@ -270,10 +254,10 @@ struct QuantizedBlockLoader {
  }
};

-template <typename T, int group_size, typename S, int D>
-METAL_FUNC void mxfp4_qmv_quad_impl(
+template <typename T, int group_size, int bits, int D>
+METAL_FUNC void fp_qmv_quad_impl(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
@@ -295,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
@@ -311,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(

  for (int row = 0; row < results_per_quadgroup; row++) {
    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
-    const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
+    const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;

    U s = dequantize_scale<U>(sl[0]);
    if (row * quads_per_simd + out_row < out_vec_size) {
@@ -327,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
  }
}

-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qmv_fast_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qmv_fast_impl(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -353,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
  typedef float U;
  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -390,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
  }
}

-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qmv_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qmv_impl(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -418,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl(

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -448,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl(
    auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
    const device auto* sl = scales + row * in_vec_size_g;

-    S s = sl[0];
+    uint8_t s = sl[0];
    result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
  }

@@ -529,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl(
  }
}

-template <typename T, const int group_size, typename S>
-METAL_FUNC void mxfp4_qvm_impl(
+template <typename T, const int group_size, int bits>
+METAL_FUNC void fp_qvm_impl(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const int in_vec_size,
@@ -561,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl(
  thread U scale = 0;
  thread U x_local = 0;

-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);

  // Adjust positions
  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
@@ -633,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl(
template <
    typename T,
    const int group_size,
-    typename S,
+    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
-METAL_FUNC void mxfp4_qmm_t_impl(
+METAL_FUNC void fp_qmm_t_impl(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
@@ -677,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
@@ -759,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
template <
    typename T,
    const int group_size,
-    typename S,
+    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
-METAL_FUNC void mxfp4_qmm_n_impl(
+METAL_FUNC void fp_qmm_n_impl(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
@@ -803,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;

  auto wl = (const device uint8_t*)w;

@@ -891,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
  }
}

-template <typename T, typename S>
+template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
-    const device S*& scales,
+    const device uint8_t*& scales,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
@@ -926,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets(
  y += tid.z * output_stride;
}

-template <typename T, typename S>
+template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
-    const device S*& scales,
+    const device uint8_t*& scales,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
@@ -976,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets(
  y += tid.z * output_stride;
}

-template <typename T, int group_size, typename S, int D, bool batched>
-[[kernel]] void mxfp4_qmv_quad(
+template <typename T, int group_size, int bits, int D, bool batched>
+[[kernel]] void fp_qmv_quad(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -1014,7 +996,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
        tid);
  }
  threadgroup float lut[16];
-  mxfp4_qmv_quad_impl<T, group_size, S, D>(
+  fp_qmv_quad_impl<T, group_size, bits, D>(
      w,
      scales,
      x,
@@ -1029,10 +1011,10 @@ template <typename T, int group_size, typename S, int D, bool batched>
      lut);
}

-template <typename T, int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qmv_fast(
+template <typename T, int group_size, int bits, bool batched>
+[[kernel]] void fp_qmv_fast(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -1065,14 +1047,14 @@ template <typename T, int group_size, typename S, bool batched>
      tid);
  }
  threadgroup float lut[16];
-  mxfp4_qmv_fast_impl<T, group_size>(
+  fp_qmv_fast_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}

-template <typename T, const int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qmv(
+template <typename T, const int group_size, int bits, bool batched>
+[[kernel]] void fp_qmv(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -1105,14 +1087,14 @@ template <typename T, const int group_size, typename S, bool batched>
      tid);
  }
  threadgroup float lut[16];
-  mxfp4_qmv_impl<T, group_size>(
+  fp_qmv_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}

-template <typename T, const int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qvm(
+template <typename T, const int group_size, int bits, bool batched>
+[[kernel]] void fp_qvm(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -1145,14 +1127,14 @@ template <typename T, const int group_size, typename S, bool batched>
      tid);
  }
  threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size>(
+  fp_qvm_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}

-template <typename T, const int group_size, typename S, int split_k = 32>
-[[kernel]] void mxfp4_qvm_split_k(
+template <typename T, const int group_size, int bits, int split_k = 32>
+[[kernel]] void fp_qvm_split_k(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
@@ -1189,7 +1171,7 @@ template <typename T, const int group_size, typename S, int split_k = 32>
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size>(
+  fp_qvm_impl<T, group_size, bits>(
      w,
      scales,
      x,
@@ -1205,15 +1187,15 @@ template <typename T, const int group_size, typename S, int split_k = 32>
template <
    typename T,
    const int group_size,
-    typename S,
+    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
-[[kernel]] void mxfp4_qmm_t(
+[[kernel]] void fp_qmm_t(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& K,
@@ -1254,21 +1236,21 @@ template <
      s_strides,
      tid);
  }
-  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
+  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}

template <
    typename T,
    const int group_size,
-    typename S,
+    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
-[[kernel]] void mxfp4_qmm_n(
+[[kernel]] void fp_qmm_n(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& K,
@@ -1311,14 +1293,14 @@ template <
      tid);
  }

-  mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
+  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}

-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qmv_fast(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qmv_fast(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
@@ -1361,14 +1343,14 @@ template <typename T, int group_size, typename S>
      s_strides,
      tid);
  threadgroup float lut[16];
-  mxfp4_qmv_fast_impl<T, group_size>(
+  fp_qmv_fast_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}

-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qmv(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qmv(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
@@ -1411,14 +1393,14 @@ template <typename T, int group_size, typename S>
      s_strides,
      tid);
  threadgroup float lut[16];
-  mxfp4_qmv_impl<T, group_size>(
+  fp_qmv_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}

-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qvm(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qvm(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
@@ -1461,21 +1443,21 @@ template <typename T, int group_size, typename S>
      s_strides,
      tid);
  threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size>(
+  fp_qvm_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}

template <
    typename T,
    const int group_size,
-    typename S,
+    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
-[[kernel]] void mxfp4_gather_qmm_t(
+[[kernel]] void fp_gather_qmm_t(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
@@ -1526,20 +1508,20 @@ template <
      w_strides,
      s_strides,
      tid);
-  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
+  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}

template <
    typename T,
    const int group_size,
-    typename S,
+    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
-[[kernel]] void mxfp4_gather_qmm_n(
+[[kernel]] void fp_gather_qmm_n(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
@@ -1591,24 +1573,24 @@ template <
      w_strides,
      s_strides,
      tid);
-  mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
+  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}

template <
    typename T,
    int group_size,
-    typename S,
+    int bits,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose>
-[[kernel]] void mxfp4_gather_qmm_rhs(
+[[kernel]] void fp_gather_qmm_rhs(
    const device T* x,
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device uint32_t* indices,
    device T* y,
    const constant int& M,
@@ -1644,8 +1626,7 @@ template <
      transpose ? BK_padded : BN_padded,
      transpose,
      WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1789,3 +1770,100 @@ template <
    }
  }
}
+
+template <int bits>
+struct Quantize {
+  uint8_t operator()(float x) {
+    if (bits == 8) {
+      return fp8_e4m3(x).bits;
+    } else {
+      return fp4_e2m1(x).bits;
+    }
+  }
+};
+
+template <int bits>
+struct Dequantize {
+  float operator()(uint8_t x) {
+    if (bits == 8) {
+      return float(*(thread fp8_e4m3*)(&x));
+    } else {
+      return float(*(thread fp4_e2m1*)(&x));
+    }
+  }
+};
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void fp_quantize(
+    const device T* w [[buffer(0)]],
+    device uint8_t* out [[buffer(1)]],
+    device uint8_t* scales [[buffer(2)]],
+    uint2 tidx [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  constexpr bool use_mx_scale = group_size == 32;
+  size_t index = tidx.x + grid_dim.x * size_t(tidx.y);
+
+  float scale;
+  float w_thread = w[index];
+  if (use_mx_scale) {
+    scale = simd_max(abs(w_thread));
+  } else {
+    float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
+    float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
+    scale = tidx.x < 16 ? w_max_l : w_max_r;
+  }
+  scale /= bits == 4 ? 6.0f : 448.0f;
+
+  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
+  auto s = ScaleType(scale);
+  uint8_t q_scale = s.bits;
+  scale = float(s);
+
+  // Write out the scales and biases
+  size_t gindex = index / group_size;
+  if (index % group_size == 0) {
+    scales[gindex] = q_scale;
+  }
+
+  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
+  if (bits == 4) {
+    uint8_t sval = simd_shuffle_down(output, 1);
+    output |= sval << bits;
+  }
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  if (index % pack_factor == 0) {
+    out[index / pack_factor] = output;
+  }
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void fp_dequantize(
+    const device uint8_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    device T* out [[buffer(3)]],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  constexpr bool use_mx_scale = group_size == 32;
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  size_t offset = index.x + grid_dim.x * size_t(index.y);
+  size_t oindex = offset * pack_factor;
+  size_t gindex = oindex / group_size;
+
+  out += oindex;
+
+  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
+  auto q_scale = ((device ScaleType*)(scales))[gindex];
+  auto scale = float(q_scale);
+
+  uint val = w[offset];
+#pragma clang loop unroll(full)
+  for (int i = 0; i < pack_factor; i++) {
+    uint8_t d;
+    if (bits == 4) {
+      d = (val >> (bits * i)) & 0x0f;
+    } else if (bits == 8) {
+      d = val;
+    }
+    out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
+  }
+}
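
Note: in the Metal fp_quantize kernel above, a 32-wide SIMD group covers two nvfp4 groups of 16, so each half masks out the other half's values before taking the SIMD-wide max. A plain C++ sketch of that masked reduction (illustrative only; lanes are simulated with a loop):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  float v[32];
  for (int i = 0; i < 32; ++i) v[i] = (i < 16) ? i * 0.1f : i * 0.2f;
  float max_l = 0.0f, max_r = 0.0f;
  for (int lane = 0; lane < 32; ++lane) {
    max_l = std::max(max_l, lane < 16 ? std::fabs(v[lane]) : 0.0f);  // left group of 16
    max_r = std::max(max_r, lane >= 16 ? std::fabs(v[lane]) : 0.0f); // right group of 16
  }
  std::printf("left max %.2f, right max %.2f\n", max_l, max_r);      // 1.50, 6.20
}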
mlx/backend/metal/kernels/fp_quantized.metal (new file, 147 lines)
@@ -0,0 +1,147 @@
// Copyright © 2025 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp_quantized.h"

#define instantiate_quantized(mode, name, type) \
  instantiate_kernel( \
      #mode "_" #name "_" #type "_gs_32_b_4", \
      fp_ ## name, \
      type, \
      32, \
      4)

#define instantiate_quantized_batched(mode, name, type, batched) \
  instantiate_kernel( \
      #mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
      fp_ ## name, \
      type, \
      32, \
      4, \
      batched)

#define instantiate_quantized_aligned(mode, name, type, aligned) \
  instantiate_kernel( \
      #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
      fp_ ## name, \
      type, \
      32, \
      4, \
      aligned)

#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
  instantiate_kernel( \
      #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
      fp_ ## name, \
      type, \
      32, \
      4, \
      aligned, \
      batched)

#define instantiate_quantized_quad(mode, name, type, D, batched) \
  instantiate_kernel( \
      #mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
      fp_ ## name, \
      type, \
      32, \
      4, \
      D, \
      batched)

#define instantiate_quantized_split_k(mode, name, type, split_k) \
  instantiate_kernel( \
      #mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
      fp_ ## name, \
      type, \
      32, \
      4, \
      split_k)

#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
      func, \
      type, \
      32, \
      4, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      transpose)

#define instantiate_quantized_batched_wrap(mode, name, type) \
  instantiate_quantized_batched(mode, name, type, 1) \
  instantiate_quantized_batched(mode, name, type, 0)

#define instantiate_quantized_all_batched(type) \
  instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
  instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
  instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
  instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)

#define instantiate_quantized_all_single(type) \
  instantiate_quantized(mxfp4, gather_qmv_fast, type) \
  instantiate_quantized(mxfp4, gather_qmv, type) \
  instantiate_quantized(mxfp4, gather_qvm, type) \
  instantiate_quantized(mxfp4, gather_qmm_n, type)

#define instantiate_quantized_all_aligned(type) \
  instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
  instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)

#define instantiate_quantized_all_quad(type) \
  instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
  instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
  instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
  instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)

#define instantiate_quantized_all_splitk(type) \
  instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
  instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)

#define instantiate_quantized_all_rhs(type) \
  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)

#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
  instantiate_kernel( \
      #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
      fp_quantize, \
      type, \
      group_size, \
      bits) \
  instantiate_kernel( \
      #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
      fp_dequantize, \
      type, \
      group_size, \
      bits)

#define instantiate_quantize_dequantize_modes(type) \
  instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
  instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
  instantiate_quantize_dequantize(type, mxfp8, 32, 8)

#define instantiate_quantized_types(type) \
  instantiate_quantized_all_batched(type) \
  instantiate_quantized_all_quad(type) \
  instantiate_quantized_all_splitk(type) \
  instantiate_quantized_all_single(type) \
  instantiate_quantized_all_aligned(type) \
  instantiate_quantized_all_rhs(type) \
  instantiate_quantize_dequantize_modes(type)

instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on
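
Note: the instantiation macros above paste the mode, kernel name, type, and parameters into the Metal kernel library names. A small sketch (plain C++, illustrative names only) of the resulting naming scheme:

#include <iostream>
#include <string>

// Mirrors the #mode "_" #name "_" #type "_gs_" ... token pasting above.
std::string kernel_name(
    const std::string& mode,
    const std::string& name,
    const std::string& type,
    int group_size,
    int bits) {
  return mode + "_" + name + "_" + type + "_gs_" + std::to_string(group_size) +
      "_b_" + std::to_string(bits);
}

int main() {
  // e.g. from instantiate_quantize_dequantize_modes(float):
  std::cout << kernel_name("mxfp4", "quantize", "float", 32, 4) << "\n";
  std::cout << kernel_name("nvfp4", "dequantize", "float", 16, 4) << "\n";
  std::cout << kernel_name("mxfp8", "quantize", "float", 32, 8) << "\n";
}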
@@ -8,6 +8,7 @@
#include "mlx/backend/metal/kernels/cexpf.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
+#include "mlx/backend/metal/kernels/fp8.h"

namespace {
constant float inf = metal::numeric_limits<float>::infinity();
@@ -439,63 +440,15 @@ complex64_t ArcTan::operator()(complex64_t x) {
  return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};

-inline float fp32_from_bits(uint32_t bits) {
-  return *(reinterpret_cast<thread float*>(&bits));
-}
-inline float fp32_to_bits(float x) {
-  return *(reinterpret_cast<thread uint32_t*>(&x));
-}
-
struct ToFP8 {
  template <typename T>
  uint8_t operator()(T f) {
-    // From PyTorch
-    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
-    uint32_t fp8_max = 1087 << 20;
-    uint32_t denorm_mask = 141 << 23;
-    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
-    uint8_t result = 0u;
-    uint32_t sign = f_bits & 0x80000000;
-    f_bits ^= sign;
-    if (f_bits >= fp8_max) {
-      // Default behavior saturates to min/max
-      result = 0x7E;
-    } else {
-      if (f_bits < (121 << 23)) {
-        f_bits =
-            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
-        result = static_cast<uint8_t>(f_bits - denorm_mask);
-      } else {
-        // resulting mantissa is odd
-        uint8_t mant_odd = (f_bits >> 20) & 1;
-        f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
-        f_bits += mant_odd;
-        result = static_cast<uint8_t>(f_bits >> 20);
-      }
-    }
-    result |= static_cast<uint8_t>(sign >> 24);
-    return result;
+    return fp8_e4m3(f).bits;
  }
};

struct FromFP8 {
  float operator()(uint8_t x) {
-    // From PyTorch:
-    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
-    uint32_t w = static_cast<uint32_t>(x) << 24;
-    uint32_t sign = w & 0x80000000;
-    uint32_t nonsign = w & 0x7FFFFFFF;
-
-    uint32_t renorm_shift = metal::clz(nonsign);
-    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
-
-    int32_t inf_nan_mask =
-        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
-    int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
-    uint32_t result = sign |
-        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
-          inf_nan_mask) &
-         ~zero_mask);
-    return fp32_from_bits(result);
+    return float(*(thread fp8_e4m3*)(&x));
  }
};
@@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped(
    int bits,
    Args... args) {
  std::string template_def;
-  auto fname = mode + "_" + func;
-  if (mode == "affine") {
-    template_def = get_template_definition(
-        name, fname, type, group_size, bits, std::forward<Args>(args)...);
-  } else {
-    template_def = get_template_definition(
-        name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
-  }
+  std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
+  template_def = get_template_definition(
+      name, fname, type, group_size, bits, std::forward<Args>(args)...);
  return get_quantized_kernel(d, name, template_def, mode);
}
@@ -1045,26 +1040,31 @@ void fast::Quantize::eval_gpu(
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
if (dequantize_) {
|
||||
auto scales = ensure_row_contiguous(inputs[1], d, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], d, s);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[2], d, s);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
}
|
||||
} else {
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_output_array(scales, 2);
|
||||
compute_encoder.set_output_array(biases, 3);
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto& biases = outputs[2];
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
compute_encoder.set_output_array(biases, 3);
|
||||
}
|
||||
}
|
||||
|
||||
auto type_string = dequantize_ ? get_type_string(out.dtype())
|
||||
: get_type_string(w_pre.dtype());
|
||||
auto mode = quantization_mode_to_string(mode_);
|
||||
std::string kname;
|
||||
concatenate(
|
||||
kname,
|
||||
dequantize_ ? "affine_dequantize" : "affine_quantize",
|
||||
mode + (dequantize_ ? "_dequantize" : "_quantize"),
|
||||
"_",
|
||||
type_string,
|
||||
"_gs_",
|
||||
@@ -1075,7 +1075,7 @@ void fast::Quantize::eval_gpu(
|
||||
d,
|
||||
kname,
|
||||
dequantize_ ? "dequantize" : "quantize",
|
||||
"affine",
|
||||
mode,
|
||||
type_string,
|
||||
group_size_,
|
||||
bits_);
|
||||
@@ -1088,7 +1088,8 @@ void fast::Quantize::eval_gpu(
|
||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
||||
: bits_ == 6 ? 4
|
||||
: 8 / bits_;
|
||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
|
||||
int per_thread =
|
||||
dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1);
|
||||
size_t nthreads =
|
||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
||||
|
||||
|
||||
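Note: the std::max clamp above is what keeps the quantize dispatch valid for the new small group sizes. A sketch of the arithmetic, assuming the usual simd_size of 32:

int simd_size = 32;
int group_size = 16;                                      // nvfp4 groups
int old_per_thread = group_size / simd_size;              // 0 -> zero-sized dispatch
int new_per_thread = std::max(group_size / simd_size, 1); // 1 -> at least one element per thread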
@@ -273,7 +273,7 @@ class ConvertFP8 : public Primitive {
};

  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE()
  DEFINE_INPUT_OUTPUT_SHAPE();

 private:
  bool to_fp8_;

487
mlx/ops.cpp
@@ -2,7 +2,6 @@

// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES

#include <algorithm>
#include <climits>
#include <cmath>
@@ -4017,21 +4016,50 @@ array conv_general(
      {in, wt});
}

void validate_mode(std::string_view tag, const std::string& mode) {
  if (mode != "affine" && mode != "mxfp4") {
    std::ostringstream msg;
    msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
    throw std::invalid_argument(msg.str());
std::pair<int, int> quantization_params_from_mode(
    QuantizationMode mode,
    std::optional<int> group_size_,
    std::optional<int> bits_) {
  int default_group_size;
  int default_bits;
  switch (mode) {
    case QuantizationMode::Affine:
      default_group_size = 64;
      default_bits = 4;
      break;
    case QuantizationMode::Nvfp4:
      default_group_size = 16;
      default_bits = 4;
      break;
    case QuantizationMode::Mxfp4:
      default_group_size = 32;
      default_bits = 4;
      break;
    case QuantizationMode::Mxfp8:
      default_group_size = 32;
      default_bits = 8;
      break;
  }
  return {
      group_size_.has_value() ? *group_size_ : default_group_size,
      bits_.has_value() ? *bits_ : default_bits};
}

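A usage sketch of the defaults encoded by the switch above (the calls are illustrative; explicit values pass through unchecked here and are validated later by fp_quantize/fp_dequantize):

// mode     default group_size   default bits
// affine   64                   4
// nvfp4    16                   4
// mxfp4    32                   4
// mxfp8    32                   8
auto [gs, bits] = quantization_params_from_mode(
    QuantizationMode::Mxfp8, std::nullopt, std::nullopt); // {32, 8}
auto [gs2, bits2] = quantization_params_from_mode(
    QuantizationMode::Mxfp8, 64, 8); // explicit values win: {64, 8}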
Dtype validate_mode_with_type(
std::pair<Dtype, QuantizationMode> validate_mode_with_type(
    std::string_view tag,
    const array& scales,
    const std::optional<array>& biases,
    const std::optional<Dtype> out_type,
    const std::string& mode) {
  validate_mode(tag, mode);
  if (mode == "affine") {
  auto qmode = string_to_quantization_mode(mode, tag);
  if (out_type.has_value() && !issubdtype(*out_type, floating)) {
    std::ostringstream msg;
    msg << "[" << tag << "] Only real floating types are supported but "
        << "output dtype == " << *out_type << ".";
    throw std::invalid_argument(msg.str());
  }

  if (qmode == QuantizationMode::Affine) {
    if (!biases) {
      std::ostringstream msg;
      msg << "[" << tag << "] Biases must be provided for affine quantization.";
@@ -4045,7 +4073,11 @@ Dtype validate_mode_with_type(
          << " and biases.dtype() == " << biases->dtype() << ".";
      throw std::invalid_argument(msg.str());
    }
    return dtype;
    if (out_type.has_value()) {
      return {*out_type, qmode};
    } else {
      return {dtype, qmode};
    }
  }
  if (biases) {
    std::ostringstream msg;
@@ -4053,7 +4085,11 @@ Dtype validate_mode_with_type(
        << "'.";
    throw std::invalid_argument(msg.str());
  }
  return bfloat16;
  if (out_type.has_value()) {
    return {*out_type, qmode};
  } else {
    return {bfloat16, qmode};
  }
}

array quantized_matmul(
@@ -4062,17 +4098,24 @@ array quantized_matmul(
    array scales,
    std::optional<array> biases /* = std::nullopt */,
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    std::optional<int> group_size_ /* = std::nullopt */,
    std::optional<int> bits_ /* = std::nullopt */,
    const std::string& mode /* = "affine" */,
    StreamOrDevice s /* = {} */) {
  auto [dtype, qmode] = validate_mode_with_type(
      "quantized_matmul", scales, biases, std::nullopt, mode);

  auto [group_size, bits] =
      quantization_params_from_mode(qmode, group_size_, bits_);
  // Check and extract the quantized matrix shape against x
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);

  auto dtype =
      validate_mode_with_type("quantized_matmul", scales, biases, mode);
  dtype = promote_types(x.dtype(), dtype);
  if (qmode == QuantizationMode::Affine) {
    dtype = promote_types(x.dtype(), dtype);
  } else {
    dtype = x.dtype();
  }

  if (!issubdtype(dtype, floating)) {
    std::ostringstream msg;
@@ -4081,7 +4124,7 @@ array quantized_matmul(
    throw std::invalid_argument(msg.str());
  }
  std::vector<array> inputs;
  if (mode == "affine") {
  if (qmode == QuantizationMode::Affine) {
    inputs = {
        astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
  } else {
@@ -4098,11 +4141,7 @@ array quantized_matmul(
      std::move(out_shape),
      dtype,
      std::make_shared<QuantizedMatmul>(
          to_stream(s),
          group_size,
          bits,
          string_to_quantization_mode(mode),
          transpose),
          to_stream(s), group_size, bits, qmode, transpose),
      std::move(inputs));
}

@@ -4216,13 +4255,110 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
      {w});
}

std::vector<array> fp_quantize(
    const array& w,
    int group_size,
    int bits,
    QuantizationMode mode,
    Stream s) {
  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
  if (group_size != expected_gs) {
    std::ostringstream msg;
    msg << "[quantize] " << quantization_mode_to_string(mode)
        << " quantization requires group size " << expected_gs << " but got "
        << group_size << ".";
    throw std::invalid_argument(msg.str());
  }
  if (bits != expected_bits) {
    std::ostringstream msg;
    msg << "[quantize] " << quantization_mode_to_string(mode)
        << " quantization requires bits to be " << expected_bits << " but got "
        << bits << ".";
    throw std::invalid_argument(msg.str());
  }
  auto fallback = [bits = bits, group_size = group_size, s](
                      const std::vector<array>& inputs) -> std::vector<array> {
    auto& w = inputs[0];
    float maxval = (bits == 4) ? 6.0f : 448.0f;
    auto new_shape = w.shape();
    new_shape.back() = -1;
    auto wq = reshape(w, {-1, group_size}, s);
    auto scales =
        divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
    if (group_size == 16) {
      // convert to e4m3
      scales = to_fp8(scales, s);
      wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
    } else {
      // convert to e8m0
      auto z = array(0, scales.dtype());
      scales = where(
          equal(scales, z, s),
          z,
          astype(round(log2(scales, s), s), int32, s),
          s);

      wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
      scales = astype(add(scales, array(127, int32), s), uint8, s);
    }
    if (bits == 4) {
      auto lut = array({
          +0.0f,
          +0.5f,
          +1.0f,
          +1.5f,
          +2.0f,
          +3.0f,
          +4.0f,
          +6.0f,
          -0.0f,
          -0.5f,
          -1.0f,
          -1.5f,
          -2.0f,
          -3.0f,
          -4.0f,
          -6.0f,
      });
      lut = astype(lut, w.dtype(), s);
      wq = argmin(
          abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
      auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
      wq = reshape(wq, {-1, 4, 8}, s);
      wq = sum(multiply(wq, shifts, s), -1, false, s);
    } else {
      wq = view(to_fp8(wq, s), uint32, s);
    }
    wq = reshape(wq, new_shape, s);
    scales = reshape(scales, new_shape, s);
    return {std::move(wq), std::move(scales)};
  };

  if (s.device == Device::gpu) {
    auto wq_shape = w.shape();
    wq_shape.back() = w.shape(-1) * bits / 32;
    auto sshape = w.shape();
    sshape.back() = w.shape(-1) / group_size;
    return array::make_arrays(
        {std::move(wq_shape), std::move(sshape)},
        {uint32, uint8},
        std::make_shared<fast::Quantize>(
            s, fallback, group_size, bits, mode, false),
        {w});
  }
  return fallback({w});
}

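To make the fallback concrete, a hand-worked mxfp4 group (illustrative numbers, not from the source):

// group_size = 32, max|w| = 3.0
//   raw scale     = 3.0 / 6.0 = 0.5        (6.0 is the largest fp4 magnitude)
//   e8m0 exponent = round(log2(0.5)) = -1
//   stored scale  = -1 + 127 = 126 (uint8)
//   element 1.5   -> 1.5 / 2^-1 = 3.0 -> nearest LUT entry +3.0 -> code 5
//   the 32 4-bit codes pack into 4 uint32 words, 8 codes per word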
std::vector<array> quantize(
    const array& w,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    std::optional<int> group_size_ /* = std::nullopt */,
    std::optional<int> bits_ /* = std::nullopt */,
    const std::string& mode /* = "affine" */,
    StreamOrDevice s /* = {} */) {
  validate_mode("quantize", mode);
  auto qmode = string_to_quantization_mode(mode, "quantize");
  auto [group_size, bits] =
      quantization_params_from_mode(qmode, group_size_, bits_);
  if (!issubdtype(w.dtype(), floating)) {
    std::ostringstream msg;
    msg << "[quantize] Only real floating types can be quantized "
@@ -4246,57 +4382,10 @@ std::vector<array> quantize(
    throw std::invalid_argument(msg.str());
  }

  if (mode == "affine") {
  if (qmode == QuantizationMode::Affine) {
    return affine_quantize(w, group_size, bits, s);
  } else {
    if (group_size != 32) {
      std::ostringstream msg;
      msg << "[quantize] mxfp4 quantization requires group size 32 "
          << "but got " << group_size << ".";
      throw std::invalid_argument(msg.str());
    }
    if (bits != 4) {
      std::ostringstream msg;
      msg << "[quantize] mxfp4 quantization requires bits to be 4 "
          << "but got " << bits << ".";
      throw std::invalid_argument(msg.str());
    }

    auto lut = array({
        +0.0f,
        +0.5f,
        +1.0f,
        +1.5f,
        +2.0f,
        +3.0f,
        +4.0f,
        +6.0f,
        -0.0f,
        -0.5f,
        -1.0f,
        -1.5f,
        -2.0f,
        -3.0f,
        -4.0f,
        -6.0f,
    });
    lut = astype(lut, w.dtype(), s);

    auto new_shape = w.shape();
    new_shape.back() = -1;
    auto wq = reshape(w, {-1, group_size}, s);
    auto scales =
        divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s);
    scales = astype(log2(scales, s), int32, s);
    wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
    scales = astype(add(scales, array(127, int32), s), uint8, s);
    wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
    auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
    wq = reshape(wq, {-1, group_size / 8, 8}, s);
    wq = sum(multiply(wq, shifts, s), -1, false, s);
    wq = reshape(wq, new_shape, s);
    scales = reshape(scales, new_shape, s);
    return {std::move(wq), std::move(scales)};
    return fp_quantize(w, group_size, bits, qmode, to_stream(s));
  }
}

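A usage sketch of the updated entry point (w is any floating-point array; the two calls are equivalent since defaults now come from the mode):

auto parts = quantize(w, std::nullopt, std::nullopt, "mxfp8");
auto same = quantize(w, 32, 8, "mxfp8");
// parts[0]: packed uint32 data; parts[1]: uint8 e8m0 scales (no biases for fp modes)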
@@ -4307,16 +4396,13 @@ array affine_dequantize(
    int group_size,
    int bits,
    StreamOrDevice s_) {
  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
    std::ostringstream msg;
    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions "
        << "but it has only " << w.ndim() << ".";
    throw std::invalid_argument(msg.str());
  }

  auto wshape = w.shape();
  auto sshape = scales.shape();
  auto bshape = biases.shape();
  if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) {
    throw std::invalid_argument(
        "[dequantize] Shape of scales and biases does not match the matrix");
  }
  wshape.back() = -1;
  sshape.back() = -1;
  bshape.back() = -1;
@@ -4397,15 +4483,132 @@ array affine_dequantize(
  return fallback({w, scales, biases})[0];
}

array fp_dequantize(
    const array& w,
    const array& scales,
    int group_size,
    int bits,
    Dtype out_type,
    QuantizationMode mode,
    Stream s) {
  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
  if (group_size != expected_gs) {
    std::ostringstream msg;
    msg << "[dequantize] " << quantization_mode_to_string(mode)
        << " quantization requires group size " << expected_gs << " but got "
        << group_size << ".";
    throw std::invalid_argument(msg.str());
  }
  if (bits != expected_bits) {
    std::ostringstream msg;
    msg << "[dequantize] " << quantization_mode_to_string(mode)
        << " quantization requires bits to be " << expected_bits << " but got "
        << bits << ".";
    throw std::invalid_argument(msg.str());
  }

  auto wshape = w.shape();
  auto sshape = scales.shape();
  if (wshape.size() != sshape.size()) {
    throw std::invalid_argument(
        "[dequantize] Shape of scales does not match the matrix");
  }

  wshape.back() = -1;
  sshape.back() = -1;

  if (wshape != sshape) {
    throw std::invalid_argument(
        "[dequantize] Shape of scales does not match the matrix");
  }

  // Packing into uint32
  int out_size = w.shape(-1) * 32 / bits;

  if (out_size != scales.shape(-1) * group_size) {
    std::ostringstream msg;
    msg << "[dequantize] Shape of scales does not match the matrix "
        << "given the quantization parameters. Provided matrix of shape "
        << w.shape() << " and scales of shape " << scales.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  auto fallback =
      [wshape = std::move(wshape),
       sshape = std::move(sshape),
       group_size,
       bits,
       out_type,
       s](const std::vector<array>& inputs) mutable -> std::vector<array> {
    auto out = inputs[0];
    auto scales = inputs[1];
    if (bits == 4) {
      auto lut = array(
          {
              +0.0f,
              +0.5f,
              +1.0f,
              +1.5f,
              +2.0f,
              +3.0f,
              +4.0f,
              +6.0f,
              -0.0f,
              -0.5f,
              -1.0f,
              -1.5f,
              -2.0f,
              -3.0f,
              -4.0f,
              -6.0f,
          },
          out_type);
      out = view(reshape(out, {-1, 4}, s), int8, s);
      auto idx_lo = bitwise_and(out, array(0x0F, int8), s);
      auto idx_hi = right_shift(out, array(4, int8), s);
      auto lo = gather(lut, idx_lo, 0, {1}, s);
      auto hi = gather(lut, idx_hi, 0, {1}, s);
      out = concatenate({lo, hi}, -1, s);
    } else {
      out = from_fp8(view(out, uint8, s), out_type, s);
    }
    out = reshape(out, {-1, group_size}, s);
    scales = reshape(scales, {-1, 1}, s);
    if (group_size == 16) {
      scales = from_fp8(scales, out_type, s);
    } else {
      scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
      scales = power(array(2.0f, out_type), scales, s);
    }
    return {reshape(multiply(out, scales, s), wshape, s)};
  };
  if (s.device == Device::gpu) {
    auto out_shape = w.shape();
    out_shape.back() = out_size;
    return array(
        std::move(out_shape),
        out_type,
        std::make_shared<fast::Quantize>(
            s, fallback, group_size, bits, mode, true),
        {w, scales});
  }
  return fallback({w, scales})[0];
}

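A scalar sketch of the nibble unpacking the fallback performs for 4-bit modes (an illustrative standalone helper, not MLX API):

#include <cstdint>

static const float kFp4Lut[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

void unpack_fp4_byte(uint8_t b, float* lo, float* hi) {
  *lo = kFp4Lut[b & 0x0F]; // value encoded in the low nibble
  *hi = kFp4Lut[b >> 4];   // value encoded in the high nibble
}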
array dequantize(
    const array& w,
    const array& scales,
    const std::optional<array>& biases /* = std::nullopt */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    std::optional<int> group_size_ /* = std::nullopt */,
    std::optional<int> bits_ /* = std::nullopt */,
    const std::string& mode /* = "affine" */,
    std::optional<Dtype> dtype /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  validate_mode_with_type("dequantize", scales, biases, mode);
  auto [out_type, qmode] =
      validate_mode_with_type("dequantize", scales, biases, dtype, mode);
  auto [group_size, bits] =
      quantization_params_from_mode(qmode, group_size_, bits_);
  if (bits <= 0) {
    std::ostringstream msg;
    msg << "[dequantize] Invalid value for bits: " << bits;
@@ -4420,89 +4623,21 @@ array dequantize(
    throw std::invalid_argument(
        "[dequantize] The matrix should be given as a uint32");
  }
  if (w.ndim() < 2) {
    std::ostringstream msg;
    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions "
        << "but it has only " << w.ndim() << ".";
    throw std::invalid_argument(msg.str());
  }

  if (mode == "affine") {
    return affine_dequantize(w, scales, *biases, group_size, bits, s);
  if (qmode == QuantizationMode::Affine) {
    return astype(
        affine_dequantize(w, scales, *biases, group_size, bits, s),
        out_type,
        s);
  } else {
    if (group_size != 32) {
      std::ostringstream msg;
      msg << "[dequantize] mxfp4 quantization requires group size 32 "
          << "but got " << group_size << ".";
      throw std::invalid_argument(msg.str());
    }
    if (bits != 4) {
      std::ostringstream msg;
      msg << "[dequantize] mxfp4 quantization requires bits to be 4 "
          << "but got " << bits << ".";
      throw std::invalid_argument(msg.str());
    }

    if (w.ndim() < 2 || scales.ndim() < 2) {
      std::ostringstream msg;
      msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions "
          << "but it has only " << w.ndim() << ".";
      throw std::invalid_argument(msg.str());
    }

    auto wshape = w.shape();
    auto sshape = scales.shape();
    wshape.back() = -1;
    sshape.back() = -1;

    if (wshape != sshape) {
      throw std::invalid_argument(
          "[dequantize] Shape of scales does not match the matrix");
    }

    if (w.dtype() != uint32) {
      throw std::invalid_argument(
          "[dequantize] The matrix should be given as a uint32");
    }

    // Packing into uint32
    int out_size = w.shape(-1) * 32 / bits;

    if (out_size != scales.shape(-1) * group_size) {
      std::ostringstream msg;
      msg << "[dequantize] Shape of scales does not match the matrix "
          << "given the quantization parameters. Provided matrix of shape "
          << w.shape() << " and scales of shape " << scales.shape() << ".";
      throw std::invalid_argument(msg.str());
    }

    auto dtype = bfloat16;
    auto lut = array(
        {
            +0.0f,
            +0.5f,
            +1.0f,
            +1.5f,
            +2.0f,
            +3.0f,
            +4.0f,
            +6.0f,
            -0.0f,
            -0.5f,
            -1.0f,
            -1.5f,
            -2.0f,
            -3.0f,
            -4.0f,
            -6.0f,
        },
        dtype);

    auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s);

    auto idx_lo = bitwise_and(what, array(0x0F, int8), s);
    auto idx_hi = right_shift(what, array(4, int8), s);
    auto lo = gather(lut, idx_lo, 0, {1}, s);
    auto hi = gather(lut, idx_hi, 0, {1}, s);
    what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s);
    auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s);
    exponent = reshape(exponent, {-1, 1}, s);
    return reshape(
        multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s);
    return fp_dequantize(
        w, scales, group_size, bits, out_type, qmode, to_stream(s));
  }
}

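A usage sketch of the new dtype override (wq and scales stand in for the outputs of a prior quantize call with mode "mxfp4"); fp modes take no biases and default to bfloat16 output unless a dtype is given:

auto out = dequantize(
    wq, scales, std::nullopt, std::nullopt, std::nullopt, "mxfp4", float16);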
@@ -4548,21 +4683,27 @@ array gather_qmm(
    std::optional<array> lhs_indices_ /* = std::nullopt */,
    std::optional<array> rhs_indices_ /* = std::nullopt */,
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    std::optional<int> group_size_ /* = std::nullopt */,
    std::optional<int> bits_ /* = std::nullopt */,
    const std::string& mode /* = "affine" */,
    bool sorted_indices /* = false */,
    StreamOrDevice s /* = {} */) {
  if (!lhs_indices_ && !rhs_indices_) {
    return quantized_matmul(
        x, w, scales, biases, transpose, group_size, bits, mode, s);
        x, w, scales, biases, transpose, group_size_, bits_, mode, s);
  }

  auto [out_type, qmode] =
      validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode);
  auto [group_size, bits] =
      quantization_params_from_mode(qmode, group_size_, bits_);
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);

  auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode);
  out_type = promote_types(x.dtype(), out_type);
  if (qmode == QuantizationMode::Affine) {
    out_type = promote_types(x.dtype(), out_type);
  } else {
    out_type = x.dtype();
  }

  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
@@ -4601,7 +4742,7 @@ array gather_qmm(
  out_shape.push_back(x.shape(-2));
  out_shape.push_back(w_outer_dims);
  std::vector<array> inputs;
  if (mode == "affine") {
  if (qmode == QuantizationMode::Affine) {
    inputs = {
        astype(x, out_type, s),
        std::move(w),
@@ -4624,7 +4765,7 @@ array gather_qmm(
        to_stream(s),
        group_size,
        bits,
        string_to_quantization_mode(mode),
        qmode,
        transpose,
        sorted_indices && !rhs_indices_,
        sorted_indices && !lhs_indices_),

17
mlx/ops.h
@@ -1379,16 +1379,16 @@ array quantized_matmul(
    array scales,
    std::optional<array> biases = std::nullopt,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    std::optional<int> group_size = std::nullopt,
    std::optional<int> bits = std::nullopt,
    const std::string& mode = "affine",
    StreamOrDevice s = {});

/** Quantize a matrix along its last axis */
std::vector<array> quantize(
    const array& w,
    int group_size = 64,
    int bits = 4,
    std::optional<int> group_size = std::nullopt,
    std::optional<int> bits = std::nullopt,
    const std::string& mode = "affine",
    StreamOrDevice s = {});

@@ -1397,9 +1397,10 @@ array dequantize(
    const array& w,
    const array& scales,
    const std::optional<array>& biases = std::nullopt,
    int group_size = 64,
    int bits = 4,
    std::optional<int> group_size = std::nullopt,
    std::optional<int> bits = std::nullopt,
    const std::string& mode = "affine",
    std::optional<Dtype> dtype = std::nullopt,
    StreamOrDevice s = {});

/** Convert an E4M3 float8 to the given floating point dtype. */
@@ -1417,8 +1418,8 @@ array gather_qmm(
    std::optional<array> lhs_indices = std::nullopt,
    std::optional<array> rhs_indices = std::nullopt,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    std::optional<int> group_size = std::nullopt,
    std::optional<int> bits = std::nullopt,
    const std::string& mode = "affine",
    bool sorted_indices = false,
    StreamOrDevice s = {});

@@ -3328,19 +3328,37 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
}

std::string quantization_mode_to_string(QuantizationMode mode) {
  if (mode == QuantizationMode::Affine) {
    return "affine";
  } else {
    return "mxfp4";
  switch (mode) {
    case QuantizationMode::Affine:
      return "affine";
    case QuantizationMode::Mxfp4:
      return "mxfp4";
    case QuantizationMode::Mxfp8:
      return "mxfp8";
    case QuantizationMode::Nvfp4:
    default:
      return "nvfp4";
  }
}

QuantizationMode string_to_quantization_mode(const std::string& mode) {
QuantizationMode string_to_quantization_mode(
    const std::string& mode,
    std::string_view tag /* = "" */) {
  if (mode == "affine") {
    return QuantizationMode::Affine;
  } else {
  } else if (mode == "mxfp4") {
    return QuantizationMode::Mxfp4;
  } else if (mode == "mxfp8") {
    return QuantizationMode::Mxfp8;
  } else if (mode == "nvfp4") {
    return QuantizationMode::Nvfp4;
  }
  std::string msg;
  if (!tag.empty()) {
    msg += "[" + std::string(tag) + "]";
  }
  msg += " Invalid quantization mode '" + mode + "'.";
  throw std::invalid_argument(msg);
}

std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
@@ -3404,6 +3422,7 @@ std::vector<array> QuantizedMatmul::vjp(
          group_size_,
          bits_,
          quantization_mode_to_string(mode_),
          std::nullopt,
          stream());
      wq = unflatten(wq, -1, {-1, group_size_}, stream());
      vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
@@ -3558,6 +3577,7 @@ std::vector<array> GatherQMM::vjp(
              group_size_,
              bits_,
              quantization_mode_to_string(mode_),
              std::nullopt,
              stream()),
          -1,
          {-1, group_size_},

@@ -151,10 +151,12 @@ class UnaryPrimitive : public Primitive {
  UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
};

enum class QuantizationMode { Affine, Mxfp4 };
enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };

std::string quantization_mode_to_string(QuantizationMode mode);
QuantizationMode string_to_quantization_mode(const std::string& mode);
QuantizationMode string_to_quantization_mode(
    const std::string& mode,
    std::string_view error_tag = "");

class Abs : public UnaryPrimitive {
 public: