mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
[CUDA] Bundle CCCL for JIT compilation (#2357)
* Ship CCCL for JIT compilation
* Remove cexpf
This commit is contained in:
parent
42cc9cfbc7
commit
6325f60d52
@ -1,5 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <dlfcn.h>
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@ -11,6 +13,17 @@ std::string get_primitive_string(Primitive* primitive) {
|
|||||||
return op_t.str();
|
return op_t.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::filesystem::path current_binary_dir() {
|
||||||
|
static std::filesystem::path binary_dir = []() {
|
||||||
|
Dl_info info;
|
||||||
|
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
|
||||||
|
throw std::runtime_error("Unable to get current binary dir.");
|
||||||
|
}
|
||||||
|
return std::filesystem::path(info.dli_fname).parent_path();
|
||||||
|
}();
|
||||||
|
return binary_dir;
|
||||||
|
}
|
||||||
|
|
||||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const std::vector<Strides>& strides,
|
const std::vector<Strides>& strides,
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -11,6 +12,9 @@ namespace mlx::core {
|
|||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive);
|
std::string get_primitive_string(Primitive* primitive);
|
||||||
|
|
||||||
|
// Return the directory that contains current shared library.
|
||||||
|
std::filesystem::path current_binary_dir();
|
||||||
|
|
||||||
inline int64_t
|
inline int64_t
|
||||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||||
int64_t loc = 0;
|
int64_t loc = 0;
|
||||||
|
@ -125,3 +125,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
|||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
|
||||||
|
# Install CCCL headers for JIT.
|
||||||
|
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||||
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||||
|
@ -58,12 +58,7 @@ inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
|||||||
|
|
||||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||||
#if __CUDA_ARCH__ < 800
|
#if __CUDA_ARCH__ < 800
|
||||||
#if CCCL_VERSION >= 2008000
|
|
||||||
atomic_add_general(out, val);
|
atomic_add_general(out, val);
|
||||||
#else
|
|
||||||
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
|
|
||||||
assert(cccl_version_too_old_for_bfloat16_atomic_add);
|
|
||||||
#endif
|
|
||||||
#else
|
#else
|
||||||
atomicAdd(out, val);
|
atomicAdd(out, val);
|
||||||
#endif
|
#endif
|
||||||
|
@ -1,138 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
// Copyright © 2008-2013 NVIDIA Corporation
|
|
||||||
// Copyright © 2013 Filipe RNC Maia
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
//
|
|
||||||
// Forked from
|
|
||||||
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
|
||||||
|
|
||||||
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
|
||||||
// can not be used in JIT.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
#include <cuda/std/cstdint>
|
|
||||||
|
|
||||||
namespace mlx::core::cu::detail {
|
|
||||||
|
|
||||||
using ieee_float_shape_type = union {
|
|
||||||
float value;
|
|
||||||
uint32_t word;
|
|
||||||
};
|
|
||||||
|
|
||||||
inline __device__ void get_float_word(uint32_t& i, float d) {
|
|
||||||
ieee_float_shape_type gf_u;
|
|
||||||
gf_u.value = (d);
|
|
||||||
(i) = gf_u.word;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline __device__ void get_float_word(int32_t& i, float d) {
|
|
||||||
ieee_float_shape_type gf_u;
|
|
||||||
gf_u.value = (d);
|
|
||||||
(i) = gf_u.word;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline __device__ void set_float_word(float& d, uint32_t i) {
|
|
||||||
ieee_float_shape_type sf_u;
|
|
||||||
sf_u.word = (i);
|
|
||||||
(d) = sf_u.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline __device__ float frexp_expf(float x, int* expt) {
|
|
||||||
const uint32_t k = 235;
|
|
||||||
const float kln2 = 162.88958740F;
|
|
||||||
|
|
||||||
float exp_x;
|
|
||||||
uint32_t hx;
|
|
||||||
|
|
||||||
exp_x = expf(x - kln2);
|
|
||||||
get_float_word(hx, exp_x);
|
|
||||||
*expt = (hx >> 23) - (0x7f + 127) + k;
|
|
||||||
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
|
|
||||||
return exp_x;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
|
|
||||||
float x, y, exp_x, scale1, scale2;
|
|
||||||
int ex_expt, half_expt;
|
|
||||||
|
|
||||||
x = cuCrealf(z);
|
|
||||||
y = cuCimagf(z);
|
|
||||||
exp_x = frexp_expf(x, &ex_expt);
|
|
||||||
expt += ex_expt;
|
|
||||||
|
|
||||||
half_expt = expt / 2;
|
|
||||||
set_float_word(scale1, (0x7f + half_expt) << 23);
|
|
||||||
half_expt = expt - half_expt;
|
|
||||||
set_float_word(scale2, (0x7f + half_expt) << 23);
|
|
||||||
|
|
||||||
return cuComplex{
|
|
||||||
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
|
|
||||||
}
|
|
||||||
|
|
||||||
inline __device__ cuComplex cexpf(const cuComplex& z) {
|
|
||||||
float x, y, exp_x;
|
|
||||||
uint32_t hx, hy;
|
|
||||||
|
|
||||||
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
|
|
||||||
|
|
||||||
x = cuCrealf(z);
|
|
||||||
y = cuCimagf(z);
|
|
||||||
|
|
||||||
get_float_word(hy, y);
|
|
||||||
hy &= 0x7fffffff;
|
|
||||||
|
|
||||||
/* cexp(x + I 0) = exp(x) + I 0 */
|
|
||||||
if (hy == 0) {
|
|
||||||
return cuComplex{expf(x), y};
|
|
||||||
}
|
|
||||||
get_float_word(hx, x);
|
|
||||||
/* cexp(0 + I y) = cos(y) + I sin(y) */
|
|
||||||
if ((hx & 0x7fffffff) == 0) {
|
|
||||||
return cuComplex{cosf(y), sinf(y)};
|
|
||||||
}
|
|
||||||
if (hy >= 0x7f800000) {
|
|
||||||
if ((hx & 0x7fffffff) != 0x7f800000) {
|
|
||||||
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
|
|
||||||
return cuComplex{y - y, y - y};
|
|
||||||
} else if (hx & 0x80000000) {
|
|
||||||
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
|
|
||||||
return cuComplex{0.0, 0.0};
|
|
||||||
} else {
|
|
||||||
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
|
|
||||||
return cuComplex{x, y - y};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
|
|
||||||
/*
|
|
||||||
* x is between 88.7 and 192, so we must scale to avoid
|
|
||||||
* overflow in expf(x).
|
|
||||||
*/
|
|
||||||
return ldexp_cexpf(z, 0);
|
|
||||||
} else {
|
|
||||||
/*
|
|
||||||
* Cases covered here:
|
|
||||||
* - x < exp_ovfl and exp(x) won't overflow (common case)
|
|
||||||
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
|
|
||||||
* - x = +-Inf (generated by exp())
|
|
||||||
* - x = NaN (spurious inexact exception from y)
|
|
||||||
*/
|
|
||||||
exp_x = expf(x);
|
|
||||||
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu::detail
|
|
@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/cexpf.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <math_constants.h>
|
#include <math_constants.h>
|
||||||
|
#include <cuda/std/complex>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
@ -152,7 +152,8 @@ struct Exp {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
return detail::cexpf(x);
|
auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
|
||||||
|
return cuComplex{r.real(), r.imag()};
|
||||||
} else {
|
} else {
|
||||||
return exp(x);
|
return exp(x);
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvrtc.h>
|
#include <nvrtc.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
@ -50,6 +51,16 @@ const std::string& cuda_home() {
|
|||||||
return home;
|
return home;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the location of CCCL headers shipped with the distribution.
|
||||||
|
bool get_cccl_include(std::string* out) {
|
||||||
|
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
|
||||||
|
if (!std::filesystem::exists(cccl_headers)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*out = fmt::format("--include-path={}", cccl_headers.string());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Get the cache directory for storing compiled results.
|
// Get the cache directory for storing compiled results.
|
||||||
const std::filesystem::path& ptx_cache_dir() {
|
const std::filesystem::path& ptx_cache_dir() {
|
||||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||||
@ -161,7 +172,6 @@ constexpr const char* g_include_names[] = {
|
|||||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||||
INCLUDE_PREFIX "binary_ops.cuh",
|
INCLUDE_PREFIX "binary_ops.cuh",
|
||||||
INCLUDE_PREFIX "cast_op.cuh",
|
INCLUDE_PREFIX "cast_op.cuh",
|
||||||
INCLUDE_PREFIX "cexpf.cuh",
|
|
||||||
INCLUDE_PREFIX "config.h",
|
INCLUDE_PREFIX "config.h",
|
||||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||||
INCLUDE_PREFIX "fp16_math.cuh",
|
INCLUDE_PREFIX "fp16_math.cuh",
|
||||||
@ -178,7 +188,6 @@ constexpr const char* g_headers[] = {
|
|||||||
jit_source_atomic_ops,
|
jit_source_atomic_ops,
|
||||||
jit_source_binary_ops,
|
jit_source_binary_ops,
|
||||||
jit_source_cast_op,
|
jit_source_cast_op,
|
||||||
jit_source_cexpf,
|
|
||||||
jit_source_config,
|
jit_source_config,
|
||||||
jit_source_cucomplex_math,
|
jit_source_cucomplex_math,
|
||||||
jit_source_fp16_math,
|
jit_source_fp16_math,
|
||||||
@ -217,16 +226,23 @@ JitModule::JitModule(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compile program.
|
// Compile program.
|
||||||
|
std::vector<const char*> args;
|
||||||
bool use_sass = compiler_supports_device_sass(device);
|
bool use_sass = compiler_supports_device_sass(device);
|
||||||
std::string compute = fmt::format(
|
std::string compute = fmt::format(
|
||||||
"--gpu-architecture={}_{}{}",
|
"--gpu-architecture={}_{}{}",
|
||||||
use_sass ? "sm" : "compute",
|
use_sass ? "sm" : "compute",
|
||||||
device.compute_capability_major(),
|
device.compute_capability_major(),
|
||||||
device.compute_capability_minor());
|
device.compute_capability_minor());
|
||||||
std::string include = fmt::format("--include-path={}/include", cuda_home());
|
args.push_back(compute.c_str());
|
||||||
const char* args[] = {compute.c_str(), include.c_str()};
|
std::string cccl_include;
|
||||||
|
if (get_cccl_include(&cccl_include)) {
|
||||||
|
args.push_back(cccl_include.c_str());
|
||||||
|
}
|
||||||
|
std::string cuda_include =
|
||||||
|
fmt::format("--include-path={}/include", cuda_home());
|
||||||
|
args.push_back(cuda_include.c_str());
|
||||||
nvrtcResult compile_result =
|
nvrtcResult compile_result =
|
||||||
nvrtcCompileProgram(prog, std::size(args), args);
|
nvrtcCompileProgram(prog, args.size(), args.data());
|
||||||
if (compile_result != NVRTC_SUCCESS) {
|
if (compile_result != NVRTC_SUCCESS) {
|
||||||
size_t log_size;
|
size_t log_size;
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||||
|
@ -1,20 +1,18 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <filesystem>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#define NS_PRIVATE_IMPLEMENTATION
|
#define NS_PRIVATE_IMPLEMENTATION
|
||||||
#define CA_PRIVATE_IMPLEMENTATION
|
#define CA_PRIVATE_IMPLEMENTATION
|
||||||
#define MTL_PRIVATE_IMPLEMENTATION
|
#define MTL_PRIVATE_IMPLEMENTATION
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -80,12 +78,7 @@ MTL::Library* try_load_bundle(
|
|||||||
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
|
||||||
MTL::Device* device,
|
MTL::Device* device,
|
||||||
const std::string& relative_path) {
|
const std::string& relative_path) {
|
||||||
std::string binary_dir = get_binary_directory();
|
auto path = current_binary_dir() / relative_path;
|
||||||
if (binary_dir.size() == 0) {
|
|
||||||
return {nullptr, nullptr};
|
|
||||||
}
|
|
||||||
|
|
||||||
auto path = fs::path(binary_dir) / relative_path;
|
|
||||||
if (!path.has_extension()) {
|
if (!path.has_extension()) {
|
||||||
path.replace_extension(".metallib");
|
path.replace_extension(".metallib");
|
||||||
}
|
}
|
||||||
@ -197,7 +190,7 @@ MTL::Library* load_library(
|
|||||||
|
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
msg << "Failed to load the metallib " << lib_name << ".metallib. "
|
||||||
<< "We attempted to load it from <" << get_binary_directory() << "/"
|
<< "We attempted to load it from <" << current_binary_dir() << "/"
|
||||||
<< lib_name << ".metallib" << ">";
|
<< lib_name << ".metallib" << ">";
|
||||||
#ifdef SWIFTPM_BUNDLE
|
#ifdef SWIFTPM_BUNDLE
|
||||||
msg << " and from the Swift PM bundle.";
|
msg << " and from the Swift PM bundle.";
|
||||||
|
@ -3,8 +3,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <Metal/Metal.hpp>
|
#include <Metal/Metal.hpp>
|
||||||
#include <dlfcn.h>
|
|
||||||
#include <filesystem>
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <shared_mutex>
|
#include <shared_mutex>
|
||||||
@ -15,22 +13,8 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
// Note, this function must be left inline in a header so that it is not
|
|
||||||
// dynamically linked.
|
|
||||||
inline std::string get_binary_directory() {
|
|
||||||
Dl_info info;
|
|
||||||
std::string directory;
|
|
||||||
int success = dladdr((void*)get_binary_directory, &info);
|
|
||||||
if (success) {
|
|
||||||
directory = fs::path(info.dli_fname).remove_filename().c_str();
|
|
||||||
}
|
|
||||||
return directory;
|
|
||||||
}
|
|
||||||
|
|
||||||
using MTLFCList =
|
using MTLFCList =
|
||||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user