mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
6 Commits
d40ed46a1a
...
323cc645ab
Author | SHA1 | Date | |
---|---|---|---|
![]() |
323cc645ab | ||
![]() |
5adf185f86 | ||
![]() |
c9a9180584 | ||
![]() |
992eac905a | ||
![]() |
c8d4d97447 | ||
![]() |
28902ece4e |
@ -3,6 +3,7 @@
|
|||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
@ -14,9 +15,11 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
|
constexpr int page_size = 16384;
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
getpagesize(),
|
page_size,
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) {
|
[this](CudaBuffer* buf) {
|
||||||
cuda_free(buf->data);
|
cuda_free(buf->data);
|
||||||
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
|
|||||||
|
|
||||||
Buffer CudaAllocator::malloc(size_t size) {
|
Buffer CudaAllocator::malloc(size_t size) {
|
||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
|
if (size < page_size) {
|
||||||
|
size = next_power_of_2(size);
|
||||||
|
} else {
|
||||||
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
|
}
|
||||||
|
|
||||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||||
|
@ -24,7 +24,6 @@ void copy_gpu_inplace(
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||||
return;
|
return;
|
||||||
|
@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
c_loc += dim_idx * c_strides[i];
|
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
IdxT b_loc = 0;
|
IdxT b_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
|||||||
IdxT c_loc = 0;
|
IdxT c_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
c_loc += dim_idx * c_strides[i];
|
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
@ -162,11 +162,15 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array workspace(
|
void* workspace_ptr = nullptr;
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
if (heuristic_.workspaceSize > 0) {
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
array workspace(
|
||||||
int8);
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
encoder.add_temporary(workspace);
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
|
int8);
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
workspace_ptr = workspace.data<void>();
|
||||||
|
}
|
||||||
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||||
@ -183,8 +187,8 @@ class MatMul {
|
|||||||
out,
|
out,
|
||||||
out_desc_,
|
out_desc_,
|
||||||
&heuristic_.algo,
|
&heuristic_.algo,
|
||||||
workspace.data<void>(),
|
workspace_ptr,
|
||||||
workspace.nbytes(),
|
heuristic_.workspaceSize,
|
||||||
stream));
|
stream));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
alpha_,
|
||||||
|
beta_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
|||||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||||
array out = out_;
|
array out = out_;
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += in.ndim();
|
axis += in.ndim();
|
||||||
}
|
}
|
||||||
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||||
|
@ -4,12 +4,15 @@
|
|||||||
#include "mlx/backend/gpu/available.h"
|
#include "mlx/backend/gpu/available.h"
|
||||||
#include "mlx/backend/gpu/eval.h"
|
#include "mlx/backend/gpu/eval.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core::gpu {
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
|
std::mutex metal_operation_mutex;
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void eval(array& arr) {
|
void eval(array& arr) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto s = arr.primitive().stream();
|
auto s = arr.primitive().stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
@ -78,6 +82,7 @@ void eval(array& arr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void finalize(Stream s) {
|
void finalize(Stream s) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
@ -88,6 +93,7 @@ void finalize(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
|
|||||||
if (stream.device == Device::cpu) {
|
if (stream.device == Device::cpu) {
|
||||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||||
} else {
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
d.end_encoding(stream.index);
|
d.end_encoding(stream.index);
|
||||||
auto command_buffer = d.get_command_buffer(stream.index);
|
auto command_buffer = d.get_command_buffer(stream.index);
|
||||||
@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
|
|||||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
d.end_encoding(stream.index);
|
d.end_encoding(stream.index);
|
||||||
auto command_buffer = d.get_command_buffer(stream.index);
|
auto command_buffer = d.get_command_buffer(stream.index);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
auto idx = stream.index;
|
auto idx = stream.index;
|
||||||
|
|
||||||
@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
auto idx = stream.index;
|
auto idx = stream.index;
|
||||||
if (!f.use_fast) {
|
if (!f.use_fast) {
|
||||||
|
7
mlx/backend/metal/thread_safey.h
Normal file
7
mlx/backend/metal/thread_safey.h
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace mlx::core::gpu {

// Shared mutex guarding Metal backend operations. This header only declares
// it; the single definition lives in the Metal eval translation unit. Callers
// in eval/finalize/synchronize, Event::wait/signal and Fence::wait/update
// take it with std::lock_guard before touching the Metal device.
extern std::mutex metal_operation_mutex;

} // namespace mlx::core::gpu
|
@ -413,7 +413,7 @@ class Module(dict):
|
|||||||
f'Module does not have sub-module named "{k}".'
|
f'Module does not have sub-module named "{k}".'
|
||||||
)
|
)
|
||||||
elif isinstance(modules, list):
|
elif isinstance(modules, list):
|
||||||
for i in range(len(dst)):
|
for i in range(len(modules)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
new_value = modules[i]
|
new_value = modules[i]
|
||||||
if self.is_module(current_value) and self.is_module(new_value):
|
if self.is_module(current_value) and self.is_module(new_value):
|
||||||
|
@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
m = m.update_modules({"list": ["hi"]})
|
m = m.update_modules({"list": ["hi"]})
|
||||||
|
|
||||||
|
# Allow updating a strict subset
|
||||||
|
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||||
|
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
||||||
|
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
@ -9,7 +9,9 @@ FetchContent_MakeAvailable(doctest)
|
|||||||
|
|
||||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||||
|
|
||||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
if(MLX_BUILD_METAL)
|
||||||
|
set(METAL_TEST_SOURCES gpu_tests.cpp metal_thread_safety_tests.cpp)
|
||||||
|
elseif(MLX_BUILD_CUDA)
|
||||||
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -589,6 +589,7 @@ TEST_CASE("test array shared buffer") {
|
|||||||
array b = array(buf_b, shape, float32, deleter);
|
array b = array(buf_b, shape, float32, deleter);
|
||||||
|
|
||||||
eval(a + b);
|
eval(a + b);
|
||||||
|
synchronize(); // ensure all operations complete before test ends
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test make empty array") {
|
TEST_CASE("test make empty array") {
|
||||||
|
250
tests/metal_thread_safety_tests.cpp
Normal file
250
tests/metal_thread_safety_tests.cpp
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
#include "doctest/doctest.h"
#include "mlx/backend/metal/device.h"
#include "mlx/mlx.h"

#include <atomic>
#include <chrono>
#include <cmath>
#include <functional>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
// Helper function to run operations across multiple threads with pre-created streams
|
||||||
|
void run_in_threads(int num_threads, const std::function<void(int, Stream)>& func,
|
||||||
|
const std::vector<Stream>& streams) {
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
threads.reserve(num_threads);
|
||||||
|
for (int i = 0; i < num_threads; ++i) {
|
||||||
|
threads.emplace_back(func, i, streams[i % streams.size()]);
|
||||||
|
}
|
||||||
|
for (auto& t : threads) {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs `func(i)` on `num_threads` std::threads (i = 0 .. num_threads - 1) and
// joins every one of them before returning. Intended for tasks that do not
// need a per-thread stream, e.g. ones that all target the default stream.
//
// A non-positive `num_threads` launches nothing; previously a negative count
// was converted to a huge unsigned value by `reserve`.
void run_in_threads_default(
    int num_threads,
    const std::function<void(int)>& func) {
  if (num_threads <= 0) {
    return;
  }

  std::vector<std::thread> threads;
  threads.reserve(num_threads);

  auto join_all = [&threads]() {
    for (auto& t : threads) {
      if (t.joinable()) {
        t.join();
      }
    }
  };

  try {
    for (int i = 0; i < num_threads; ++i) {
      threads.emplace_back(func, i);
    }
  } catch (...) {
    // Destroying a still-joinable std::thread calls std::terminate, so wait
    // for the threads that did start before propagating the spawn failure.
    join_all();
    throw;
  }

  join_all();
}
|
||||||
|
|
||||||
|
// Collects per-operation outcomes reported by worker threads so that all
// doctest assertions can be made afterwards on the main thread. Every call to
// record_result() appends exactly one entry to each vector, so index i across
// the five vectors describes the same operation.
struct TestResults {
  std::mutex mutex; // serializes the appends done by record_result()
  std::vector<bool> shape_checks;
  std::vector<bool> availability_checks;
  std::vector<bool> value_checks;
  std::vector<float> expected_values;
  std::vector<float> actual_values;

  // Appends one operation's outcome while holding `mutex`.
  void record_result(
      bool shape_ok,
      bool available_ok,
      bool value_ok,
      float expected,
      float actual) {
    const std::lock_guard<std::mutex> guard(mutex);
    expected_values.emplace_back(expected);
    actual_values.emplace_back(actual);
    shape_checks.emplace_back(shape_ok);
    availability_checks.emplace_back(available_ok);
    value_checks.emplace_back(value_ok);
  }
};
|
||||||
|
|
||||||
|
// Every worker thread repeatedly builds and evaluates a small `add` graph on
// its own pre-created GPU stream. Outcomes are recorded in TestResults and
// asserted on the main thread once all workers have been joined.
TEST_CASE("test metal concurrent eval operations") {
  Device D_GPU = Device::gpu;
  const int num_threads = std::thread::hardware_concurrency() > 0
      ? std::thread::hardware_concurrency()
      : 8;
  const int ops_per_thread = 10;
  const int array_size = 32;
  std::atomic<int> completed_ops{0};
  TestResults results;

  // Pre-create streams to avoid concurrent stream creation
  std::vector<Stream> streams;
  for (int i = 0; i < num_threads; ++i) {
    streams.push_back(new_stream(D_GPU));
  }
  synchronize(); // Ensure stream creation is complete

  auto task = [&](int thread_id, Stream s) {
    try {
      for (int i = 0; i < ops_per_thread; ++i) {
        float val1 = static_cast<float>(thread_id * ops_per_thread + i + 1);
        float val2 = val1 * 2.0f;

        auto x = full({array_size, array_size}, val1, s);
        auto y = full({array_size, array_size}, val2, s);
        auto z = add(x, y);
        eval(z);

        bool shape_ok = (z.shape() == Shape{array_size, array_size});
        bool available_ok = z.is_available();

        // Get a value from the array
        int mid = array_size / 2;
        auto sample = slice(z, {mid, mid}, {mid + 1, mid + 1});
        float actual = sample.item<float>();
        float expected = val1 + val2;

        bool values_match = (std::abs(actual - expected) < 1e-5);

        results.record_result(
            shape_ok, available_ok, values_match, expected, actual);

        if (shape_ok && available_ok && values_match) {
          completed_ops++;
        }
      }
    } catch (const std::exception& e) {
      // A failed thread is reported here and then surfaces as a mismatch in
      // the completed_ops count checked at the end of the test.
      std::cerr << "Thread " << thread_id << " exception: " << e.what()
                << std::endl;
    }
  };

  // Run the threads with pre-created streams
  CHECK_NOTHROW(run_in_threads(num_threads, task, streams));

  // Check all results outside of threads
  for (size_t i = 0; i < results.shape_checks.size(); ++i) {
    // Captures only annotate assertions that fail *after* them in the same
    // scope, so they must precede the CHECKs they are meant to describe.
    CAPTURE(i); // Help identify which operation failed
    CAPTURE(results.expected_values[i]);
    CAPTURE(results.actual_values[i]);
    CHECK(results.shape_checks[i]);
    CHECK(results.availability_checks[i]);
    CHECK(results.value_checks[i]);
  }

  // Verify all operations completed successfully
  CHECK_EQ(completed_ops.load(), num_threads * ops_per_thread);
}
|
||||||
|
|
||||||
|
// All worker threads build and evaluate small `multiply` graphs on the same
// default GPU stream. Exceptions raised inside a worker are collected as
// strings and reported from the main thread after all workers are joined.
TEST_CASE("test metal high contention on default stream eval") {
  Device D_GPU = Device::gpu;
  const int num_threads = std::thread::hardware_concurrency() > 0
      ? std::thread::hardware_concurrency()
      : 8;
  const int ops_per_thread = 5;
  const int array_size = 16;
  Stream default_gpu_stream = default_stream(D_GPU);
  std::atomic<int> successful_ops{0};
  std::vector<std::string> thread_errors;
  std::mutex errors_mutex;
  TestResults results;

  auto task = [&](int thread_id) {
    try {
      for (int i = 0; i < ops_per_thread; ++i) {
        float val = static_cast<float>(thread_id * 100 + i + 1);
        auto x = full({array_size, array_size}, val, default_gpu_stream);
        auto y =
            full({array_size, array_size}, val * 0.5f, default_gpu_stream);
        auto z = multiply(x, y);
        eval(z);

        // Sample a value
        auto sample = slice(z, {0, 0}, {1, 1});
        float actual = sample.item<float>();
        float expected = val * val * 0.5f;

        bool shape_ok = (z.shape() == Shape{array_size, array_size});
        bool available_ok = z.is_available();
        bool values_match = (std::abs(actual - expected) < 1e-5);

        results.record_result(
            shape_ok, available_ok, values_match, expected, actual);

        if (shape_ok && available_ok && values_match) {
          successful_ops++;
        }
      }
    } catch (const std::exception& e) {
      std::lock_guard<std::mutex> lock(errors_mutex);
      thread_errors.push_back(
          std::string("Thread ") + std::to_string(thread_id) +
          " exception: " + e.what());
    }
  };

  // Use the default helper for this test since it uses the default stream
  CHECK_NOTHROW(run_in_threads_default(num_threads, task));

  // Check for thread errors
  {
    // Captures only annotate assertions that fail after them in the same
    // scope, so gather the messages first and capture them before the CHECK.
    std::string all_errors;
    for (const auto& err : thread_errors) {
      all_errors += err;
      all_errors += '\n';
    }
    CAPTURE(all_errors);
    CHECK(thread_errors.empty());
  }

  // Check all results
  for (size_t i = 0; i < results.shape_checks.size(); ++i) {
    CAPTURE(i);
    CAPTURE(results.expected_values[i]);
    CAPTURE(results.actual_values[i]);
    CHECK(results.shape_checks[i]);
    CHECK(results.availability_checks[i]);
    CHECK(results.value_checks[i]);
  }

  // Verify operation count
  CHECK_EQ(successful_ops.load(), num_threads * ops_per_thread);
}
|
||||||
|
|
||||||
|
// Each worker thread evaluates one two-op graph, (x + y) * x, on its own
// pre-created GPU stream and records a single result. Results are asserted on
// the main thread after all workers have been joined.
TEST_CASE("test metal concurrent graph eval from different threads") {
  Device D_GPU = Device::gpu;
  const int num_threads = std::thread::hardware_concurrency() > 0
      ? std::thread::hardware_concurrency()
      : 4; // Keep modest for clarity
  const int array_size = 64;
  TestResults all_results;

  // Pre-create streams
  std::vector<Stream> streams;
  for (int i = 0; i < num_threads; ++i) {
    streams.push_back(new_stream(D_GPU));
  }
  synchronize();

  auto task = [&](int thread_id, Stream s) {
    try {
      float val1_base = static_cast<float>(thread_id + 1) * 10.0f;
      auto x = full({array_size, array_size}, val1_base, s);
      auto y = full({array_size, array_size}, val1_base + 1.0f, s);
      auto z = add(x, y);
      auto w = multiply(z, x);
      eval(w);

      float expected_val = (val1_base + (val1_base + 1.0f)) * val1_base;
      auto sample = slice(w, {0, 0}, {1, 1});
      float actual_val = sample.item<float>();

      bool shape_ok = (w.shape() == Shape{array_size, array_size});
      bool available_ok = w.is_available();
      bool value_ok = (std::abs(actual_val - expected_val) < 1e-4);

      all_results.record_result(
          shape_ok, available_ok, value_ok, expected_val, actual_val);
    } catch (const std::exception& e) {
      // A thread that throws records no result; the size check below then
      // reports the shortfall.
      std::cerr << "Thread " << thread_id
                << " exception in concurrent graph eval: " << e.what()
                << std::endl;
    }
  };

  CHECK_NOTHROW(run_in_threads(num_threads, task, streams));

  // One result per thread
  const size_t recorded = all_results.shape_checks.size();
  CHECK_EQ(recorded, static_cast<size_t>(num_threads));

  // CHECK_EQ above is non-fatal, so iterate only over the results that were
  // actually recorded; indexing up to num_threads would read out of bounds
  // whenever a worker threw before recording.
  for (size_t i = 0; i < recorded; ++i) {
    // Captures must precede the CHECKs they are meant to annotate.
    CAPTURE(i);
    CAPTURE(all_results.expected_values[i]);
    CAPTURE(all_results.actual_values[i]);
    CHECK(all_results.shape_checks[i]);
    CHECK(all_results.availability_checks[i]);
    CHECK(all_results.value_checks[i]);
  }
}
|
Loading…
Reference in New Issue
Block a user