mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
1d4eacb737
...
f8bd675655
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8bd675655 | ||
|
|
23a9168d34 | ||
|
|
bca205e287 |
@@ -154,17 +154,21 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
|||||||
}
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
buf = new CudaBuffer{nullptr, size, device};
|
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
|
void* data = nullptr;
|
||||||
if (device == -1) {
|
if (device == -1) {
|
||||||
err = cudaMallocManaged(&buf->data, size);
|
err = cudaMallocManaged(&data, size);
|
||||||
} else {
|
} else {
|
||||||
err = cudaMallocAsync(&buf->data, size, stream);
|
err = cudaMallocAsync(&data, size, stream);
|
||||||
}
|
}
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||||
}
|
}
|
||||||
|
if (!data) {
|
||||||
|
return Buffer{nullptr};
|
||||||
|
}
|
||||||
|
buf = new CudaBuffer{data, size, device};
|
||||||
}
|
}
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ class CudaHandle {
|
|||||||
}
|
}
|
||||||
|
|
||||||
~CudaHandle() {
|
~CudaHandle() {
|
||||||
|
// Skip if there was an error to avoid throwing in the destructors
|
||||||
|
if (cudaPeekAtLastError() != cudaSuccess) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
reset();
|
reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool use_cuda_graphs() {
|
bool use_cuda_graphs() {
|
||||||
static bool use_graphs = []() {
|
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
|
||||||
}();
|
|
||||||
return use_graphs;
|
return use_graphs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* save_cuda_graphs_dot_file() {
|
||||||
|
static const char* filename = []() -> const char* {
|
||||||
|
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
|
||||||
|
if (env && std::strlen(env) == 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return env;
|
||||||
|
}();
|
||||||
|
return filename;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
|
|||||||
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save cuda graph to dot file
|
||||||
|
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
|
||||||
|
static int count = 0;
|
||||||
|
auto path = fmt::format("{}_{}.dot", filename, ++count);
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
|
||||||
|
}
|
||||||
|
|
||||||
// Reset state
|
// Reset state
|
||||||
from_nodes_.clear();
|
from_nodes_.clear();
|
||||||
to_nodes_.clear();
|
to_nodes_.clear();
|
||||||
|
|||||||
@@ -305,6 +305,7 @@ void Event::wait() {
|
|||||||
} else {
|
} else {
|
||||||
event->atomic->wait(value());
|
event->atomic->wait(value());
|
||||||
}
|
}
|
||||||
|
CHECK_CUDA_ERROR(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Event::wait(Stream s) {
|
void Event::wait(Stream s) {
|
||||||
|
|||||||
@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void malloc_with_same_layout(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& o,
|
||||||
|
const array& q) {
|
||||||
|
if (q.flags().row_contiguous) {
|
||||||
|
o.set_data(cu::malloc_async(o.nbytes(), encoder));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// fill_order = argsort(q.strides())
|
||||||
|
Shape fill_order(q.ndim());
|
||||||
|
std::iota(fill_order.begin(), fill_order.end(), 0);
|
||||||
|
std::stable_sort(
|
||||||
|
fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
|
||||||
|
auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
|
||||||
|
auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
|
||||||
|
return s1 < s2;
|
||||||
|
});
|
||||||
|
// Generate o_strides with fill_order
|
||||||
|
Strides o_strides(q.ndim());
|
||||||
|
int64_t stride = 1;
|
||||||
|
for (int i : fill_order) {
|
||||||
|
o_strides[i] = stride;
|
||||||
|
stride *= o.shape(i);
|
||||||
|
}
|
||||||
|
// o is a transposed contiguous array
|
||||||
|
o.set_data(
|
||||||
|
cu::malloc_async(o.nbytes(), encoder),
|
||||||
|
o.size(),
|
||||||
|
o_strides,
|
||||||
|
{true, false, false});
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int QKV_NDIM = 4;
|
constexpr int QKV_NDIM = 4;
|
||||||
|
|
||||||
struct SDPACacheKey {
|
struct SDPACacheKey {
|
||||||
@@ -338,9 +370,7 @@ void sdpa_cudnn(
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
auto handle = encoder.device().cudnn_handle();
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
|
||||||
// TODO: Handle donation.
|
malloc_with_same_layout(encoder, o, q);
|
||||||
// TODO: Make O use same memory layout with Q.
|
|
||||||
o.set_data(cu::malloc_async(o.nbytes(), encoder));
|
|
||||||
|
|
||||||
encoder.set_input_array(q);
|
encoder.set_input_array(q);
|
||||||
encoder.set_input_array(k);
|
encoder.set_input_array(k);
|
||||||
@@ -392,10 +422,9 @@ void sdpa_backward_cudnn(
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
auto handle = encoder.device().cudnn_handle();
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
|
||||||
// TODO: Handle donation.
|
malloc_with_same_layout(encoder, d_q, q);
|
||||||
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
|
malloc_with_same_layout(encoder, d_k, k);
|
||||||
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
|
malloc_with_same_layout(encoder, d_v, v);
|
||||||
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
|
|
||||||
|
|
||||||
encoder.set_input_array(q);
|
encoder.set_input_array(q);
|
||||||
encoder.set_input_array(k);
|
encoder.set_input_array(k);
|
||||||
|
|||||||
@@ -135,7 +135,11 @@ class Scheduler {
|
|||||||
|
|
||||||
~Scheduler() {
|
~Scheduler() {
|
||||||
for (auto s : streams_) {
|
for (auto s : streams_) {
|
||||||
synchronize(s);
|
try {
|
||||||
|
synchronize(s);
|
||||||
|
} catch (const std::runtime_error&) {
|
||||||
|
// ignore errors if synch fails
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto t : threads_) {
|
for (auto t : threads_) {
|
||||||
if (t != nullptr) {
|
if (t != nullptr) {
|
||||||
|
|||||||
Reference in New Issue
Block a user