mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* cuda graph prototype fix signal bug + start to add dependencies capture more capture more ops remaining ops fix reduce and rope deps add concurrent context try update, but not working consistent topology order use node api use node api directly to reduce overhead fix bug use kernels in unary cache graph format fix synchronization format * comment
67 lines
1.8 KiB
C++
67 lines
1.8 KiB
C++
// Copyright © 2025 Apple Inc.
|
|
|
|
#include "mlx/backend/gpu/eval.h"
|
|
#include "mlx/backend/cuda/allocator.h"
|
|
#include "mlx/backend/cuda/device.h"
|
|
#include "mlx/backend/gpu/available.h"
|
|
#include "mlx/primitives.h"
|
|
|
|
#include <nvtx3/nvtx3.hpp>
|
|
|
|
namespace mlx::core::gpu {
|
|
|
|
// Reports whether the GPU backend can be used. This implementation
// unconditionally answers yes.
bool is_available() { return true; }
|
|
|
|
// Performs the one-time setup needed when stream `s` is created.
void new_stream(Stream s) {
  // Touch the CUDA runtime before anything else so that it is initialized
  // first and, as a consequence, gets destroyed last.
  cudaFree(nullptr);

  // Looking up the encoder makes sure the static stream objects exist.
  cu::get_command_encoder(s);

  // Register the calling (main) thread with the allocator: it is safe for
  // this thread to free buffers.
  cu::allocator().register_this_thread();
}
|
|
|
|
// Dispatches the primitive of `arr` to the GPU and then registers a
// completion handler that keeps the buffers involved alive until the encoded
// work has finished.
void eval(array& arr) {
  nvtx3::scoped_range r("gpu::eval");

  auto outputs = arr.outputs();
  {
    // While the primitive is being dispatched, a tracer holds an extra
    // reference to its inputs so that they don't get donated.
    std::vector<array> held_inputs;
    if (arr.is_tracer()) {
      held_inputs = arr.inputs();
    }
    arr.primitive().eval_gpu(arr.inputs(), outputs);
  }

  auto& encoder = cu::get_command_encoder(arr.primitive().stream());

  // Gather the data of every input and sibling; these must outlive the
  // kernel that uses them.
  std::unordered_set<std::shared_ptr<array::Data>> keep_alive;
  for (auto& input : arr.inputs()) {
    keep_alive.insert(input.data_shared_ptr());
  }
  for (auto& sibling : arr.siblings()) {
    keep_alive.insert(sibling.data_shared_ptr());
  }

  // If an input donated its buffer to the output, the output's data is in
  // the set as well; drop it (erasing a missing key is a no-op).
  keep_alive.erase(arr.data_shared_ptr());

  // The handler body is empty on purpose: owning the set inside the closure
  // is what extends the buffers' lifetime until the handler is released.
  encoder.add_completed_handler([keep_alive = std::move(keep_alive)]() {});
  encoder.maybe_commit();
}
|
|
|
|
// Commits the work encoded so far on the command encoder of stream `s`.
void finalize(Stream s) {
  nvtx3::scoped_range r("gpu::finalize");
  auto& encoder = cu::get_command_encoder(s);
  encoder.commit();
}
|
|
|
|
// Synchronizes with stream `s` by delegating to its command encoder.
void synchronize(Stream s) {
  nvtx3::scoped_range r("gpu::synchronize");
  auto& encoder = cu::get_command_encoder(s);
  encoder.synchronize();
}
|
|
|
|
} // namespace mlx::core::gpu
|