Compare commits

..

4 Commits

Author SHA1 Message Date
Arkar Min Aung
5feed6cb77
Merge cb4dc59a9e into 5adf185f86 2025-06-21 10:45:06 +10:00
Angelos Katharopoulos
5adf185f86
Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584
Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d
Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
12 changed files with 213 additions and 32 deletions

View File

@ -16,6 +16,9 @@ parameters:
linux_release: linux_release:
type: boolean type: boolean
default: false default: false
cuda_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@ -104,7 +107,7 @@ jobs:
command: | command: |
echo "stubs" echo "stubs"
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
@ -162,7 +165,7 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
@ -223,7 +226,6 @@ jobs:
command: | command: |
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
python -m venv env python -m venv env
source env/bin/activate source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
@ -283,7 +285,7 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Build Python package name: Build Python package
command: | command: |
@ -342,7 +344,7 @@ jobs:
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v pip install . -v
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel python -m build --wheel
@ -356,6 +358,48 @@ jobs:
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
build_cuda_release:
parameters:
python_version:
type: string
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Build wheel
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install ".[dev]" -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows: workflows:
build_and_test: build_and_test:
when: when:
@ -625,3 +669,14 @@ workflows:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"] extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
conda install conda-forge::mlx conda install conda-forge::mlx
CUDA
^^^^
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell
pip install mlx-cuda
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^
@ -65,6 +75,8 @@ Build Requirements
Python API Python API
^^^^^^^^^^ ^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_: `its GitHub repo <https://github.com/ml-explore/mlx>`_:
@ -107,6 +119,8 @@ IDE:
C++ API C++ API
^^^^^^^ ^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source. Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start Similarly to the python library, to build and install the MLX C++ library start
@ -185,6 +199,7 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version xcrun -sdk macosx --show-sdk-version
Binary Size Minimization Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots. Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@ -3,6 +3,7 @@
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <fmt/format.h> #include <fmt/format.h>
@ -14,9 +15,11 @@ namespace mlx::core {
namespace cu { namespace cu {
constexpr int page_size = 16384;
CudaAllocator::CudaAllocator() CudaAllocator::CudaAllocator()
: buffer_cache_( : buffer_cache_(
getpagesize(), page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { [this](CudaBuffer* buf) {
cuda_free(buf->data); cuda_free(buf->data);
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) { Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache. // Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) { if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure or are over the maximum cache size,

View File

@ -24,7 +24,6 @@ void copy_gpu_inplace(
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) { if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out); copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return; return;

View File

@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
std::future<void> f = p->get_future(); std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); }); add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch(); worker_.end_batch();
worker_.commit(); commit();
f.wait(); f.wait();
} }

View File

@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
#pragma unroll #pragma unroll
for (int i = NDIM - 1; i >= 0; --i) { for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc); return cuda::std::make_tuple(a_loc, b_loc);
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
#pragma unroll #pragma unroll
for (int i = NDIM - 1; i >= 0; --i) { for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * c_strides[i]; c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc, c_loc); return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT b_loc = 0; IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc); return cuda::std::make_tuple(a_loc, b_loc);
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT c_loc = 0; IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i]; int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i]; a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * b_strides[i]; b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * c_strides[i]; c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i]; elem /= shape[i];
} }
return cuda::std::make_tuple(a_loc, b_loc, c_loc); return cuda::std::make_tuple(a_loc, b_loc, c_loc);

View File

@ -162,11 +162,15 @@ class MatMul {
} }
} }
array workspace( void* workspace_ptr = nullptr;
allocator::malloc(heuristic_.workspaceSize), if (heuristic_.workspaceSize > 0) {
{static_cast<int>(heuristic_.workspaceSize)}, array workspace(
int8); allocator::malloc(heuristic_.workspaceSize),
encoder.add_temporary(workspace); {static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
@ -183,8 +187,8 @@ class MatMul {
out, out,
out_desc_, out_desc_,
&heuristic_.algo, &heuristic_.algo,
workspace.data<void>(), workspace_ptr,
workspace.nbytes(), heuristic_.workspaceSize,
stream)); stream));
}); });
} }
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

View File

@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_; array out = out_;
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) { if (axis < 0) {
axis += in.ndim(); axis += in.ndim();
} }
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.flags()); in.flags());
} }
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) { if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

View File

@ -413,7 +413,7 @@ class Module(dict):
f'Module does not have sub-module named "{k}".' f'Module does not have sub-module named "{k}".'
) )
elif isinstance(modules, list): elif isinstance(modules, list):
for i in range(len(dst)): for i in range(len(modules)):
current_value = dst[i] current_value = dst[i]
new_value = modules[i] new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value): if self.is_module(current_value) and self.is_module(new_value):

View File

@ -0,0 +1,17 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc*
cd wheelhouse
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath=$(patchelf --print-rpath "${core_so}")
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
# Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" .

View File

@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]}) m = m.update_modules({"list": ["hi"]})
# Allow updating a strict subset
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):

View File

@ -174,20 +174,26 @@ if __name__ == "__main__":
) )
package_dir = {"": "python"} package_dir = {"": "python"}
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
install_requires = []
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
setup( setup(
name="mlx", name="mlx-cuda" if build_cuda else "mlx",
version=get_version(), version=get_version(),
author="MLX Contributors", author="MLX Contributors",
author_email="mlx@group.apple.com", author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.", description="A framework for machine learning on Apple silicon.",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
license="MIT",
url="https://github.com/ml-explore/mlx", url="https://github.com/ml-explore/mlx",
packages=packages, packages=packages,
package_dir=package_dir, package_dir=package_dir,
package_data=package_data, package_data=package_data,
include_package_data=True, include_package_data=True,
install_requires=install_requires,
extras_require={ extras_require={
"dev": [ "dev": [
"nanobind==2.4.0", "nanobind==2.4.0",