diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp
index 9766e5e0c..942f9576e 100644
--- a/mlx/backend/common/utils.cpp
+++ b/mlx/backend/common/utils.cpp
@@ -1,5 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <dlfcn.h>
+
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 
@@ -11,6 +13,17 @@ std::string get_primitive_string(Primitive* primitive) {
   return op_t.str();
 }
 
+std::filesystem::path current_binary_dir() {
+  static std::filesystem::path binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path();
+  }();
+  return binary_dir;
+}
+
 std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
     const Shape& shape,
     const std::vector<Strides>& strides,
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 114878846..543868e36 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include <filesystem>
 #include <vector>
 
@@ -11,6 +12,9 @@ namespace mlx::core {
 
 std::string get_primitive_string(Primitive* primitive);
 
+// Return the directory that contains current shared library.
+std::filesystem::path current_binary_dir();
+
 inline int64_t
 elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
   int64_t loc = 0;
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 87f4cb4ae..29f2eeab6 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -125,3 +125,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 # Suppress nvcc warnings on MLX headers.
 target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                    --diag_suppress=997>)
+
+# Install CCCL headers for JIT.
+install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
+        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh
index b6915606e..e0d3c3eac 100644
--- a/mlx/backend/cuda/device/atomic_ops.cuh
+++ b/mlx/backend/cuda/device/atomic_ops.cuh
@@ -58,12 +58,7 @@ inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
 
 inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
 #if __CUDA_ARCH__ < 800
-#if CCCL_VERSION >= 2008000
   atomic_add_general(out, val);
-#else
-  bool cccl_version_too_old_for_bfloat16_atomic_add = false;
-  assert(cccl_version_too_old_for_bfloat16_atomic_add);
-#endif
 #else
   atomicAdd(out, val);
 #endif
diff --git a/mlx/backend/cuda/device/cexpf.cuh b/mlx/backend/cuda/device/cexpf.cuh
deleted file mode 100644
index 61c94c00f..000000000
--- a/mlx/backend/cuda/device/cexpf.cuh
+++ /dev/null
@@ -1,138 +0,0 @@
-// Copyright © 2025 Apple Inc.
-// Copyright © 2008-2013 NVIDIA Corporation
-// Copyright © 2013 Filipe RNC Maia
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Forked from
-// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
-
-// TODO: We should use thrust::exp but the thrust header in old CUDA versions
-// can not be used in JIT.
-
-#pragma once
-
-#include <cuComplex.h>
-#include <cuda/std/cstdint>
-
-namespace mlx::core::cu::detail {
-
-using ieee_float_shape_type = union {
-  float value;
-  uint32_t word;
-};
-
-inline __device__ void get_float_word(uint32_t& i, float d) {
-  ieee_float_shape_type gf_u;
-  gf_u.value = (d);
-  (i) = gf_u.word;
-}
-
-inline __device__ void get_float_word(int32_t& i, float d) {
-  ieee_float_shape_type gf_u;
-  gf_u.value = (d);
-  (i) = gf_u.word;
-}
-
-inline __device__ void set_float_word(float& d, uint32_t i) {
-  ieee_float_shape_type sf_u;
-  sf_u.word = (i);
-  (d) = sf_u.value;
-}
-
-inline __device__ float frexp_expf(float x, int* expt) {
-  const uint32_t k = 235;
-  const float kln2 = 162.88958740F;
-
-  float exp_x;
-  uint32_t hx;
-
-  exp_x = expf(x - kln2);
-  get_float_word(hx, exp_x);
-  *expt = (hx >> 23) - (0x7f + 127) + k;
-  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
-  return exp_x;
-}
-
-inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
-  float x, y, exp_x, scale1, scale2;
-  int ex_expt, half_expt;
-
-  x = cuCrealf(z);
-  y = cuCimagf(z);
-  exp_x = frexp_expf(x, &ex_expt);
-  expt += ex_expt;
-
-  half_expt = expt / 2;
-  set_float_word(scale1, (0x7f + half_expt) << 23);
-  half_expt = expt - half_expt;
-  set_float_word(scale2, (0x7f + half_expt) << 23);
-
-  return cuComplex{
-      cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
-}
-
-inline __device__ cuComplex cexpf(const cuComplex& z) {
-  float x, y, exp_x;
-  uint32_t hx, hy;
-
-  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
-
-  x = cuCrealf(z);
-  y = cuCimagf(z);
-
-  get_float_word(hy, y);
-  hy &= 0x7fffffff;
-
-  /* cexp(x + I 0) = exp(x) + I 0 */
-  if (hy == 0) {
-    return cuComplex{expf(x), y};
-  }
-  get_float_word(hx, x);
-  /* cexp(0 + I y) = cos(y) + I sin(y) */
-  if ((hx & 0x7fffffff) == 0) {
-    return cuComplex{cosf(y), sinf(y)};
-  }
-  if (hy >= 0x7f800000) {
-    if ((hx & 0x7fffffff) != 0x7f800000) {
-      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
-      return cuComplex{y - y, y - y};
-    } else if (hx & 0x80000000) {
-      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
-      return cuComplex{0.0, 0.0};
-    } else {
-      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
-      return cuComplex{x, y - y};
-    }
-  }
-
-  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
-    /*
-     * x is between 88.7 and 192, so we must scale to avoid
-     * overflow in expf(x).
-     */
-    return ldexp_cexpf(z, 0);
-  } else {
-    /*
-     * Cases covered here:
-     * - x < exp_ovfl and exp(x) won't overflow (common case)
-     * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
-     * - x = +-Inf (generated by exp())
-     * - x = NaN (spurious inexact exception from y)
-     */
-    exp_x = expf(x);
-    return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
-  }
-}
-
-} // namespace mlx::core::cu::detail
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index 8716d3a8c..447569eeb 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -2,12 +2,12 @@
 
 #pragma once
 
-#include "mlx/backend/cuda/device/cexpf.cuh"
 #include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <cuComplex.h>
+#include <cuda/std/complex>
 
 namespace mlx::core::cu {
 
@@ -152,7 +152,8 @@ struct Exp {
   template <typename T>
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return detail::cexpf(x);
+      auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
+      return cuComplex{r.real(), r.imag()};
     } else {
       return exp(x);
     }
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 834e4a3d1..4ce79999e 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -13,6 +13,7 @@
 
 #include <fmt/format.h>
 #include <nvrtc.h>
+#include <vector>
 
 namespace mlx::core::cu {
 
@@ -50,6 +51,16 @@ const std::string& cuda_home() {
   return home;
 }
 
+// Return the location of CCCL headers shipped with the distribution.
+bool get_cccl_include(std::string* out) {
+  auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
+  if (!std::filesystem::exists(cccl_headers)) {
+    return false;
+  }
+  *out = fmt::format("--include-path={}", cccl_headers.string());
+  return true;
+}
+
 // Get the cache directory for storing compiled results.
 const std::filesystem::path& ptx_cache_dir() {
   static std::filesystem::path cache = []() -> std::filesystem::path {
@@ -161,7 +172,6 @@ constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
     INCLUDE_PREFIX "binary_ops.cuh",
     INCLUDE_PREFIX "cast_op.cuh",
-    INCLUDE_PREFIX "cexpf.cuh",
     INCLUDE_PREFIX "config.h",
     INCLUDE_PREFIX "cucomplex_math.cuh",
     INCLUDE_PREFIX "fp16_math.cuh",
@@ -178,7 +188,6 @@ constexpr const char* g_headers[] = {
     jit_source_atomic_ops,
     jit_source_binary_ops,
     jit_source_cast_op,
-    jit_source_cexpf,
    jit_source_config,
     jit_source_cucomplex_math,
     jit_source_fp16_math,
@@ -217,16 +226,23 @@ JitModule::JitModule(
   }
 
   // Compile program.
+  std::vector<const char*> args;
   bool use_sass = compiler_supports_device_sass(device);
   std::string compute = fmt::format(
       "--gpu-architecture={}_{}{}",
       use_sass ? "sm" : "compute",
       device.compute_capability_major(),
       device.compute_capability_minor());
-  std::string include = fmt::format("--include-path={}/include", cuda_home());
-  const char* args[] = {compute.c_str(), include.c_str()};
+  args.push_back(compute.c_str());
+  std::string cccl_include;
+  if (get_cccl_include(&cccl_include)) {
+    args.push_back(cccl_include.c_str());
+  }
+  std::string cuda_include =
+      fmt::format("--include-path={}/include", cuda_home());
+  args.push_back(cuda_include.c_str());
   nvrtcResult compile_result =
-      nvrtcCompileProgram(prog, std::size(args), args);
+      nvrtcCompileProgram(prog, args.size(), args.data());
   if (compile_result != NVRTC_SUCCESS) {
     size_t log_size;
     CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 88835eb75..e22d9da2d 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -1,20 +1,18 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <cstdlib>
-#include <filesystem>
 #include <sstream>
 
 #define NS_PRIVATE_IMPLEMENTATION
 #define CA_PRIVATE_IMPLEMENTATION
 #define MTL_PRIVATE_IMPLEMENTATION
 
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/utils.h"
 
-namespace fs = std::filesystem;
-
 namespace mlx::core::metal {
 
 namespace {
@@ -80,12 +78,7 @@ MTL::Library* try_load_bundle(
 std::pair<MTL::Library*, NS::Error*> load_colocated_library(
     MTL::Device* device,
     const std::string& relative_path) {
-  std::string binary_dir = get_binary_directory();
-  if (binary_dir.size() == 0) {
-    return {nullptr, nullptr};
-  }
-
-  auto path = fs::path(binary_dir) / relative_path;
+  auto path = current_binary_dir() / relative_path;
   if (!path.has_extension()) {
     path.replace_extension(".metallib");
   }
@@ -197,7 +190,7 @@ MTL::Library* load_library(
 
     std::ostringstream msg;
     msg << "Failed to load the metallib " << lib_name << ".metallib. "
-        << "We attempted to load it from <" << get_binary_directory() << "/"
+        << "We attempted to load it from <" << current_binary_dir() << "/"
         << lib_name << ".metallib" << ">";
 #ifdef SWIFTPM_BUNDLE
     msg << " and from the Swift PM bundle.";
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index f87a8c48b..52595e6e6 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -3,8 +3,6 @@
 #pragma once
 
 #include <Metal/Metal.hpp>
-#include <dlfcn.h>
-#include <filesystem>
 #include <functional>
 #include <mutex>
 #include <shared_mutex>
@@ -15,22 +13,8 @@
 
 #include "mlx/array.h"
 #include "mlx/device.h"
 
-namespace fs = std::filesystem;
-
 namespace mlx::core::metal {
 
-// Note, this function must be left inline in a header so that it is not
-// dynamically linked.
-inline std::string get_binary_directory() {
-  Dl_info info;
-  std::string directory;
-  int success = dladdr((void*)get_binary_directory, &info);
-  if (success) {
-    directory = fs::path(info.dli_fname).remove_filename().c_str();
-  }
-  return directory;
-}
-
 using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;