Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: 1d4eacb737 ... f8bd675655 (3 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | f8bd675655 |  |
|  | 23a9168d34 |  |
|  | bca205e287 |  |
```diff
@@ -154,17 +154,21 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
   }
   lock.unlock();
   if (!buf) {
-    buf = new CudaBuffer{nullptr, size, device};
     cudaError_t err;
+    void* data = nullptr;
     if (device == -1) {
-      err = cudaMallocManaged(&buf->data, size);
+      err = cudaMallocManaged(&data, size);
     } else {
-      err = cudaMallocAsync(&buf->data, size, stream);
+      err = cudaMallocAsync(&data, size, stream);
     }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(
           "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
     }
+    if (!data) {
+      return Buffer{nullptr};
+    }
+    buf = new CudaBuffer{data, size, device};
   }
   lock.lock();
 }
```
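This hunk changes `malloc_async` to allocate into a local pointer and to treat `cudaErrorMemoryAllocation` as recoverable: instead of throwing, it returns a null `Buffer` so the caller can react (for example by releasing cached buffers and retrying). A minimal standalone sketch of the same pattern — not MLX code, and `try_alloc` is a hypothetical helper name:

```cpp
// Sketch: treat out-of-memory from the CUDA allocator as recoverable,
// surfacing it as a null pointer; throw only on unexpected errors.
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

void* try_alloc(size_t size, cudaStream_t stream) {
  void* data = nullptr;
  cudaError_t err = cudaMallocAsync(&data, size, stream);
  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
    throw std::runtime_error(
        std::string("cudaMallocAsync failed: ") + cudaGetErrorString(err));
  }
  // On cudaErrorMemoryAllocation, data stays null; the caller can free
  // cached buffers and retry instead of aborting.
  return data;
}
```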
```diff
@@ -29,6 +29,10 @@ class CudaHandle {
   }
 
   ~CudaHandle() {
+    // Skip if there was an error to avoid throwing in the destructors
+    if (cudaPeekAtLastError() != cudaSuccess) {
+      return;
+    }
     reset();
   }
 
```
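The guard follows the rule that destructors must not throw: once a CUDA error is pending, `reset()` would most likely fail as well, and an exception escaping a destructor terminates the program. A standalone sketch of the same guard with a hypothetical RAII wrapper (not MLX code):

```cpp
// Skip cleanup when a CUDA error is already pending: the cleanup call
// would fail too, so the resource is deliberately leaked instead.
#include <cuda_runtime.h>

struct EventGuard {
  cudaEvent_t event = nullptr;
  ~EventGuard() {
    if (cudaPeekAtLastError() != cudaSuccess) {
      return; // process is already in a failed state; skip cleanup
    }
    if (event) {
      cudaEventDestroy(event);
    }
  }
};
```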
```diff
@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
 }
 
 bool use_cuda_graphs() {
-  static bool use_graphs = []() {
-    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
-  }();
+  static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
   return use_graphs;
 }
 
+const char* save_cuda_graphs_dot_file() {
+  static const char* filename = []() -> const char* {
+    const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
+    if (env && std::strlen(env) == 0) {
+      return nullptr;
+    }
+    return env;
+  }();
+  return filename;
+}
+
 } // namespace
 
 Device::Device(int device) : device_(device) {
```
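Both helpers latch their result in a function-local static, so the environment is consulted only once per process: CUDA graphs default to on via `env::get_var("MLX_USE_CUDA_GRAPHS", true)`, and graph dumping stays disabled unless `MLX_SAVE_CUDA_GRAPHS_DOT_FILE` is set to a non-empty value — the empty-string check deliberately treats `MLX_SAVE_CUDA_GRAPHS_DOT_FILE=` the same as unset.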
```diff
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
 
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
   }
+
+  // Save cuda graph to dot file
+  if (const char* filename = save_cuda_graphs_dot_file(); filename) {
+    static int count = 0;
+    auto path = fmt::format("{}_{}.dot", filename, ++count);
+    CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
+  }
+
   // Reset state
   from_nodes_.clear();
   to_nodes_.clear();
```
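Since `count` starts at 0 and is pre-incremented, a prefix of `graphs` yields `graphs_1.dot`, `graphs_2.dot`, ..., one file per committed graph. `cudaGraphDebugDotPrint` emits standard Graphviz dot files, so they can be rendered with e.g. `dot -Tsvg graphs_1.dot -o graphs_1.svg`.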
```diff
@@ -305,6 +305,7 @@ void Event::wait() {
   } else {
     event->atomic->wait(value());
   }
+  CHECK_CUDA_ERROR(cudaPeekAtLastError());
 }
 
 void Event::wait(Stream s) {
```
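`cudaPeekAtLastError()` returns the most recent CUDA error without clearing it, so the added check makes `Event::wait()` a point where asynchronous failures from previously launched work are reported rather than silently swallowed.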
```diff
@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
   return x;
 }
 
+void malloc_with_same_layout(
+    cu::CommandEncoder& encoder,
+    array& o,
+    const array& q) {
+  if (q.flags().row_contiguous) {
+    o.set_data(cu::malloc_async(o.nbytes(), encoder));
+    return;
+  }
+  // fill_order = argsort(q.strides())
+  Shape fill_order(q.ndim());
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+  std::stable_sort(
+      fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
+        auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
+        auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
+        return s1 < s2;
+      });
+  // Generate o_strides with fill_order
+  Strides o_strides(q.ndim());
+  int64_t stride = 1;
+  for (int i : fill_order) {
+    o_strides[i] = stride;
+    stride *= o.shape(i);
+  }
+  // o is a transposed contiguous array
+  o.set_data(
+      cu::malloc_async(o.nbytes(), encoder),
+      o.size(),
+      o_strides,
+      {true, false, false});
+}
+
 constexpr int QKV_NDIM = 4;
 
 struct SDPACacheKey {
```
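`malloc_with_same_layout` makes the output follow `q`'s memory layout rather than defaulting to row-contiguous: it argsorts `q`'s strides in ascending order and assigns the output strides from the fastest-moving axis outward; the two hunks below then use it for `o` in the forward pass and `d_q`/`d_k`/`d_v` in the backward pass. A standalone illustration with hypothetical shapes (plain C++, not MLX code):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical q: shape (2, 8, 4, 64), made by transposing axes 1 and 2
  // of a row-contiguous (2, 4, 8, 64) array, so strides are (2048, 64, 512, 1).
  std::vector<int64_t> shape = {2, 8, 4, 64};
  std::vector<int64_t> q_strides = {2048, 64, 512, 1};

  // fill_order = argsort(q_strides), ascending: (3, 1, 2, 0)
  std::vector<int> fill_order(shape.size());
  std::iota(fill_order.begin(), fill_order.end(), 0);
  std::stable_sort(fill_order.begin(), fill_order.end(),
                   [&](int a, int b) { return q_strides[a] < q_strides[b]; });

  // Fill output strides from the fastest-moving axis outward.
  std::vector<int64_t> o_strides(shape.size());
  int64_t stride = 1;
  for (int i : fill_order) {
    o_strides[i] = stride;
    stride *= shape[i];
  }

  // Prints 2048 64 512 1: the output reproduces q's axis ordering.
  for (auto s : o_strides) std::printf("%lld ", (long long)s);
  std::printf("\n");
}
```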
```diff
@@ -338,9 +370,7 @@ void sdpa_cudnn(
   auto& encoder = cu::get_command_encoder(s);
   auto handle = encoder.device().cudnn_handle();
 
-  // TODO: Handle donation.
-  // TODO: Make O use same memory layout with Q.
-  o.set_data(cu::malloc_async(o.nbytes(), encoder));
+  malloc_with_same_layout(encoder, o, q);
 
   encoder.set_input_array(q);
   encoder.set_input_array(k);
```
```diff
@@ -392,10 +422,9 @@ void sdpa_backward_cudnn(
   auto& encoder = cu::get_command_encoder(s);
   auto handle = encoder.device().cudnn_handle();
 
-  // TODO: Handle donation.
-  d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
-  d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
-  d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
+  malloc_with_same_layout(encoder, d_q, q);
+  malloc_with_same_layout(encoder, d_k, k);
+  malloc_with_same_layout(encoder, d_v, v);
 
   encoder.set_input_array(q);
   encoder.set_input_array(k);
```
```diff
@@ -135,7 +135,11 @@ class Scheduler {
 
   ~Scheduler() {
     for (auto s : streams_) {
+      try {
       synchronize(s);
+      } catch (const std::runtime_error&) {
+        // ignore errors if synch fails
+      }
     }
     for (auto t : threads_) {
       if (t != nullptr) {
```