mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
6 Commits
e2a2bae148
...
4498a46248
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4498a46248 | ||
![]() |
b3d7b85376 | ||
![]() |
cad5c0241c | ||
![]() |
992eac905a | ||
![]() |
c8d4d97447 | ||
![]() |
28902ece4e |
@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(buf);
|
||||
}
|
||||
|
||||
|
@ -63,25 +63,30 @@ void copy_general(
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
data_size,
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <future>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -107,6 +108,16 @@ void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
stream().synchronize();
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
worker_.end_batch();
|
||||
worker_.commit();
|
||||
f.wait();
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
|
@ -123,6 +123,9 @@ class CommandEncoder {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
// Wait until kernels and completion handlers are finished
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
|
@ -62,7 +62,7 @@ void finalize(Stream s) {
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_stream(s).synchronize();
|
||||
cu::get_command_encoder(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
||||
|
@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) {
|
||||
}
|
||||
|
||||
// Return the location of the CUDA toolkit.
|
||||
const char* cuda_home() {
|
||||
const char* home = std::getenv("CUDA_HOME");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
home = std::getenv("CUDA_PATH");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
const std::string& cuda_home() {
|
||||
static std::string home = []() -> std::string {
|
||||
const char* home = std::getenv("CUDA_HOME");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
home = std::getenv("CUDA_PATH");
|
||||
if (home) {
|
||||
return home;
|
||||
}
|
||||
#if defined(__linux__)
|
||||
home = "/usr/local/cuda";
|
||||
if (std::filesystem::exists(home)) {
|
||||
return home;
|
||||
}
|
||||
home = "/usr/local/cuda";
|
||||
if (std::filesystem::exists(home)) {
|
||||
return home;
|
||||
}
|
||||
#endif
|
||||
throw std::runtime_error(
|
||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||
throw std::runtime_error(
|
||||
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||
}();
|
||||
return home;
|
||||
}
|
||||
|
||||
// Get the cache directory for storing compiled results.
|
||||
bool get_ptx_cache_dir(std::filesystem::path* result) {
|
||||
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
if (!std::filesystem::is_directory(path)) {
|
||||
std::error_code error;
|
||||
if (!std::filesystem::create_directories(path, error)) {
|
||||
return false;
|
||||
const std::filesystem::path& ptx_cache_dir() {
|
||||
static std::filesystem::path cache = []() -> std::filesystem::path {
|
||||
std::filesystem::path cache;
|
||||
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
|
||||
cache = c;
|
||||
} else {
|
||||
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||
}
|
||||
}
|
||||
*result = path;
|
||||
return true;
|
||||
if (!std::filesystem::exists(cache)) {
|
||||
std::error_code error;
|
||||
if (!std::filesystem::create_directories(cache, error)) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
}
|
||||
return cache;
|
||||
}();
|
||||
return cache;
|
||||
}
|
||||
|
||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||
@ -75,6 +85,10 @@ bool read_cached_ptx(
|
||||
const std::string& module_name,
|
||||
std::vector<char>* ptx,
|
||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||
std::error_code error;
|
||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||
@ -105,6 +119,10 @@ void write_cached_ptx(
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||
if (!ptx.empty()) {
|
||||
ptx_file.write(&ptx.front(), ptx.size());
|
||||
@ -184,11 +202,9 @@ JitModule::JitModule(
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder) {
|
||||
// Check cache.
|
||||
std::filesystem::path cache_dir;
|
||||
std::vector<char> ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
if (!get_ptx_cache_dir(&cache_dir) ||
|
||||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
||||
// Create program.
|
||||
auto [source_code, kernel_names] = builder();
|
||||
nvrtcProgram prog;
|
||||
@ -246,7 +262,7 @@ JitModule::JitModule(
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
|
@ -80,7 +80,9 @@ void Worker::thread_fn() {
|
||||
}
|
||||
worker_tasks_.erase(worker_tasks_.begin(), end);
|
||||
}
|
||||
for (auto& task : tasks) {
|
||||
// Make sure tasks are cleared before the next wait
|
||||
for (int i = 0; i < tasks.size(); ++i) {
|
||||
auto task = std::move(tasks[i]);
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
|
@ -4,12 +4,15 @@
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/thread_safey.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
std::mutex metal_operation_mutex;
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -78,6 +82,7 @@ void eval(array& arr) {
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
@ -88,6 +93,7 @@ void finalize(Stream s) {
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/thread_safey.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
auto command_buffer = d.get_command_buffer(stream.index);
|
||||
@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
});
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
auto command_buffer = d.get_command_buffer(stream.index);
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/thread_safey.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
auto idx = stream.index;
|
||||
|
||||
@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
auto idx = stream.index;
|
||||
if (!f.use_fast) {
|
||||
|
7
mlx/backend/metal/thread_safey.h
Normal file
7
mlx/backend/metal/thread_safey.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
extern std::mutex metal_operation_mutex;
|
||||
}
|
@ -6,7 +6,6 @@ cuda_skip = {
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
"TestMemory.test_memory_info",
|
||||
"TestLayers.test_group_norm",
|
||||
"TestLayers.test_pooling",
|
||||
"TestLayers.test_quantized_embedding",
|
||||
|
@ -9,7 +9,9 @@ FetchContent_MakeAvailable(doctest)
|
||||
|
||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_TEST_SOURCES gpu_tests.cpp metal_thread_safety_tests.cpp)
|
||||
elseif(MLX_BUILD_CUDA)
|
||||
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
||||
endif()
|
||||
|
||||
|
@ -589,6 +589,7 @@ TEST_CASE("test array shared buffer") {
|
||||
array b = array(buf_b, shape, float32, deleter);
|
||||
|
||||
eval(a + b);
|
||||
synchronize(); // ensure all operations complete before test ends
|
||||
}
|
||||
|
||||
TEST_CASE("test make empty array") {
|
||||
|
250
tests/metal_thread_safety_tests.cpp
Normal file
250
tests/metal_thread_safety_tests.cpp
Normal file
@ -0,0 +1,250 @@
|
||||
#include "doctest/doctest.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <iostream>
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
// Helper function to run operations across multiple threads with pre-created streams
|
||||
void run_in_threads(int num_threads, const std::function<void(int, Stream)>& func,
|
||||
const std::vector<Stream>& streams) {
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_threads);
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
threads.emplace_back(func, i, streams[i % streams.size()]);
|
||||
}
|
||||
for (auto& t : threads) {
|
||||
if (t.joinable()) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for tasks not requiring streams (e.g., using default stream)
|
||||
void run_in_threads_default(int num_threads, const std::function<void(int)>& func) {
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_threads);
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
threads.emplace_back(func, i);
|
||||
}
|
||||
for (auto& t : threads) {
|
||||
if (t.joinable()) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Thread-safe result collection
|
||||
struct TestResults {
|
||||
std::mutex mutex;
|
||||
std::vector<bool> shape_checks;
|
||||
std::vector<bool> availability_checks;
|
||||
std::vector<bool> value_checks;
|
||||
std::vector<float> expected_values;
|
||||
std::vector<float> actual_values;
|
||||
|
||||
void record_result(bool shape_ok, bool available_ok, bool value_ok,
|
||||
float expected, float actual) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
shape_checks.push_back(shape_ok);
|
||||
availability_checks.push_back(available_ok);
|
||||
value_checks.push_back(value_ok);
|
||||
expected_values.push_back(expected);
|
||||
actual_values.push_back(actual);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_CASE("test metal concurrent eval operations") {
|
||||
Device D_GPU = Device::gpu;
|
||||
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
|
||||
const int ops_per_thread = 10;
|
||||
const int array_size = 32;
|
||||
std::atomic<int> completed_ops{0};
|
||||
TestResults results;
|
||||
|
||||
// Pre-create streams to avoid concurrent stream creation
|
||||
std::vector<Stream> streams;
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
streams.push_back(new_stream(D_GPU));
|
||||
}
|
||||
synchronize(); // Ensure stream creation is complete
|
||||
|
||||
auto task = [&](int thread_id, Stream s) {
|
||||
try {
|
||||
for (int i = 0; i < ops_per_thread; ++i) {
|
||||
float val1 = static_cast<float>(thread_id * ops_per_thread + i + 1);
|
||||
float val2 = val1 * 2.0f;
|
||||
|
||||
auto x = full({array_size, array_size}, val1, s);
|
||||
auto y = full({array_size, array_size}, val2, s);
|
||||
auto z = add(x, y);
|
||||
eval(z);
|
||||
|
||||
bool shape_ok = (z.shape() == Shape{array_size, array_size});
|
||||
bool available_ok = z.is_available();
|
||||
|
||||
// Get a value from the array
|
||||
int mid = array_size/2;
|
||||
auto sample = slice(z, {mid, mid}, {mid+1, mid+1});
|
||||
float actual = sample.item<float>();
|
||||
float expected = val1 + val2;
|
||||
|
||||
bool values_match = (std::abs(actual - expected) < 1e-5);
|
||||
|
||||
results.record_result(shape_ok, available_ok, values_match, expected, actual);
|
||||
|
||||
if (shape_ok && available_ok && values_match) {
|
||||
completed_ops++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Thread " << thread_id << " exception: " << e.what() << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
// Run the threads with pre-created streams
|
||||
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
|
||||
|
||||
// Check all results outside of threads
|
||||
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
|
||||
CAPTURE(i); // Help identify which operation failed
|
||||
CHECK(results.shape_checks[i]);
|
||||
CHECK(results.availability_checks[i]);
|
||||
CHECK(results.value_checks[i]);
|
||||
if (!results.value_checks[i]) {
|
||||
CAPTURE(results.expected_values[i]);
|
||||
CAPTURE(results.actual_values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all operations completed successfully
|
||||
CHECK_EQ(completed_ops.load(), num_threads * ops_per_thread);
|
||||
}
|
||||
|
||||
TEST_CASE("test metal high contention on default stream eval") {
|
||||
Device D_GPU = Device::gpu;
|
||||
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
|
||||
const int ops_per_thread = 5;
|
||||
const int array_size = 16;
|
||||
Stream default_gpu_stream = default_stream(D_GPU);
|
||||
std::atomic<int> successful_ops{0};
|
||||
std::vector<std::string> thread_errors;
|
||||
std::mutex errors_mutex;
|
||||
TestResults results;
|
||||
|
||||
auto task = [&](int thread_id) {
|
||||
try {
|
||||
for (int i = 0; i < ops_per_thread; ++i) {
|
||||
float val = static_cast<float>(thread_id * 100 + i + 1);
|
||||
auto x = full({array_size, array_size}, val, default_gpu_stream);
|
||||
auto y = full({array_size, array_size}, val * 0.5f, default_gpu_stream);
|
||||
auto z = multiply(x, y);
|
||||
eval(z);
|
||||
|
||||
// Sample a value
|
||||
auto sample = slice(z, {0, 0}, {1, 1});
|
||||
float actual = sample.item<float>();
|
||||
float expected = val * val * 0.5f;
|
||||
|
||||
bool shape_ok = (z.shape() == Shape{array_size, array_size});
|
||||
bool available_ok = z.is_available();
|
||||
bool values_match = (std::abs(actual - expected) < 1e-5);
|
||||
|
||||
results.record_result(shape_ok, available_ok, values_match, expected, actual);
|
||||
|
||||
if (shape_ok && available_ok && values_match) {
|
||||
successful_ops++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::lock_guard<std::mutex> lock(errors_mutex);
|
||||
thread_errors.push_back(std::string("Thread ") +
|
||||
std::to_string(thread_id) +
|
||||
" exception: " + e.what());
|
||||
}
|
||||
};
|
||||
|
||||
// Use the default helper for this test since it uses the default stream
|
||||
CHECK_NOTHROW(run_in_threads_default(num_threads, task));
|
||||
|
||||
// Check for thread errors
|
||||
CHECK(thread_errors.empty());
|
||||
if (!thread_errors.empty()) {
|
||||
for (const auto& err : thread_errors) {
|
||||
CAPTURE(err);
|
||||
}
|
||||
}
|
||||
|
||||
// Check all results
|
||||
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
|
||||
CAPTURE(i);
|
||||
CHECK(results.shape_checks[i]);
|
||||
CHECK(results.availability_checks[i]);
|
||||
CHECK(results.value_checks[i]);
|
||||
if (!results.value_checks[i]) {
|
||||
CAPTURE(results.expected_values[i]);
|
||||
CAPTURE(results.actual_values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify operation count
|
||||
CHECK_EQ(successful_ops.load(), num_threads * ops_per_thread);
|
||||
}
|
||||
|
||||
TEST_CASE("test metal concurrent graph eval from different threads") {
|
||||
Device D_GPU = Device::gpu;
|
||||
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 4; // Keep modest for clarity
|
||||
const int array_size = 64;
|
||||
TestResults all_results;
|
||||
|
||||
// Pre-create streams
|
||||
std::vector<Stream> streams;
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
streams.push_back(new_stream(D_GPU));
|
||||
}
|
||||
synchronize();
|
||||
|
||||
auto task = [&](int thread_id, Stream s) {
|
||||
try {
|
||||
float val1_base = static_cast<float>(thread_id + 1) * 10.0f;
|
||||
auto x = full({array_size, array_size}, val1_base, s);
|
||||
auto y = full({array_size, array_size}, val1_base + 1.0f, s);
|
||||
auto z = add(x, y);
|
||||
auto w = multiply(z, x);
|
||||
eval(w);
|
||||
|
||||
float expected_val = (val1_base + (val1_base + 1.0f)) * val1_base;
|
||||
auto sample = slice(w, {0,0}, {1,1});
|
||||
float actual_val = sample.item<float>();
|
||||
|
||||
bool shape_ok = (w.shape() == Shape{array_size, array_size});
|
||||
bool available_ok = w.is_available();
|
||||
bool value_ok = (std::abs(actual_val - expected_val) < 1e-4);
|
||||
|
||||
all_results.record_result(shape_ok, available_ok, value_ok, expected_val, actual_val);
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Thread " << thread_id << " exception in concurrent graph eval: " << e.what() << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
|
||||
|
||||
CHECK_EQ(all_results.shape_checks.size(), num_threads); // One result per thread
|
||||
for (size_t i = 0; i < num_threads; ++i) {
|
||||
CAPTURE(i);
|
||||
CHECK(all_results.shape_checks[i]);
|
||||
CHECK(all_results.availability_checks[i]);
|
||||
CHECK(all_results.value_checks[i]);
|
||||
if (!all_results.value_checks[i]) {
|
||||
CAPTURE(all_results.expected_values[i]);
|
||||
CAPTURE(all_results.actual_values[i]);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user