mlx (https://github.com/ml-explore/mlx.git)
Commit: fast cuda kernel for mx/nv quantization
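The commit adds a dedicated CUDA kernel pair for the microscaling (mx) and NVIDIA (nv) floating-point quantization modes, alongside the existing affine path. A minimal usage sketch of the Python API these kernels back; the shapes and the random input are illustrative, not taken from the commit:

import mlx.core as mx

w = mx.random.normal((128, 256))

# mxfp4: fp4 (e2m1) elements, one shared e8m0 power-of-two scale per group of 32
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")

# mxfp8: fp8 (e4m3) elements with e8m0 scales, also group size 32
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")

# nvfp4: fp4 elements with fp8 (e4m3) scales, group size 16
w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")

# affine mode still returns a third biases array
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

Unlike the affine mode, the fp modes fix their parameters: group_size must be 32 for mxfp4/mxfp8 and 16 for nvfp4, and bits must match the element format; anything else raises an error.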
@@ -51,6 +51,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
@@ -306,7 +306,7 @@ void affine_dequantize(
   enc.set_input_array(scales);
   enc.set_input_array(biases);
   enc.set_output_array(w);
-  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
+  dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
         using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
mlx/backend/cuda/quantized/fp_quantize.cu (new file, 218 lines)
@@ -0,0 +1,218 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/quantized/quantized.h"
+#include "mlx/dtype_utils.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cuda_fp4.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+template <int bits>
+struct Quantize {
+  __device__ uint8_t operator()(float x) {
+    if constexpr (bits == 8) {
+      return __nv_fp8_e4m3(x).__x;
+    } else {
+      return __nv_fp4_e2m1(x).__x;
+    }
+  }
+};
+
+template <int bits>
+struct Dequantize {
+  __device__ float operator()(uint8_t x) {
+    if constexpr (bits == 8) {
+      return float(*(__nv_fp8_e4m3*)(&x));
+    } else {
+      return float(*(__nv_fp4_e2m1*)(&x));
+    }
+  }
+};
+
+namespace cg = cooperative_groups;
+
+template <typename T, int group_size, int bits, bool use_mx_scale>
+__global__ void
+fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
+  auto block_size = cg::this_thread_block().dim_threads();
+  auto block_idx = cg::this_thread_block().group_index();
+  auto idx_in_block = cg::this_thread_block().thread_index();
+
+  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
+  auto tidy = block_idx.y * block_size.y + idx_in_block.y;
+
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_thread_block().dim_threads().x;
+  size_t out_index = tidx + grid_dim_x * size_t(tidy);
+  size_t in_index = out_index;
+  if (in_index >= size) {
+    return;
+  }
+
+  float w_thread = w[in_index];
+
+  cg::greater<float> max_op;
+  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
+
+  float scale = cg::reduce(warp, abs(w_thread), max_op);
+  scale /= bits == 4 ? 6.0f : 448.0f;
+  // Convert to mx scale or nv scale
+  using ScaleType =
+      std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
+  auto s = ScaleType(scale);
+  uint8_t q_scale = s.__x;
+  scale = float(s);
+
+  // Write out the scales
+  size_t gindex = in_index / group_size;
+  if (in_index % group_size == 0) {
+    scales[gindex] = q_scale;
+  }
+
+  uint8_t output = 0;
+  uint8_t val = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
+  output = val;
+  if (bits == 4) {
+    uint8_t sval = warp.shfl_down(val, 1);
+    output |= sval << bits;
+  }
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  if (out_index % pack_factor == 0) {
+    out[out_index / pack_factor] = output;
+  }
+}
+
+template <typename T, int group_size, int bits, bool use_mx_scale>
+__global__ void
+fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
+  auto block_size = cg::this_thread_block().dim_threads();
+  auto block_idx = cg::this_thread_block().group_index();
+  auto idx_in_block = cg::this_thread_block().thread_index();
+
+  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
+  auto tidy = block_idx.y * block_size.y + idx_in_block.y;
+
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_thread_block().dim_threads().x;
+
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  size_t offset = tidx + grid_dim_x * size_t(tidy);
+  size_t oindex = offset * pack_factor;
+
+  if (oindex >= size) {
+    return;
+  }
+
+  size_t gindex = oindex / group_size;
+  using ScaleType =
+      std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
+  auto scale = float(((ScaleType*)(scales))[gindex]);
+
+  out += oindex;
+
+  uint val = w[offset];
+#pragma clang loop unroll(full)
+  for (int i = 0; i < pack_factor; i++) {
+    uint8_t d;
+    if (bits == 4) {
+      d = (val >> (bits * i)) & 0x0f;
+    } else if (bits == 8) {
+      d = val;
+    }
+    out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
+  }
+}
+
+} // namespace cu
+
+void fp_quantize(
+    const array& w,
+    array& wq,
+    array& scales,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s) {
+  enc.set_input_array(w);
+  enc.set_output_array(wq);
+  enc.set_output_array(scales);
+  dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    if constexpr (!std::is_same_v<T, double>) {
+      auto kernel = cu::fp_quantize<T, 32, 4, true>;
+      if (bits == 8) {
+        kernel = cu::fp_quantize<T, 32, 8, true>;
+      } else if (group_size == 16) {
+        kernel = cu::fp_quantize<T, 16, 4, false>;
+      }
+      bool large = w.size() > UINT_MAX;
+      auto [num_blocks, block_dims] =
+          get_launch_args(w.size(), w.shape(), w.strides(), large);
+      enc.add_kernel_node(
+          kernel,
+          num_blocks,
+          block_dims,
+          0,
+          w.data<T>(),
+          wq.data<uint8_t>(),
+          scales.data<uint8_t>(),
+          w.size());
+    } else {
+      throw std::runtime_error(
+          "[Quantize::eval_gpu] Can not quantize input with type float64.");
+    }
+  });
+}
+
+void fp_dequantize(
+    const array& wq,
+    const array& scales,
+    array& w,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s) {
+  constexpr int uint8_per_uint32 = 4;
+  int packs_per_int = 8 / bits;
+
+  size_t size = w.size() / packs_per_int;
+  bool large = size > UINT_MAX;
+  auto grid_shape = w.shape();
+  grid_shape.back() *= uint8_per_uint32;
+
+  enc.set_input_array(wq);
+  enc.set_input_array(scales);
+  enc.set_output_array(w);
+  dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    if constexpr (!std::is_same_v<T, double>) {
+      auto kernel = cu::fp_dequantize<T, 32, 4, true>;
+      if (bits == 8) {
+        kernel = cu::fp_dequantize<T, 32, 8, true>;
+      } else if (group_size == 16) {
+        kernel = cu::fp_dequantize<T, 16, 4, false>;
+      }
+      auto [num_blocks, block_dims] =
+          get_launch_args(size, grid_shape, w.strides(), large);
+      enc.add_kernel_node(
+          kernel,
+          num_blocks,
+          block_dims,
+          0,
+          wq.data<uint8_t>(),
+          scales.data<uint8_t>(),
+          w.data<T>(),
+          w.size());
+    } else {
+      throw std::runtime_error(
+          "[Quantize::eval_gpu] Can not dequantize to output with type float64.");
+    }
+  });
+}
+
+} // namespace mlx::core
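To make the kernel's per-group arithmetic concrete, here is a NumPy sketch (not part of the commit) of what one mxfp4 group of 32 values goes through: the group maximum fixes a power-of-two scale (e8m0, stored as a biased exponent byte, the q_scale written to scales[gindex] above), and each element is then rounded to the nearest fp4 (e2m1) code. Rounding the scale down to a power of two is an assumption made for illustration; the kernel defers to the __nv_fp8_e8m0 conversion for the exact behavior.

import numpy as np

# The 16 fp4 (e2m1) code values, in bit-pattern order.
FP4_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def quantize_group_mxfp4(group):
    # scale = max|w| / 6.0; 6.0 is the largest fp4 magnitude (448.0 for fp8)
    scale = np.abs(group).max() / 6.0
    q_scale = 0 if scale == 0 else int(np.floor(np.log2(scale))) + 127
    scale = 2.0 ** (q_scale - 127)  # the decoded e8m0 scale actually applied
    # round every element to the nearest representable fp4 value
    codes = np.abs(FP4_VALUES[None, :] - (group / scale)[:, None]).argmin(axis=1)
    return codes.astype(np.uint8), q_scale  # 4-bit codes plus one scale byte

def dequantize_group_mxfp4(codes, q_scale):
    return FP4_VALUES[codes] * 2.0 ** (q_scale - 127)

group = np.random.randn(32).astype(np.float32)
codes, q_scale = quantize_group_mxfp4(group)
print(np.abs(group - dequantize_group_mxfp4(codes, q_scale)).max())

On the GPU the same work is done by one cooperative-groups tile of group_size threads: cg::reduce computes the group maximum, and for 4-bit output each even lane fetches its neighbor's code with warp.shfl_down and packs the two codes into a single byte.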
@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
   if (dequantize_) {
     auto wq = ensure_row_contiguous(inputs[0], enc, s);
     auto scales = ensure_row_contiguous(inputs[1], enc, s);
-    auto biases = ensure_row_contiguous(inputs[2], enc, s);
     auto& w = outputs[0];
 
     w.set_data(allocator::malloc(w.nbytes()));
 
-    affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
+    if (mode_ == QuantizationMode::Affine) {
+      auto biases = ensure_row_contiguous(inputs[2], enc, s);
+      affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
+    } else {
+      fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
+    }
   } else {
     auto w = ensure_row_contiguous(inputs[0], enc, s);
     auto& wq = outputs[0];
     auto& scales = outputs[1];
-    auto& biases = outputs[2];
 
     wq.set_data(allocator::malloc(wq.nbytes()));
     scales.set_data(allocator::malloc(scales.nbytes()));
-    biases.set_data(allocator::malloc(biases.nbytes()));
-    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+    if (mode_ == QuantizationMode::Affine) {
+      auto& biases = outputs[2];
+      biases.set_data(allocator::malloc(biases.nbytes()));
+      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+    } else {
+      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
+    }
   }
 }
@@ -24,4 +24,22 @@ void affine_dequantize(
     cu::CommandEncoder& enc,
     const Stream& s);
 
+void fp_quantize(
+    const array& w,
+    array& wq,
+    array& scales,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
+void fp_dequantize(
+    const array& wq,
+    const array& scales,
+    array& w,
+    int group_size,
+    int bits,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
 } // namespace mlx::core
mlx/ops.cpp (352 changed lines)
@@ -4017,22 +4017,22 @@ array conv_general(
       {in, wt});
 }
 
-void validate_mode(std::string_view tag, const std::string& mode) {
-  if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
-      mode != "nvfp4") {
-    std::ostringstream msg;
-    msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
-    throw std::invalid_argument(msg.str());
-  }
-}
-
-Dtype validate_mode_with_type(
+std::pair<Dtype, QuantizationMode> validate_mode_with_type(
     std::string_view tag,
     const array& scales,
     const std::optional<array>& biases,
+    const std::optional<Dtype> out_type,
    const std::string& mode) {
-  validate_mode(tag, mode);
-  if (mode == "affine") {
+  auto qmode = string_to_quantization_mode(mode, tag);
+  // TODO add tests for out_type
+  if (out_type.has_value() && !issubdtype(*out_type, floating)) {
+    std::ostringstream msg;
+    msg << "[" << tag << "] Only real floating types are supported but "
+        << "output dtype == " << *out_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (qmode == QuantizationMode::Affine) {
     if (!biases) {
       std::ostringstream msg;
       msg << "[" << tag << "] Biases must be provided for affine quantization.";
@@ -4046,7 +4046,11 @@ Dtype validate_mode_with_type(
           << " and biases.dtype() == " << biases->dtype() << ".";
       throw std::invalid_argument(msg.str());
     }
-    return dtype;
+    if (out_type.has_value()) {
+      return {*out_type, qmode};
+    } else {
+      return {dtype, qmode};
+    }
   }
   if (biases) {
     std::ostringstream msg;
@@ -4054,7 +4058,11 @@ Dtype validate_mode_with_type(
        << "'.";
     throw std::invalid_argument(msg.str());
   }
-  return bfloat16;
+  if (out_type.has_value()) {
+    return {*out_type, qmode};
+  } else {
+    return {bfloat16, qmode};
+  }
 }
 
 array quantized_matmul(
@@ -4071,8 +4079,8 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  auto dtype =
-      validate_mode_with_type("quantized_matmul", scales, biases, mode);
+  auto [dtype, qmode] = validate_mode_with_type(
+      "quantized_matmul", scales, biases, std::nullopt, mode);
   dtype = promote_types(x.dtype(), dtype);
 
   if (!issubdtype(dtype, floating)) {
@@ -4082,7 +4090,7 @@ array quantized_matmul(
     throw std::invalid_argument(msg.str());
   }
   std::vector<array> inputs;
-  if (mode == "affine") {
+  if (qmode == QuantizationMode::Affine) {
     inputs = {
         astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
   } else {
@@ -4099,11 +4107,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s),
-          group_size,
-          bits,
-          string_to_quantization_mode(mode),
-          transpose),
+          to_stream(s), group_size, bits, qmode, transpose),
       std::move(inputs));
 }
 
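At the Python level, quantized_matmul now takes the mode by name and parses it once into the QuantizationMode enum; biases are simply omitted for the fp modes. A sketch, assuming the Python binding mirrors this C++ signature (illustrative shapes):

import mlx.core as mx

x = mx.random.normal((1, 128))
w = mx.random.normal((256, 128))
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
y = mx.quantized_matmul(
    x, w_q, scales, transpose=True, group_size=32, bits=4, mode="mxfp4"
)  # y has shape (1, 256)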
@@ -4217,53 +4221,31 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
       {w});
 }
 
-std::vector<array> quantize(
+std::vector<array> fp_quantize(
     const array& w,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
-    const std::string& mode /* = "affine" */,
-    StreamOrDevice s /* = {} */) {
-  validate_mode("quantize", mode);
-  if (!issubdtype(w.dtype(), floating)) {
+    int group_size,
+    int bits,
+    QuantizationMode mode,
+    Stream s) {
+  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
+  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
+  if (group_size != expected_gs) {
     std::ostringstream msg;
-    msg << "[quantize] Only real floating types can be quantized "
-        << "but w has type " << w.dtype() << ".";
+    msg << "[quantize] " << quantization_mode_to_string(mode)
+        << " quantization requires group size " << expected_gs << " but got "
+        << group_size << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (w.ndim() < 2) {
+  if (bits != expected_bits) {
     std::ostringstream msg;
-    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
-        << "but it has only " << w.ndim() << ".";
+    msg << "[quantize] " << quantization_mode_to_string(mode)
+        << " quantization requires bits to be " << expected_bits << " but got "
+        << bits << ".";
     throw std::invalid_argument(msg.str());
   }
-  if ((w.shape(-1) % group_size) != 0) {
-    std::ostringstream msg;
-    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
-        << "the quantization group size " << group_size
-        << ". However the provided "
-        << " matrix has shape " << w.shape();
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (mode == "affine") {
-    return affine_quantize(w, group_size, bits, s);
-  } else {
-    int expected_gs = (mode[0] == 'm') ? 32 : 16;
-    int expected_bits = (mode.back() == '8') ? 8 : 4;
-    if (group_size != expected_gs) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires group size "
-          << expected_gs << " but got " << group_size << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    if (bits != expected_bits) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires bits to be "
-          << expected_bits << " but got " << bits << ".";
-      throw std::invalid_argument(msg.str());
-    }
+  auto fallback = [bits = bits, group_size = group_size, s](
+                      const std::vector<array>& inputs) -> std::vector<array> {
+    auto& w = inputs[0];
     float maxval = (bits == 4) ? 6.0f : 448.0f;
     auto new_shape = w.shape();
     new_shape.back() = -1;
@@ -4314,6 +4296,57 @@ std::vector<array> quantize(
     wq = reshape(wq, new_shape, s);
     scales = reshape(scales, new_shape, s);
     return {std::move(wq), std::move(scales)};
+  };
+
+  if (s.device == Device::gpu) {
+    auto wq_shape = w.shape();
+    wq_shape.back() = w.shape(-1) * bits / 32;
+    auto sshape = w.shape();
+    sshape.back() = w.shape(-1) / group_size;
+    return array::make_arrays(
+        {std::move(wq_shape), std::move(sshape)},
+        {uint32, uint8},
+        std::make_shared<fast::Quantize>(
+            s, fallback, group_size, bits, mode, false),
+        {w});
+  }
+  return fallback({w});
+}
+
+std::vector<array> quantize(
+    const array& w,
+    int group_size /* = 64 */,
+    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
+    StreamOrDevice s /* = {} */) {
+  auto qmode = string_to_quantization_mode(mode, "quantize");
+  if (!issubdtype(w.dtype(), floating)) {
+    std::ostringstream msg;
+    msg << "[quantize] Only real floating types can be quantized "
+        << "but w has type " << w.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if ((w.shape(-1) % group_size) != 0) {
+    std::ostringstream msg;
+    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
+        << "the quantization group size " << group_size
+        << ". However the provided "
+        << " matrix has shape " << w.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (qmode == QuantizationMode::Affine) {
+    return affine_quantize(w, group_size, bits, s);
+  } else {
+    return fp_quantize(w, group_size, bits, qmode, to_stream(s));
+  }
 }
 
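The GPU path above allocates both outputs directly rather than going through the fallback graph: the packed weights keep w's shape except that the last axis holds the bits-wide codes packed into uint32, and the scales hold one uint8 per group. A quick check of that arithmetic with illustrative numbers:

# w of shape (128, 256), mode="mxfp4" (bits=4, group_size=32):
#   wq_shape.back() = 256 * 4 / 32 -> 32, so wq is (128, 32) uint32
#   sshape.back()   = 256 / 32     -> 8,  so scales is (128, 8) uint8

Each uint32 therefore carries eight fp4 codes (four fp8 codes for mxfp8), and each scale byte covers one group.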
@@ -4324,16 +4357,13 @@ array affine_dequantize(
     int group_size,
     int bits,
     StreamOrDevice s_) {
-  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
-    std::ostringstream msg;
-    msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
-        << "but it has only " << w.ndim() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
   auto wshape = w.shape();
   auto sshape = scales.shape();
   auto bshape = biases.shape();
+  if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) {
+    throw std::invalid_argument(
+        "[dequantize] Shape of scales and biases does not match the matrix");
+  }
   wshape.back() = -1;
   sshape.back() = -1;
   bshape.back() = -1;
@@ -4414,88 +4444,66 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }
 
-array dequantize(
+array fp_dequantize(
     const array& w,
     const array& scales,
-    const std::optional<array>& biases /* = std::nullopt */,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
-    const std::string& mode /* = "affine" */,
-    std::optional<Dtype> dtype /* = std::nullopt */,
-    StreamOrDevice s /* = {} */) {
-  validate_mode_with_type("dequantize", scales, biases, mode);
-  if (bits <= 0) {
+    int group_size,
+    int bits,
+    Dtype out_type,
+    QuantizationMode mode,
+    Stream s) {
+  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
+  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
+  if (group_size != expected_gs) {
     std::ostringstream msg;
-    msg << "[dequantize] Invalid value for bits: " << bits;
+    msg << "[dequantize] " << quantization_mode_to_string(mode)
+        << " quantization requires group size " << expected_gs << " but got "
+        << group_size << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (group_size <= 0) {
+  if (bits != expected_bits) {
     std::ostringstream msg;
-    msg << "[dequantize] Invalid value for group_size: " << group_size;
+    msg << "[dequantize] " << quantization_mode_to_string(mode)
+        << " quantization requires bits to be " << expected_bits << " but got "
+        << bits << ".";
     throw std::invalid_argument(msg.str());
   }
-  if (w.dtype() != uint32) {
+
+  auto wshape = w.shape();
+  auto sshape = scales.shape();
+  if (wshape.size() != sshape.size()) {
     throw std::invalid_argument(
-        "[dequantize] The matrix should be given as a uint32");
+        "[dequantize] Shape of scales does not match the matrix");
   }
 
-  if (mode == "affine") {
-    auto out = affine_dequantize(w, scales, *biases, group_size, bits, s);
-    if (dtype) {
-      out = astype(out, *dtype, s);
-    }
-    return out;
-  } else {
-    int expected_gs = (mode[0] == 'm') ? 32 : 16;
-    int expected_bits = (mode.back() == '8') ? 8 : 4;
-    if (group_size != expected_gs) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires group size "
-          << expected_gs << " but got " << group_size << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    if (bits != expected_bits) {
-      std::ostringstream msg;
-      msg << "[quantize] " << mode << " quantization requires bits to be "
-          << expected_bits << " but got " << bits << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    if (w.ndim() < 2 || scales.ndim() < 2) {
-      std::ostringstream msg;
-      msg << "[quantize] The matrix to be dequantized must have at least 2 dimension "
-          << "but it has only " << w.ndim() << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    auto wshape = w.shape();
-    auto sshape = scales.shape();
-    wshape.back() = -1;
-    sshape.back() = -1;
-
-    if (wshape != sshape) {
-      throw std::invalid_argument(
-          "[dequantize] Shape of scales does not match the matrix");
-    }
-
-    if (w.dtype() != uint32) {
-      throw std::invalid_argument(
-          "[dequantize] The matrix should be given as a uint32");
-    }
-
-    // Packing into uint32
-    int out_size = w.shape(-1) * 32 / bits;
-
-    if (out_size != scales.shape(-1) * group_size) {
-      std::ostringstream msg;
-      msg << "[dequantize] Shape of scales does not match the matrix "
-          << "given the quantization parameters. Provided matrix of shape "
-          << w.shape() << " and scales of shape " << scales.shape() << ".";
-      throw std::invalid_argument(msg.str());
-    }
-
-    auto out_type = dtype.has_value() ? *dtype : bfloat16;
-    auto out = w;
+  wshape.back() = -1;
+  sshape.back() = -1;
+
+  if (wshape != sshape) {
+    throw std::invalid_argument(
+        "[dequantize] Shape of scales does not match the matrix");
+  }
+
+  // Packing into uint32
+  int out_size = w.shape(-1) * 32 / bits;
+
+  if (out_size != scales.shape(-1) * group_size) {
+    std::ostringstream msg;
+    msg << "[dequantize] Shape of scales does not match the matrix "
+        << "given the quantization parameters. Provided matrix of shape "
+        << w.shape() << " and scales of shape " << scales.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto fallback =
+      [wshape = std::move(wshape),
+       sshape = std::move(sshape),
+       group_size,
+       bits,
+       out_type,
+       s](const std::vector<array>& inputs) mutable -> std::vector<array> {
+    auto out = inputs[0];
+    auto scales = inputs[1];
     if (bits == 4) {
       auto lut = array(
           {
@@ -4527,15 +4535,68 @@ array dequantize(
       out = from_fp8(view(out, uint8, s), out_type, s);
     }
     out = reshape(out, {-1, group_size}, s);
-    auto flat_scales = reshape(scales, {-1, 1}, s);
+    scales = reshape(scales, {-1, 1}, s);
     if (group_size == 16) {
-      flat_scales = from_fp8(flat_scales, out_type, s);
+      scales = from_fp8(scales, out_type, s);
     } else {
-      flat_scales =
-          subtract(astype(flat_scales, out_type, s), array(127, out_type), s);
-      flat_scales = power(array(2.0f, out_type), flat_scales, s);
+      scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
+      scales = power(array(2.0f, out_type), scales, s);
     }
-    return reshape(multiply(out, flat_scales, s), wshape, s);
+    return {reshape(multiply(out, scales, s), wshape, s)};
+  };
+
+  if (s.device == Device::gpu) {
+    auto out_shape = w.shape();
+    out_shape.back() = out_size;
+    return array(
+        std::move(out_shape),
+        out_type,
+        std::make_shared<fast::Quantize>(
+            s, fallback, group_size, bits, mode, true),
+        {w, scales});
+  }
+  return fallback({w, scales})[0];
+}
+
+array dequantize(
+    const array& w,
+    const array& scales,
+    const std::optional<array>& biases /* = std::nullopt */,
+    int group_size /* = 64 */,
+    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
+    std::optional<Dtype> dtype /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto [out_type, qmode] =
+      validate_mode_with_type("dequantize", scales, biases, dtype, mode);
+  if (bits <= 0) {
+    std::ostringstream msg;
+    msg << "[dequantize] Invalid value for bits: " << bits;
+    throw std::invalid_argument(msg.str());
+  }
+  if (group_size <= 0) {
+    std::ostringstream msg;
+    msg << "[dequantize] Invalid value for group_size: " << group_size;
+    throw std::invalid_argument(msg.str());
+  }
+  if (w.dtype() != uint32) {
+    throw std::invalid_argument(
+        "[dequantize] The matrix should be given as a uint32");
+  }
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimension "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (qmode == QuantizationMode::Affine) {
+    return astype(
+        affine_dequantize(w, scales, *biases, group_size, bits, s),
+        out_type,
+        s);
+  } else {
+    return fp_dequantize(
+        w, scales, group_size, bits, out_type, qmode, to_stream(s));
+  }
 }
 
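The fallback decodes the stored scale bytes with ordinary ops: for the mx modes the byte is a biased power-of-two exponent, recovered as 2^(s - 127), while for nvfp4 (group_size == 16) it is an fp8 e4m3 value decoded by from_fp8. A NumPy sketch of the two decodings (illustrative, not the commit's code; the e4m3 branch ignores the NaN encoding):

import numpy as np

def decode_e8m0(s):  # mx scales: biased power-of-two exponent
    return 2.0 ** (s.astype(np.float32) - 127)

def decode_e4m3(s):  # nv scales: fp8 with 4 exponent and 3 mantissa bits, bias 7
    sign = np.where((s >> 7) & 1, -1.0, 1.0)
    exp = ((s >> 3) & 0xF).astype(np.int32)
    mant = (s & 0x7).astype(np.float32)
    subnormal = sign * (mant / 8.0) * 2.0 ** -6
    normal = sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)
    return np.where(exp == 0, subnormal, normal)

print(decode_e8m0(np.array([127, 130], dtype=np.uint8)))    # [1. 8.]
print(decode_e4m3(np.array([0x38, 0x40], dtype=np.uint8)))  # [1. 2.]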
@@ -4594,7 +4655,8 @@ array gather_qmm(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
 
-  auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode);
+  auto [out_type, qmode] =
+      validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode);
   out_type = promote_types(x.dtype(), out_type);
 
   if (!issubdtype(out_type, floating)) {
@@ -4634,7 +4696,7 @@ array gather_qmm(
   out_shape.push_back(x.shape(-2));
   out_shape.push_back(w_outer_dims);
   std::vector<array> inputs;
-  if (mode == "affine") {
+  if (qmode == QuantizationMode::Affine) {
     inputs = {
         astype(x, out_type, s),
         std::move(w),
@@ -4657,7 +4719,7 @@ array gather_qmm(
           to_stream(s),
           group_size,
           bits,
-          string_to_quantization_mode(mode),
+          qmode,
           transpose,
           sorted_indices && !rhs_indices_,
           sorted_indices && !lhs_indices_),
@@ -3328,19 +3328,37 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
 }
 
 std::string quantization_mode_to_string(QuantizationMode mode) {
-  if (mode == QuantizationMode::Affine) {
-    return "affine";
-  } else {
-    return "mxfp4";
+  switch (mode) {
+    case QuantizationMode::Affine:
+      return "affine";
+    case QuantizationMode::Mxfp4:
+      return "mxfp4";
+    case QuantizationMode::Mxfp8:
+      return "mxfp8";
+    case QuantizationMode::Nvfp4:
+    default:
+      return "nvfp4";
   }
 }
 
-QuantizationMode string_to_quantization_mode(const std::string& mode) {
+QuantizationMode string_to_quantization_mode(
+    const std::string& mode,
+    std::string_view tag /* = "" */) {
   if (mode == "affine") {
     return QuantizationMode::Affine;
-  } else {
+  } else if (mode == "mxfp4") {
     return QuantizationMode::Mxfp4;
+  } else if (mode == "mxfp8") {
+    return QuantizationMode::Mxfp8;
+  } else if (mode == "nvfp4") {
+    return QuantizationMode::Nvfp4;
   }
+  std::string msg;
+  if (!tag.empty()) {
+    msg += "[" + std::string(tag) + "]";
+  }
+  msg += " Invalid quantization mode '" + mode + "'.";
+  throw std::invalid_argument(msg);
 }
 
 std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
@@ -151,10 +151,12 @@ class UnaryPrimitive : public Primitive {
   UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
 };
 
-enum class QuantizationMode { Affine, Mxfp4 };
+enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };
 
 std::string quantization_mode_to_string(QuantizationMode mode);
-QuantizationMode string_to_quantization_mode(const std::string& mode);
+QuantizationMode string_to_quantization_mode(
+    const std::string& mode,
+    std::string_view error_tag = "");
 
 class Abs : public UnaryPrimitive {
  public:
@@ -61,13 +61,18 @@ class TestQuantized(mlx_tests.MLXTestCase):
             mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
 
         w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
 
         with self.assertRaises(ValueError):
             mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
 
         with self.assertRaises(ValueError):
             mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
 
+        # Invalid output type
+        with self.assertRaises(ValueError):
+            mx.dequantize(
+                w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
+            )
+
         w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
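The new out_type plumbing is what the added test exercises: dequantize accepts a dtype override, which must be a real floating type, hence the ValueError above for mx.int32. A valid override looks like this (a sketch; without dtype, the fp modes default to bfloat16):

w_hat = mx.dequantize(
    w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.float32
)
assert w_hat.dtype == mx.float32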