perf tuning

This commit is contained in:
Awni Hannun 2025-06-18 16:42:39 -07:00
parent 76831ed83d
commit 72e21b7d51
4 changed files with 32 additions and 19 deletions

View File

@@ -1,5 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/utils.h"
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
@@ -14,9 +15,11 @@ namespace mlx::core {
namespace cu { namespace cu {
constexpr int page_size = 16384;
CudaAllocator::CudaAllocator() CudaAllocator::CudaAllocator()
: buffer_cache_( : buffer_cache_(
getpagesize(), page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { [this](CudaBuffer* buf) {
cuda_free(buf->data); cuda_free(buf->data);
@@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) { Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache. // Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) { if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure or are over the maximum cache size,

View File

@@ -24,7 +24,6 @@ void copy_gpu_inplace(
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) { if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out); copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return; return;

View File

@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
#pragma unroll #pragma unroll
for (int i = NDIM - 1; i >= 0; --i) { for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc); return cuda::std::make_tuple(a_loc, b_loc);
@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
#pragma unroll #pragma unroll
for (int i = NDIM - 1; i >= 0; --i) { for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * c_strides[i]; c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc, c_loc); return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT b_loc = 0; IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc); return cuda::std::make_tuple(a_loc, b_loc);
@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT c_loc = 0; IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * c_strides[i]; c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc, c_loc); return cuda::std::make_tuple(a_loc, b_loc, c_loc);

View File

@@ -162,11 +162,15 @@ class MatMul {
} }
} }
array workspace( void *workspace_ptr = nullptr;
allocator::malloc(heuristic_.workspaceSize), if (heuristic_.workspaceSize > 0) {
{static_cast<int>(heuristic_.workspaceSize)}, array workspace(
int8); allocator::malloc(heuristic_.workspaceSize),
encoder.add_temporary(workspace); {static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
@@ -183,8 +187,8 @@ class MatMul {
out, out,
out_desc_, out_desc_,
&heuristic_.algo, &heuristic_.algo,
workspace.data<void>(), workspace_ptr,
workspace.nbytes(), heuristic_.workspaceSize,
stream)); stream));
}); });
} }