mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
[CUDA] Bundle CCCL for JIT compilation (#2357)
* Ship CCCL for JIT compilation * Remove cexpf
This commit is contained in:
parent
42cc9cfbc7
commit
6325f60d52
@ -1,5 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@ -11,6 +13,17 @@ std::string get_primitive_string(Primitive* primitive) {
|
||||
return op_t.str();
|
||||
}
|
||||
|
||||
// Return the directory that contains the binary (executable or shared
// library) this function was linked into.
//
// The lookup is done once via dladdr() on this function's own address and the
// result is cached in a function-local static for all later calls.
//
// Throws std::runtime_error if dladdr() cannot resolve the address.
std::filesystem::path current_binary_dir() {
  static std::filesystem::path binary_dir = []() {
    Dl_info info;
    // dladdr() returns 0 when the address matches no loaded object; also
    // reject a null file name so the path constructor never sees nullptr.
    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info) ||
        info.dli_fname == nullptr) {
      throw std::runtime_error("Unable to get current binary dir.");
    }
    return std::filesystem::path(info.dli_fname).parent_path();
  }();
  return binary_dir;
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
@ -11,6 +12,9 @@ namespace mlx::core {
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive);
|
||||
|
||||
// Return the directory that contains the current shared library.
|
||||
std::filesystem::path current_binary_dir();
|
||||
|
||||
inline int64_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
int64_t loc = 0;
|
||||
|
@ -125,3 +125,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||
|
@ -58,12 +58,7 @@ inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
||||
|
||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||
#if __CUDA_ARCH__ < 800
|
||||
#if CCCL_VERSION >= 2008000
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
|
||||
assert(cccl_version_too_old_for_bfloat16_atomic_add);
|
||||
#endif
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
|
@ -1,138 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2008-2013 NVIDIA Corporation
|
||||
// Copyright © 2013 Filipe RNC Maia
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
||||
|
||||
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
||||
// can not be used in JIT.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/cstdint>
|
||||
|
||||
namespace mlx::core::cu::detail {
|
||||
|
||||
// Overlays a float and a 32-bit unsigned integer on the same storage so that
// a float's raw bit pattern can be read or written through the `word` member.
using ieee_float_shape_type = union {
|
||||
float value;
|
||||
uint32_t word;
|
||||
};
|
||||
|
||||
// Copies the raw 32-bit pattern of the float `d` into `i`.
//   i  out: receives the bits of d as an unsigned 32-bit integer
//   d  float whose representation is inspected
// NOTE(review): this reads the union member that was not the one last written
// (type punning through a union) — confirm the target toolchain permits it.
inline __device__ void get_float_word(uint32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
// Signed overload: copies the raw 32-bit pattern of the float `d` into `i`.
//   i  out: receives the bits of d; the union's uint32_t word is converted to
//      int32_t by the assignment
//   d  float whose representation is inspected
inline __device__ void get_float_word(int32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
// Inverse of get_float_word: sets `d` to the float whose raw 32-bit pattern
// is `i`.
//   d  out: receives the float built from the bits
//   i  32-bit pattern to reinterpret as a float
inline __device__ void set_float_word(float& d, uint32_t i) {
|
||||
ieee_float_shape_type sf_u;
|
||||
sf_u.word = (i);
|
||||
(d) = sf_u.value;
|
||||
}
|
||||
|
||||
// Evaluates expf(x - kln2) and splits the result into a float with a fixed
// exponent field plus a separate integer exponent.
//   x     argument of the exponential
//   expt  out: (exponent bits of expf(x - kln2)) - (0x7f + 127) + k
// Returns the computed value with its fraction bits kept and its exponent
// field overwritten with 0x7f + 127.
// NOTE(review): kln2 is presumably k * ln(2) for k = 235, which would make
// expf(x) == returned * 2^(*expt) — confirm against the upstream source.
inline __device__ float frexp_expf(float x, int* expt) {
|
||||
const uint32_t k = 235;
|
||||
const float kln2 = 162.88958740F;
|
||||
|
||||
float exp_x;
|
||||
uint32_t hx;
|
||||
|
||||
// Shift the argument down by kln2 before calling expf.
exp_x = expf(x - kln2);
|
||||
get_float_word(hx, exp_x);
|
||||
// hx >> 23 drops the 23 fraction bits, leaving the exponent field.
*expt = (hx >> 23) - (0x7f + 127) + k;
|
||||
// Keep the 23 fraction bits and force the exponent field to 0x7f + 127.
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
|
||||
return exp_x;
|
||||
}
|
||||
|
||||
// Builds the complex exponential of `z` from frexp_expf's scaled result,
// applying an extra power-of-two exponent `expt`.
//   z     complex input; the real part goes to frexp_expf, the imaginary part
//         to cosf/sinf
//   expt  additional exponent, added to the one reported by frexp_expf
// Returns {cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2}.
// NOTE(review): the combined exponent is applied as two factors (scale1 and
// scale2), presumably so that neither factor overflows on its own — confirm.
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
|
||||
float x, y, exp_x, scale1, scale2;
|
||||
int ex_expt, half_expt;
|
||||
|
||||
x = cuCrealf(z);
|
||||
y = cuCimagf(z);
|
||||
exp_x = frexp_expf(x, &ex_expt);
|
||||
expt += ex_expt;
|
||||
|
||||
// First factor: exponent field 0x7f + expt / 2, zero fraction bits.
half_expt = expt / 2;
|
||||
set_float_word(scale1, (0x7f + half_expt) << 23);
|
||||
// Second factor carries the remainder, so the two exponents sum to expt.
half_expt = expt - half_expt;
|
||||
set_float_word(scale2, (0x7f + half_expt) << 23);
|
||||
|
||||
return cuComplex{
|
||||
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
|
||||
}
|
||||
|
||||
// Complex exponential of a single-precision complex number.
//
// Inspects the raw bit patterns of the real part (hx) and imaginary part (hy)
// to select one of the special cases documented inline. Otherwise it returns
// expf(x) * (cosf(y) + I sinf(y)), going through ldexp_cexpf when hx lies in
// [exp_ovfl, cexp_ovfl].
//   z  complex input
// Returns the complex result as a cuComplex.
inline __device__ cuComplex cexpf(const cuComplex& z) {
|
||||
float x, y, exp_x;
|
||||
uint32_t hx, hy;
|
||||
|
||||
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
|
||||
|
||||
x = cuCrealf(z);
|
||||
y = cuCimagf(z);
|
||||
|
||||
get_float_word(hy, y);
|
||||
// Clear the sign bit so hy is compared as the magnitude of y.
hy &= 0x7fffffff;
|
||||
|
||||
/* cexp(x + I 0) = exp(x) + I 0 */
|
||||
if (hy == 0) {
|
||||
return cuComplex{expf(x), y};
|
||||
}
|
||||
get_float_word(hx, x);
|
||||
/* cexp(0 + I y) = cos(y) + I sin(y) */
|
||||
if ((hx & 0x7fffffff) == 0) {
|
||||
return cuComplex{cosf(y), sinf(y)};
|
||||
}
|
||||
// hy >= 0x7f800000 means every exponent bit of y is set (Inf or NaN, as the
// case comments below spell out).
if (hy >= 0x7f800000) {
|
||||
if ((hx & 0x7fffffff) != 0x7f800000) {
|
||||
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
|
||||
return cuComplex{y - y, y - y};
|
||||
} else if (hx & 0x80000000) {
|
||||
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
|
||||
return cuComplex{0.0, 0.0};
|
||||
} else {
|
||||
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
|
||||
return cuComplex{x, y - y};
|
||||
}
|
||||
}
|
||||
|
||||
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
|
||||
/*
|
||||
* x is between 88.7 and 192, so we must scale to avoid
|
||||
* overflow in expf(x).
|
||||
*/
|
||||
return ldexp_cexpf(z, 0);
|
||||
} else {
|
||||
/*
|
||||
* Cases covered here:
|
||||
* - x < exp_ovfl and exp(x) won't overflow (common case)
|
||||
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
|
||||
* - x = +-Inf (generated by exp())
|
||||
* - x = NaN (spurious inexact exception from y)
|
||||
*/
|
||||
exp_x = expf(x);
|
||||
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu::detail
|
@ -2,12 +2,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/cexpf.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <math_constants.h>
|
||||
#include <cuda/std/complex>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
@ -152,7 +152,8 @@ struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return detail::cexpf(x);
|
||||
auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
|
||||
return cuComplex{r.real(), r.imag()};
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvrtc.h>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
@ -50,6 +51,16 @@ const std::string& cuda_home() {
|
||||
return home;
|
||||
}
|
||||
|
||||
// Locate the CCCL headers shipped with the distribution.
//
// Looks for <parent of current_binary_dir()>/include/cccl. If that path
// exists, writes an "--include-path=<dir>" compiler argument into *out and
// returns true; otherwise returns false and leaves *out untouched.
// `out` is dereferenced unconditionally on success, so it must not be null.
|
||||
bool get_cccl_include(std::string* out) {
|
||||
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
|
||||
// No bundled headers at that path: report failure to the caller.
if (!std::filesystem::exists(cccl_headers)) {
|
||||
return false;
|
||||
}
|
||||
*out = fmt::format("--include-path={}", cccl_headers.string());
|
||||
return true;
|
||||
}
|
||||
|
||||
// Get the cache directory for storing compiled results.
|
||||
const std::filesystem::path& ptx_cache_dir() {
|
||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||
@ -161,7 +172,6 @@ constexpr const char* g_include_names[] = {
|
||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||
INCLUDE_PREFIX "binary_ops.cuh",
|
||||
INCLUDE_PREFIX "cast_op.cuh",
|
||||
INCLUDE_PREFIX "cexpf.cuh",
|
||||
INCLUDE_PREFIX "config.h",
|
||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||
INCLUDE_PREFIX "fp16_math.cuh",
|
||||
@ -178,7 +188,6 @@ constexpr const char* g_headers[] = {
|
||||
jit_source_atomic_ops,
|
||||
jit_source_binary_ops,
|
||||
jit_source_cast_op,
|
||||
jit_source_cexpf,
|
||||
jit_source_config,
|
||||
jit_source_cucomplex_math,
|
||||
jit_source_fp16_math,
|
||||
@ -217,16 +226,23 @@ JitModule::JitModule(
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
std::vector<const char*> args;
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
std::string include = fmt::format("--include-path={}/include", cuda_home());
|
||||
const char* args[] = {compute.c_str(), include.c_str()};
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include;
|
||||
if (get_cccl_include(&cccl_include)) {
|
||||
args.push_back(cccl_include.c_str());
|
||||
}
|
||||
std::string cuda_include =
|
||||
fmt::format("--include-path={}/include", cuda_home());
|
||||
args.push_back(cuda_include.c_str());
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, std::size(args), args);
|
||||
nvrtcCompileProgram(prog, args.size(), args.data());
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
|
@ -1,20 +1,18 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
@ -80,12 +78,7 @@ MTL::Library* try_load_bundle(
|
||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||
MTL::Device* device,
|
||||
const std::string& relative_path) {
|
||||
std::string binary_dir = get_binary_directory();
|
||||
if (binary_dir.size() == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
auto path = fs::path(binary_dir) / relative_path;
|
||||
auto path = current_binary_dir() / relative_path;
|
||||
if (!path.has_extension()) {
|
||||
path.replace_extension(".metallib");
|
||||
}
|
||||
@ -197,7 +190,7 @@ MTL::Library* load_library(
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||
<< "We attempted to load it from <" << get_binary_directory() << "/"
|
||||
<< "We attempted to load it from <" << current_binary_dir() << "/"
|
||||
<< lib_name << ".metallib" << ">";
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
msg << " and from the Swift PM bundle.";
|
||||
|
@ -3,8 +3,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <Metal/Metal.hpp>
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
@ -15,22 +13,8 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
// Note, this function must be left inline in a header so that it is not
|
||||
// dynamically linked.
|
||||
//
// Returns, as a string, the path dladdr() reports for the object containing
// this function, with the final file-name component removed. Returns an empty
// string when dladdr() fails.
inline std::string get_binary_directory() {
|
||||
Dl_info info;
|
||||
std::string directory;
|
||||
// Look up the loaded object that contains this function's own address.
int success = dladdr((void*)get_binary_directory, &info);
|
||||
if (success) {
|
||||
directory = fs::path(info.dli_fname).remove_filename().c_str();
|
||||
}
|
||||
return directory;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user