mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00

Compare commits: 695fd9281f ... 970a5d0a5e

3 Commits:

- 970a5d0a5e
- b3d7b85376
- 7c99acb799
@@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) {
   }
 }
 
 // Return the location of the CUDA toolkit.
-const char* cuda_home() {
-  const char* home = std::getenv("CUDA_HOME");
-  if (home) {
-    return home;
-  }
-  home = std::getenv("CUDA_PATH");
-  if (home) {
-    return home;
-  }
-
+const std::string& cuda_home() {
+  static std::string home = []() -> std::string {
+    const char* home = std::getenv("CUDA_HOME");
+    if (home) {
+      return home;
+    }
+    home = std::getenv("CUDA_PATH");
+    if (home) {
+      return home;
+    }
 #if defined(__linux__)
-  home = "/usr/local/cuda";
-  if (std::filesystem::exists(home)) {
-    return home;
-  }
+    home = "/usr/local/cuda";
+    if (std::filesystem::exists(home)) {
+      return home;
+    }
 #endif
-  throw std::runtime_error(
-      "Environment variable CUDA_HOME or CUDA_PATH is not set.");
+    throw std::runtime_error(
+        "Environment variable CUDA_HOME or CUDA_PATH is not set.");
+  }();
+  return home;
 }
 
 // Get the cache directory for storing compiled results.
-bool get_ptx_cache_dir(std::filesystem::path* result) {
-  auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
-  if (!std::filesystem::is_directory(path)) {
-    std::error_code error;
-    if (!std::filesystem::create_directories(path, error)) {
-      return false;
-    }
-  }
-  *result = path;
-  return true;
+const std::filesystem::path& ptx_cache_dir() {
+  static std::filesystem::path cache = []() -> std::filesystem::path {
+    std::filesystem::path cache;
+    if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
+      cache = c;
+    } else {
+      cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
+    }
+    if (!std::filesystem::exists(cache)) {
+      std::error_code error;
+      if (!std::filesystem::create_directories(cache, error)) {
+        return std::filesystem::path();
+      }
+    }
+    return cache;
+  }();
+  return cache;
 }
 
 // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
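Both rewritten helpers use the same idiom: a function-local static initialized by an immediately invoked lambda, so the environment lookup runs exactly once and is thread-safe under C++11 static-initialization rules. A minimal standalone sketch of the pattern (`TOOL_HOME` and the fallback path are made-up names for illustration):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Returns a cached value computed on the first call. The lambda body
// executes exactly once; later calls return the stored string by reference.
const std::string& tool_home() {
  static std::string home = []() -> std::string {
    if (const char* env = std::getenv("TOOL_HOME")) {
      return env; // honor the environment override
    }
    return "/opt/tool"; // hypothetical fallback; cuda_home() throws instead
  }();
  return home;
}

int main() {
  std::cout << tool_home() << "\n"; // env is read only on the first call
}
```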
@@ -75,6 +85,10 @@ bool read_cached_ptx(
     const std::string& module_name,
     std::vector<char>* ptx,
     std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+  if (cache_dir.empty()) {
+    return false;
+  }
+
   auto ptx_path = cache_dir / (module_name + ".ptx");
   std::error_code error;
   auto ptx_size = std::filesystem::file_size(ptx_path, error);
@@ -105,6 +119,10 @@ void write_cached_ptx(
     const std::string& module_name,
     const std::vector<char>& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+  if (cache_dir.empty()) {
+    return;
+  }
+
   std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
   if (!ptx.empty()) {
     ptx_file.write(&ptx.front(), ptx.size());
@@ -184,11 +202,9 @@ JitModule::JitModule(
     const std::string& module_name,
     const KernelBuilder& builder) {
   // Check cache.
-  std::filesystem::path cache_dir;
   std::vector<char> ptx;
   std::vector<std::pair<std::string, std::string>> ptx_kernels;
-  if (!get_ptx_cache_dir(&cache_dir) ||
-      !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
     // Create program.
     auto [source_code, kernel_names] = builder();
     nvrtcProgram prog;
@@ -246,7 +262,7 @@ JitModule::JitModule(
     } else {
       CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
     }
-    write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
+    write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
   }
 
   // Load module.
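Net effect of the JitModule hunks: the cache directory is no longer threaded through a local; both the read and write sites call ptx_cache_dir(), and the empty path it returns when the directory cannot be created turns read_cached_ptx and write_cached_ptx into no-ops via the guards added above (setting MLX_PTX_CACHE to an empty string appears to disable the cache the same way). A standalone sketch of that empty-path sentinel, with hypothetical names:

```cpp
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>

namespace fs = std::filesystem;

// Returns false when caching is disabled (empty dir) or the entry is absent,
// so the caller falls through to the expensive "compile" path.
bool read_cached(const fs::path& dir, const std::string& key, std::string* out) {
  if (dir.empty()) {
    return false;
  }
  std::ifstream f(dir / key);
  if (!f) {
    return false;
  }
  return static_cast<bool>(std::getline(f, *out));
}

// Silently skips the write when caching is disabled.
void write_cached(const fs::path& dir, const std::string& key,
                  const std::string& value) {
  if (dir.empty()) {
    return;
  }
  std::ofstream(dir / key) << value;
}

int main() {
  fs::path cache_dir; // empty: simulates ptx_cache_dir() failing
  std::string ptx;
  if (!read_cached(cache_dir, "module.ptx", &ptx)) {
    ptx = "<compiled ptx>";                     // expensive path
    write_cached(cache_dir, "module.ptx", ptx); // no-op here
  }
  std::cout << ptx << "\n";
}
```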
@@ -5,28 +5,33 @@ template <typename T, typename AccT = float, int N_READS = 4>
     const device T* in,
     device T* out,
     constant int& axis_size,
-    uint gid [[threadgroup_position_in_grid]],
-    uint _lid [[thread_position_in_threadgroup]],
+    uint2 gid [[threadgroup_position_in_grid]],
+    uint2 tid [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]],
+    uint2 _lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  int lid = _lid;
+  int lid = _lid.x;
 
   constexpr int SIMD_SIZE = 32;
+  constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS;
 
   threadgroup AccT local_max[SIMD_SIZE];
   threadgroup AccT local_normalizer[SIMD_SIZE];
 
   AccT ld[N_READS];
 
-  in += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  const int axis_offset = tid.y * elem_per_group;
+  in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset;
+  if (axis_offset + lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       ld[i] = AccT(in[i]);
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] =
-          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
+      ld[i] = ((axis_offset + lid * N_READS + i) < axis_size)
+          ? AccT(in[i])
+          : Limits<AccT>::min;
     }
   }
   if (simd_group_id == 0) {
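The new elem_per_group constant is how much of the reduced axis one threadgroup consumes per pass. Assuming the 1024-thread threadgroups the host code below dispatches (32 simdgroups of SIMD_SIZE = 32 threads) and the default N_READS = 4, this works out to

$$\text{elem\_per\_group} = 32 \times 32 \times 4 = 4096,$$

so `axis_offset = tid.y * elem_per_group` steers each threadgroup to its own 4096-element chunk of the axis.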
@@ -55,6 +60,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
   maxval = local_max[0];
 
   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
+  out += gid.x * grid_dim.y + tid.y;
   AccT normalizer = 0;
   for (int i = 0; i < N_READS; i++) {
     normalizer += fast::exp(ld[i] - maxval);
@@ -67,7 +73,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
   if (simd_group_id == 0) {
     normalizer = simd_sum(local_normalizer[simd_lane_id]);
     if (simd_lane_id == 0) {
-      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
+      out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
     }
   }
 }
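The write through out[0] is correct because the previous hunk already advanced out by `gid.x * grid_dim.y + tid.y`, giving every threadgroup a distinct slot for the logsumexp of its chunk. Splitting the axis is exact rather than approximate because logsumexp composes over a partition of $x$ into chunks $C_1, \ldots, C_k$:

$$\operatorname{LSE}(x) = \log \sum_i e^{x_i} = \log \sum_{j=1}^{k} \sum_{i \in C_j} e^{x_i} = \operatorname{LSE}\bigl(\operatorname{LSE}(C_1), \ldots, \operatorname{LSE}(C_k)\bigr),$$

which is what lets the host code below run a second reduction over the per-chunk partials.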
@@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   const int n_reads = 4;
   const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
 
-  std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
+  bool split = n_rows < 4 && axis_size > 4 * looped_limit;
+  bool looped = !split && axis_size > looped_limit;
+  std::string kernel_name = looped ? "looped_" : "block_";
   kernel_name += "logsumexp_";
   kernel_name += type_to_name(out);
 
   auto kernel = get_logsumexp_kernel(d, kernel_name, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
+  if (split) {
+    auto tmp_size = ceildiv(axis_size, looped_limit);
+    auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)};
+    array tmp(tmp_shape, in.dtype(), nullptr, {});
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
+    size_t threadgroup_size = 1024;
+    assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+    size_t n_threads = n_rows * threadgroup_size;
+    auto grid_dims = MTL::Size(n_threads, tmp_size, 1);
+    auto group_dims = MTL::Size(threadgroup_size, 1, 1);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(tmp, 1);
+    compute_encoder.set_bytes(axis_size, 2);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+    d.add_temporary(tmp, s.index);
+    in = tmp;
+    axis_size = tmp_size;
+  }
+
   {
     MTL::Size grid_dims, group_dims;
-    if (axis_size <= looped_limit) {
+    if (!looped) {
       size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
       size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
       size_t threadgroup_size = simd_size * simds_needed;
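A CPU sketch (not MLX code; names here are made up for illustration) of the two-pass scheme the split branch implements: pass one produces one partial logsumexp per 4096-element chunk, mirroring what each threadgroup writes into tmp, and pass two reduces the partials to the final value:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable logsumexp of n values: log(sum(exp(x))) computed as
// max + log(sum(exp(x - max))).
double lse(const double* x, size_t n) {
  double m = *std::max_element(x, x + n);
  double s = 0.0;
  for (size_t i = 0; i < n; i++) {
    s += std::exp(x[i] - m);
  }
  return std::isinf(m) ? m : std::log(s) + m;
}

int main() {
  const size_t chunk = 4096;           // plays the role of elem_per_group
  std::vector<double> x(4 * 4096 + 3); // same size as the new Python test
  for (size_t i = 0; i < x.size(); i++) {
    x[i] = std::sin(0.001 * i);
  }

  // Pass 1: one partial logsumexp per chunk (one threadgroup each on GPU).
  std::vector<double> partials;
  for (size_t off = 0; off < x.size(); off += chunk) {
    partials.push_back(lse(x.data() + off, std::min(chunk, x.size() - off)));
  }

  // Pass 2: logsumexp over the partials equals the direct reduction.
  std::printf("split  = %.12f\ndirect = %.12f\n",
              lse(partials.data(), partials.size()),
              lse(x.data(), x.size()));
}
```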
@@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
         x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
 
+        # Even larger
+        x = mx.random.uniform(shape=(4 * 4096 + 3,))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
     def test_mean(self):
         x = mx.array(
             [
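The added size looks chosen to reach the new split path: a 1-D input gives n_rows = 1 < 4, and assuming LOGSUMEXP_LOOPED_LIMIT is 4096 (matching elem_per_group in the kernel),

$$4 \times 4096 + 3 = 16387 > 4 \times 4096 = 16384,$$

so split is taken; the trailing + 3 also exercises the ragged-tail masking in the kernel's load loop.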