mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
CUDA backend: unary ops (#2158)
This commit is contained in:
parent
5866b3857b
commit
f8bad60609
@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
|
26
mlx/backend/common/unary.h
Normal file
26
mlx/backend/common/unary.h
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,32 +2,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
|
@ -15,6 +15,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
|
121
mlx/backend/cuda/iterators/general_iterator.cuh
Normal file
121
mlx/backend/cuda/iterators/general_iterator.cuh
Normal file
@ -0,0 +1,121 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <cuda/std/utility>
|
||||
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Iterating non-contiguous array.
|
||||
template <typename Iterator, typename IdxT = int64_t>
|
||||
class general_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ general_iterator(
|
||||
Iterator it,
|
||||
IdxT index,
|
||||
int ndim,
|
||||
Shape shape,
|
||||
Strides strides)
|
||||
: super_t(it),
|
||||
index_(index),
|
||||
ndim_(ndim),
|
||||
shape_(cuda::std::move(shape)),
|
||||
strides_(cuda::std::move(strides)) {}
|
||||
|
||||
__host__ __device__ IdxT index() const {
|
||||
return index_;
|
||||
}
|
||||
|
||||
__host__ __device__ const Shape& shape() const {
|
||||
return shape_;
|
||||
}
|
||||
|
||||
__host__ __device__ const Strides& strides() const {
|
||||
return strides_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const general_iterator& other) const {
|
||||
return this->base() == other.base() && this->index() == other.index();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->index_ += n;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->index_ += 1;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->index_ -= 1;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const general_iterator& other) const {
|
||||
_CCCL_ASSERT(
|
||||
this->base() == other.base(),
|
||||
"Underlying iterator must point to same base iterator");
|
||||
return other.index() - this->index();
|
||||
}
|
||||
|
||||
// The dereference is device-only to avoid accidental running in host.
|
||||
__device__ typename super_t::reference dereference() const {
|
||||
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
|
||||
return *(this->base() + offset);
|
||||
}
|
||||
|
||||
IdxT index_;
|
||||
int ndim_;
|
||||
Shape shape_;
|
||||
Strides strides_;
|
||||
};
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
__host__ __device__ auto make_general_iterator(
|
||||
Iterator it,
|
||||
IdxT index,
|
||||
int ndim,
|
||||
Shape shape,
|
||||
Strides strides) {
|
||||
return general_iterator<Iterator, IdxT>(
|
||||
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
|
||||
}
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
auto make_general_iterator(
|
||||
Iterator it,
|
||||
const std::vector<int32_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
return make_general_iterator<IdxT>(
|
||||
it, 0, shape.size(), const_param(shape), const_param(strides));
|
||||
}
|
||||
|
||||
template <typename IdxT, typename Iterator>
|
||||
auto make_general_iterators(
|
||||
Iterator it,
|
||||
IdxT size,
|
||||
const std::vector<int32_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
auto ndim = shape.size();
|
||||
auto shape_arg = const_param(shape);
|
||||
auto strides_arg = const_param(strides);
|
||||
return std::make_pair(
|
||||
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
|
||||
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -7,10 +7,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -38,6 +40,24 @@ struct CTypeToCudaType<complex64_t> {
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
// Type traits for detecting floating numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
cuda::std::array<T, NDIM> result;
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
|
||||
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||
|
240
mlx/backend/cuda/kernels/cucomplex_math.cuh
Normal file
240
mlx/backend/cuda/kernels/cucomplex_math.cuh
Normal file
@ -0,0 +1,240 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2017-2024 The Simons Foundation, Inc.
|
||||
//
|
||||
// FINUFFT is licensed under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance with the
|
||||
// License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
|
||||
// This header provides some helper functions for cuComplex types.
|
||||
// It mainly wraps existing CUDA implementations to provide operator overloads
|
||||
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
|
||||
// all provided by CUDA
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCadd(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCsub(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCmul(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
return cuCdiv(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
||||
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
|
||||
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
|
||||
return make_cuDoubleComplex(r, i);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator==(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator!=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
|
||||
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
|
||||
return mag_a > mag_b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return a > b || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return b > a;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<=(
|
||||
const cuDoubleComplex& a,
|
||||
const cuDoubleComplex& b) {
|
||||
return b > a || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator+(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator-(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator*(double a, const cuDoubleComplex& b) {
|
||||
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(const cuDoubleComplex& a, double b) {
|
||||
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
||||
operator/(double a, const cuDoubleComplex& b) {
|
||||
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
|
||||
return make_cuDoubleComplex(
|
||||
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCaddf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCsubf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCmulf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
return cuCdivf(a, b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
|
||||
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
|
||||
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
|
||||
return make_cuFloatComplex(r, i);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator==(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator!=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
|
||||
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
|
||||
return mag_a > mag_b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator>=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return a > b || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return b > a;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ bool operator<=(
|
||||
const cuFloatComplex& a,
|
||||
const cuFloatComplex& b) {
|
||||
return b > a || a == b;
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator+(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator-(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator*(float a, const cuFloatComplex& b) {
|
||||
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(const cuFloatComplex& a, float b) {
|
||||
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
|
||||
}
|
||||
|
||||
__forceinline__ __host__ __device__ cuFloatComplex
|
||||
operator/(float a, const cuFloatComplex& b) {
|
||||
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
|
||||
return make_cuFloatComplex(
|
||||
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
|
||||
}
|
@ -9,6 +9,78 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return ::NAME(__half2float(x)); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return ::NAME(__bfloat162float(x)); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
|
||||
MLX_DEFINE_UNARY_OP(abs, __habs)
|
||||
MLX_DEFINE_UNARY_OP(ceil, hceil)
|
||||
MLX_DEFINE_UNARY_OP(cos, hcos)
|
||||
MLX_DEFINE_UNARY_OP(exp, hexp)
|
||||
MLX_DEFINE_UNARY_OP(floor, hfloor)
|
||||
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
|
||||
MLX_DEFINE_UNARY_OP(log, hlog)
|
||||
MLX_DEFINE_UNARY_OP(log2, hlog2)
|
||||
MLX_DEFINE_UNARY_OP(log10, hlog10)
|
||||
MLX_DEFINE_UNARY_OP(rint, hrint)
|
||||
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
|
||||
MLX_DEFINE_UNARY_OP(sin, hsin)
|
||||
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acos)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asin)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atan)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tan)
|
||||
#if __CUDA_ARCH__ >= 1280
|
||||
MLX_DEFINE_UNARY_OP(tanh, htanh)
|
||||
#else
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
|
||||
#endif
|
||||
|
||||
#undef MLX_DEFINE_UNARY_OP
|
||||
#undef MLX_DEFINE_UNARY_OP_FALLBCK
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
349
mlx/backend/cuda/kernels/unary_ops.cuh
Normal file
349
mlx/backend/cuda/kernels/unary_ops.cuh
Normal file
@ -0,0 +1,349 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
|
||||
} else {
|
||||
return abs(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseInvert {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return ~x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return ceil(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
return {cuCrealf(x), -cuCimagf(x)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
|
||||
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return cos(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
|
||||
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
|
||||
} else {
|
||||
return cosh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erf(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erf(__bfloat162float(x));
|
||||
} else {
|
||||
return erf(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erfinv(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erfinv(__bfloat162float(x));
|
||||
} else {
|
||||
return erfinv(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto m = exp(cuCrealf(x));
|
||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return expm1(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return expm1(__bfloat162float(x));
|
||||
} else {
|
||||
return expm1(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return floor(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
__device__ float operator()(cuComplex x) {
|
||||
return cuCimagf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log2(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log10(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
__device__ bool operator()(bool x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return 0 - x;
|
||||
} else {
|
||||
return -x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
__device__ float operator()(cuComplex x) {
|
||||
return cuCrealf(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
|
||||
} else {
|
||||
return rint(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x != 0;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
|
||||
return x;
|
||||
} else {
|
||||
return x / Abs()(x);
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
|
||||
} else {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
|
||||
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
|
||||
} else {
|
||||
return sin(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return {
|
||||
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
|
||||
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
|
||||
} else {
|
||||
return sinh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float tan_a = tan(cuCrealf(x));
|
||||
float tanh_b = tanh(cuCimagf(x));
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
} else {
|
||||
return tan(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float tanh_a = tanh(cuCrealf(x));
|
||||
float tan_b = tan(cuCimagf(x));
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
} else {
|
||||
return tanh(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
43
mlx/backend/cuda/kernels/utils.cuh
Normal file
43
mlx/backend/cuda/kernels/utils.cuh
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
// https://docs.nvidia.com/cuda/nvrtc/#language
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
#define MAX_NDIM 8
|
||||
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -71,39 +71,22 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(Abs)
|
||||
NO_GPU(Add)
|
||||
NO_GPU(ArcCos)
|
||||
NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(ArgSort)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BitwiseInvert)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
@ -111,13 +94,9 @@ NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Imag)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
@ -126,33 +105,22 @@ NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
NO_GPU(Multiply)
|
||||
NO_GPU(Negative)
|
||||
NO_GPU(NotEqual)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Real)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
NO_GPU(Sin)
|
||||
NO_GPU(Sinh)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eig)
|
||||
|
196
mlx/backend/cuda/unary.cu
Normal file
196
mlx/backend/cuda/unary.cu
Normal file
@ -0,0 +1,196 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
std::is_same_v<Op, Sign>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
|
||||
std::is_same_v<Op, Square>) {
|
||||
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Conjugate>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
|
||||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
|
||||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
|
||||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
}
|
||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto policy = cu::thrust_policy(stream);
|
||||
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
if (in.flags().contiguous) {
|
||||
thrust::transform(
|
||||
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.data_size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do unary op {} on input of {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define UNARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
UNARY_GPU(Abs)
|
||||
UNARY_GPU(ArcCos)
|
||||
UNARY_GPU(ArcCosh)
|
||||
UNARY_GPU(ArcSin)
|
||||
UNARY_GPU(ArcSinh)
|
||||
UNARY_GPU(ArcTan)
|
||||
UNARY_GPU(ArcTanh)
|
||||
UNARY_GPU(BitwiseInvert)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Conjugate)
|
||||
UNARY_GPU(Cos)
|
||||
UNARY_GPU(Cosh)
|
||||
UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Imag)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Real)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
UNARY_GPU(Sinh)
|
||||
UNARY_GPU(Square)
|
||||
UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Log::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu<cu::Log>(inputs, out, op, s);
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu<cu::Log2>(inputs, out, op, s);
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu<cu::Log10>(inputs, out, op, s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Round::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (recip_) {
|
||||
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
||||
} else {
|
||||
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@ -99,21 +100,7 @@ void unary_op_gpu(
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
if (contig) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user