Compare commits

...

15 Commits

Author SHA1 Message Date
Anastasiia Filippova
012fb220a1 fp quantize (#2892)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-11 06:11:25 -08:00
Nathan Goldbaum
e1fee0074b Update nanobind pin to most recent version (#2896) 2025-12-11 06:07:36 -08:00
CCYeh
3c8ce9b00e Fix input buffer donation in compile (#2897) 2025-12-11 06:07:03 -08:00
David Koski
937ce79660 do not use simd neon intrinsics on x86 (#2893)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-10 12:23:28 -08:00
Nathan Goldbaum
208f5441a7 bump minimum required Python version (#2891)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-09 16:54:38 -08:00
Awni Hannun
b862d842e1 Allow events in sub graph to be updatable (#2886)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-09 12:34:37 -08:00
Satyam singh
f7a400951a Fix docs: replace mx.random.randn with mx.random.normal (#2890) 2025-12-09 11:46:30 -08:00
Awni Hannun
27232db1ba [CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-08 06:18:01 -08:00
Awni Hannun
a4b3bc969b Try not to fail when there should be memory available (#2869)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-07 06:11:00 -08:00
Awni Hannun
667c0f3bb9 [Metal] No copy array init (#2875)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 13:36:45 -08:00
Cheng
6245824d42 Make allocator::malloc throw on allocation failure (#2874)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 17:44:38 +09:00
Awni Hannun
39289ef025 [CUDA] Release build for cuda 13 (#2872) 2025-12-04 21:42:26 -08:00
Awni Hannun
aefc9bd3f6 [CUDA] Faster general copy (#2873) 2025-12-04 21:42:15 -08:00
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
41 changed files with 965 additions and 191 deletions

View File

@@ -1,6 +1,15 @@
name: 'Build CUDA wheel' name: 'Build CUDA wheel'
description: 'Build CUDA wheel' description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs: runs:
using: "composite" using: "composite"
steps: steps:
@@ -12,4 +21,4 @@ runs:
pip install auditwheel build patchelf setuptools pip install auditwheel build patchelf setuptools
python setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

View File

@@ -11,7 +11,7 @@ runs:
shell: bash -l {0} shell: bash -l {0}
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0 pip install cmake setuptools nanobind==2.10.2
pip install -e . -v pip install -e . -v
- name: Generate package stubs - name: Generate package stubs

View File

@@ -15,6 +15,7 @@ runs:
using: "composite" using: "composite"
steps: steps:
- name: Use ccache - name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2 uses: hendrikmuhs/ccache-action@v1.2
with: with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
@@ -35,7 +36,7 @@ runs:
run: | run: |
python -m venv .venv python -m venv .venv
source .venv/bin/activate source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0 pip install setuptools cmake nanobind==2.10.2
echo PATH=$PATH >> $GITHUB_ENV echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind # Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

View File

@@ -95,7 +95,7 @@ jobs:
shell: bash -l {0} shell: bash -l {0}
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0 pip install cmake setuptools nanobind==2.10.2
pip install -e . -v pip install -e . -v
- name: Generate package stubs - name: Generate package stubs
shell: bash -l {0} shell: bash -l {0}
@@ -128,7 +128,11 @@ jobs:
build_cuda_release: build_cuda_release:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }} DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
@@ -136,9 +140,11 @@ jobs:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
toolkit: 'cuda-12.9' toolkit: ${{ matrix.toolkit }}
- name: Build Python package - name: Build Python package
uses: ./.github/actions/build-cuda-release uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:

View File

@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package( find_package(
Python 3.8 Python 3.10
COMPONENTS Interpreter Development.Module COMPONENTS Interpreter Development.Module
REQUIRED) REQUIRED)
execute_process( execute_process(

View File

@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda] pip install mlx[cuda12]
To install the CUDA package from PyPi your system must meet the following To install the CUDA package from PyPi your system must meet the following
requirements: requirements:
- Nvidia architecture >= SM 7.0 (Volta) - Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux) CPU-only (Linux)
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
.. code-block:: shell .. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10) >>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1 >>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]`` The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``. selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.

View File

@@ -3,6 +3,6 @@ requires = [
"setuptools>=42", "setuptools>=42",
"cmake>=3.25", "cmake>=3.25",
"mlx>=0.18.0", "mlx>=0.18.0",
"nanobind==2.4.0", "nanobind==2.10.2",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.25 cmake>=3.25
mlx>=0.21.0 mlx>=0.21.0
nanobind==2.4.0 nanobind==2.10.2

View File

@@ -1,7 +1,6 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

View File

@@ -1,24 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
}; };
}; };
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size) = 0; virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0; virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default; Allocator() = default;
Allocator(const Allocator& other) = delete; Allocator(const Allocator& other) = delete;
@@ -49,4 +49,25 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin()); init(data.begin());
} }
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter) array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

View File

@@ -57,6 +57,16 @@ class array {
Shape shape, Shape shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */ /* Build an array from a buffer */
explicit array( explicit array(
allocator::Buffer data, allocator::Buffer data,

View File

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Not a constant // - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && is_constant(i)) { in.is_donatable() && !is_constant(i)) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
is_constant(i)) { !is_constant(i)) {
outputs[o].copy_shared_buffer( outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size()); in, outputs[o].strides(), in.flags(), in.data_size());
o++; o++;

View File

@@ -3,5 +3,9 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE #ifdef MLX_USE_ACCELERATE
#if defined(__x86_64__)
// the accelerate_simd implementation require neon -- use base implementation
#else
#include "mlx/backend/cpu/simd/accelerate_simd.h" #include "mlx/backend/cpu/simd/accelerate_simd.h"
#endif #endif
#endif

View File

@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool // Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8; constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page // The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size. // size and small_block_size.
constexpr int small_pool_size = 4 * page_size; constexpr int small_pool_size = 4 * page_size;
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
int device_count = 0; int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count)); CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) { for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000 auto loc = cuda_mem_loc(i);
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc)); cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
} }
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
page_size, page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) { [this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free, total; size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total * 0.9; memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
max_pool_size_ = memory_limit_; max_pool_size_ = memory_limit_;
int device_count = 0; int device_count = 0;
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
cudaStream_t s; cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking)); CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s); free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
} }
CHECK_CUDA_ERROR(cudaSetDevice(curr)); CHECK_CUDA_ERROR(cudaSetDevice(curr));
} }
@@ -154,23 +166,35 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
} }
lock.unlock(); lock.unlock();
if (!buf) { if (!buf) {
cudaError_t err;
void* data = nullptr; void* data = nullptr;
if (device == -1) { if (device == -1) {
err = cudaMallocManaged(&data, size); CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
} else { } else {
err = cudaMallocAsync(&data, size, stream); CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
} }
if (!data) { if (!data) {
return Buffer{nullptr}; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
buf = new CudaBuffer{data, size, device}; buf = new CudaBuffer{data, size, device};
} }
lock.lock(); lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
} }
active_memory_ += buf->size; active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_); peak_memory_ = std::max(active_memory_, peak_memory_);

View File

@@ -71,11 +71,14 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_; std::mutex mutex_;
size_t memory_limit_; size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_; size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_; std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_; SmallSizePool scalar_pool_;
}; };

View File

@@ -95,11 +95,14 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1; auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0; auto rest = out.size() / dim0;
if (dim0 >= 4) { if (dim0 >= 4 && dim0 < 8) {
work_per_thread = 4; work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
} }
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1); auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) { if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel = kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
} }
@@ -127,7 +133,9 @@ void copy_general_input(
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>; auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) { if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>; kernel = cu::copy_g<InType, OutType, IdxT, 4>;
} }
encoder.add_kernel_node( encoder.add_kernel_node(

View File

@@ -318,46 +318,64 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
insert_graph_dependencies(GraphNode{node, "K"}); insert_graph_dependencies(GraphNode{node, "K"});
} }
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) { std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
// CUDA graphs do not get updated correctly if a kernel node getting updated // Constructs a key representing the nodes of a sub-graph.
// has a different cluster shape than the node it's being updated with. // Also checks if the sub-graph is updatable as CUDA graphs do not get
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
size_t num_nodes = 0; size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) { if (num_nodes == 0) {
return true; return {key + ")", true};
} }
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes); std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) { for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type; cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type)); CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) { switch (type) {
case cudaGraphNodeTypeGraph: {
// Try to be updatable for a structure like graph -> graph -> kernel // Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child; cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child)); CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
return is_graph_updatable(child, cluster_dim_x); auto [subkey, sub_is_updatable] = subgraph_to_key(child);
} else if (type != cudaGraphNodeTypeKernel) { is_updatable &= sub_is_updatable;
return false; key += subkey;
} else { break;
}
case cudaGraphNodeTypeMemset:
key += "M";
break;
case cudaGraphNodeTypeKernel: {
cudaLaunchAttributeValue cluster_dim; cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute( CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim)); node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1 // Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) { if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false; is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
} }
// Only one child node allowed when subgraph uses clusters break;
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
} }
cluster_dim_x = cluster_dim.clusterDim.x; case cudaGraphNodeTypeWaitEvent:
key += "W";
break;
case cudaGraphNodeTypeEventRecord:
key += "R";
break;
default:
is_updatable = false;
} }
} }
return true; key += ")";
return {key, is_updatable};
} }
void CommandEncoder::add_graph_node(cudaGraph_t child) { void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -370,11 +388,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return; return;
} }
cudaGraphNode_t node; cudaGraphNode_t node;
int cluster_dim_x = 0; auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x); is_graph_updatable_ &= is_updatable;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies( insert_graph_dependencies(GraphNode{node, sub_graph_key});
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
} }
bool CommandEncoder::needs_commit() { bool CommandEncoder::needs_commit() {

View File

@@ -106,7 +106,7 @@ class CommandEncoder {
cudaGraphNode_t node; cudaGraphNode_t node;
// K = kernel // K = kernel
// E = empty // E = empty
// G* = subgraph (with metadata) // () = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators // Symbols ':', '-' are reserved as separators
std::string node_type; std::string node_type;
std::string id; std::string id;

View File

@@ -2,7 +2,11 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
#include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -13,17 +17,6 @@
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits> template <int bits>
struct Dequantize { struct Dequantize {
__device__ float operator()(uint8_t x) { __device__ float operator()(uint8_t x) {
@@ -37,29 +30,40 @@ struct Dequantize {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale> template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
__global__ void __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { using Tx2 = Vector2_t<T>;
using Tx4 = Vector4_t<T>;
uint32_t rbits = 0; // reserved bits for future use
auto block_size = cg::this_thread_block().dim_threads(); auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index(); auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index(); auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y; auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
auto grid_dim_x = size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; size_t base_idx = thread_idx * group_size;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) { if (base_idx >= size) {
return; return;
} }
float w_thread = w[index]; auto w_tile = load_vector<group_size, T>(w, thread_idx);
float scale = 0.0f;
cg::greater<float> max_op; Tx2 amax_2x = Tx2{0.0f, 0.0f};
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
#pragma unroll
for (int i = 0; i < group_size; i += 2) {
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
}
scale = static_cast<float>(
max(fabsf(static_cast<float>(amax_2x.x)),
fabsf(static_cast<float>(amax_2x.y))));
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f; scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale // Convert to mx scale or nv scale
using ScaleType = using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
uint8_t q_scale = s.__x; uint8_t q_scale = s.__x;
scale = float(s); scale = float(s);
// Write out the scales scales[thread_idx] = q_scale;
size_t gindex = index / group_size; constexpr int elem_per_byte = bits == 8 ? 1 : 2;
if (index % group_size == 0) { AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale); #pragma unroll
if (bits == 4) { for (int i = 0; i < group_size / 4; i++) {
uint8_t sval = warp.shfl_down(output, 1); Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
output |= sval << bits; if constexpr (bits == 8) {
uint32_t quantized_val =
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
*reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
} else {
uint16_t quantized_val =
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
*reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
} }
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
} }
store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
} }
template <typename T, int group_size, int bits, bool use_mx_scale> template <typename T, int group_size, int bits, bool use_mx_scale>
@@ -142,15 +149,16 @@ void fp_quantize(
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) { if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>; auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
if (bits == 8) { if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>; kernel = cu::fp_quantize<T, 32, 8, true, false>;
} else if (group_size == 16) { } else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>; kernel = cu::fp_quantize<T, 16, 4, false, false>;
} }
bool large = w.size() > UINT_MAX; bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large); get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
enc.add_kernel_node( enc.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -0,0 +1,32 @@
#pragma once
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
uint32_t out_fp8x4 = 0;
float4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
return out_fp8x4;
}
// Place holder for future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,334 @@
#pragma once
#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;
template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
// Fallback implementation for architectures that do not support cvt
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
uint16_t out_fp4x4 = 0;
fp32x4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
static_cast<uint16_t>(q0);
return out_fp4x4;
}
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t" // first bf16
".reg.b16 x1_bf16; \n\t" // second bf16
".reg.b16 x2_bf16; \n\t" // third bf16
".reg.b16 x3_bf16; \n\t" // fourth bf16
".reg.b32 x0; \n\t" // to hold scaled first
".reg.b32 x1; \n\t" // to hold scaled second
".reg.b32 x2; \n\t" // to hold scaled third
".reg.b32 x3; \n\t" // to hold scaled fourth
".reg.b64 x01; \n\t" // to hold vector mul
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
"cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t" // scale first pair
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t" // scale second pair
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
// pair
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
// pair
"mov.b16 %0, {q0, q1}; \n\t" // pack to output
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(
scale))); // here cast is needed becuase an asm operand must have
// scalar type
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
const bf16x4 input_bf16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t"
".reg.b16 x1_bf16; \n\t"
".reg.b16 x2_bf16; \n\t"
".reg.b16 x3_bf16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
"cvt.f32.bf16 x0, x0_bf16; \n\t"
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
const fp16x4 input_fp16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
const bf16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
const fp16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
float2 input_fp32x2_0 = make_float2(input.x, input.y);
float2 input_fp32x2_1 = make_float2(input.z, input.w);
if constexpr (USE_SR) {
return scale_cvt_fp32x4_to_fp4x4_rs(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
} else {
return scale_cvt_fp32x4_to_fp4x4_rn(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
}
}
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else if constexpr (std::is_same<T, __half>::value) {
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else {
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
}
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
static_assert(
!USE_SR,
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu

View File

@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
} }
template <typename T>
__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
if constexpr (
(std::is_same<T, __nv_bfloat162>::value) ||
(std::is_same<T, __half2>::value)) {
T a = x1;
T b = x2;
out = __hmax2(__habs2(a), __habs2(b));
} else if constexpr (std::is_same<T, float2>::value) {
float2 a = x1;
float2 b = x2;
out.x = fmaxf(fabsf(a.x), fabsf(b.x));
out.y = fmaxf(fabsf(a.y), fabsf(b.y));
}
}
} // namespace cu } // namespace cu
template <typename F> template <typename F>

View File

@@ -89,9 +89,13 @@ template <
int NDIM, int NDIM,
int BM, int BM,
int BN, int BN,
int N_READS = 4> int N_READS = 4,
__global__ void int BLOCKS = 1>
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { __global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
size_t tile_idx = grid.block_rank(); size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile // Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row; short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
totals[i] = ReduceInit<Op, T>::value(); totals[i] = ReduceInit<Op, T>::value();
} }
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size; size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
if (tile_x * BN + BN <= args.reduction_stride) { if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) { if (args.reduction_stride % N_READS == 0) {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
} else { } else {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
} }
} }
} else { } else {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
thread_x, thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
// Write result. // Write result.
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked( cub::StoreDirectBlocked(
warp.meta_group_rank(), warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN, out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args, const cu::ColReduceArgs& args,
int bn) { int bn,
int outer = 1) {
int gx, gy = 1; int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride; size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks; size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
while (n_blocks / gy > INT32_MAX) { while (n_blocks / gy > INT32_MAX) {
gy *= 2; gy *= 2;
} }
@@ -277,7 +298,8 @@ void col_reduce_looped(
0, 0,
indata, indata,
gpu_ptr<U>(out), gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args)); static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
}); });
}); });
}); });
@@ -320,6 +342,117 @@ void col_reduce_small(
}); });
} }
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce( void col_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for // It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements. // a subrow of the fast moving axis. For instance 32 elements.
// //
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and // Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend). // leave transpositions as they are (contrary to our Metal backend).
// //
@@ -349,6 +494,14 @@ void col_reduce(
return; return;
} }
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }

View File

@@ -3,31 +3,10 @@
#pragma once #pragma once
#include "mlx/backend/cuda/steel/utils.cuh" #include "mlx/backend/cuda/steel/utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu { namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/** /**
* The basic building block for Ampere mmas. A 16x16 tile distributed across * The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp. * the warp.

View File

@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
} }
void CudaGraph::end_capture(cudaStream_t stream) { void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
} }

View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core::cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Vector4 {
T x, y, z, w;
};
template <typename T>
using Vector4_t = Vector4<T>;
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;
} // namespace mlx::core::cu

View File

@@ -7,8 +7,6 @@
namespace mlx::core { namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) { void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream()); copy_gpu(in, out, ctype, out.primitive().stream());
} }

View File

@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options); buf = device_->newBuffer(size, resource_options);
} }
if (!buf) { if (!buf) {
return Buffer{nullptr}; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
lk.lock(); lk.lock();
num_resources_++; num_resources_++;
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length(); return static_cast<MTL::Buffer*>(buffer.ptr())->length();
} }
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() { MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator // By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This // will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override; virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() { size_t get_active_memory() {
return active_memory_; return active_memory_;
}; };

View File

@@ -25,6 +25,7 @@ class CommonAllocator : public Allocator {
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override; virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const { size_t get_active_memory() const {
return active_memory_; return active_memory_;
}; };

View File

@@ -1,7 +1,7 @@
[build-system] [build-system]
requires = [ requires = [
"setuptools>=80", "setuptools>=80",
"nanobind==2.4.0", "nanobind==2.10.2",
"cmake>=3.25", "cmake>=3.25",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \ --plat manylinux_2_35_${1} \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* \ --exclude libnvrtc* \
--exclude libcuda* \ --exclude libcuda* \

View File

@@ -89,7 +89,8 @@ static PyType_Spec gc_func_spec = {
/* .name = */ "mlx.gc_func", /* .name = */ "mlx.gc_func",
/* .basicsize = */ (int)sizeof(gc_func), /* .basicsize = */ (int)sizeof(gc_func),
/* .itemsize = */ 0, /* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL, /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_HAVE_VECTORCALL,
/* .slots = */ gc_func_slots}; /* .slots = */ gc_func_slots};
static PyTypeObject* gc_func_tp = nullptr; static PyTypeObject* gc_func_tp = nullptr;

View File

@@ -16,8 +16,7 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
NB_TYPE_CASTER( NB_TYPE_CASTER(
List, List,
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name + const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
const_name(", ...]"))
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
size_t size; size_t size;

View File

@@ -4,12 +4,12 @@ import gc
import inspect import inspect
import io import io
import math import math
import unittest
from functools import partial, wraps from functools import partial, wraps
from io import StringIO from io import StringIO
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
import numpy as np
class TestCompile(mlx_tests.MLXTestCase): class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
loss, grads = step(emb, w, x) loss, grads = step(emb, w, x)
mx.eval(loss, grads) mx.eval(loss, grads)
def test_compile_donates_input_buffer(self):
mx.set_default_device(mx.cpu)
def fun(x):
return mx.sin(x) + 1
compiled_fn = mx.compile(fun)
input = mx.arange(16, dtype=mx.float32)
mx.eval(input)
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
out = compiled_fn(input)
del input # Ensure the reference is dropped
mx.eval(out)
self.assertEqual(
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
)
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()

View File

@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(np_arr, axis=axis) ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True)) self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_long_column(self):
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
b = mx.array(a)
c1 = a.sum(0)
c2 = b.sum(0)
self.assertTrue(np.all(c1 == c2))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True) mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -7,13 +7,21 @@ import re
import subprocess import subprocess
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version(): def get_version():
with open("mlx/version.h", "r") as fid: with open("mlx/version.h", "r") as fid:
for l in fid: for l in fid:
@@ -31,7 +39,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release: if not pypi_release and not dev_release:
git_hash = ( git_hash = (
run( subprocess.run(
"git rev-parse --short HEAD".split(), "git rev-parse --short HEAD".split(),
capture_output=True, capture_output=True,
check=True, check=True,
@@ -247,7 +255,7 @@ if __name__ == "__main__":
extras = { extras = {
"dev": [ "dev": [
"nanobind==2.4.0", "nanobind==2.10.2",
"numpy", "numpy",
"pre-commit", "pre-commit",
"setuptools>=80", "setuptools>=80",
@@ -284,7 +292,11 @@ if __name__ == "__main__":
install_requires.append( install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"' f'mlx-metal=={version}; platform_system == "Darwin"'
) )
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"'] extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup( _setup(
@@ -299,13 +311,25 @@ if __name__ == "__main__":
if build_macos: if build_macos:
name = "mlx-metal" name = "mlx-metal"
elif build_cuda: elif build_cuda:
name = "mlx-cuda" toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
install_requires += [ install_requires += [
"nvidia-cublas-cu12==12.9.*", "nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
] ]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
install_requires += [
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
]
else: else:
name = "mlx-cpu" name = "mlx-cpu"
_setup( _setup(

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <climits> #include <climits>
#include "doctest/doctest.h" #include "doctest/doctest.h"
@@ -608,3 +607,24 @@ TEST_CASE("test make empty array") {
CHECK_EQ(a.size(), 0); CHECK_EQ(a.size(), 0);
CHECK_EQ(a.dtype(), bool_); CHECK_EQ(a.dtype(), bool_);
} }
TEST_CASE("test make array from user buffer") {
int size = 4096;
std::vector<int> buffer(size, 0);
int count = 0;
auto deleter = [&count](void*) { count++; };
{
auto a = array(buffer.data(), Shape{size}, int32, deleter);
if (metal::is_available()) {
CHECK_EQ(buffer.data(), a.data<int>());
}
auto b = a + array(1);
eval(b);
auto expected = ones({4096});
CHECK(array_equal(b, expected).item<bool>());
}
// deleter should always get called
CHECK_EQ(count, 1);
}