// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace {

// Return x unchanged if it is already row contiguous; otherwise copy it into
// a row-contiguous temporary that the encoder keeps alive for the command.
inline array ensure_row_contiguous(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (!x.flags().row_contiguous) {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  } else {
    return x;
  }
}

// Like ensure_row_contiguous, but only requires the trailing matrix
// (last two dimensions) to be contiguous; leading dimensions may be strided.
inline array ensure_row_contiguous_matrix(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (x.ndim() < 2) {
    if (x.strides()[0] == 1) {
      return x;
    }
  } else {
    auto stride_0 = x.strides()[x.ndim() - 2];
    auto stride_1 = x.strides()[x.ndim() - 1];
    if (stride_0 == x.shape(-1) && stride_1 == 1) {
      return x;
    }
  }
  array x_copy = contiguous_copy_gpu(x, s);
  enc.add_temporary(x_copy);
  return x_copy;
}

} // namespace

void fast::Quantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("Quantize::eval_gpu");
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  if (dequantize_) {
    // Dequantize: inputs are (wq, scales[, biases]); output is the full
    // precision array w.
    auto wq = ensure_row_contiguous(inputs[0], enc, s);
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto& w = outputs[0];
    w.set_data(cu::malloc_async(w.nbytes(), enc));
    if (mode_ == QuantizationMode::Affine) {
      auto biases = ensure_row_contiguous(inputs[2], enc, s);
      affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
    } else {
      fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
    }
  } else {
    // Quantize: input is w; outputs are (wq, scales[, biases]).
    auto w = ensure_row_contiguous(inputs[0], enc, s);
    auto& wq = outputs[0];
    auto& scales = outputs[1];
    wq.set_data(cu::malloc_async(wq.nbytes(), enc));
    scales.set_data(cu::malloc_async(scales.nbytes(), enc));
    if (mode_ == QuantizationMode::Affine) {
      auto& biases = outputs[2];
      biases.set_data(cu::malloc_async(biases.nbytes(), enc));
      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
    } else {
      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
    }
  }
}

} // namespace mlx::core