mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Move arange to its own file (#2438)
This commit is contained in:
		@@ -6,6 +6,7 @@
 | 
				
			|||||||
target_sources(
 | 
					target_sources(
 | 
				
			||||||
  mlx
 | 
					  mlx
 | 
				
			||||||
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
 | 
					  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
 | 
				
			||||||
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
 | 
				
			||||||
@@ -29,7 +30,7 @@ target_sources(
 | 
				
			|||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
 | 
				
			||||||
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
 | 
					          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										55
									
								
								mlx/backend/cuda/arange.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								mlx/backend/cuda/arange.cu
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
				
			|||||||
 | 
					// Copyright © 2025 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlx/backend/cuda/device.h"
 | 
				
			||||||
 | 
					#include "mlx/backend/cuda/device/fp16_math.cuh"
 | 
				
			||||||
 | 
					#include "mlx/backend/cuda/kernel_utils.cuh"
 | 
				
			||||||
 | 
					#include "mlx/dtype_utils.h"
 | 
				
			||||||
 | 
					#include "mlx/primitives.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <nvtx3/nvtx3.hpp>
 | 
				
			||||||
 | 
					#include <thrust/device_ptr.h>
 | 
				
			||||||
 | 
					#include <thrust/transform.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mlx::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace cu {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					struct Arange {
 | 
				
			||||||
 | 
					  const T start;
 | 
				
			||||||
 | 
					  const T step;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  __device__ T operator()(uint32_t i) const {
 | 
				
			||||||
 | 
					    return start + i * step;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					} // namespace cu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			||||||
 | 
					  nvtx3::scoped_range r("Arange::eval_gpu");
 | 
				
			||||||
 | 
					  if (out.size() == 0) {
 | 
				
			||||||
 | 
					    return;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  out.set_data(allocator::malloc(out.nbytes()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(stream());
 | 
				
			||||||
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto capture = encoder.capture_context();
 | 
				
			||||||
 | 
					  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
 | 
				
			||||||
 | 
					    using CTYPE = MLX_GET_TYPE(type_tag);
 | 
				
			||||||
 | 
					    using OutType = cuda_type_t<CTYPE>;
 | 
				
			||||||
 | 
					    CTYPE step =
 | 
				
			||||||
 | 
					        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
 | 
				
			||||||
 | 
					    thrust::transform(
 | 
				
			||||||
 | 
					        cu::thrust_policy(encoder.stream()),
 | 
				
			||||||
 | 
					        thrust::counting_iterator<uint32_t>(0),
 | 
				
			||||||
 | 
					        thrust::counting_iterator<uint32_t>(out.data_size()),
 | 
				
			||||||
 | 
					        thrust::device_pointer_cast(out.data<OutType>()),
 | 
				
			||||||
 | 
					        cu::Arange<OutType>{
 | 
				
			||||||
 | 
					            static_cast<OutType>(start_), static_cast<OutType>(step)});
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					} // namespace mlx::core
 | 
				
			||||||
@@ -1,15 +0,0 @@
 | 
				
			|||||||
// Copyright © 2025 Apple Inc.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace mlx::core::cu {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <typename T>
 | 
					 | 
				
			||||||
struct Arange {
 | 
					 | 
				
			||||||
  const T start;
 | 
					 | 
				
			||||||
  const T step;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  __device__ T operator()(uint32_t i) const {
 | 
					 | 
				
			||||||
    return start + i * step;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
} // namespace mlx::core::cu
 | 
					 | 
				
			||||||
@@ -1,47 +1,11 @@
 | 
				
			|||||||
// Copyright © 2025 Apple Inc.
 | 
					// Copyright © 2025 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/backend/cuda/device.h"
 | 
					 | 
				
			||||||
#include "mlx/backend/cuda/device/arange.cuh"
 | 
					 | 
				
			||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
 | 
					 | 
				
			||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
 | 
					 | 
				
			||||||
#include "mlx/distributed/primitives.h"
 | 
					#include "mlx/distributed/primitives.h"
 | 
				
			||||||
#include "mlx/dtype_utils.h"
 | 
					 | 
				
			||||||
#include "mlx/fast_primitives.h"
 | 
					#include "mlx/fast_primitives.h"
 | 
				
			||||||
#include "mlx/primitives.h"
 | 
					#include "mlx/primitives.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <nvtx3/nvtx3.hpp>
 | 
					 | 
				
			||||||
#include <thrust/device_ptr.h>
 | 
					 | 
				
			||||||
#include <thrust/transform.h>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include <cassert>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace mlx::core {
 | 
					namespace mlx::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
					 | 
				
			||||||
  nvtx3::scoped_range r("Arange::eval_gpu");
 | 
					 | 
				
			||||||
  assert(inputs.size() == 0);
 | 
					 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					 | 
				
			||||||
  if (out.size() == 0) {
 | 
					 | 
				
			||||||
    return;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(stream());
 | 
					 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					 | 
				
			||||||
  auto capture = encoder.capture_context();
 | 
					 | 
				
			||||||
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
 | 
					 | 
				
			||||||
    using CTYPE = MLX_GET_TYPE(type_tag);
 | 
					 | 
				
			||||||
    using OutType = cuda_type_t<CTYPE>;
 | 
					 | 
				
			||||||
    CTYPE step =
 | 
					 | 
				
			||||||
        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
 | 
					 | 
				
			||||||
    thrust::transform(
 | 
					 | 
				
			||||||
        cu::thrust_policy(encoder.stream()),
 | 
					 | 
				
			||||||
        thrust::counting_iterator<uint32_t>(0),
 | 
					 | 
				
			||||||
        thrust::counting_iterator<uint32_t>(out.data_size()),
 | 
					 | 
				
			||||||
        thrust::device_pointer_cast(out.data<OutType>()),
 | 
					 | 
				
			||||||
        cu::Arange<OutType>{
 | 
					 | 
				
			||||||
            static_cast<OutType>(start_), static_cast<OutType>(step)});
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
bool fast::ScaledDotProductAttention::use_fallback(
 | 
					bool fast::ScaledDotProductAttention::use_fallback(
 | 
				
			||||||
    const array& q,
 | 
					    const array& q,
 | 
				
			||||||
    const array& k,
 | 
					    const array& k,
 | 
				
			||||||
		Reference in New Issue
	
	Block a user