// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"

namespace mlx::core {

namespace {

// Return x unchanged if it is already row-contiguous; otherwise make a
// contiguous copy on the GPU and register it as a temporary with the encoder.
inline array ensure_row_contiguous(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (!x.flags().row_contiguous) {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  } else {
    return x;
  }
}

// Like ensure_row_contiguous, but only requires the trailing matrix
// dimensions (the last two axes) to be contiguous.
inline array ensure_row_contiguous_matrix(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  auto stride_0 = x.strides()[x.ndim() - 2];
  auto stride_1 = x.strides()[x.ndim() - 1];
  if (stride_0 == x.shape(-1) && stride_1 == 1) {
    return x;
  } else {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  }
}

} // namespace

void fast::AffineQuantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  if (dequantize_) {
    // Dequantize: inputs are the packed weights, scales, and biases;
    // the single output is the dequantized weight matrix.
    auto wq = ensure_row_contiguous(inputs[0], enc, s);
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto biases = ensure_row_contiguous(inputs[2], enc, s);
    auto& w = outputs[0];
    w.set_data(allocator::malloc(w.nbytes()));
    affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
  } else {
    // Quantize: the single input is the weight matrix; the outputs are the
    // packed weights, scales, and biases.
    auto w = ensure_row_contiguous(inputs[0], enc, s);
    auto& wq = outputs[0];
    auto& scales = outputs[1];
    auto& biases = outputs[2];
    wq.set_data(allocator::malloc(wq.nbytes()));
    scales.set_data(allocator::malloc(scales.nbytes()));
    biases.set_data(allocator::malloc(biases.nbytes()));
    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
  }
}

} // namespace mlx::core