Mirror of https://github.com/ml-explore/mlx.git, synced 2025-11-01 00:28:11 +08:00
	Move arange to its own file (#2438)

mlx/backend/cuda/CMakeLists.txt
@@ -6,6 +6,7 @@
 target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
@@ -29,7 +30,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu

mlx/backend/cuda/arange.cu (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

namespace mlx::core {

namespace cu {

template <typename T>
struct Arange {
  const T start;
  const T step;

  __device__ T operator()(uint32_t i) const {
    return start + i * step;
  }
};

} // namespace cu

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");
  if (out.size() == 0) {
    return;
  }
  out.set_data(allocator::malloc(out.nbytes()));

  auto& encoder = cu::get_command_encoder(stream());
  encoder.set_output_array(out);

  auto capture = encoder.capture_context();
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    using OutType = cuda_type_t<CTYPE>;
    CTYPE step =
        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
    thrust::transform(
        cu::thrust_policy(encoder.stream()),
        thrust::counting_iterator<uint32_t>(0),
        thrust::counting_iterator<uint32_t>(out.data_size()),
        thrust::device_pointer_cast(out.data<OutType>()),
        cu::Arange<OutType>{
            static_cast<OutType>(start_), static_cast<OutType>(step)});
  });
}

} // namespace mlx::core
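
For context, the new file's fill pattern is the standard Thrust idiom of mapping a counting index through a device functor: out[i] = start + i * step, written by thrust::transform. The sketch below is illustrative only and not part of the commit; it swaps MLX's command-encoder and dtype-dispatch machinery for a plain thrust::device_vector so the pattern stands alone (the name ArangeOp is made up for the example).

// Standalone CUDA/Thrust sketch of the pattern used in arange.cu
// (illustrative; compile with nvcc, not taken from the MLX sources).
#include <cstdint>
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

struct ArangeOp {
  float start;
  float step;
  // Each element is computed independently from its linear index.
  __device__ float operator()(uint32_t i) const {
    return start + i * step;
  }
};

int main() {
  thrust::device_vector<float> out(8);
  thrust::transform(
      thrust::counting_iterator<uint32_t>(0),
      thrust::counting_iterator<uint32_t>(static_cast<uint32_t>(out.size())),
      out.begin(),
      ArangeOp{1.0f, 0.5f});
  thrust::host_vector<float> host = out; // copy back to print
  for (float v : host) {
    std::printf("%g ", v); // 1 1.5 2 2.5 3 3.5 4 4.5
  }
  std::printf("\n");
  return 0;
}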

mlx/backend/cuda/device/arange.cuh (deleted, 15 lines)
@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.

namespace mlx::core::cu {

template <typename T>
struct Arange {
  const T start;
  const T step;

  __device__ T operator()(uint32_t i) const {
    return start + i * step;
  }
};

} // namespace mlx::core::cu

mlx/backend/cuda/primitives.cu → mlx/backend/cuda/primitives.cpp
@@ -1,47 +1,11 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

#include <cassert>

namespace mlx::core {

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }
  auto& encoder = cu::get_command_encoder(stream());
  encoder.set_output_array(out);
  auto capture = encoder.capture_context();
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    using OutType = cuda_type_t<CTYPE>;
    CTYPE step =
        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
    thrust::transform(
        cu::thrust_policy(encoder.stream()),
        thrust::counting_iterator<uint32_t>(0),
        thrust::counting_iterator<uint32_t>(out.data_size()),
        thrust::device_pointer_cast(out.data<OutType>()),
        cu::Arange<OutType>{
            static_cast<OutType>(start_), static_cast<OutType>(step)});
  });
}

bool fast::ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
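
Two details are worth calling out from the diff above. First, both the old and the new eval_gpu recompute the step as static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_), so the per-element increment is first rounded to the output dtype instead of being applied at full precision. Second, the new arange.cu returns early when out.size() == 0 before allocating, whereas the old version allocated first and then checked. Below is a minimal host-only sketch of the step rule (the function name arange_fill is illustrative, not an MLX API):

// Host-only C++ sketch of out[i] = start + i * step with the step
// round-tripped through the output type, mirroring the CTYPE cast trick.
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
std::vector<T> arange_fill(double start, double step, uint32_t n) {
  T typed_start = static_cast<T>(start);
  // Recompute the step in T so each increment is representable in T.
  T typed_step = static_cast<T>(start + step) - typed_start;
  std::vector<T> out(n);
  for (uint32_t i = 0; i < n; ++i) {
    out[i] = typed_start + static_cast<T>(i) * typed_step;
  }
  return out;
}

int main() {
  for (float x : arange_fill<float>(0.5, 0.25, 6)) {
    std::printf("%g ", x); // prints: 0.5 0.75 1 1.25 1.5 1.75
  }
  std::printf("\n");
  return 0;
}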