Compare commits

...

15 Commits

Author SHA1 Message Date
dependabot[bot]
c2764d1073 Bump actions/download-artifact from 6 to 7 (#2912)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-15 06:10:16 -08:00
dependabot[bot]
093a62d2ed Bump actions/upload-artifact from 5 to 6 (#2911)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-15 06:09:55 -08:00
Awni Hannun
1b591ec736 No VJP for mask or sinks in attention (#2909)
2025-12-13 19:48:39 -08:00
Awni Hannun
47d2505ea9 Fix attention for large sizes (#2903)
2025-12-13 06:54:30 -08:00
Cheng
bedefed784 Fix ccache getting disabled (#2905)
2025-12-13 13:00:51 +09:00
Melissa Kilby
ccaaa7d6df fix: possible heap-buffer-overflow in RandomBits::eval_cpu (#2877)
2025-12-12 02:11:18 -08:00
Awni Hannun
f3e5ca5414 [CUDA] Add host nodes to subgraph types for graph update (#2901)
2025-12-11 19:13:44 -08:00
Awni Hannun
81dfe5f137 Fix grad in place updates (#2899)
2025-12-11 14:44:58 -08:00
Anastasiia Filippova
012fb220a1 fp quantize (#2892)
2025-12-11 06:11:25 -08:00
Nathan Goldbaum
e1fee0074b Update nanobind pin to most recent version (#2896) 2025-12-11 06:07:36 -08:00
CCYeh
3c8ce9b00e Fix input buffer donation in compile (#2897) 2025-12-11 06:07:03 -08:00
David Koski
937ce79660 do not use simd neon intrinsics on x86 (#2893)
2025-12-10 12:23:28 -08:00
Nathan Goldbaum
208f5441a7 bump minimum required Python version (#2891)
2025-12-09 16:54:38 -08:00
Awni Hannun
b862d842e1 Allow events in sub graph to be updatable (#2886)
2025-12-09 12:34:37 -08:00
Satyam singh
f7a400951a Fix docs: replace mx.random.randn with mx.random.normal (#2890) 2025-12-09 11:46:30 -08:00
30 changed files with 651 additions and 146 deletions

View File

@@ -11,7 +11,7 @@ runs:
     shell: bash -l {0}
     run: |
       pip install --upgrade pip
-      pip install cmake setuptools nanobind==2.4.0
+      pip install cmake setuptools nanobind==2.10.2
       pip install -e . -v
   - name: Generate package stubs

View File

@@ -10,23 +10,29 @@ inputs:
     description: 'Version of python to set up'
     required: false
     default: '3.10'
+  use-ccache:
+    description: 'Whether to enable ccache'
+    required: false
+    default: 'true'
 runs:
   using: "composite"
   steps:
-    - name: Use ccache
-      if: ${{ runner.arch == 'x86_64' }}
-      uses: hendrikmuhs/ccache-action@v1.2
-      with:
-        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
-        max-size: 1GB
     - name: Install common dependencies
       shell: bash
       run: |
        sudo apt-get update
        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
+    - name: Use ccache
+      if: ${{ inputs.use-ccache == 'true' }}
+      uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
+        max-size: 1GB
+        # ccache-action bug: running "apt-get update" fails on large arm runner.
+        update-package-index: false
     - uses: actions/setup-python@v6
       with:
        python-version: ${{ inputs.python-version }}
@@ -36,7 +42,7 @@ runs:
       run: |
        python -m venv .venv
        source .venv/bin/activate
-       pip install setuptools cmake nanobind==2.4.0
+       pip install setuptools cmake nanobind==2.10.2
        echo PATH=$PATH >> $GITHUB_ENV
        # Make cmake search .venv for nanobind
        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

View File

@@ -23,14 +23,14 @@ jobs:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: "x86_64"
       - name: Upload mlx artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: linux-wheels-${{ matrix.python_version }}
           path: wheelhouse/mlx-*.whl
           retention-days: 7
       - name: Upload mlx-cpu artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: mlx-cpu
           path: wheelhouse/mlx_cpu-*.whl
@@ -89,7 +89,7 @@ jobs:
         with:
           toolkit: 'cuda-12.9'
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl

View File

@@ -57,19 +57,20 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           python-version: ${{ matrix.python_version }}
+          use-ccache: false
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: ${{ matrix.arch }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
           path: wheelhouse/mlx-*.whl
       - name: Upload CPU artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-cpu-${{ matrix.arch }}
@@ -95,7 +96,7 @@ jobs:
         shell: bash -l {0}
         run: |
           pip install --upgrade pip
-          pip install cmake setuptools nanobind==2.4.0
+          pip install cmake setuptools nanobind==2.10.2
           pip install -e . -v
       - name: Generate package stubs
         shell: bash -l {0}
@@ -113,14 +114,14 @@ jobs:
           macos-target: 15.0
           build-backend: ${{ matrix.python-version == '3.10' }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mac-wheels-${{ matrix.python-version }}
           path: dist/mlx-*.whl
       - name: Upload Metal artifacts
         if: matrix.python-version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-metal
@@ -141,12 +142,13 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           toolkit: ${{ matrix.toolkit }}
+          use-ccache: false
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
           arch: ${{ matrix.arch }}
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-cuda
@@ -162,12 +164,12 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           pattern: linux-wheels-*
           merge-multiple: true
           path: dist
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           pattern: mac-wheels-*
           merge-multiple: true
@@ -189,7 +191,7 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx-cuda
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           name: mlx-cuda
           path: dist
@@ -210,7 +212,7 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx-cpu
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           pattern: mlx-cpu-*
           merge-multiple: true
@@ -232,7 +234,7 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx-metal
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
        with:
          name: mlx-metal
          path: dist

View File

@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(
-    Python 3.8
+    Python 3.10
     COMPONENTS Interpreter Development.Module
     REQUIRED)
   execute_process(

View File

@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
 .. code-block:: shell

    >>> a = mx.arange(1000).reshape(10, 10, 10)
-   >>> a[mx.random.randn(10, 10) > 0.0] = 0  # valid: mask covers axes 0 and 1
+   >>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1

 The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
 selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.

View File

@@ -3,6 +3,6 @@ requires = [
     "setuptools>=42",
     "cmake>=3.25",
     "mlx>=0.18.0",
-    "nanobind==2.4.0",
+    "nanobind==2.10.2",
 ]
 build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.4.0
+nanobind==2.10.2

View File

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Not a constant
     if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-        in.is_donatable() && is_constant(i)) {
+        in.is_donatable() && !is_constant(i)) {
       outputs[o++].copy_shared_buffer(in);
     }
     // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
     // - Not a constant
     if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
         in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-        is_constant(i)) {
+        !is_constant(i)) {
       outputs[o].copy_shared_buffer(
           in, outputs[o].strides(), in.flags(), in.data_size());
       o++;
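
The flipped condition is the whole fix: a donatable input may hand its buffer to an output only when it is not one of the compiled graph's constants, and the old check (`is_constant(i)`) donated exactly the wrong inputs. A sketch of how donation is observable from Python, mirroring the regression test added to test_compile.py further down in this compare (CPU device assumed, so the NumPy view is zero-copy):

    import mlx.core as mx
    import numpy as np

    mx.set_default_device(mx.cpu)

    def fun(x):
        return mx.sin(x) + 1

    compiled_fn = mx.compile(fun)
    x = mx.arange(16, dtype=mx.float32)
    mx.eval(x)
    # Buffer address of the input, via a zero-copy NumPy view.
    in_ptr = np.asarray(x, copy=False).__array_interface__["data"][0]
    out = compiled_fn(x)
    del x  # drop the last reference so the input buffer is donatable
    mx.eval(out)
    out_ptr = np.asarray(out, copy=False).__array_interface__["data"][0]
    print(out_ptr == in_ptr)  # True once donation works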

View File

@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       num_keys,
       kshape = keys.shape(),
       kstrides = keys.strides()]() mutable {
+    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
+      if (4 * loc + 4 <= bytes_per_key) {
+        reinterpret_cast<uint32_t*>(cptr)[loc] = v;
+      } else {
+        std::copy(
+            reinterpret_cast<char*>(&v),
+            reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
+            cptr + 4 * loc);
+      }
+    };
     size_t out_skip = (bytes_per_key + 4 - 1) / 4;
     auto half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       if (count.first < half_size) {
         auto rb = random::threefry2x32_hash(key, count);
         ptr[count.first++] = rb.first;
-        if (bytes_per_key % 4 > 0) {
-          std::copy(
-              reinterpret_cast<char*>(&rb.second),
-              reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
-              cptr + 4 * count.second);
-        } else {
-          ptr[count.second] = rb.second;
-        }
+        copy_remaining(cptr, count.second, rb.second);
       }
       if (!even) {
         count.second = 0;
-        ptr[half_size] = random::threefry2x32_hash(key, count).first;
+        copy_remaining(
+            cptr, half_size, random::threefry2x32_hash(key, count).first);
       }
     }
   });
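
The overflow only triggers when bytes_per_key is not a multiple of 4, so the final 32-bit word of a key's output must be written with a partial byte copy; the new copy_remaining helper routes both tail sites through one bounds-checked path. A hedged sketch of a request that exercises it, assuming mx.random.bits' width argument behaves as documented:

    import mlx.core as mx

    # Three 16-bit values per key -> 6 bytes, not a multiple of 4, so the
    # second word of the last threefry pair is written via the byte-wise
    # tail copy rather than a full 4-byte store.
    key = mx.random.key(0)
    bits = mx.random.bits((3,), width=2, key=key)
    print(bits.dtype, bits.shape)  # mlx.core.uint16 (3,)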

View File

@@ -3,5 +3,9 @@

 #include "mlx/backend/cpu/simd/base_simd.h"

 #ifdef MLX_USE_ACCELERATE
+#if defined(__x86_64__)
+// the accelerate_simd implementation requires neon -- use base implementation
+#else
 #include "mlx/backend/cpu/simd/accelerate_simd.h"
 #endif
+#endif

View File

@@ -338,18 +338,23 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
     }
     cudaGraphNodeType type;
     CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
-    if (type == cudaGraphNodeTypeGraph) {
+    switch (type) {
+      case cudaGraphNodeTypeGraph: {
         // Try to be updatable for a structure like graph -> graph -> kernel
         cudaGraph_t child;
         CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
         auto [subkey, sub_is_updatable] = subgraph_to_key(child);
         is_updatable &= sub_is_updatable;
         key += subkey;
-    } else if (type == cudaGraphNodeTypeMemset) {
+        break;
+      }
+      case cudaGraphNodeTypeHost:
+        key += "H";
+        break;
+      case cudaGraphNodeTypeMemset:
         key += "M";
-    } else if (type != cudaGraphNodeTypeKernel) {
-      is_updatable = false;
-    } else {
+        break;
+      case cudaGraphNodeTypeKernel: {
         cudaLaunchAttributeValue cluster_dim;
         CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
             node, cudaLaunchAttributeClusterDimension, &cluster_dim));
@@ -360,6 +365,16 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
           key += "K";
           key += std::to_string(cluster_dim.clusterDim.x);
         }
+        break;
+      }
+      case cudaGraphNodeTypeWaitEvent:
+        key += "W";
+        break;
+      case cudaGraphNodeTypeEventRecord:
+        key += "R";
+        break;
+      default:
+        is_updatable = false;
     }
   }
   key += ")";

View File

@@ -2,7 +2,11 @@

 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
+#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
 #include "mlx/backend/cuda/quantized/quantized.h"
+#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"
 #include "mlx/dtype_utils.h"

 #include <cooperative_groups.h>
@@ -13,17 +17,6 @@
 namespace mlx::core {
 namespace cu {

-template <int bits>
-struct Quantize {
-  __device__ uint8_t operator()(float x) {
-    if constexpr (bits == 8) {
-      return __nv_fp8_e4m3(x).__x;
-    } else {
-      return __nv_fp4_e2m1(x).__x;
-    }
-  }
-};

 template <int bits>
 struct Dequantize {
   __device__ float operator()(uint8_t x) {
@@ -37,29 +30,40 @@ struct Dequantize {
 namespace cg = cooperative_groups;

-template <typename T, int group_size, int bits, bool use_mx_scale>
-__global__ void
-fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
+template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
+__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
+  using Tx2 = Vector2_t<T>;
+  using Tx4 = Vector4_t<T>;
+  uint32_t rbits = 0; // reserved bits for future use
   auto block_size = cg::this_thread_block().dim_threads();
   auto block_idx = cg::this_thread_block().group_index();
   auto idx_in_block = cg::this_thread_block().thread_index();
   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;
-  auto grid_dim_x =
-      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
-  size_t index = tidx + grid_dim_x * size_t(tidy);
-  if (index >= size) {
+  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
+  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
+  size_t base_idx = thread_idx * group_size;
+  if (base_idx >= size) {
     return;
   }
-  float w_thread = w[index];
-  cg::greater<float> max_op;
-  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
-  float scale = cg::reduce(warp, abs(w_thread), max_op);
+  auto w_tile = load_vector<group_size, T>(w, thread_idx);
+  float scale = 0.0f;
+  Tx2 amax_2x = Tx2{0.0f, 0.0f};
+#pragma unroll
+  for (int i = 0; i < group_size; i += 2) {
+    auto pair = Tx2{w_tile[i], w_tile[i + 1]};
+    abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
+  }
+  scale = static_cast<float>(
+      max(fabsf(static_cast<float>(amax_2x.x)),
          fabsf(static_cast<float>(amax_2x.y))));
   scale /= bits == 4 ? 6.0f : 448.0f;

   // Convert to mx scale or nv scale
   using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   uint8_t q_scale = s.__x;
   scale = float(s);

-  // Write out the scales
-  size_t gindex = index / group_size;
-  if (index % group_size == 0) {
-    scales[gindex] = q_scale;
-  }
+  scales[thread_idx] = q_scale;
+  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
+  AlignedVector<uint8_t, group_size / elem_per_byte> quantized;

-  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
-  if (bits == 4) {
-    uint8_t sval = warp.shfl_down(output, 1);
-    output |= sval << bits;
+#pragma unroll
+  for (int i = 0; i < group_size / 4; i++) {
+    Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
+    if constexpr (bits == 8) {
+      uint32_t quantized_val =
+          scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
+    } else {
+      uint16_t quantized_val =
+          scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
+    }
   }
-  constexpr int pack_factor = bits == 8 ? 1 : 2;
-  if (index % pack_factor == 0) {
-    out[index / pack_factor] = output;
-  }
+  store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
 }

 template <typename T, int group_size, int bits, bool use_mx_scale>
@@ -142,15 +149,16 @@ void fp_quantize(
   dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     if constexpr (!std::is_same_v<T, double>) {
-      auto kernel = cu::fp_quantize<T, 32, 4, true>;
+      auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
       if (bits == 8) {
-        kernel = cu::fp_quantize<T, 32, 8, true>;
+        kernel = cu::fp_quantize<T, 32, 8, true, false>;
       } else if (group_size == 16) {
-        kernel = cu::fp_quantize<T, 16, 4, false>;
+        kernel = cu::fp_quantize<T, 16, 4, false, false>;
       }
       bool large = w.size() > UINT_MAX;
       auto [num_blocks, block_dims] =
-          get_launch_args(w.size(), w.shape(), w.strides(), large);
+          get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
       enc.add_kernel_node(
           kernel,
           num_blocks,
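
At a high level the rewritten kernel computes, per group of group_size elements, an absmax-derived scale (the group's largest magnitude divided by the largest representable value of the target format: 6.0 for fp4 e2m1, 448.0 for fp8 e4m3), then converts four elements at a time. A rough Python reference of that scale math only, not of the kernel's vectorized layout:

    import mlx.core as mx

    def fp_scale(w_group, bits=4):
        # Largest finite magnitude of the target format:
        # e2m1 (fp4) -> 6.0, e4m3 (fp8) -> 448.0
        max_mag = 6.0 if bits == 4 else 448.0
        return mx.max(mx.abs(w_group)) / max_mag

    w = mx.random.normal((32,))  # one group of 32 values
    scale = fp_scale(w, bits=4)
    q = w / scale  # values now lie in [-6, 6] before rounding to fp4
    print(scale.item(), mx.max(mx.abs(q)).item())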

View File

@@ -0,0 +1,32 @@
#pragma once
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
uint32_t out_fp8x4 = 0;
float4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
return out_fp8x4;
}
// Placeholder for future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,334 @@
#pragma once
#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;
template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
// Fallback implementation for architectures that do not support cvt
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
uint16_t out_fp4x4 = 0;
fp32x4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
static_cast<uint16_t>(q0);
return out_fp4x4;
}
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t" // first bf16
".reg.b16 x1_bf16; \n\t" // second bf16
".reg.b16 x2_bf16; \n\t" // third bf16
".reg.b16 x3_bf16; \n\t" // fourth bf16
".reg.b32 x0; \n\t" // to hold scaled first
".reg.b32 x1; \n\t" // to hold scaled second
".reg.b32 x2; \n\t" // to hold scaled third
".reg.b32 x3; \n\t" // to hold scaled fourth
".reg.b64 x01; \n\t" // to hold vector mul
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
"cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t" // scale first pair
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t" // scale second pair
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
// pair
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
// pair
"mov.b16 %0, {q0, q1}; \n\t" // pack to output
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(
scale))); // here cast is needed becuase an asm operand must have
// scalar type
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
const bf16x4 input_bf16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t"
".reg.b16 x1_bf16; \n\t"
".reg.b16 x2_bf16; \n\t"
".reg.b16 x3_bf16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
"cvt.f32.bf16 x0, x0_bf16; \n\t"
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
const fp16x4 input_fp16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
const bf16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
const fp16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
float2 input_fp32x2_0 = make_float2(input.x, input.y);
float2 input_fp32x2_1 = make_float2(input.z, input.w);
if constexpr (USE_SR) {
return scale_cvt_fp32x4_to_fp4x4_rs(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
} else {
return scale_cvt_fp32x4_to_fp4x4_rn(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
}
}
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else if constexpr (std::is_same<T, __half>::value) {
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else {
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
}
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
static_assert(
!USE_SR,
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu

View File

@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
   return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
 }

+template <typename T>
+__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
+  if constexpr (
+      (std::is_same<T, __nv_bfloat162>::value) ||
+      (std::is_same<T, __half2>::value)) {
+    T a = x1;
+    T b = x2;
+    out = __hmax2(__habs2(a), __habs2(b));
+  } else if constexpr (std::is_same<T, float2>::value) {
+    float2 a = x1;
+    float2 b = x2;
+    out.x = fmaxf(fabsf(a.x), fabsf(b.x));
+    out.y = fmaxf(fabsf(a.y), fabsf(b.y));
+  }
+}
+
 } // namespace cu

 template <typename F>

View File

@@ -3,31 +3,10 @@
 #pragma once

 #include "mlx/backend/cuda/steel/utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"

 namespace mlx::core::cu {

-// Map types to their vector of 2 type float -> float2, double -> double2 etc
-template <typename T>
-struct Vector2;
-template <>
-struct Vector2<double> {
-  using type = double2;
-};
-template <>
-struct Vector2<float> {
-  using type = float2;
-};
-template <>
-struct Vector2<__half> {
-  using type = __half2;
-};
-template <>
-struct Vector2<__nv_bfloat16> {
-  using type = __nv_bfloat162;
-};
-template <typename T>
-using Vector2_t = typename Vector2<T>::type;
-
 /**
  * The basic building block for Ampere mmas. A 16x16 tile distributed across
  * the warp.

View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core::cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Vector4 {
T x, y, z, w;
};
template <typename T>
using Vector4_t = Vector4<T>;
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;
} // namespace mlx::core::cu

View File

@@ -347,7 +347,7 @@ template <
     MMAFrag_mask_t::load_safe(
         mfrag,
         mask,
-        int(mask_params->M_strides[2]),
+        int64_t(mask_params->M_strides[2]),
         Int<1>{},
         params->qL,
         params->kL,

View File

@@ -346,7 +346,7 @@ template <
     MSubTile mfrag;
     mfrag.load_safe(
         mask,
-        int(mask_params->M_strides[2]),
+        int64_t(mask_params->M_strides[2]),
        Int<1>{},
        params->qL,
        params->kL,
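
The two hunks above apply the same one-character class of fix in both attention kernels: the mask stride was narrowed to int, which overflows once the mask grows large enough that a stride no longer fits in 32 bits. A quick back-of-the-envelope check (illustrative sizes; the exact stride depends on the mask layout):

    INT32_MAX = 2**31 - 1  # 2,147,483,647

    # A stride that scales with the product of sequence lengths, e.g. the
    # head stride of a (B, H, qL, kL) mask, overflows int32 well within
    # reach of long-context attention:
    qL = kL = 46_341
    print(qL * kL)              # 2,147,488,281
    print(qL * kL > INT32_MAX)  # True -> must be held in an int64_t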

View File

@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
       LimY lim_y,
       OffX off_x = Int<0>{},
       OffY off_y = Int<0>{}) {
+    src += off_x * str_x + off_y * str_y;
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kElemRows; i++) {
       STEEL_PRAGMA_UNROLL
       for (short j = 0; j < kElemCols; j++) {
         if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
-          dst[i * kElemCols + j] =
-              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
+          dst[i * kElemCols + j] = static_cast<T>(src[0]);
         } else {
           dst[i * kElemCols + j] = T(0);
         }
+        src += str_y;
       }
+      src -= kElemCols * str_y;
+      src += str_x;
     }
   }
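
The refactor hoists the offset arithmetic out of the inner loop: instead of recomputing (off_x + i) * str_x + (off_y + j) * str_y per element, the pointer advances by str_y per column, rewinds by kElemCols * str_y at the end of each row, and advances by str_x per row. A small sketch checking that the two index sequences agree (plain Python, with hypothetical small tile constants):

    # Hypothetical tile constants for illustration.
    kElemRows, kElemCols = 2, 2
    str_x, str_y, off_x, off_y = 100, 1, 3, 5

    old = [(off_x + i) * str_x + (off_y + j) * str_y
           for i in range(kElemRows) for j in range(kElemCols)]

    new, p = [], off_x * str_x + off_y * str_y  # hoisted base offset
    for i in range(kElemRows):
        for j in range(kElemCols):
            new.append(p)
            p += str_y           # advance one column
        p -= kElemCols * str_y   # rewind the row
        p += str_x               # advance one row

    assert old == new
    print(old)  # [305, 306, 405, 406]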

View File

@@ -880,6 +880,11 @@ std::vector<array> ScaledDotProductAttention::vjp(
   std::vector<array> returned_vjps;
   for (int arg : argnums) {
+    if (arg >= 3) {
+      throw std::invalid_argument(
+          "[scale_dot_product_attention] Does not support VJP with respect "
+          " to mask or attention sinks.");
+    }
     returned_vjps.push_back(std::move(vjps[arg]));
   }
   return returned_vjps;
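
With this guard, differentiating through the fused attention primitive is only defined for the first three arguments (queries, keys, values); asking for a cotangent with respect to the mask or the attention sinks now raises instead of silently returning an undefined gradient. A hedged usage sketch (shapes and scale are illustrative):

    import mlx.core as mx

    q = mx.random.normal((1, 2, 8, 16))
    k = mx.random.normal((1, 2, 8, 16))
    v = mx.random.normal((1, 2, 8, 16))

    def loss(q, k, v):
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=0.25)
        return out.sum()

    # Gradients w.r.t. q, k and v (argnums 0..2) are supported...
    dq, dk, dv = mx.grad(loss, argnums=(0, 1, 2))(q, k, v)
    # ...but a VJP w.r.t. the mask (argument index >= 3) now raises
    # an invalid_argument error instead of producing a bogus gradient.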

View File

@@ -1,7 +1,7 @@
 [build-system]
 requires = [
     "setuptools>=80",
-    "nanobind==2.4.0",
+    "nanobind==2.10.2",
     "cmake>=3.25",
 ]
 build-backend = "setuptools.build_meta"

View File

@@ -89,7 +89,8 @@ static PyType_Spec gc_func_spec = {
     /* .name = */ "mlx.gc_func",
     /* .basicsize = */ (int)sizeof(gc_func),
     /* .itemsize = */ 0,
-    /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
+    /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+        Py_TPFLAGS_HAVE_VECTORCALL,
     /* .slots = */ gc_func_slots};

 static PyTypeObject* gc_func_tp = nullptr;

View File

@@ -16,8 +16,7 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
   NB_TYPE_CASTER(
       List,
-      const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
-          const_name(", ...]"))
+      const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))

   bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
     size_t size;

View File

@@ -124,37 +124,53 @@ auto py_value_and_grad(
     // Collect the arrays
     std::vector<mx::array> arrays;
+    std::vector<nb::object> array_objects;
+    auto flatten_with_objects = [&arrays, &array_objects](
+                                    auto tree, bool strict) {
+      tree_visit(tree, [&](nb::handle obj) {
+        if (nb::isinstance<mx::array>(obj)) {
+          arrays.push_back(nb::cast<mx::array>(obj));
+          array_objects.push_back(nb::borrow<nb::object>(obj));
+        } else if (strict) {
+          throw std::invalid_argument(
+              "[tree_flatten] The argument should contain only arrays");
+        }
+      });
+    };
     std::vector<int> counts(1, 0);
     std::vector<int> gradient_indices;
     for (int i = 0, j = 0; i < args.size(); ++i) {
       bool needs_grad = (j < argnums.size() && argnums[j] == i);
-      auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
+      auto pre_size = arrays.size();
+      flatten_with_objects(args[i], /* strict = */ needs_grad);
       if (needs_grad) {
         auto old_size = gradient_indices.size();
-        gradient_indices.resize(old_size + argsi.size());
+        auto delta_size = arrays.size() - pre_size;
+        gradient_indices.resize(old_size + delta_size);
         std::iota(
             gradient_indices.begin() + old_size,
             gradient_indices.end(),
-            arrays.size());
+            pre_size);
         j++;
-        counts.push_back(argsi.size());
+        counts.push_back(delta_size);
       }
-      arrays.insert(arrays.end(), argsi.begin(), argsi.end());
     }
     for (auto item : kwargs) {
       bool needs_grad =
           (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
-      auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
+      auto pre_size = arrays.size();
+      flatten_with_objects(item.second, /* strict = */ needs_grad);
       if (needs_grad) {
         auto old_size = gradient_indices.size();
-        gradient_indices.resize(old_size + argsk.size());
+        auto delta_size = arrays.size() - pre_size;
+        gradient_indices.resize(old_size + delta_size);
         std::iota(
             gradient_indices.begin() + old_size,
             gradient_indices.end(),
-            arrays.size());
+            pre_size);
-        counts.push_back(argsk.size());
+        counts.push_back(delta_size);
       }
-      arrays.insert(arrays.end(), argsk.begin(), argsk.end());
     }

     std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
@@ -163,7 +179,7 @@ auto py_value_and_grad(
     nb::object py_value_out;
     auto value_and_grads = mx::value_and_grad(
         [&fun,
-         &arrays,
+         &array_objects,
          &args,
          &kwargs,
         &py_value_out,
@@ -183,8 +199,9 @@ auto py_value_and_grad(
       tree_visit_update(tree, [&](nb::handle node) {
         auto replace_arr = nb::cast<mx::array>(node);
         if (replace_arr.id() == a[index].id()) {
-          return nb::cast(arrays[index++]);
+          return array_objects[index++];
         } else {
+          index++;
           return nb::cast(replace_arr);
         }
       });

View File

@@ -780,9 +780,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
             return arrs[0]

         arrs = [mx.array(1.0)]
-        init_id = id(arrs[0])
+        arr = arrs[0]
         mx.grad(fun)(arrs)
-        self.assertEqual(init_id, id(arrs[0]))
+        self.assertEqual(id(arr), id(arrs[0]))
+
+        def fun(arrs):
+            arrs[1] = sum(arrs)
+            return arrs[1]
+
+        arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
+        a_0, a_1, a_2 = arrs
+        mx.grad(fun)(arrs)
+        self.assertEqual(id(a_0), id(arrs[0]))
+        self.assertNotEqual(id(a_1), id(arrs[1]))
+        self.assertEqual(id(a_2), id(arrs[2]))

     def test_grad_with_inplace_update(self):
         def loss_fn(model):

View File

@@ -4,12 +4,12 @@ import gc
 import inspect
 import io
 import math
-import unittest
 from functools import partial, wraps
 from io import StringIO

 import mlx.core as mx
 import mlx_tests
+import numpy as np


 class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
         loss, grads = step(emb, w, x)
         mx.eval(loss, grads)

+    def test_compile_donates_input_buffer(self):
+        mx.set_default_device(mx.cpu)
+
+        def fun(x):
+            return mx.sin(x) + 1
+
+        compiled_fn = mx.compile(fun)
+        input = mx.arange(16, dtype=mx.float32)
+        mx.eval(input)
+        in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
+        out = compiled_fn(input)
+        del input  # Ensure the reference is dropped
+        mx.eval(out)
+        self.assertEqual(
+            np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
+        )
+

 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()

View File

@@ -744,7 +744,6 @@ class TestVmap(mlx_tests.MLXTestCase):
             return Vector([t[0] + 10, t[1] * 10])

         x = State(mx.array(1), mx.array(2))
-        print(f"{transform(x)=}")
         vmap_transform = mx.vmap(transform)
         vmap_transform_tuple = mx.vmap(transform_tuple)

View File

@@ -255,7 +255,7 @@ if __name__ == "__main__":
     extras = {
         "dev": [
-            "nanobind==2.4.0",
+            "nanobind==2.10.2",
             "numpy",
             "pre-commit",
             "setuptools>=80",