mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	| @@ -2,7 +2,9 @@ | ||||
|  | ||||
| #include <cassert> | ||||
|  | ||||
| #include "mlx/backend/metal/copy.h" | ||||
| #include "mlx/backend/common/copy.h" | ||||
| #include "mlx/backend/common/ops.h" | ||||
| #include "mlx/fast_primitives.h" | ||||
| #include "mlx/primitives.h" | ||||
|  | ||||
| namespace mlx::core { | ||||
| @@ -404,4 +406,103 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) { | ||||
|       transpose_); | ||||
| } | ||||
|  | ||||
// Affine-quantize `w_` into bit-packed 32-bit words. Each group of
// `group_size` consecutive elements shares one scale and one bias
// (quantized value = rint((w - bias) / scale), clamped to [0, n_bins]),
// and `32 / bits` quantized values are packed into each uint32 of `out_`.
//
// T: floating element type of w_/scales_/biases_.
// w_: row-contiguous weights; size assumed divisible by group_size.
// out_: uint32 destination, one word per `32 / bits` input elements.
// scales_/biases_: per-group affine parameters — written here when
//   compute_scale_bias is true, otherwise read as provided by the caller.
template <typename T>
void quantize(
    const array& w_,
    array& out_,
    array& scales_,
    array& biases_,
    int bits,
    int group_size,
    bool compute_scale_bias) {
  const T* w = w_.data<T>();
  T* scales = scales_.data<T>();
  T* biases = biases_.data<T>();
  auto out = out_.data<uint32_t>();

  T n_bins = (1 << bits) - 1; // largest representable quantized level
  T eps = 1e-7; // floor on the scale so constant groups don't divide by 0
  int el_per_int = 32 / bits; // elements packed into one uint32 word
  int int_per_group = group_size / el_per_int;
  size_t n_groups = w_.size() / group_size;

  for (size_t i = 0; i < n_groups; ++i) {
    size_t w_idx = i * group_size;
    if (compute_scale_bias) {
      // Group min/max determine the affine range.
      T w_min = std::numeric_limits<float>::infinity();
      T w_max = -w_min;
      for (int j = 0; j < group_size; ++j) {
        w_max = std::max(w_max, w[w_idx + j]);
        w_min = std::min(w_min, w[w_idx + j]);
      }
      // Give the scale the sign of the larger-magnitude endpoint; that
      // endpoint ("edge") is the value we want to represent exactly.
      bool mask = std::abs(w_min) > std::abs(w_max);
      T scale = std::max(T((w_max - w_min) / n_bins), eps);
      scale = mask ? scale : -scale;

      // Snap the scale so `edge / scale` is the integer q0 with no rounding
      // error; when edge already rounds to level 0, no bias is needed.
      auto edge = mask ? w_min : w_max;
      auto q0 = std::rint(edge / scale);
      if (q0 == 0) {
        scales[i] = scale;
        biases[i] = 0;
      } else {
        scales[i] = edge / q0;
        biases[i] = edge;
      }
    }
    // Quantize each run of el_per_int elements and pack it into one word.
    size_t out_idx = i * int_per_group;
    for (int j = 0; j < int_per_group; ++j) {
      uint32_t out_el = 0;
      for (int k = 0; k < el_per_int; ++k) {
        T w_el = w[w_idx + j * el_per_int + k];
        w_el = std::rint((w_el - biases[i]) / scales[i]);
        w_el = std::min(std::max(w_el, T(0)), n_bins); // clamp to [0, n_bins]
        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
      }
      out[out_idx + j] = out_el;
    }
  }
}
|  | ||||
| void fast::AffineQuantize::eval_cpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs) { | ||||
|   bool compute_scale_bias = inputs.size() == 1; | ||||
|  | ||||
|   auto ensure_row_contiguous = [](const array& arr) { | ||||
|     if (arr.flags().row_contiguous) { | ||||
|       return arr; | ||||
|     } else { | ||||
|       array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); | ||||
|       copy(arr, arr_copy, CopyType::General); | ||||
|       return arr_copy; | ||||
|     } | ||||
|   }; | ||||
|   auto w = ensure_row_contiguous(inputs[0]); | ||||
|  | ||||
|   auto& out = outputs[0]; | ||||
|   out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||
|  | ||||
|   auto& scales = | ||||
|       compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]); | ||||
|   auto& biases = | ||||
|       compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]); | ||||
|   if (compute_scale_bias) { | ||||
|     scales.set_data(allocator::malloc_or_wait(scales.nbytes())); | ||||
|     biases.set_data(allocator::malloc_or_wait(biases.nbytes())); | ||||
|   } | ||||
|   if (w.dtype() == float16) { | ||||
|     quantize<float16_t>( | ||||
|         w, out, scales, biases, bits_, group_size_, compute_scale_bias); | ||||
|   } else if (w.dtype() == bfloat16) { | ||||
|     quantize<bfloat16_t>( | ||||
|         w, out, scales, biases, bits_, group_size_, compute_scale_bias); | ||||
|   } else if (w.dtype() == float32) { | ||||
|     quantize<float>( | ||||
|         w, out, scales, biases, bits_, group_size_, compute_scale_bias); | ||||
|   } else { | ||||
|     throw std::runtime_error( | ||||
|         "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| // Copyright © 2024 Apple Inc. | ||||
|  | ||||
| #include "mlx/primitives.h" | ||||
| #include "mlx/fast_primitives.h" | ||||
|  | ||||
| #define NO_CPU_MULTI(func)                                             \ | ||||
|   void func::eval_cpu(                                                 \ | ||||
| @@ -112,4 +113,8 @@ NO_CPU(Transpose) | ||||
| NO_CPU(Inverse) | ||||
| NO_CPU(View) | ||||
|  | ||||
| namespace fast { | ||||
| NO_CPU_MULTI(AffineQuantize) | ||||
| } // namespace fast | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
							
								
								
									
										42
									
								
								mlx/fast.cpp
									
									
									
									
									
								
							
							
						
						
									
										42
									
								
								mlx/fast.cpp
									
									
									
									
									
								
							| @@ -773,20 +773,15 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { | ||||
|     }; | ||||
|   }; | ||||
|  | ||||
|   std::vector<array> outputs; | ||||
|   if (s.device == Device::gpu) { | ||||
|     auto wq_shape = w.shape(); | ||||
|     wq_shape.back() = w.shape(-1) / el_per_int; | ||||
|     auto sshape = w.shape(); | ||||
|     sshape.back() = w.shape(-1) / group_size; | ||||
|     outputs = array::make_arrays( | ||||
|         {wq_shape, sshape, sshape}, | ||||
|         {uint32, w.dtype(), w.dtype()}, | ||||
|         std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false), | ||||
|         {w}); | ||||
|   } else { | ||||
|     outputs = fallback({w}); | ||||
|   } | ||||
|   auto wq_shape = w.shape(); | ||||
|   wq_shape.back() = w.shape(-1) / el_per_int; | ||||
|   auto sshape = w.shape(); | ||||
|   sshape.back() = w.shape(-1) / group_size; | ||||
|   auto outputs = array::make_arrays( | ||||
|       {std::move(wq_shape), sshape, sshape}, | ||||
|       {uint32, w.dtype(), w.dtype()}, | ||||
|       std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false), | ||||
|       {w}); | ||||
|   return {outputs[0], outputs[1], outputs[2]}; | ||||
| } | ||||
|  | ||||
| @@ -814,16 +809,13 @@ array affine_quantize( | ||||
|     return {reshape(packed_w, wshape, s)}; | ||||
|   }; | ||||
|  | ||||
|   if (s.device == Device::gpu) { | ||||
|     auto out_shape = w.shape(); | ||||
|     out_shape.back() = w.shape(-1) / el_per_int; | ||||
|     return array( | ||||
|         out_shape, | ||||
|         uint32, | ||||
|         std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false), | ||||
|         {w, scales, biases}); | ||||
|   } | ||||
|   return fallback({w, scales, biases})[0]; | ||||
|   auto out_shape = w.shape(); | ||||
|   out_shape.back() = w.shape(-1) / el_per_int; | ||||
|   return array( | ||||
|       std::move(out_shape), | ||||
|       uint32, | ||||
|       std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false), | ||||
|       {w, scales, biases}); | ||||
| } | ||||
|  | ||||
| array affine_dequantize( | ||||
| @@ -916,7 +908,7 @@ array affine_dequantize( | ||||
|     auto out_shape = w.shape(); | ||||
|     out_shape.back() = w.shape(-1) * el_per_int; | ||||
|     return array( | ||||
|         out_shape, | ||||
|         std::move(out_shape), | ||||
|         scales.dtype(), | ||||
|         std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true), | ||||
|         {w, scales, biases}); | ||||
|   | ||||
| @@ -228,9 +228,7 @@ class AffineQuantize : public Custom { | ||||
|         dequantize_(dequantize) {} | ||||
|  | ||||
|   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) | ||||
|       override { | ||||
|     throw std::runtime_error("NYI"); | ||||
|   } | ||||
|       override; | ||||
|  | ||||
|   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) | ||||
|       override; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun