mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
[CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear
* fix copy
This commit is contained in:
parent
b8022c578a
commit
cad5c0241c
@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaFree(buf);
|
cudaFree(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,25 +63,30 @@ void copy_general(
|
|||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
size_t data_size = 1;
|
||||||
|
for (auto& s : shape)
|
||||||
|
data_size *= s;
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
data_size,
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out));
|
const_param<NDIM>(strides_out));
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
data_size,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <future>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -107,6 +108,16 @@ void CommandEncoder::commit() {
|
|||||||
worker_.commit(stream_.last_cuda_stream());
|
worker_.commit(stream_.last_cuda_stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::synchronize() {
|
||||||
|
stream().synchronize();
|
||||||
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
|
std::future<void> f = p->get_future();
|
||||||
|
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||||
|
worker_.end_batch();
|
||||||
|
worker_.commit();
|
||||||
|
f.wait();
|
||||||
|
}
|
||||||
|
|
||||||
Device& device(mlx::core::Device device) {
|
Device& device(mlx::core::Device device) {
|
||||||
static std::unordered_map<int, Device> devices;
|
static std::unordered_map<int, Device> devices;
|
||||||
auto it = devices.find(device.index);
|
auto it = devices.find(device.index);
|
||||||
|
@ -123,6 +123,9 @@ class CommandEncoder {
|
|||||||
return has_gpu_work_;
|
return has_gpu_work_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait until kernels and completion handlers are finished
|
||||||
|
void synchronize();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Device& device_;
|
Device& device_;
|
||||||
DeviceStream& stream_;
|
DeviceStream& stream_;
|
||||||
|
@ -62,7 +62,7 @@ void finalize(Stream s) {
|
|||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
nvtx3::scoped_range r("gpu::synchronize");
|
nvtx3::scoped_range r("gpu::synchronize");
|
||||||
cu::get_stream(s).synchronize();
|
cu::get_command_encoder(s).synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::gpu
|
} // namespace mlx::core::gpu
|
||||||
|
@ -80,7 +80,9 @@ void Worker::thread_fn() {
|
|||||||
}
|
}
|
||||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||||
}
|
}
|
||||||
for (auto& task : tasks) {
|
// Make sure tasks are cleared before the next wait
|
||||||
|
for (int i = 0; i < tasks.size(); ++i) {
|
||||||
|
auto task = std::move(tasks[i]);
|
||||||
task();
|
task();
|
||||||
}
|
}
|
||||||
worker_event_.wait(batch + 1);
|
worker_event_.wait(batch + 1);
|
||||||
|
@ -6,7 +6,6 @@ cuda_skip = {
|
|||||||
"TestEinsum.test_ellipses",
|
"TestEinsum.test_ellipses",
|
||||||
"TestEinsum.test_opt_einsum_test_cases",
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
"TestLoad.test_load_f8_e4m3",
|
"TestLoad.test_load_f8_e4m3",
|
||||||
"TestMemory.test_memory_info",
|
|
||||||
"TestLayers.test_group_norm",
|
"TestLayers.test_group_norm",
|
||||||
"TestLayers.test_pooling",
|
"TestLayers.test_pooling",
|
||||||
"TestLayers.test_quantized_embedding",
|
"TestLayers.test_quantized_embedding",
|
||||||
|
Loading…
Reference in New Issue
Block a user