mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-03 01:06:43 +08:00)

Move arange to its own file (#2438)

parent 254476718b
commit a0ae49d397
mlx/backend/cuda/CMakeLists.txt
@@ -6,6 +6,7 @@
 target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
@@ -29,7 +30,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
mlx/backend/cuda/arange.cu (new file, 55 lines)
@@ -0,0 +1,55 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+template <typename T>
+struct Arange {
+  const T start;
+  const T step;
+
+  __device__ T operator()(uint32_t i) const {
+    return start + i * step;
+  }
+};
+
+} // namespace cu
+
+void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Arange::eval_gpu");
+  if (out.size() == 0) {
+    return;
+  }
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& encoder = cu::get_command_encoder(stream());
+  encoder.set_output_array(out);
+
+  auto capture = encoder.capture_context();
+  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
+    using CTYPE = MLX_GET_TYPE(type_tag);
+    using OutType = cuda_type_t<CTYPE>;
+    CTYPE step =
+        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
+    thrust::transform(
+        cu::thrust_policy(encoder.stream()),
+        thrust::counting_iterator<uint32_t>(0),
+        thrust::counting_iterator<uint32_t>(out.data_size()),
+        thrust::device_pointer_cast(out.data<OutType>()),
+        cu::Arange<OutType>{
+            static_cast<OutType>(start_), static_cast<OutType>(step)});
+  });
+}
+
+} // namespace mlx::core
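
For illustration only: below is a minimal, self-contained sketch (plain CUDA/Thrust, outside MLX) of the pattern the new arange.cu uses, namely transforming a counting iterator over [0, n) through the Arange functor into the output buffer. The start/step values, the fixed size of 8, and the use of thrust::device_vector are assumptions made for the sketch, not MLX code.

// Standalone sketch of the counting-iterator + transform pattern from arange.cu.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

#include <cstdio>

template <typename T>
struct Arange {
  const T start;
  const T step;

  __device__ T operator()(uint32_t i) const {
    return start + i * step;
  }
};

int main() {
  thrust::device_vector<float> out(8);

  // Map each index i in [0, out.size()) to start + i * step on the GPU.
  thrust::transform(
      thrust::counting_iterator<uint32_t>(0),
      thrust::counting_iterator<uint32_t>(static_cast<uint32_t>(out.size())),
      out.begin(),
      Arange<float>{2.0f, 0.5f});

  // Copy back and print the generated sequence.
  thrust::host_vector<float> h = out;
  for (float v : h) {
    std::printf("%g ", v); // 2 2.5 3 3.5 4 4.5 5 5.5
  }
  std::printf("\n");
  return 0;
}

Compiled with nvcc this prints 2 2.5 3 3.5 4 4.5 5 5.5; the MLX version does the same work but writes through a raw device pointer on the command encoder's stream.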
mlx/backend/cuda/device/arange.cuh (deleted, 15 lines)
@@ -1,15 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-namespace mlx::core::cu {
-
-template <typename T>
-struct Arange {
-  const T start;
-  const T step;
-
-  __device__ T operator()(uint32_t i) const {
-    return start + i * step;
-  }
-};
-
-} // namespace mlx::core::cu
mlx/backend/cuda/primitives.cu → mlx/backend/cuda/primitives.cpp
@@ -1,47 +1,11 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/arange.cuh"
-#include "mlx/backend/cuda/device/fp16_math.cuh"
-#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/distributed/primitives.h"
-#include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 
-#include <nvtx3/nvtx3.hpp>
-#include <thrust/device_ptr.h>
-#include <thrust/transform.h>
-
-#include <cassert>
-
 namespace mlx::core {
 
-void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Arange::eval_gpu");
-  assert(inputs.size() == 0);
-  out.set_data(allocator::malloc(out.nbytes()));
-  if (out.size() == 0) {
-    return;
-  }
-  auto& encoder = cu::get_command_encoder(stream());
-  encoder.set_output_array(out);
-  auto capture = encoder.capture_context();
-  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
-    using CTYPE = MLX_GET_TYPE(type_tag);
-    using OutType = cuda_type_t<CTYPE>;
-    CTYPE step =
-        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
-    thrust::transform(
-        cu::thrust_policy(encoder.stream()),
-        thrust::counting_iterator<uint32_t>(0),
-        thrust::counting_iterator<uint32_t>(out.data_size()),
-        thrust::device_pointer_cast(out.data<OutType>()),
-        cu::Arange<OutType>{
-            static_cast<OutType>(start_), static_cast<OutType>(step)});
-  });
-}
-
 bool fast::ScaledDotProductAttention::use_fallback(
     const array& q,
     const array& k,
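
Both the removed block above and the new arange.cu compute the per-element step in the output type as static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_) rather than casting step_ directly, presumably so the generated sequence stays consistent with how the output dtype rounds start_ + step_. A host-only sketch, plain C++ with values chosen only to make float rounding visible, of how the two can differ:

#include <cstdio>

int main() {
  // Near 2^24 the spacing between consecutive floats is 2, so adding 1.0
  // and casting to float rounds back to the starting value.
  double start_ = 16777216.0;
  double step_ = 1.0;

  float naive = static_cast<float>(step_);                                       // 1.0f
  float used = static_cast<float>(start_ + step_) - static_cast<float>(start_);  // 0.0f

  std::printf("direct cast of step_: %g\n", naive);
  std::printf("step as recomputed in eval_gpu: %g\n", used);
  return 0;
}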