mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Compare commits
5 Commits
4498a46248
...
d40ed46a1a
Author | SHA1 | Date | |
---|---|---|---|
![]() |
d40ed46a1a | ||
![]() |
76831ed83d | ||
![]() |
992eac905a | ||
![]() |
c8d4d97447 | ||
![]() |
28902ece4e |
@ -16,6 +16,9 @@ parameters:
|
|||||||
linux_release:
|
linux_release:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
|
cuda_release:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@ -104,7 +107,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
echo "stubs"
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@ -162,7 +165,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@ -223,7 +226,6 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
|
||||||
python -m venv env
|
python -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
@ -283,7 +285,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
@ -342,7 +344,7 @@ jobs:
|
|||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
pip install . -v
|
pip install . -v
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
python -m build --wheel
|
python -m build --wheel
|
||||||
@ -356,6 +358,48 @@ jobs:
|
|||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
|
build_cuda_release:
|
||||||
|
parameters:
|
||||||
|
python_version:
|
||||||
|
type: string
|
||||||
|
default: "3.9"
|
||||||
|
extra_env:
|
||||||
|
type: string
|
||||||
|
default: "DEV_RELEASE=1"
|
||||||
|
machine:
|
||||||
|
image: linux-cuda-12:default
|
||||||
|
resource_class: gpu.nvidia.small.gen2
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Build wheel
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
python -m venv env
|
||||||
|
source env/bin/activate
|
||||||
|
pip install auditwheel
|
||||||
|
pip install patchelf
|
||||||
|
pip install build
|
||||||
|
pip install twine
|
||||||
|
<< parameters.extra_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
pip install ".[dev]" -v
|
||||||
|
python setup.py generate_stubs
|
||||||
|
<< parameters.extra_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
python -m build --wheel
|
||||||
|
bash python/scripts/repair_cuda.sh
|
||||||
|
- run:
|
||||||
|
name: Upload package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
twine upload wheelhouse/*.whl
|
||||||
|
- store_artifacts:
|
||||||
|
path: wheelhouse/
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
build_and_test:
|
build_and_test:
|
||||||
when:
|
when:
|
||||||
@ -625,3 +669,14 @@ workflows:
|
|||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
cuda_test_release:
|
||||||
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.cuda_release >>
|
||||||
|
jobs:
|
||||||
|
- build_cuda_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
|||||||
|
|
||||||
conda install conda-forge::mlx
|
conda install conda-forge::mlx
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||||
|
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx-cuda
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
@ -65,6 +75,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@ -107,6 +119,8 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@ -185,6 +199,7 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
|
|||||||
std::future<void> f = p->get_future();
|
std::future<void> f = p->get_future();
|
||||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||||
worker_.end_batch();
|
worker_.end_batch();
|
||||||
worker_.commit();
|
commit();
|
||||||
f.wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,12 +4,15 @@
|
|||||||
#include "mlx/backend/gpu/available.h"
|
#include "mlx/backend/gpu/available.h"
|
||||||
#include "mlx/backend/gpu/eval.h"
|
#include "mlx/backend/gpu/eval.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core::gpu {
|
namespace mlx::core::gpu {
|
||||||
|
|
||||||
|
std::mutex metal_operation_mutex;
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void eval(array& arr) {
|
void eval(array& arr) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto s = arr.primitive().stream();
|
auto s = arr.primitive().stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
@ -78,6 +82,7 @@ void eval(array& arr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void finalize(Stream s) {
|
void finalize(Stream s) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
@ -88,6 +93,7 @@ void finalize(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
|
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto cb = d.get_command_buffer(s.index);
|
auto cb = d.get_command_buffer(s.index);
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
|
|||||||
if (stream.device == Device::cpu) {
|
if (stream.device == Device::cpu) {
|
||||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||||
} else {
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
d.end_encoding(stream.index);
|
d.end_encoding(stream.index);
|
||||||
auto command_buffer = d.get_command_buffer(stream.index);
|
auto command_buffer = d.get_command_buffer(stream.index);
|
||||||
@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
|
|||||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
d.end_encoding(stream.index);
|
d.end_encoding(stream.index);
|
||||||
auto command_buffer = d.get_command_buffer(stream.index);
|
auto command_buffer = d.get_command_buffer(stream.index);
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/thread_safey.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
auto idx = stream.index;
|
auto idx = stream.index;
|
||||||
|
|
||||||
@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||||
auto& d = metal::device(stream.device);
|
auto& d = metal::device(stream.device);
|
||||||
auto idx = stream.index;
|
auto idx = stream.index;
|
||||||
if (!f.use_fast) {
|
if (!f.use_fast) {
|
||||||
|
7
mlx/backend/metal/thread_safey.h
Normal file
7
mlx/backend/metal/thread_safey.h
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace mlx::core::gpu {
|
||||||
|
extern std::mutex metal_operation_mutex;
|
||||||
|
}
|
17
python/scripts/repair_cuda.sh
Normal file
17
python/scripts/repair_cuda.sh
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
auditwheel repair dist/* \
|
||||||
|
--plat manylinux_2_35_x86_64 \
|
||||||
|
--exclude libcublas* \
|
||||||
|
--exclude libnvrtc*
|
||||||
|
|
||||||
|
cd wheelhouse
|
||||||
|
repaired_wheel=$(find . -name "*.whl" -print -quit)
|
||||||
|
unzip -q "${repaired_wheel}"
|
||||||
|
core_so=$(find mlx -name "core*.so" -print -quit)
|
||||||
|
rpath=$(patchelf --print-rpath "${core_so}")
|
||||||
|
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
|
||||||
|
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
|
||||||
|
|
||||||
|
# Re-zip the repaired wheel
|
||||||
|
zip -r -q "${repaired_wheel}" .
|
8
setup.py
8
setup.py
@ -174,20 +174,26 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
package_dir = {"": "python"}
|
package_dir = {"": "python"}
|
||||||
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
||||||
|
install_requires = []
|
||||||
|
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||||
|
if build_cuda:
|
||||||
|
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx",
|
name="mlx-cuda" if build_cuda else "mlx",
|
||||||
version=get_version(),
|
version=get_version(),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
license="MIT",
|
||||||
url="https://github.com/ml-explore/mlx",
|
url="https://github.com/ml-explore/mlx",
|
||||||
packages=packages,
|
packages=packages,
|
||||||
package_dir=package_dir,
|
package_dir=package_dir,
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
|
install_requires=install_requires,
|
||||||
extras_require={
|
extras_require={
|
||||||
"dev": [
|
"dev": [
|
||||||
"nanobind==2.4.0",
|
"nanobind==2.4.0",
|
||||||
|
@ -9,7 +9,9 @@ FetchContent_MakeAvailable(doctest)
|
|||||||
|
|
||||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||||
|
|
||||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
if(MLX_BUILD_METAL)
|
||||||
|
set(METAL_TEST_SOURCES gpu_tests.cpp metal_thread_safety_tests.cpp)
|
||||||
|
elseif(MLX_BUILD_CUDA)
|
||||||
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -589,6 +589,7 @@ TEST_CASE("test array shared buffer") {
|
|||||||
array b = array(buf_b, shape, float32, deleter);
|
array b = array(buf_b, shape, float32, deleter);
|
||||||
|
|
||||||
eval(a + b);
|
eval(a + b);
|
||||||
|
synchronize(); // ensure all operations complete before test ends
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test make empty array") {
|
TEST_CASE("test make empty array") {
|
||||||
|
250
tests/metal_thread_safety_tests.cpp
Normal file
250
tests/metal_thread_safety_tests.cpp
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
#include <vector>
|
||||||
|
#include <atomic>
|
||||||
|
#include <chrono>
|
||||||
|
#include <mutex>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
// Helper function to run operations across multiple threads with pre-created streams
|
||||||
|
void run_in_threads(int num_threads, const std::function<void(int, Stream)>& func,
|
||||||
|
const std::vector<Stream>& streams) {
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
threads.reserve(num_threads);
|
||||||
|
for (int i = 0; i < num_threads; ++i) {
|
||||||
|
threads.emplace_back(func, i, streams[i % streams.size()]);
|
||||||
|
}
|
||||||
|
for (auto& t : threads) {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function for tasks not requiring streams (e.g., using default stream)
|
||||||
|
void run_in_threads_default(int num_threads, const std::function<void(int)>& func) {
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
threads.reserve(num_threads);
|
||||||
|
for (int i = 0; i < num_threads; ++i) {
|
||||||
|
threads.emplace_back(func, i);
|
||||||
|
}
|
||||||
|
for (auto& t : threads) {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Thread-safe result collection
|
||||||
|
struct TestResults {
|
||||||
|
std::mutex mutex;
|
||||||
|
std::vector<bool> shape_checks;
|
||||||
|
std::vector<bool> availability_checks;
|
||||||
|
std::vector<bool> value_checks;
|
||||||
|
std::vector<float> expected_values;
|
||||||
|
std::vector<float> actual_values;
|
||||||
|
|
||||||
|
void record_result(bool shape_ok, bool available_ok, bool value_ok,
|
||||||
|
float expected, float actual) {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
shape_checks.push_back(shape_ok);
|
||||||
|
availability_checks.push_back(available_ok);
|
||||||
|
value_checks.push_back(value_ok);
|
||||||
|
expected_values.push_back(expected);
|
||||||
|
actual_values.push_back(actual);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_CASE("test metal concurrent eval operations") {
|
||||||
|
Device D_GPU = Device::gpu;
|
||||||
|
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
|
||||||
|
const int ops_per_thread = 10;
|
||||||
|
const int array_size = 32;
|
||||||
|
std::atomic<int> completed_ops{0};
|
||||||
|
TestResults results;
|
||||||
|
|
||||||
|
// Pre-create streams to avoid concurrent stream creation
|
||||||
|
std::vector<Stream> streams;
|
||||||
|
for (int i = 0; i < num_threads; ++i) {
|
||||||
|
streams.push_back(new_stream(D_GPU));
|
||||||
|
}
|
||||||
|
synchronize(); // Ensure stream creation is complete
|
||||||
|
|
||||||
|
auto task = [&](int thread_id, Stream s) {
|
||||||
|
try {
|
||||||
|
for (int i = 0; i < ops_per_thread; ++i) {
|
||||||
|
float val1 = static_cast<float>(thread_id * ops_per_thread + i + 1);
|
||||||
|
float val2 = val1 * 2.0f;
|
||||||
|
|
||||||
|
auto x = full({array_size, array_size}, val1, s);
|
||||||
|
auto y = full({array_size, array_size}, val2, s);
|
||||||
|
auto z = add(x, y);
|
||||||
|
eval(z);
|
||||||
|
|
||||||
|
bool shape_ok = (z.shape() == Shape{array_size, array_size});
|
||||||
|
bool available_ok = z.is_available();
|
||||||
|
|
||||||
|
// Get a value from the array
|
||||||
|
int mid = array_size/2;
|
||||||
|
auto sample = slice(z, {mid, mid}, {mid+1, mid+1});
|
||||||
|
float actual = sample.item<float>();
|
||||||
|
float expected = val1 + val2;
|
||||||
|
|
||||||
|
bool values_match = (std::abs(actual - expected) < 1e-5);
|
||||||
|
|
||||||
|
results.record_result(shape_ok, available_ok, values_match, expected, actual);
|
||||||
|
|
||||||
|
if (shape_ok && available_ok && values_match) {
|
||||||
|
completed_ops++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "Thread " << thread_id << " exception: " << e.what() << std::endl;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run the threads with pre-created streams
|
||||||
|
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
|
||||||
|
|
||||||
|
// Check all results outside of threads
|
||||||
|
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
|
||||||
|
CAPTURE(i); // Help identify which operation failed
|
||||||
|
CHECK(results.shape_checks[i]);
|
||||||
|
CHECK(results.availability_checks[i]);
|
||||||
|
CHECK(results.value_checks[i]);
|
||||||
|
if (!results.value_checks[i]) {
|
||||||
|
CAPTURE(results.expected_values[i]);
|
||||||
|
CAPTURE(results.actual_values[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all operations completed successfully
|
||||||
|
CHECK_EQ(completed_ops.load(), num_threads * ops_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test metal high contention on default stream eval") {
|
||||||
|
Device D_GPU = Device::gpu;
|
||||||
|
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
|
||||||
|
const int ops_per_thread = 5;
|
||||||
|
const int array_size = 16;
|
||||||
|
Stream default_gpu_stream = default_stream(D_GPU);
|
||||||
|
std::atomic<int> successful_ops{0};
|
||||||
|
std::vector<std::string> thread_errors;
|
||||||
|
std::mutex errors_mutex;
|
||||||
|
TestResults results;
|
||||||
|
|
||||||
|
auto task = [&](int thread_id) {
|
||||||
|
try {
|
||||||
|
for (int i = 0; i < ops_per_thread; ++i) {
|
||||||
|
float val = static_cast<float>(thread_id * 100 + i + 1);
|
||||||
|
auto x = full({array_size, array_size}, val, default_gpu_stream);
|
||||||
|
auto y = full({array_size, array_size}, val * 0.5f, default_gpu_stream);
|
||||||
|
auto z = multiply(x, y);
|
||||||
|
eval(z);
|
||||||
|
|
||||||
|
// Sample a value
|
||||||
|
auto sample = slice(z, {0, 0}, {1, 1});
|
||||||
|
float actual = sample.item<float>();
|
||||||
|
float expected = val * val * 0.5f;
|
||||||
|
|
||||||
|
bool shape_ok = (z.shape() == Shape{array_size, array_size});
|
||||||
|
bool available_ok = z.is_available();
|
||||||
|
bool values_match = (std::abs(actual - expected) < 1e-5);
|
||||||
|
|
||||||
|
results.record_result(shape_ok, available_ok, values_match, expected, actual);
|
||||||
|
|
||||||
|
if (shape_ok && available_ok && values_match) {
|
||||||
|
successful_ops++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::lock_guard<std::mutex> lock(errors_mutex);
|
||||||
|
thread_errors.push_back(std::string("Thread ") +
|
||||||
|
std::to_string(thread_id) +
|
||||||
|
" exception: " + e.what());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Use the default helper for this test since it uses the default stream
|
||||||
|
CHECK_NOTHROW(run_in_threads_default(num_threads, task));
|
||||||
|
|
||||||
|
// Check for thread errors
|
||||||
|
CHECK(thread_errors.empty());
|
||||||
|
if (!thread_errors.empty()) {
|
||||||
|
for (const auto& err : thread_errors) {
|
||||||
|
CAPTURE(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check all results
|
||||||
|
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
|
||||||
|
CAPTURE(i);
|
||||||
|
CHECK(results.shape_checks[i]);
|
||||||
|
CHECK(results.availability_checks[i]);
|
||||||
|
CHECK(results.value_checks[i]);
|
||||||
|
if (!results.value_checks[i]) {
|
||||||
|
CAPTURE(results.expected_values[i]);
|
||||||
|
CAPTURE(results.actual_values[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify operation count
|
||||||
|
CHECK_EQ(successful_ops.load(), num_threads * ops_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test metal concurrent graph eval from different threads") {
|
||||||
|
Device D_GPU = Device::gpu;
|
||||||
|
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 4; // Keep modest for clarity
|
||||||
|
const int array_size = 64;
|
||||||
|
TestResults all_results;
|
||||||
|
|
||||||
|
// Pre-create streams
|
||||||
|
std::vector<Stream> streams;
|
||||||
|
for (int i = 0; i < num_threads; ++i) {
|
||||||
|
streams.push_back(new_stream(D_GPU));
|
||||||
|
}
|
||||||
|
synchronize();
|
||||||
|
|
||||||
|
auto task = [&](int thread_id, Stream s) {
|
||||||
|
try {
|
||||||
|
float val1_base = static_cast<float>(thread_id + 1) * 10.0f;
|
||||||
|
auto x = full({array_size, array_size}, val1_base, s);
|
||||||
|
auto y = full({array_size, array_size}, val1_base + 1.0f, s);
|
||||||
|
auto z = add(x, y);
|
||||||
|
auto w = multiply(z, x);
|
||||||
|
eval(w);
|
||||||
|
|
||||||
|
float expected_val = (val1_base + (val1_base + 1.0f)) * val1_base;
|
||||||
|
auto sample = slice(w, {0,0}, {1,1});
|
||||||
|
float actual_val = sample.item<float>();
|
||||||
|
|
||||||
|
bool shape_ok = (w.shape() == Shape{array_size, array_size});
|
||||||
|
bool available_ok = w.is_available();
|
||||||
|
bool value_ok = (std::abs(actual_val - expected_val) < 1e-4);
|
||||||
|
|
||||||
|
all_results.record_result(shape_ok, available_ok, value_ok, expected_val, actual_val);
|
||||||
|
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "Thread " << thread_id << " exception in concurrent graph eval: " << e.what() << std::endl;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
|
||||||
|
|
||||||
|
CHECK_EQ(all_results.shape_checks.size(), num_threads); // One result per thread
|
||||||
|
for (size_t i = 0; i < num_threads; ++i) {
|
||||||
|
CAPTURE(i);
|
||||||
|
CHECK(all_results.shape_checks[i]);
|
||||||
|
CHECK(all_results.availability_checks[i]);
|
||||||
|
CHECK(all_results.value_checks[i]);
|
||||||
|
if (!all_results.value_checks[i]) {
|
||||||
|
CAPTURE(all_results.expected_values[i]);
|
||||||
|
CAPTURE(all_results.actual_values[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user