[CUDA] Bundle CCCL for JIT compilation (#2357)

* Ship CCCL for JIT compilation

* Remove cexpf
This commit is contained in:
Cheng 2025-07-12 10:45:37 +09:00 committed by GitHub
parent 42cc9cfbc7
commit 6325f60d52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 48 additions and 176 deletions

View File

@ -1,5 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@ -11,6 +13,17 @@ std::string get_primitive_string(Primitive* primitive) {
return op_t.str(); return op_t.str();
} }
// Returns the directory of the binary image (shared library or executable)
// that contains this function. The lookup is performed once, on the first
// call, and the result is cached for subsequent calls.
//
// Throws std::runtime_error when dladdr() cannot resolve the address of this
// function.
std::filesystem::path current_binary_dir() {
  static const std::filesystem::path cached_dir = [] {
    Dl_info dl_info;
    // Ask the dynamic loader which image owns the address of this function.
    void* self_addr = reinterpret_cast<void*>(&current_binary_dir);
    if (dladdr(self_addr, &dl_info) == 0) {
      throw std::runtime_error("Unable to get current binary dir.");
    }
    std::filesystem::path image_path(dl_info.dli_fname);
    return image_path.parent_path();
  }();
  return cached_dir;
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims( std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape, const Shape& shape,
const std::vector<Strides>& strides, const std::vector<Strides>& strides,

View File

@ -2,6 +2,7 @@
#pragma once #pragma once
#include <filesystem>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
@ -11,6 +12,9 @@ namespace mlx::core {
std::string get_primitive_string(Primitive* primitive); std::string get_primitive_string(Primitive* primitive);
// Return the directory that contains the current shared library.
std::filesystem::path current_binary_dir();
inline int64_t inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) { elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0; int64_t loc = 0;

View File

@ -125,3 +125,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Suppress nvcc warnings on MLX headers. # Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>) --diag_suppress=997>)
# Ship the CCCL headers with the installation so that they are available to
# JIT compilation.
install(
  DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@ -58,12 +58,7 @@ inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800 #if __CUDA_ARCH__ < 800
#if CCCL_VERSION >= 2008000
atomic_add_general(out, val); atomic_add_general(out, val);
#else
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
assert(cccl_version_too_old_for_bfloat16_atomic_add);
#endif
#else #else
atomicAdd(out, val); atomicAdd(out, val);
#endif #endif

View File

@ -1,138 +0,0 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// cannot be used in JIT.
#pragma once
#include <cuComplex.h>
#include <cuda/std/cstdint>
namespace mlx::core::cu::detail {
// Overlay of a single-precision float and a 32-bit unsigned word, used by the
// helpers in this file to move between a float and its raw bit pattern.
union ieee_float_shape_type {
  float value;
  uint32_t word;
};
// Stores the raw 32-bit pattern of the float `d` into `i`.
//
// The bits are obtained with the CUDA reinterpretation intrinsic instead of
// reading the inactive member of a union, which is undefined behavior in
// standard C++.
inline __device__ void get_float_word(uint32_t& i, float d) {
  i = __float_as_uint(d);
}
// Stores the raw 32-bit pattern of the float `d` into `i`, viewed as a signed
// integer (so a float with the sign bit set yields a negative value).
//
// The bits are obtained with the CUDA reinterpretation intrinsic instead of
// reading the inactive member of a union, which is undefined behavior in
// standard C++.
inline __device__ void get_float_word(int32_t& i, float d) {
  i = __float_as_int(d);
}
// Sets the float `d` to the value whose raw 32-bit pattern is `i`.
//
// The conversion uses the CUDA reinterpretation intrinsic instead of writing
// one union member and reading another, which is undefined behavior in
// standard C++.
inline __device__ void set_float_word(float& d, uint32_t i) {
  d = __uint_as_float(i);
}
// Evaluates expf(x - kln2) and splits the result into a float whose exponent
// field is forced to 0x7f + 127 and an integer exponent written to `*expt`.
//
// `kln2` is k * ln(2) with k = 235, so subtracting it from `x` before calling
// expf() keeps the intermediate finite for the large `x` this helper is used
// with; `k` is added back into the reported exponent.
inline __device__ float frexp_expf(float x, int* expt) {
  const uint32_t k = 235;
  const float kln2 = 162.88958740F;

  float scaled_exp = expf(x - kln2);
  uint32_t bits;
  get_float_word(bits, scaled_exp);

  // Exponent field of the scaled result, re-based and compensated by k.
  *expt = (bits >> 23) - (0x7f + 127) + k;

  // Keep the mantissa bits and overwrite the exponent field.
  const uint32_t mantissa = bits & 0x7fffff;
  set_float_word(scaled_exp, mantissa | ((0x7f + 127) << 23));
  return scaled_exp;
}
// Computes the complex exponential of `z` for a large real part, with an
// additional power-of-two exponent `expt` folded into the result.
//
// exp(real(z)) is obtained from frexp_expf() as a float plus an integer
// exponent. The combined exponent is split into two halves that are applied
// as separate power-of-two factors, so neither factor has to carry the whole
// exponent on its own.
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
  const float re = cuCrealf(z);
  const float im = cuCimagf(z);

  int exp_expt;
  const float exp_re = frexp_expf(re, &exp_expt);
  expt += exp_expt;

  // Build the floats 2**first_half and 2**second_half directly from their
  // exponent fields.
  const int first_half = expt / 2;
  const int second_half = expt - first_half;
  float scale1, scale2;
  set_float_word(scale1, (0x7f + first_half) << 23);
  set_float_word(scale2, (0x7f + second_half) << 23);

  const float out_re = cosf(im) * exp_re * scale1 * scale2;
  const float out_im = sinf(im) * exp_re * scale1 * scale2;
  return cuComplex{out_re, out_im};
}
// Single-precision complex exponential of `z = x + I y`.
//
// Special inputs (zero imaginary part, zero real part, non-finite imaginary
// part) are handled first by inspecting the raw bit patterns of x and y. The
// remaining inputs go either through the scaled path ldexp_cexpf() or through
// the direct formula exp(x) * (cos(y) + I sin(y)).
inline __device__ cuComplex cexpf(const cuComplex& z) {
  // Bit patterns bounding the range of x handled by the scaled path.
  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;

  const float re = cuCrealf(z);
  const float im = cuCimagf(z);

  // Bits of y with the sign bit cleared.
  uint32_t im_abs_bits;
  get_float_word(im_abs_bits, im);
  im_abs_bits &= 0x7fffffff;

  /* cexp(x + I 0) = exp(x) + I 0 */
  if (im_abs_bits == 0) {
    return cuComplex{expf(re), im};
  }

  uint32_t re_bits;
  get_float_word(re_bits, re);
  const uint32_t re_abs_bits = re_bits & 0x7fffffff;

  /* cexp(0 + I y) = cos(y) + I sin(y) */
  if (re_abs_bits == 0) {
    return cuComplex{cosf(im), sinf(im)};
  }

  // y is Inf or NaN.
  if (im_abs_bits >= 0x7f800000) {
    if (re_abs_bits != 0x7f800000) {
      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
      return cuComplex{im - im, im - im};
    }
    if (re_bits & 0x80000000) {
      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
      return cuComplex{0.0, 0.0};
    }
    /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
    return cuComplex{re, im - im};
  }

  /*
   * x is between 88.7 and 192, so we must scale to avoid
   * overflow in expf(x).
   */
  const bool needs_scaling = re_bits >= exp_ovfl && re_bits <= cexp_ovfl;
  if (needs_scaling) {
    return ldexp_cexpf(z, 0);
  }

  /*
   * Cases covered here:
   *  - x < exp_ovfl and exp(x) won't overflow (common case)
   *  - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
   *  - x = +-Inf (generated by exp())
   *  - x = NaN (spurious inexact exception from y)
   */
  const float exp_re = expf(re);
  return cuComplex{exp_re * cosf(im), exp_re * sinf(im)};
}
} // namespace mlx::core::cu::detail

View File

@ -2,12 +2,12 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/cexpf.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <math_constants.h> #include <math_constants.h>
#include <cuda/std/complex>
namespace mlx::core::cu { namespace mlx::core::cu {
@ -152,7 +152,8 @@ struct Exp {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return detail::cexpf(x); auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
return cuComplex{r.real(), r.imag()};
} else { } else {
return exp(x); return exp(x);
} }

View File

@ -13,6 +13,7 @@
#include <fmt/format.h> #include <fmt/format.h>
#include <nvrtc.h> #include <nvrtc.h>
#include <unistd.h>
namespace mlx::core::cu { namespace mlx::core::cu {
@ -50,6 +51,16 @@ const std::string& cuda_home() {
return home; return home;
} }
// Build the include-path compiler flag for the CCCL headers shipped with the
// distribution.
//
// The headers are looked up at "<parent of current_binary_dir()>/include/cccl".
// When that directory exists, `*out` is set to "--include-path=<dir>" and true
// is returned. Otherwise `*out` is left untouched and false is returned.
bool get_cccl_include(std::string* out) {
  auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
  // This is a best-effort probe, so use the non-throwing overload: a
  // filesystem error while inspecting the path (e.g. an unreadable parent
  // directory) is reported as "not found" instead of escaping as an exception.
  std::error_code ec;
  if (!std::filesystem::is_directory(cccl_headers, ec)) {
    return false;
  }
  *out = fmt::format("--include-path={}", cccl_headers.string());
  return true;
}
// Get the cache directory for storing compiled results. // Get the cache directory for storing compiled results.
const std::filesystem::path& ptx_cache_dir() { const std::filesystem::path& ptx_cache_dir() {
static std::filesystem::path cache = []() -> std::filesystem::path { static std::filesystem::path cache = []() -> std::filesystem::path {
@ -161,7 +172,6 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "atomic_ops.cuh",
INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "cexpf.cuh",
INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "cucomplex_math.cuh", INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh", INCLUDE_PREFIX "fp16_math.cuh",
@ -178,7 +188,6 @@ constexpr const char* g_headers[] = {
jit_source_atomic_ops, jit_source_atomic_ops,
jit_source_binary_ops, jit_source_binary_ops,
jit_source_cast_op, jit_source_cast_op,
jit_source_cexpf,
jit_source_config, jit_source_config,
jit_source_cucomplex_math, jit_source_cucomplex_math,
jit_source_fp16_math, jit_source_fp16_math,
@ -217,16 +226,23 @@ JitModule::JitModule(
} }
// Compile program. // Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device); bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format( std::string compute = fmt::format(
"--gpu-architecture={}_{}{}", "--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute", use_sass ? "sm" : "compute",
device.compute_capability_major(), device.compute_capability_major(),
device.compute_capability_minor()); device.compute_capability_minor());
std::string include = fmt::format("--include-path={}/include", cuda_home()); args.push_back(compute.c_str());
const char* args[] = {compute.c_str(), include.c_str()}; std::string cccl_include;
if (get_cccl_include(&cccl_include)) {
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result = nvrtcResult compile_result =
nvrtcCompileProgram(prog, std::size(args), args); nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) { if (compile_result != NVRTC_SUCCESS) {
size_t log_size; size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));

View File

@ -1,20 +1,18 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cstdlib> #include <cstdlib>
#include <filesystem>
#include <sstream> #include <sstream>
#define NS_PRIVATE_IMPLEMENTATION #define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal { namespace mlx::core::metal {
namespace { namespace {
@ -80,12 +78,7 @@ MTL::Library* try_load_bundle(
std::pair<MTL::Library*, NS::Error*> load_colocated_library( std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device, MTL::Device* device,
const std::string& relative_path) { const std::string& relative_path) {
std::string binary_dir = get_binary_directory(); auto path = current_binary_dir() / relative_path;
if (binary_dir.size() == 0) {
return {nullptr, nullptr};
}
auto path = fs::path(binary_dir) / relative_path;
if (!path.has_extension()) { if (!path.has_extension()) {
path.replace_extension(".metallib"); path.replace_extension(".metallib");
} }
@ -197,7 +190,7 @@ MTL::Library* load_library(
std::ostringstream msg; std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. " msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_binary_directory() << "/" << "We attempted to load it from <" << current_binary_dir() << "/"
<< lib_name << ".metallib" << ">"; << lib_name << ".metallib" << ">";
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle."; msg << " and from the Swift PM bundle.";

View File

@ -3,8 +3,6 @@
#pragma once #pragma once
#include <Metal/Metal.hpp> #include <Metal/Metal.hpp>
#include <dlfcn.h>
#include <filesystem>
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <shared_mutex> #include <shared_mutex>
@ -15,22 +13,8 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/device.h" #include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal { namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_binary_directory() {
Dl_info info;
std::string directory;
int success = dladdr((void*)get_binary_directory, &info);
if (success) {
directory = fs::path(info.dli_fname).remove_filename().c_str();
}
return directory;
}
using MTLFCList = using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>; std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;