// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"

namespace mlx::core {

// Affine quantization: pack `w` into `wq`, writing one scale and one bias
// per group of `group_size_` elements so that, element-wise,
// w ≈ scale * q + bias, with each code q stored in `bits_` bits.
void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);
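
// A minimal sketch of the usual affine recipe (an assumption for
// exposition; the CUDA kernels own the exact scale selection and
// rounding): per group,
//   scale = (max - min) / (2^bits - 1),  bias = min,
//   q     = round((w - bias) / scale).
// For bits = 4 and a group spanning [-1, 2]: scale = 0.2, bias = -1, so
// w = 0.5 quantizes to q = round(1.5 / 0.2) = 8 and dequantizes back to
// 0.2 * 8 - 1 = 0.6.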

// Inverse of affine_quantize: reconstruct `w` from the packed codes as
// w = scale * q + bias, one scale/bias pair per group.
void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

// Floating-point (microscaling) quantization: pack `w` into `wq` with one
// shared scale per group of `group_size` elements and no bias.
void fp_quantize(
    const array& w,
    array& wq,
    array& scales,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);
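
// Why there is no `biases` array here: in microscaling formats such as
// mxfp8 and nvfp4 the elements are themselves narrow floats, so a single
// per-group scale suffices. For reference (from the public format specs,
// not this header): mxfp8 uses 32-element groups with a power-of-two
// (E8M0) scale, while nvfp4 uses 16-element groups of FP4 (E2M1) values
// with an FP8 (E4M3) scale.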

// Inverse of fp_quantize: reconstruct `w` as w = scale * q per group.
void fp_dequantize(
    const array& wq,
    const array& scales,
    array& w,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);

} // namespace mlx::core
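
// Below is an illustrative, MLX-independent sketch of the affine scheme
// declared above, operating on a flat float buffer. The helper names
// (ref_affine_quantize / ref_affine_dequantize) are hypothetical and for
// exposition only; the real kernels may choose scales and round
// differently.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// Quantize one group of `group_size` floats to unsigned `bits`-bit codes,
// appending the codes to `q` and returning the (scale, bias) pair.
inline std::pair<float, float> ref_affine_quantize(
    const float* w,
    int group_size,
    int bits,
    std::vector<uint8_t>& q) {
  float lo = *std::min_element(w, w + group_size);
  float hi = *std::max_element(w, w + group_size);
  float scale = (hi - lo) / float((1 << bits) - 1);
  if (scale == 0.0f) {
    scale = 1.0f; // constant group: every code is 0, any scale works
  }
  for (int i = 0; i < group_size; ++i) {
    q.push_back(static_cast<uint8_t>(std::lround((w[i] - lo) / scale)));
  }
  return {scale, lo};
}

// Reconstruct one group from its codes: w ≈ scale * q + bias.
inline void ref_affine_dequantize(
    const std::vector<uint8_t>& q,
    float scale,
    float bias,
    std::vector<float>& w) {
  for (uint8_t qi : q) {
    w.push_back(scale * float(qi) + bias);
  }
}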