Compare commits

22 Commits

Author SHA1 Message Date
Awni Hannun
529842fed9 fix 2025-11-03 16:43:19 -08:00
Awni Hannun
cc6df9fc8a fix 2025-11-03 15:07:01 -08:00
Awni Hannun
742033fefe remove use of cuda pool, use cuda free async 2025-11-03 09:14:17 -08:00
Awni Hannun
c27a0647a3 load eval gpu for cuda 2025-11-01 13:18:57 -07:00
Awni Hannun
d378567cc6 refactor for regular cuda malloc 2025-10-31 14:12:15 -07:00
Awni Hannun
b84fc978d3 add pool threshold 2025-10-30 10:32:57 -07:00
Awni Hannun
764b4b7ce8 Use async cuda malloc managed with cuda 13 2025-10-30 10:32:57 -07:00
Mike Drob
74c1ed25bb Migrate CircleCI to GitHub Actions (#2716)
Co-authored-by: Joseph Heck <j_heck@apple.com>
2025-10-30 12:26:55 -05:00
Awni Hannun
ec72b44417 Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
2025-10-28 16:23:12 -07:00
Melissa Kilby
460691a0e8 fix: linux-{fedora}x86_64-build (#2707)
Signed-off-by: Melissa Kilby <mkilby@apple.com>
2025-10-27 16:36:08 -07:00
Awni Hannun
969924cc69 Fp8 conversion (#2686)
* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
2025-10-27 16:35:50 -07:00
Awni Hannun
d1e06117e8 bump python (#2694) 2025-10-27 11:34:31 -07:00
Awni Hannun
539d8322d1 add median op (#2705) 2025-10-27 11:33:42 -07:00
Awni Hannun
c4767d110f fix addmm cpu (#2699) 2025-10-27 11:33:32 -07:00
David Koski
895217f25b optionally load metallib from framework (#2702)
* optionally load metallib from framework

* pre-commit

* adjust logic
2025-10-27 07:52:03 -07:00
Manuel Villanueva
0cfeeb60ca Einsum error msg improvement (#2690)
* Improved error message for Einsum

* Modifications via pre-commit

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-27 06:31:47 -07:00
Ronan Collobert
8f8af61a37 fix warnings showing up with -Wall (#2692) 2025-10-24 11:43:35 -07:00
Manuel Villanueva
233384161e Improved mx.split() docs (#2689)
* Improved mx.split() documentation

* Fix typo in docstring for array split function

* add example

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-24 09:48:41 -07:00
Awni Hannun
5bcf3a6794 format 2025-10-22 16:08:47 -07:00
wickedcoder
7707196297 Merge commit from fork
* add length validation to the header

* fix accessing out of bound index with .at()
2025-10-22 15:31:25 -07:00
wickedcoder
7e3471c987 Merge commit from fork
* add tensor->weights_data validation

* add null pointer check for tensor
2025-10-22 15:31:03 -07:00
Awni Hannun
9f0ba3ddf1 patch bump (#2680) 2025-10-17 12:12:07 -07:00
154 changed files with 3547 additions and 1399 deletions


@@ -26,9 +26,9 @@ jobs:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
-brew install python@3.9
+brew install python@3.10
brew install doxygen
-python3.9 -m venv env
+python3.10 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
@@ -140,7 +140,7 @@ jobs:
- run:
name: Install Python package
command: |
-uv venv --python 3.9
+uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
@@ -273,7 +273,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.9"
default: "3.10"
xcode_version:
type: string
default: "26.0.0"
@@ -328,7 +328,7 @@ jobs:
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -351,7 +351,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.9"
default: "3.10"
build_env:
type: string
default: ""
@@ -387,7 +387,7 @@ jobs:
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -484,7 +484,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
@@ -503,7 +503,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
@@ -546,13 +546,13 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
@@ -564,14 +564,14 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:


@@ -0,0 +1,24 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
MLX_BUILD_STAGE: 2
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
python -m build -w
if [ -f "python/scripts/repair_cuda.sh" ]; then
bash python/scripts/repair_cuda.sh
fi

.github/actions/build-cuda/action.yml

@@ -0,0 +1,68 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
# this value is dependent on the CUDA tools installed in the setup-linux workflow
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Check if build actually worked
shell: bash
run: python -c "import mlx.core"
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- name: Build Python package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: ${{ inputs.nvcc-location }}

.github/actions/build-docs/action.yml

@@ -0,0 +1,38 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs/build/html -L .
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

.github/actions/build-linux/action.yml

@@ -0,0 +1,78 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
build-type:
description: 'Build type'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
type: boolean
runs:
using: "composite"
steps:
- name: Set DEBUG
shell: sh
if: inputs.build-type == 'debug'
run: echo "DEBUG=1" >> $GITHUB_ENV
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: sh
run: ./build/tests/tests
- name: Build Python package
if: inputs.build-type == 'release'
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
if [ -f "python/scripts/repair_linux.sh" ]; then
bash python/scripts/repair_linux.sh
fi
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64


@@ -0,0 +1,22 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
runs:
using: "composite"
steps:
- name: Build Python package(s)
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

.github/actions/build-macos/action.yml

@@ -0,0 +1,124 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
build-jit:
description: 'Whether to build with JIT'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
DEV_RELEASE: 1
run: |
uv pip install --upgrade pip cmake setuptools
uv pip install nanobind==2.4.0 \
numpy torch tensorflow unittest-xml-reporting
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
if: inputs.run-tests == 'true'
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
if: inputs.build-jit == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
- name: Build macOS 13 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 13.0
- name: Build macOS 14 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
- name: Build macOS 15 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0

.github/actions/setup-linux/action.yml

@@ -0,0 +1,83 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"
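The notes above pin the toolkit at CUDA 12.9 to stay compatible with cuDNN 9.x and the 570-series driver. As a quick standalone check of what a runner actually provides, the CUDA runtime API exposes both versions (real calls; error handling elided for brevity):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int driver = 0, runtime = 0;
  cudaDriverGetVersion(&driver);   // highest CUDA version the installed driver supports
  cudaRuntimeGetVersion(&runtime); // version of the linked CUDA runtime
  std::printf(
      "driver supports CUDA %d.%d, runtime is %d.%d\n",
      driver / 1000, (driver % 1000) / 10,
      runtime / 1000, (runtime % 1000) / 10);
  return 0;
}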

.github/actions/setup-macos/action.yml

@@ -0,0 +1,31 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
install-mpi:
description: 'Whether to install MPI'
required: false
default: 'true'
type: boolean
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
if: inputs.install-mpi == 'true'
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

.github/dependabot.yml

@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

.github/workflows/documentation.yml

@@ -0,0 +1,28 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

.github/workflows/nightly.yml

@@ -0,0 +1,93 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
env:
MACOSX_DEPLOYMENT_TARGET: "15.0"
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
build_cuda_with_tests:
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7


@@ -1,20 +1,46 @@
-on:
-  pull_request:
-    branches:
-      - main
+name: Build and Test
+on: pull_request
+permissions:
+  contents: read
jobs:
  check_lint:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
    steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v5
+      - uses: ./.github/actions/setup-linux
+      - uses: pre-commit/action@v3.0.1
+  linux_build_and_test:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v5
+      - uses: ./.github/actions/setup-linux
+      - uses: ./.github/actions/build-linux
+  mac_build_and_test:
+    runs-on: [self-hosted, macos]
+    needs: check_lint
+    steps:
+      - uses: actions/checkout@v5
+      - uses: ./.github/actions/setup-macos
+      - uses: ./.github/actions/build-macos
+  cuda_build_and_test:
+    runs-on: gpu-t4-4-core
+    needs: check_lint
+    steps:
+      - uses: actions/checkout@v5
+      - uses: ./.github/actions/setup-linux
        with:
-          python-version: 3.8
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install pre-commit black isort clang-format
-      - name: Run lint
-        run: |
-          pre-commit run --all-files
+          runner-type: 'cuda'
+      - uses: ./.github/actions/build-cuda
+  build_documentation:
+    runs-on: [self-hosted, macos]
+    needs: check_lint
+    steps:
+      - uses: actions/checkout@v5
+      - uses: ./.github/actions/build-docs

.github/workflows/release.yml

@@ -0,0 +1,188 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
build_documentation:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
with:
build-type: release
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiples: true
path: artifacts
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiples: true
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: build_cuda_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: build_linux_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: build_mac_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1


@@ -1,4 +1,10 @@
repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
+    hooks:
+      - id: check-yaml
+      # - id: end-of-file-fixer
+      # - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:


@@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
+# Supress warnings: note: parameter passing for argument of type
+# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
+# 10.1
+target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()


@@ -16,7 +16,7 @@ silicon computer is
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
-- Using a native Python >= 3.9
+- Using a native Python >= 3.10
- macOS >= 13.5
.. note::
@@ -39,7 +39,7 @@ requirements:
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
-- Python >= 3.9
+- Python >= 3.10
CPU-only (Linux)
@@ -55,7 +55,7 @@ To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
-- Python >= 3.9
+- Python >= 3.10
Troubleshooting


@@ -112,6 +112,7 @@ Operations
max
maximum
mean
+median
meshgrid
min
minimum


@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
-Buffer(void* ptr) : ptr_(ptr) {};
+explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();


@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
-cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
+cpy.array_desc_->offset = other.array_desc_->offset;
return cpy;
}
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
-array_desc_->data_ptr = buffer.raw_ptr();
+array_desc_->offset = 0;
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
-array_desc_->data_ptr = buffer.raw_ptr();
+array_desc_->offset = 0;
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -172,9 +172,8 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
-auto char_offset = sizeof(char) * itemsize() * offset;
-array_desc_->data_ptr = static_cast<void*>(
-    static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+array_desc_->offset =
+    sizeof(char) * itemsize() * offset + other.array_desc_->offset;
}
void array::copy_shared_buffer(const array& other) {
@@ -241,8 +240,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
-status(Status::unscheduled),
primitive(std::move(primitive)),
+status(Status::unscheduled),
inputs(std::move(inputs)) {
init();
}


@@ -349,15 +349,23 @@ class array {
return array_desc_->data;
}
-// Return a raw pointer to the arrays data
+// Return a raw pointer to the arrays data. This function may do a copy if
+// the underlying buffer is not accessible on the CPU. When accessing the
+// data for GPU kernels, be sure to use the correct method / function for the
+// given backend to access the GPU pointer.
template <typename T>
T* data() {
-return static_cast<T*>(array_desc_->data_ptr);
+return reinterpret_cast<T*>(
+    (static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
}
template <typename T>
const T* data() const {
-return static_cast<T*>(array_desc_->data_ptr);
+return const_cast<array&>(*this).data<T>();
}
+int64_t offset() const {
+  return array_desc_->offset;
+}
enum Status {
@@ -461,8 +469,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
-// Properly offset data pointer
-void* data_ptr{nullptr};
+// Offset from beginning of data pointer
+int64_t offset{0};
// The size in elements of the data buffer the array accesses
size_t data_size;
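The hunk above replaces the cached data_ptr with a byte offset: every data<T>() call now reconstructs the pointer from buffer().raw_ptr() plus offset, which is what lets the CUDA backend's raw_ptr() migrate a stream-ordered buffer back to host-visible memory on demand. A minimal sketch of the offset-based view representation, with stand-in types (Buffer and raw_base here are illustrative, not the MLX API):

#include <cstdint>

// Illustrative only: an offset-based view resolves its pointer on demand,
// so many views can share one buffer and differ only in their byte offset.
struct Buffer {
  void* base;
};
inline void* raw_base(Buffer b) {
  return b.base;
}

struct View {
  Buffer buf;
  int64_t offset; // byte offset from the start of the buffer
  template <typename T>
  T* data() {
    return reinterpret_cast<T*>(static_cast<char*>(raw_base(buf)) + offset);
  }
};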


@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
-    BinaryOpType bopt) {
+    BinaryOpType bopt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
-out.set_data(
-    allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
+out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
-allocator::malloc(b.data_size() * out.itemsize()),
+mallocfn(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
-allocator::malloc(a.data_size() * out.itemsize()),
+mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
-allocator::malloc(a.data_size() * out.itemsize()),
+mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
-out.set_data(allocator::malloc(out.nbytes()));
+out.set_data(mallocfn(out.nbytes()));
}
break;
}
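The extra mallocfn parameter defaults to allocator::malloc, so existing call sites compile unchanged while a backend can inject a stream-aware allocator; the CUDA diffs further down do exactly that with a lambda over cu::malloc_async, and the same hook is threaded through the ternary, copy, unary, and compiled helpers below. A small self-contained sketch of the pattern (stand-in Buffer type, not the MLX one):

#include <cstddef>
#include <cstdlib>
#include <functional>

struct Buffer {
  void* ptr;
};

// Default path, standing in for allocator::malloc.
Buffer default_malloc(size_t n) {
  return Buffer{std::malloc(n)};
}

// Callers may override how the output buffer is obtained.
void set_output(
    size_t nbytes,
    std::function<Buffer(size_t)> mallocfn = default_malloc) {
  Buffer out = mallocfn(nbytes);
  // ... hand `out` to the output array ...
  std::free(out.ptr); // cleanup so this sketch leaks nothing
}

int main() {
  set_output(64); // default allocator
  set_output(64, [](size_t n) {
    // a backend would do a stream-ordered allocation here
    return default_malloc(n);
  });
}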


@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
-out.set_data(nullptr);
+out.set_data(allocator::malloc(0));
return;
}
Strides strides(out.ndim(), 0);


@@ -114,7 +114,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
-    bool contiguous) {
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>&
+        mallocfn /* = allocator::malloc */) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -140,7 +142,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
-allocator::malloc(data_size * outputs[o].itemsize()),
+mallocfn(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -163,7 +165,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
-outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
+outputs[o].set_data(mallocfn(outputs[o].nbytes()));
}
}
}


@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
-    bool contiguous);
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>& mallocfn =
+        allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(


@@ -22,7 +22,11 @@ enum class CopyType {
GeneralGeneral
};
-inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
+inline bool set_copy_output_data(
+    const array& in,
+    array& out,
+    CopyType ctype,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
-allocator::malloc(in.data_size() * out.itemsize()),
+mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
-out.set_data(allocator::malloc(out.nbytes()));
+out.set_data(mallocfn(out.nbytes()));
return false;
}
}


@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
-out.set_data(nullptr);
+out.set_data(allocator::malloc(0));
return;
}


@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
-    TernaryOpType topt) {
+    TernaryOpType topt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
-out.set_data(
-    allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
+out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
-allocator::malloc(out.itemsize() * b.data_size()),
+mallocfn(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
-out.set_data(allocator::malloc(out.nbytes()));
+out.set_data(mallocfn(out.nbytes()));
}
break;
}


@@ -7,19 +7,22 @@
namespace mlx::core {
-inline void set_unary_output_data(const array& in, array& out) {
+inline void set_unary_output_data(
+    const array& in,
+    array& out,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
-allocator::malloc(in.data_size() * out.itemsize()),
+mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
-out.set_data(allocator::malloc(out.nbytes()));
+out.set_data(mallocfn(out.nbytes()));
}
}


@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,


@@ -46,7 +46,6 @@ void eig_impl(
int info;
{
T work;
-int iwork;
geev<T>(
&jobl,
&jobr,


@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -49,9 +48,15 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+if (beta != 1.0 && beta != 0.0) {
+  // scale the output
+  for (auto i = 0; i < batch_size * M * N; ++i) {
+    out[i] *= beta;
+  }
+  beta = 1.0;
+}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,


@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a);
encoder.set_input_array(b);
-const void* a_mask_ptr;
-const void* b_mask_ptr;
-const void* out_mask_ptr;
+const void* a_mask_ptr = nullptr;
+const void* b_mask_ptr = nullptr;
+const void* out_mask_ptr = nullptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
-bool a_mask_bool;
-bool b_mask_bool;
-bool out_mask_bool;
+bool a_mask_bool = false;
+bool b_mask_bool = false;
+bool out_mask_bool = false;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
@@ -423,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());


@@ -91,7 +91,6 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}


@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
-out.set_data(nullptr);
+out.set_data(allocator::malloc(0));
return;
}
auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
-out.set_data(nullptr);
+out.set_data(allocator::malloc(0));
return;
}
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
-out.set_data(nullptr);
+out.set_data(allocator::malloc(0));
return;
}


@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -445,7 +448,6 @@ void mxfp4_qmm(
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
-constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -487,7 +489,6 @@ void mxfp4_qmm_t(
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
-constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -1104,4 +1105,44 @@ void fast::Quantize::eval_cpu(
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core


@@ -1,5 +1,6 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);


@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;


@@ -39,7 +39,7 @@ struct StridedIterator {
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
-    : ptr_(ptr + offset * stride), stride_(stride) {}
+    : stride_(stride), ptr_(ptr + offset * stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}


@@ -83,8 +83,6 @@ void svd_impl(
auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
-T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not


@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
-constexpr int N = simd::max_size<T>;
+constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) {
-simd::store(dst, Op{}(simd::load<T, N>(src)));
+simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N;
src += N;
dst += N;
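The vector step previously came from the input type alone; with type-converting unary ops (for example the new uint8-to-float16 FP8 paths) the output type's SIMD width must bound the step too, hence the std::min. An illustrative check of the width mismatch, with made-up vector widths standing in for simd::max_size<T>:

#include <algorithm>
#include <cstdint>

// Hypothetical widths for illustration only: 64-byte vectors.
template <typename T>
constexpr int max_size_v = 64 / sizeof(T);

int main() {
  constexpr int n_in = max_size_v<uint8_t>; // 64 lanes of input
  constexpr int n_out = max_size_v<float>;  // but only 16 lanes of output
  constexpr int N = std::min(n_in, n_out);  // safe step per iteration
  static_assert(N == 16, "step must satisfy both input and output widths");
  return 0;
}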


@@ -108,4 +108,73 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail
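ToFP8 and FromFP8 above are branchless bit manipulations of the FP8 E4M3 format: 1 sign, 4 exponent, and 3 mantissa bits with bias 7, in the finite-only "FN" flavor (no infinities; 0x7F and 0xFF are NaN), so the 0x7E saturation constant in ToFP8 is the maximum finite encoding, 448. A plain scalar decoder to sanity-check the SIMD version against, under those format assumptions:

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode one FP8 E4M3FN byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits.
float fp8_e4m3_to_float(uint8_t b) {
  float sign = (b & 0x80) ? -1.0f : 1.0f;
  int exp = (b >> 3) & 0xF;
  int mant = b & 0x7;
  if (exp == 0xF && mant == 0x7) {
    return std::nanf(""); // the only NaN encodings; E4M3FN has no infinities
  }
  if (exp == 0) {
    // Subnormal: mant * 2^-3 * 2^-6
    return sign * std::ldexp(static_cast<float>(mant), -9);
  }
  // Normal: (1 + mant/8) * 2^(exp - 7)
  return sign * std::ldexp(8.0f + static_cast<float>(mant), exp - 10);
}

int main() {
  assert(fp8_e4m3_to_float(0x3C) == 1.5f);
  assert(fp8_e4m3_to_float(0x7E) == 448.0f); // max finite, ToFP8's saturation value
  assert(fp8_e4m3_to_float(0x00) == 0.0f);
  return 0;
}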


@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
@@ -51,12 +52,19 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -170,11 +178,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
-# Supress warnings: note: parameter passing for argument of type
-# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
-# 10.1
-target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)


@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
@@ -67,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
+b->buf.device = -1;
return &b->buf;
}
@@ -88,14 +90,22 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
+int device_count = 0;
+CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
+for (int i = 0; i < device_count; ++i) {
+  CHECK_CUDA_ERROR(cudaSetDevice(i));
+  cudaStream_t s;
+  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
+  free_streams_.push_back(s);
+}
}
-Buffer CudaAllocator::malloc(size_t size) {
+Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
@@ -122,8 +132,17 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.unlock();
if (!buf) {
-buf = new CudaBuffer{nullptr, size};
-cudaError_t err = cudaMallocManaged(&buf->data, size);
+int device = -1;
+if (stream != nullptr) {
+  cudaStreamGetDevice(stream, &device);
+}
+buf = new CudaBuffer{nullptr, size, device};
+cudaError_t err;
+if (device == -1) {
+  err = cudaMallocManaged(&buf->data, size);
+} else {
+  err = cudaMallocAsync(&buf->data, size, stream);
+}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -141,6 +160,14 @@ Buffer CudaAllocator::malloc(size_t size) {
return Buffer{buf};
}
+Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
+  return malloc_impl(size, stream);
+}
+
+Buffer CudaAllocator::malloc(size_t size) {
+  return malloc_impl(size, nullptr);
+}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
@@ -169,7 +196,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
-cudaFree(buf->data);
+if (buf->device >= 0) {
+  cudaFreeAsync(buf->data, free_streams_[buf->device]);
+} else {
+  cudaFree(buf->data);
+}
delete buf;
}
}
@@ -220,6 +251,16 @@ CudaAllocator& allocator() {
return *allocator_;
}
+Buffer malloc_async(size_t size, cudaStream_t stream) {
+  auto buffer = allocator().malloc_async(size, stream);
+  if (size && !buffer.ptr()) {
+    std::ostringstream msg;
+    msg << "[malloc_async] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+  return buffer;
+}
} // namespace cu
namespace allocator {
@@ -232,7 +273,19 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
-return static_cast<cu::CudaBuffer*>(ptr_)->data;
+auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
+if (cbuf.device != -1) {
+  // TODO maybe make this async on a i/o stream to avoid synchronizing the
+  // device on malloc/and free
+  void* new_data;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
+  cbuf.device = -1;
+  CHECK_CUDA_ERROR(
+      cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
+  CHECK_CUDA_ERROR(cudaFree(cbuf.data));
+  cbuf.data = new_data;
+}
+return cbuf.data;
}
} // namespace allocator
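Taken together, CudaBuffer now records which device it was allocated on (device == -1 meaning cudaMallocManaged), stream-ordered buffers are released with cudaFreeAsync on a per-device stream, and raw_ptr() migrates such a buffer back to managed memory before handing out a host-visible pointer. For reference, the underlying CUDA stream-ordered allocation pattern in isolation (real cudaMallocAsync and cudaFreeAsync APIs; error handling elided):

#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  // Stream-ordered allocation: usable by work enqueued on `stream`
  // after this point, with no device-wide synchronization.
  void* ptr = nullptr;
  cudaMallocAsync(&ptr, 1 << 20, stream);

  // ... launch kernels on `stream` that read/write `ptr` ...

  // Free in stream order: the memory becomes reusable once prior
  // work on `stream` completes.
  cudaFreeAsync(ptr, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}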


@@ -4,7 +4,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
@@ -17,6 +19,7 @@ using allocator::Buffer;
struct CudaBuffer {
void* data;
size_t size;
+int device; // -1 for managed
};
class SmallSizePool {
@@ -45,6 +48,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
+Buffer malloc_async(size_t size, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -58,6 +62,7 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
+Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
@@ -69,9 +74,12 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
+std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
+Buffer malloc_async(size_t size, cudaStream_t stream);
} // namespace mlx::core::cu


@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
return;
}
-out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
+out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
@@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dims,
0,
-out.data<OutType>(),
+gpu_ptr<OutType>(out),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));


@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
-out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
+auto& encoder = cu::get_command_encoder(s);
+out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size();
// ArgReduce.
-auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim(),
0,
-in.data<T>(),
-out.data<uint32_t>(),
+gpu_ptr<T>(in),
+gpu_ptr<uint32_t>(out),
out.size(),
const_param(shape),
const_param(in_strides),


@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
-a.data<InType>(),
-b.data<InType>(),
-out.data<OutType>(),
+gpu_ptr<InType>(a),
+gpu_ptr<InType>(b),
+gpu_ptr<OutType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
-a.data<InType>(),
-b.data<InType>(),
-out.data<OutType>(),
+gpu_ptr<InType>(a),
+gpu_ptr<InType>(b),
+gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
num_blocks,
block_dims,
0,
-a.data<InType>(),
-b.data<InType>(),
-out.data<OutType>(),
+gpu_ptr<InType>(a),
+gpu_ptr<InType>(b),
+gpu_ptr<OutType>(out),
out.data_size());
});
}
@@ -365,7 +365,11 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
-set_binary_op_output_data(a, b, out, bopt);
+auto& encoder = cu::get_command_encoder(s);
+set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
+  return cu::malloc_async(n, encoder.stream());
+});
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}


@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param(shape),
const_param(a_strides),
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
out_a.data_size());
});
}


@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
}
}
auto& encoder = cu::get_command_encoder(s);
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
for (auto& x : outputs) {
args.append(x);
}
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}


@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
if (out_.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),


@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);


@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);


@@ -5,6 +5,22 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -87,11 +103,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core
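One note on the donation fast path added to copy_gpu above: when the input buffer is donated and the dtypes match, the "copy" degenerates to buffer reuse and no kernel is launched. Reading the control flow as a sketch:

//   bool donated = set_copy_output_data(in, out, ctype, alloc);
//   if (donated && in.dtype() == out.dtype())
//     return;  // out already aliases in's buffer; nothing left to do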


@@ -77,8 +77,8 @@ void copy_contiguous(
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
gpu_ptr<InType>(in) + in_offset,
gpu_ptr<OutType>(out) + out_offset,
out.data_size());
});
});


@@ -106,8 +106,8 @@ void copy_general(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)


@@ -69,8 +69,8 @@ void copy_general_dynamic(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
}
});
});


@@ -92,8 +92,8 @@ void copy_general_input(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;


@@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace mlx::core {
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
namespace cu {
class Device;
}; // namespace cu
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core
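A usage sketch for the RAII wrappers above (make_stream is an illustrative name; assumes a cu::Device is at hand). Copying is deleted, moving transfers ownership, and the Destroy function runs exactly once, in the destructor of the final owner:

CudaStream make_stream(cu::Device& dev) {
  CudaStream s(dev);  // created via the explicit constructor
  return s;           // moved out; the moved-from handle is nulled
}
// ~CudaStream() calls cudaStreamDestroy once, on the last owner.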


@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false;
}
encoder.add_temporary(workspace);
return true;
}


@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
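Across this diff, kernel call sites swap array::data<T>() for gpu_ptr<T>(arr), as in get_alignment and get_data_ptr above. The accessor is declared alongside the allocator; a hedged reading of its role (an assumption about intent, not its exact definition):

// With stream-ordered buffers the array's raw handle need not be the device
// address, so one accessor centralizes the conversion:
//   template <typename T>
//   T* gpu_ptr(const array& a);  // device-visible pointer to a's data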


@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
std::vector<array> copies;
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
}
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}


@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"


@@ -2,6 +2,8 @@
#pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -334,4 +336,17 @@ struct Tanh {
}
};
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu
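The two functors above are inverses up to e4m3 rounding. A small device-side round-trip sketch (roundtrip is an illustrative name):

__device__ float roundtrip(float x) {
  uint8_t q = ToFP8{}(x);  // quantize: 1 sign, 4 exponent, 3 mantissa bits
  return FromFP8{}(q);     // exact only for values representable in e4m3
}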


@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto set_input_output = [&](const array& in,
array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:


@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
auto* bias_ptr = gpu_ptr<void>(bias);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
@@ -278,9 +278,9 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
nullptr,
alpha);
}
@@ -321,10 +321,10 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
gpu_ptr<void>(c),
alpha,
beta);
}
@@ -370,11 +370,11 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
cu::malloc_async(nbytes, encoder.stream()),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
workspace_ptr = gpu_ptr<void>(workspace);
}
auto capture = encoder.capture_context();


@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
nullptr,
alpha);
a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();


@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
{batch_count * 3},
uint64);
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
{batch_count * 4},
uint64);
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;


@@ -149,13 +149,13 @@ void gemv(
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
mat = gpu_ptr<DataType>(b);
vec = gpu_ptr<DataType>(a);
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
mat = gpu_ptr<DataType>(a);
vec = gpu_ptr<DataType>(b);
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -177,7 +177,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols);
} else {
@@ -189,7 +189,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols,
const_param(batch_shape),


@@ -31,7 +31,7 @@ void append_indices_arg(
int idx_ndim) {
SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
indices[i] = gpu_ptr<void>(inputs[i + 1]);
}
args.append(std::move(indices));
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() > 0);
const auto& src = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(idx_dtype),
nidx);
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_ndim,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& src = inputs[0];
const auto& idx = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(out.dtype()),
dtype_to_string(idx.dtype()));
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.flags().row_contiguous,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}


@@ -31,7 +31,7 @@ struct KernelArgs {
}
void append(const array& a) {
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
}
template <typename T>


@@ -9,6 +9,7 @@
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuda.h>


@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(b),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride,
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);

mlx/backend/cuda/load.cpp (new file, 60 lines)

@@ -0,0 +1,60 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <utility>
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/primitives.h"
namespace {
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
Elem* data = reinterpret_cast<Elem*>(data_bytes);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < (scalar_size / 2); j++) {
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
}
}
}
} // namespace
namespace mlx::core {
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
auto size = out.size();
auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
CHECK_CUDA_ERROR(cudaMemcpyAsync(
gpu_ptr<void>(out),
out_ptr,
nbytes,
cudaMemcpyDefault,
encoder.stream()));
CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
}
} // namespace mlx::core
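Load::eval_gpu stages the file through a plain host buffer, uploads it with a stream-ordered copy, and defers the host free until the copy has finished. The same pattern in isolation, assuming only the CUDA runtime (upload_and_free is an illustrative name):

#include <cstdlib>
#include <cuda_runtime.h>

// Enqueue a host->device copy, then free the host buffer only after the
// copy completes: cudaLaunchHostFunc runs the callback in stream order.
inline void upload_and_free(void* dst, void* host, size_t n, cudaStream_t s) {
  cudaMemcpyAsync(dst, host, n, cudaMemcpyDefault, s);
  cudaLaunchHostFunc(s, [](void* p) { std::free(p); }, host);
}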


@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
cu::malloc_async(in.nbytes() / n, encoder.stream()),
in.data_size() / n,
std::move(strides),
flags);
@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});


@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
gemm_and_bias(
encoder,
M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);


@@ -28,7 +28,6 @@ NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)


@@ -262,10 +262,10 @@ void affine_quantize(
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
gpu_ptr<T>(w),
gpu_ptr<uint8_t>(wq),
gpu_ptr<T>(scales),
gpu_ptr<T>(biases),
w.size());
});
});
@@ -306,7 +306,7 @@ void affine_dequantize(
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -318,10 +318,10 @@ void affine_dequantize(
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
gpu_ptr<uint8_t>(wq),
gpu_ptr<T>(scales),
gpu_ptr<T>(biases),
gpu_ptr<T>(w),
w.size());
});
});


@@ -0,0 +1,19 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"
namespace mlx::core {
void fast::ConvertFP8::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
auto& in = inputs[0];
auto& out = outputs[0];
auto& s = out.primitive().stream();
if (to_fp8_) {
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
} else {
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
}
}
} // namespace mlx::core


@@ -0,0 +1,83 @@
#pragma once
struct __nv_fp8_e8m0 {
__device__ __nv_fp8_e8m0(float x) {
if (!std::isfinite(x)) {
__x = 0xFF;
return;
}
if (x < 0.0f) {
__x = 0x00;
return;
}
float le = std::log2f(x);
int n = static_cast<int>(std::nearbyintf(le));
n = n < -127 ? -127 : n;
n = n > 127 ? 127 : n;
__x = static_cast<uint8_t>(n + 127);
}
__device__ operator float() {
if (__x == 0xFF) {
return std::numeric_limits<float>::quiet_NaN();
}
return std::ldexp(1.0f, static_cast<int>(__x) - 127);
}
uint8_t __x{0};
};
struct __nv_fp4_e2m1 {
__device__ __nv_fp4_e2m1(float x) {
if (std::isnan(x)) {
__x = 0x7;
return;
}
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
x = std::abs(x);
if (x > 5.0f) {
__x = 0x7;
} else if (x >= 3.5f) {
__x = 0x6;
} else if (x > 2.5f) {
__x = 0x5;
} else if (x >= 1.75f) {
__x = 0x4;
} else if (x > 1.25f) {
__x = 0x3;
} else if (x >= 0.75f) {
__x = 0x2;
} else if (x > 0.25f) {
__x = 0x1;
} else {
__x = 0x0;
}
__x |= sign_bit;
}
__device__ operator float() {
static const float LUT[16] = {
0.0f,
0.5f,
1.0f,
1.5f,
2.0f,
3.0f,
4.0f,
6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
return LUT[__x];
}
uint8_t __x{0};
};
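A hedged reading of the branch thresholds above: e2m1 represents the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}, each cutoff sits at the midpoint between two neighbors, and the alternation of > and >= sends exact midpoints to the candidate whose code has an even low bit, i.e. round-to-nearest-even:

//   0.25 -> 0.0    0.75 -> 1.0    1.25 -> 1.0    1.75 -> 2.0
//   2.5  -> 2.0    3.5  -> 4.0    5.0  -> 4.0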


@@ -0,0 +1,216 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>
namespace mlx::core {
namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits>
struct Dequantize {
__device__ float operator()(uint8_t x) {
if constexpr (bits == 8) {
return float(*(__nv_fp8_e4m3*)(&x));
} else {
return float(*(__nv_fp4_e2m1*)(&x));
}
}
};
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) {
return;
}
float w_thread = w[index];
cg::greater<float> max_op;
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.__x;
scale = float(s);
// Write out the scales
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = warp.shfl_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
}
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto scale = float(((ScaleType*)(scales))[gindex]);
out += oindex;
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}
} // namespace cu
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>;
}
bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
gpu_ptr<T>(w),
gpu_ptr<uint8_t>(wq),
gpu_ptr<uint8_t>(scales),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not quantize input with type float64.");
}
});
}
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
constexpr int uint8_per_uint32 = 4;
int packs_per_int = 8 / bits;
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_dequantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_dequantize<T, 16, 4, false>;
}
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
gpu_ptr<uint8_t>(wq),
gpu_ptr<uint8_t>(scales),
gpu_ptr<T>(w),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
}
});
}
} // namespace mlx::core
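Two hedged readings of the kernels above. First, the per-group scale (6 and 448 are the largest finite magnitudes of e2m1 and e4m3 respectively):

//   amax    = max over the group of |w_i|        (warp reduction)
//   scale   = amax / 6.0f  for 4-bit (e2m1), amax / 448.0f for 8-bit (e4m3)
//   q_scale = e8m0(scale) for the MX variants, e4m3(scale) for NVFP4
//   code_i  = encode(w_i / scale), with the scale re-read after its own
//             quantization so encode and decode stay consistent

Second, the 4-bit packing: after warp.shfl_down(output, 1) each lane also sees its right neighbor's code, and the index % pack_factor guard leaves only even lanes writing, so bytes come out as

//   lane 0 writes out[0] = (c1 << 4) | c0
//   lane 2 writes out[1] = (c3 << 4) | c2
// with c_i the 4-bit code produced by lane i.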


@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
}
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
}
}
}


@@ -24,4 +24,22 @@ void affine_dequantize(
cu::CommandEncoder& enc,
const Stream& s);
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core


@@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
uint32_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
@@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
gpu_ptr<uint32_t>(keys),
gpu_ptr<uint8_t>(out),
grid_dims,
odd,
bytes_per_key);
@@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
gpu_ptr<uint32_t>(keys),
gpu_ptr<uint8_t>(out),
grid_dims,
odd,
bytes_per_key,


@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
@@ -100,14 +100,15 @@ void all_reduce(
Dtype dt = in.dtype();
// Cub doesn't like const pointers for load (sigh).
void* indata = const_cast<void*>(in.data<void>());
void* indata = const_cast<void*>(gpu_ptr<void>(in));
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) {
@@ -122,14 +123,14 @@ void all_reduce(
threads,
0,
static_cast<T*>(indata),
intermediate.data<U>(),
gpu_ptr<U>(intermediate),
block_step,
insize);
});
});
// Set the input for the next step and recalculate the blocks
indata = intermediate.data<void>();
indata = gpu_ptr<void>(intermediate);
dt = intermediate.dtype();
insize = intermediate.size();
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
@@ -149,7 +150,7 @@ void all_reduce(
threads,
0,
static_cast<T*>(indata),
out.data<U>(),
gpu_ptr<U>(out),
block_step,
insize);
});


@@ -250,7 +250,7 @@ void col_reduce_looped(
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -261,7 +261,7 @@ void col_reduce_looped(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
@@ -276,7 +276,7 @@ void col_reduce_looped(
blocks,
0,
indata,
out.data<U>(),
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args));
});
});
@@ -293,7 +293,7 @@ void col_reduce_small(
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -312,8 +312,8 @@ void col_reduce_small(
grid,
block,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args),
out.size());
});


@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
encoder.set_output_array(out);
@@ -42,7 +42,7 @@ void init_reduce(
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
encoder.add_kernel_node(
kernel, grid, block, 0, out.data<U>(), out.size());
kernel, grid, block, 0, gpu_ptr<U>(out), out.size());
});
});
}


@@ -5,6 +5,7 @@
#include <numeric>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
@@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
const std::vector<int>& axes,
cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return;
}
@@ -133,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
allocator::malloc(out.nbytes()),
cu::malloc_async(out.nbytes(), encoder.stream()),
data_size,
final_strides,
fl,


@@ -238,7 +238,7 @@ void row_reduce_simple(
const ReductionPlan& plan) {
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
// TODO: If out.size() < 1024 which will be a common case then write this in
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
@@ -268,10 +268,10 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
T* indata = const_cast<T*>(gpu_ptr<T>(in));
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
kernel, grid, block, 0, indata, gpu_ptr<U>(out), out.size(), size);
});
});
}
@@ -286,7 +286,7 @@ void row_reduce_looped(
cu::RowReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
allocate_same_layout(out, in, axes, encoder);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -315,7 +315,7 @@ void row_reduce_looped(
});
encoder.add_kernel_node(
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
kernel, grid, block, 0, gpu_ptr<T>(in), gpu_ptr<U>(out), args);
});
});
}


@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
nvtx3::scoped_range r("RMSNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_output_array(out);
@@ -223,9 +223,9 @@ void RMSNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride);
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}
@@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);


@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
nvtx3::scoped_range r("RoPE::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
if (with_freqs) {
@@ -340,9 +340,9 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
mat_size,
@@ -357,10 +357,10 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
mat_size,
dims,
@@ -381,10 +381,10 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
std::log2(base_),
strides,
@@ -408,9 +408,9 @@ void RoPE::eval_gpu(
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
strides,


@@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback(
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
gpu_ptr<DataType>(q),
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
gpu_ptr<DataType>(o),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
params);
});
});
@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
@@ -601,13 +602,13 @@ void sdpa_vector_2pass_fallback(
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
gpu_ptr<DataType>(q),
gpu_ptr<DataType>(k),
gpu_ptr<DataType>(v),
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
gpu_ptr<float>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
params);
}
@@ -628,10 +629,10 @@ void sdpa_vector_2pass_fallback(
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
gpu_ptr<float>(intermediate),
gpu_ptr<float>(sums),
gpu_ptr<float>(maxs),
gpu_ptr<DataType>(o),
params);
}
});
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
};
o.set_data(
allocator::malloc(o.nbytes()),
cu::malloc_async(o.nbytes(), encoder.stream()),
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);


@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto in = inputs[0];
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
in.data_size(),
in.strides(),
in.flags());
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t axis_size = in.shape(axis_);
bool contiguous = in.strides()[axis_] == 1;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
@@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.data_size() / axis_size,
block_dim,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
axis_size);
} else {
constexpr int BM = WARP_SIZE;
@@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim,
0,
in.data<T>(),
out.data<U>(),
gpu_ptr<T>(in),
gpu_ptr<U>(out),
axis_size,
stride,
stride_blocks);


@@ -23,14 +23,15 @@ void concatenate_gpu(
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto concurrent = cu::get_command_encoder(s).concurrent_context();
auto concurrent = encoder.concurrent_context();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
});
auto& encoder = cu::get_command_encoder(s);
// Prepare output.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
}
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(offset);
encoder.set_input_array(indices);
encoder.set_output_array(offset);


@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Softmax::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
@@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});


@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
     in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
-    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    out = array(
+        cu::malloc_async(out.nbytes(), encoder.stream()),
+        in.shape(),
+        out.dtype());
     encoder.add_temporary(out);
   } else {
     out.set_data(
-        allocator::malloc(in.data_size() * out.itemsize()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -70,22 +73,28 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         thrust::make_counting_iterator(0), OffsetTransform{nsort});
     if (argsort) {
       // Indices in the sorted dimension.
-      array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+      array indices(
+          cu::malloc_async(out.nbytes(), encoder.stream()),
+          in.shape(),
+          out.dtype());
       encoder.add_temporary(indices);
       // In argsort though we don't need the result of sorted values, the
       // API requires us to provide an array to store it.
-      array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
+      array discard(
+          cu::malloc_async(in.nbytes(), encoder.stream()),
+          in.shape(),
+          in.dtype());
       encoder.add_temporary(discard);
       size_t size;
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
           nullptr,
           size,
-          in.data<Type>(),
-          discard.data<Type>(),
-          indices.data<uint32_t>(),
-          out.data<uint32_t>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(discard),
+          gpu_ptr<uint32_t>(indices),
+          gpu_ptr<uint32_t>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           sizeof(Type) * 8,
           stream));
-      array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+      array temp(
+          cu::malloc_async(size, encoder.stream()),
+          {static_cast<int>(size)},
+          uint8);
       encoder.add_temporary(temp);
       // Start capturing after allocations
@@ -103,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           cu::thrust_policy(stream),
           thrust::counting_iterator<uint32_t>(0),
           thrust::counting_iterator<uint32_t>(indices.data_size()),
-          thrust::device_pointer_cast(indices.data<uint32_t>()),
+          thrust::device_pointer_cast(gpu_ptr<uint32_t>(indices)),
           ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
-          temp.data<void>(),
+          gpu_ptr<void>(temp),
           size,
-          in.data<Type>(),
-          discard.data<Type>(),
-          indices.data<uint32_t>(),
-          out.data<uint32_t>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(discard),
+          gpu_ptr<uint32_t>(indices),
+          gpu_ptr<uint32_t>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -125,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
           nullptr,
           size,
-          in.data<Type>(),
-          out.data<Type>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,
@@ -135,16 +147,19 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           sizeof(Type) * 8,
           stream));
-      array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+      array temp(
+          cu::malloc_async(size, encoder.stream()),
+          {static_cast<int>(size)},
+          uint8);
       encoder.add_temporary(temp);
       // Start capturing after allocations
       auto capture = encoder.capture_context();
       CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
-          temp.data<void>(),
+          gpu_ptr<void>(temp),
           size,
-          in.data<Type>(),
-          out.data<Type>(),
+          gpu_ptr<Type>(in),
+          gpu_ptr<Type>(out),
           in.data_size(),
           in.data_size() / nsort,
           offsets,

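Both SortPairs and SortKeys above follow CUB's two-phase temp-storage protocol: the first call, with a null scratch pointer, only writes the required size into `size`; the real sort happens when the identical call is repeated with allocated scratch, which is why the diff creates `temp` between the two calls. A self-contained sketch of the protocol using the simpler, unsegmented cub::DeviceRadixSort (buffer names and the stream-ordered scratch allocation are illustrative, not MLX code):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

void radix_sort_keys(const int* d_keys_in, int* d_keys_out, int n,
                     cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: null temp pointer, so CUB only reports the scratch size.
  cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, n, 0, sizeof(int) * 8,
      stream);
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Phase 2: the same call with real scratch memory performs the sort.
  cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, n, 0, sizeof(int) * 8,
      stream);
  cudaFreeAsync(d_temp, stream);
}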

@@ -168,10 +168,10 @@ void ternary_op_gpu_inplace(
             num_blocks,
             block_dims,
             0,
-            a.data<bool>(),
-            b.data<DType>(),
-            c.data<DType>(),
-            out.data<DType>(),
+            gpu_ptr<bool>(a),
+            gpu_ptr<DType>(b),
+            gpu_ptr<DType>(c),
+            gpu_ptr<DType>(out),
             out.data_size());
       });
     } else {
@@ -211,10 +211,10 @@ void ternary_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<bool>(),
-            b.data<DType>(),
-            c.data<DType>(),
-            out.data<DType>(),
+            gpu_ptr<bool>(a),
+            gpu_ptr<DType>(b),
+            gpu_ptr<DType>(c),
+            gpu_ptr<DType>(out),
             rest,
             const_param<dims_constant()>(shape),
             const_param<dims_constant()>(a_strides),
@@ -231,10 +231,10 @@ void ternary_op_gpu_inplace(
             {num_blocks_x, num_blocks_y},
             block_dims,
             0,
-            a.data<bool>(),
-            b.data<DType>(),
-            c.data<DType>(),
-            out.data<DType>(),
+            gpu_ptr<bool>(a),
+            gpu_ptr<DType>(b),
+            gpu_ptr<DType>(c),
+            gpu_ptr<DType>(out),
             rest,
             const_param(shape),
             const_param(a_strides),
@@ -256,7 +256,10 @@ void ternary_op_gpu(
   auto& b = inputs[1];
   auto& c = inputs[2];
   auto topt = get_ternary_op_type(a, b, c);
-  set_ternary_op_output_data(a, b, c, out, topt);
+  auto& encoder = cu::get_command_encoder(s);
+  set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
   ternary_op_gpu_inplace<Op>(inputs, out, s);
 }

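set_ternary_op_output_data (and set_unary_output_data in the next file) now take an allocation callback, so the shared output-setup logic stays backend-agnostic while the CUDA path routes allocation through the stream-ordered allocator. A hedged sketch of the callback pattern; Buffer, set_output_data, and both callsites are illustrative stand-ins, not MLX's actual signatures:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Stand-in buffer type; in MLX this role is played by the allocator's
// buffer handle.
struct Buffer {
  void* ptr;
};

// Accepts any callable returning a Buffer, so one implementation serves
// both the default allocator and a stream-ordered GPU allocator.
template <typename F>
void set_output_data(size_t nbytes, F&& alloc) {
  Buffer buf = alloc(nbytes);
  std::printf("allocated %zu bytes at %p\n", nbytes, buf.ptr);
  std::free(buf.ptr); // illustration only; real code hands buf to the array
}

int main() {
  // Default path.
  set_output_data(256, [](size_t n) { return Buffer{std::malloc(n)}; });
  // A CUDA backend would instead pass something like:
  //   [&](size_t n) { return malloc_async(n, encoder.stream()); }
  return 0;
}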

@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
   if (std::is_same_v<Op, LogicalNot>) {
     return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
   }
+  if (std::is_same_v<Op, ToFP8>) {
+    return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
+  }
+  if (std::is_same_v<Op, FromFP8>) {
+    return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
+  }
   return false;
 }
@@ -152,8 +158,8 @@ void unary_op_gpu_inplace(
         num_blocks,
         block_dims,
         0,
-        in.data<InType>(),
-        out.data<OutType>(),
+        gpu_ptr<InType>(in),
+        gpu_ptr<OutType>(out),
         out.data_size());
   } else {
     using IdxT = std::conditional_t<large(), int64_t, int32_t>;
@@ -176,8 +182,8 @@ void unary_op_gpu_inplace(
         {num_blocks_x, num_blocks_y},
         block_dims,
         0,
-        in.data<InType>(),
-        out.data<OutType>(),
+        gpu_ptr<InType>(in),
+        gpu_ptr<OutType>(out),
         rest,
         const_param(shape),
         const_param(strides),
@@ -201,7 +207,10 @@ void unary_op_gpu(
     array& out,
     const char* op,
     const Stream& s) {
-  set_unary_output_data(inputs[0], out);
+  auto& encoder = cu::get_command_encoder(s);
+  set_unary_output_data(inputs[0], out, [&](auto n) {
+    return cu::malloc_async(n, encoder.stream());
+  });
   unary_op_gpu_inplace<Op>(inputs, out, op, s);
 }

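The new ToFP8/FromFP8 cases treat FP8 values as raw uint8_t storage and gate kernel instantiation at compile time, so invalid (Op, In, Out) triples are never instantiated. A standalone sketch of the same trait, simplified to std::is_floating_point_v (MLX's is_floating_v presumably also admits half-precision types, which this sketch does not):

#include <cstdint>
#include <type_traits>

struct ToFP8 {};
struct FromFP8 {};

// Compile-time gate: FP8 kernels only exist for floating point -> uint8_t
// storage and back; every other combination returns false and is skipped.
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  if (std::is_same_v<Op, ToFP8>) {
    return std::is_same_v<Out, uint8_t> && std::is_floating_point_v<In>;
  }
  if (std::is_same_v<Op, FromFP8>) {
    return std::is_same_v<In, uint8_t> && std::is_floating_point_v<Out>;
  }
  return false;
}

static_assert(supports_unary_op<ToFP8, float, uint8_t>());
static_assert(!supports_unary_op<ToFP8, uint8_t, float>());
static_assert(supports_unary_op<FromFP8, uint8_t, double>());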

@@ -4,89 +4,12 @@
 #pragma once

-#include <cublasLt.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
+#include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/cuda_utils.h"

 namespace mlx::core {

-namespace cu {
-class Device;
-}
-
-struct Dtype;
-
-// Throw exception if the cuda API does not succeed.
-void check_cublas_error(const char* name, cublasStatus_t err);
-void check_cuda_error(const char* name, cudaError_t err);
-void check_cuda_error(const char* name, CUresult err);
-
-// The macro version that prints the command that failed.
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
-
-// Convert Dtype to CUDA C++ types.
-const char* dtype_to_cuda_type(const Dtype& dtype);
-
-// Base class for RAII managed CUDA resources.
-template <typename Handle, cudaError_t (*Destroy)(Handle)>
-class CudaHandle {
- public:
-  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
-
-  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
-    assert(this != &other);
-    other.handle_ = nullptr;
-  }
-
-  ~CudaHandle() {
-    reset();
-  }
-
-  CudaHandle(const CudaHandle&) = delete;
-  CudaHandle& operator=(const CudaHandle&) = delete;
-
-  CudaHandle& operator=(CudaHandle&& other) {
-    assert(this != &other);
-    reset();
-    std::swap(handle_, other.handle_);
-    return *this;
-  }
-
-  void reset() {
-    if (handle_ != nullptr) {
-      CHECK_CUDA_ERROR(Destroy(handle_));
-      handle_ = nullptr;
-    }
-  }
-
-  operator Handle() const {
-    return handle_;
-  }
-
- protected:
-  Handle handle_;
-};
-
-// Wrappers of CUDA resources.
-class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
- public:
-  using CudaHandle::CudaHandle;
-  explicit CudaGraph(cu::Device& device);
-  void end_capture(cudaStream_t stream);
-};
-
-class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
- public:
-  void instantiate(cudaGraph_t graph);
-};
-
-class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
- public:
-  explicit CudaStream(cu::Device& device);
-};
-
 template <typename T>
 inline uint max_occupancy_block_dim(T kernel) {
   int _, block_dim;
@@ -100,4 +23,22 @@ inline uint max_occupancy_block_dim(T kernel) {
   return block_dim;
 }

+template <typename T>
+inline T* gpu_ptr(array& arr) {
+  return reinterpret_cast<T*>(
+      static_cast<char*>(
+          static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
+      arr.offset());
+}
+
+template <typename T>
+inline const T* gpu_ptr(const array& arr) {
+  return gpu_ptr<T>(const_cast<array&>(arr));
+}
+
+struct Dtype;
+
+// Convert Dtype to CUDA C++ types.
+const char* dtype_to_cuda_type(const Dtype& dtype);
+
 } // namespace mlx::core

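The new gpu_ptr helper recovers a typed device pointer from an array: cast the allocator's buffer handle to cu::CudaBuffer to get the base address, advance by the array's byte offset on a char*, then cast to T*. A reduced sketch of that arithmetic with stand-in types (CudaBuffer and ArrayView here mimic, but are not, the MLX structs; in particular the sketch assumes the offset is in bytes, as the pointer arithmetic in the diff implies):

#include <cstddef>

namespace cu {
struct CudaBuffer {
  void* data; // base device pointer owned by the allocator
};
} // namespace cu

struct ArrayView {
  cu::CudaBuffer* buffer;
  size_t offset; // byte offset of this view into the shared buffer
};

// Typed pointer = buffer base + byte offset; the arithmetic runs on
// char* so the offset is not scaled by sizeof(T).
template <typename T>
inline T* gpu_ptr(ArrayView& arr) {
  return reinterpret_cast<T*>(
      static_cast<char*>(arr.buffer->data) + arr.offset);
}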

@@ -7,18 +7,7 @@
 namespace mlx::core {

-void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
-  bool donated = set_copy_output_data(in, out, ctype);
-  if (donated && in.dtype() == out.dtype()) {
-    // If the output has the same type as the input then there is nothing to
-    // copy, just use the buffer.
-    return;
-  }
-  if (ctype == CopyType::GeneralGeneral) {
-    ctype = CopyType::General;
-  }
-  copy_gpu_inplace(in, out, ctype, s);
-}
+void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);

 void copy_gpu(const array& in, array& out, CopyType ctype) {
   copy_gpu(in, out, ctype, out.primitive().stream());
@@ -52,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
   return arr_copy;
 }

-void reshape_gpu(const array& in, array& out, Stream s) {
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-  if (copy_necessary) {
-    out.set_data(allocator::malloc(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        s);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
-}
-
 array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
   int ndim = x.ndim();
   if (start_axis < 0) {

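The removed copy_gpu definition held the donation fast path: when the input's buffer can be reused and the dtypes match, the "copy" degenerates to a buffer handoff with no kernel launch (the declaration that replaces it suggests the definition now lives in backend-specific code). A self-contained sketch of that fast path with stand-in types:

#include <memory>

// Minimal stand-ins to illustrate buffer donation; not MLX's real types.
enum class Dtype { f32, f16 };
struct Array {
  std::shared_ptr<char> buffer;
  Dtype dtype;
  // Donatable when this array holds the only reference to its buffer.
  bool is_donatable() const { return buffer.use_count() == 1; }
};

// If the input can donate and element types match, the "copy" is just a
// buffer handoff; otherwise the caller falls through to a copy kernel.
bool try_donate(Array& in, Array& out) {
  if (in.is_donatable() && in.dtype == out.dtype) {
    out.buffer = std::move(in.buffer);
    return true;
  }
  return false;
}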
Some files were not shown because too many files have changed in this diff.