// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Fills `out` with start, start + step, start + 2 * step, ...
// Each thread produces N_WRITES contiguous elements so that stores
// can be vectorized.
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_WRITES > size) {
    // Tail: this thread's chunk extends past `size`, so write the
    // remaining elements one by one.
    for (IdxT i = index * N_WRITES; i < size; ++i) {
      out[i] = start + i * step;
    }
  } else {
    // Fast path: compute N_WRITES values in registers and emit a
    // single aligned, vectorized store.
    AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
    for (int i = 0; i < N_WRITES; ++i) {
      out_vec[i] = start + (index * N_WRITES + i) * step;
    }

    store_vector(out, index, out_vec);
  }
}

} // namespace cu

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");
  if (out.size() == 0) {
    return;
  }
  auto& encoder = cu::get_command_encoder(stream());
  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
  encoder.set_output_array(out);

  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    using OutType = cuda_type_t<CTYPE>;
    // Write 16 bytes per thread regardless of the element width.
    constexpr int N_WRITES = 16 / sizeof(OutType);
    // Fall back to 64-bit indexing only when the output does not fit
    // in int32 indices.
    dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
      auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
      encoder.add_kernel_node(
          cu::arange<OutType, IdxT, N_WRITES>,
          num_blocks,
          block_dims,
          0,
          gpu_ptr<OutType>(out),
          out.data_size(),
          static_cast<CTYPE>(start_),
          // Compute the step in the output dtype so it matches the
          // difference between the first two emitted values.
          static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
    });
  });
}

} // namespace mlx::core