Compare commits

..

8 Commits

Author SHA1 Message Date
Angelos Katharopoulos
4987e7615a Improve the cutlass gemm 2025-08-25 18:18:19 -07:00
Angelos Katharopoulos
e1303f6160 Reset cutlass gemm to working state again 2025-08-21 01:29:43 -07:00
Angelos Katharopoulos
cf5eef095d tmp 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
395d582719 Add a cutlass gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
05583bcd10 More pipelining for the sm_80 gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
6fce01593a Improve gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
97afe40b7b Remove duplicate register tile 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
f70c62d69c Simple gemm example 2025-08-20 23:51:25 -07:00
259 changed files with 4095 additions and 11807 deletions

View File

@@ -18,17 +18,16 @@ jobs:
type: boolean
default: false
macos:
xcode: "26.0.0"
resource_class: m4pro.medium
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.10
brew install python@3.9
brew install doxygen
python3.10 -m venv env
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
@@ -90,8 +89,7 @@ jobs:
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
@@ -120,7 +118,7 @@ jobs:
parameters:
xcode_version:
type: string
default: "26.0.0"
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
@@ -128,19 +126,18 @@ jobs:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
resource_class: m2pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
uv venv --python 3.10
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
@@ -199,7 +196,7 @@ jobs:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv pip install -e .
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
@@ -225,20 +222,15 @@ jobs:
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
@@ -246,23 +238,12 @@ jobs:
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
@@ -273,10 +254,10 @@ jobs:
parameters:
python_version:
type: string
default: "3.10"
default: "3.9"
xcode_version:
type: string
default: "26.0.0"
default: "16.2.0"
build_env:
type: string
default: ""
@@ -285,7 +266,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m4pro.medium
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
@@ -293,15 +274,11 @@ jobs:
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -311,29 +288,29 @@ jobs:
- run:
name: Install Python package
command: |
conda activate env
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
conda activate env
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
conda activate env
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
@@ -342,7 +319,7 @@ jobs:
- run:
name: Upload package
command: |
conda activate env
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
@@ -351,7 +328,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.10"
default: "3.9"
build_env:
type: string
default: ""
@@ -387,7 +364,7 @@ jobs:
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -415,7 +392,7 @@ jobs:
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
resource_class: large
steps:
- checkout
- run:
@@ -462,7 +439,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
@@ -484,10 +461,71 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -503,7 +541,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
@@ -529,7 +567,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
@@ -546,13 +584,59 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
@@ -564,14 +648,75 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:

View File

@@ -1,24 +0,0 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
MLX_BUILD_STAGE: 2
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
python -m build -w
if [ -f "python/scripts/repair_cuda.sh" ]; then
bash python/scripts/repair_cuda.sh
fi

View File

@@ -1,68 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
# this value is dependent on the CUDA tools installed in the setup-linux workflow
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Check if build actually worked
shell: bash
run: python -c "import mlx.core"
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- name: Build Python package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: ${{ inputs.nvcc-location }}

View File

@@ -1,38 +0,0 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

View File

@@ -1,78 +0,0 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
build-type:
description: 'Build type'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
type: boolean
runs:
using: "composite"
steps:
- name: Set DEBUG
shell: sh
if: inputs.build-type == 'debug'
run: echo "DEBUG=1" >> $GITHUB_ENV
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: sh
run: ./build/tests/tests
- name: Build Python package
if: inputs.build-type == 'release'
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
if [ -f "python/scripts/repair_linux.sh" ]; then
bash python/scripts/repair_linux.sh
fi
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64

View File

@@ -1,22 +0,0 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
runs:
using: "composite"
steps:
- name: Build Python package(s)
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

View File

@@ -1,124 +0,0 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
build-jit:
description: 'Whether to build with JIT'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
DEV_RELEASE: 1
run: |
uv pip install --upgrade pip cmake setuptools
uv pip install nanobind==2.4.0 \
numpy torch tensorflow unittest-xml-reporting
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
if: inputs.run-tests == 'true'
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
if: inputs.build-jit == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
- name: Build macOS 13 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 13.0
- name: Build macOS 14 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
- name: Build macOS 15 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0

View File

@@ -1,83 +0,0 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

View File

@@ -1,31 +0,0 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
install-mpi:
description: 'Whether to install MPI'
required: false
default: 'true'
type: boolean
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
if: inputs.install-mpi == 'true'
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

View File

@@ -1,6 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

View File

@@ -1,27 +0,0 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd

View File

@@ -1,28 +0,0 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

View File

@@ -1,114 +0,0 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
env:
MACOSX_DEPLOYMENT_TARGET: "15.0"
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
build_cuda_with_tests:
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -1,68 +1,20 @@
name: Build and Test
on: pull_request
permissions:
contents: read
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
mac_build_and_test:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_documentation:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
python-version: 3.8
- name: Install dependencies
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files

View File

@@ -1,206 +0,0 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
outputs:
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
with:
build-type: release
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiples: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiples: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}

View File

@@ -1,10 +1,4 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:

View File

@@ -19,17 +19,12 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -88,26 +87,22 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -116,8 +111,7 @@ if(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -146,12 +140,6 @@ if(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -179,7 +167,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate not found, using default backend.")
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()

View File

@@ -2,7 +2,7 @@
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[**Examples**](#examples)
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
improve MLX with the goal of quickly exploring new ideas.
The design of MLX is inspired by frameworks like
[NumPy](https://numpy.org/doc/stable/index.html),
@@ -91,7 +91,7 @@ Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```text
```
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},

View File

@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16", "complex64")
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16", "complex64"):
for dtype in ("float32", "float16"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
)
)
plt.close(fig)

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source,
source=source
ensure_row_contiguous=False,
)

View File

@@ -16,7 +16,7 @@ silicon computer is
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.10
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
@@ -39,7 +39,7 @@ requirements:
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
- Python >= 3.9
CPU-only (Linux)
@@ -55,7 +55,7 @@ To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.10
- Python >= 3.9
Troubleshooting
@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag

View File

@@ -27,7 +27,6 @@ simple functions.
mish
prelu
relu
relu2
relu6
selu
sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE

View File

@@ -112,7 +112,6 @@ Operations
max
maximum
mean
median
meshgrid
min
minimum

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)
timeit(mx.compile(gelu), x)
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added
grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss

View File

@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array([-1.0]))
out, = imported_abs(mx.array(-1.0))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))

View File

@@ -107,20 +107,8 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell

View File

@@ -241,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
}

View File

@@ -294,11 +294,6 @@ class array {
return array_desc_->siblings;
}
/** The array's position in the sibling list. */
int sibling_position() const {
return array_desc_->position;
}
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}};
return {{1}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
return {{1}, {0}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -11,8 +11,6 @@ namespace mlx::core {
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General,
};
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else {
topt = TernaryOpType::General;
}
@@ -69,8 +59,6 @@ inline void set_ternary_op_output_data(
b.flags());
}
break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General:
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||

View File

@@ -15,7 +15,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core {
@@ -95,11 +94,7 @@ void* compile(
kernel_file_name = kernel_name;
}
auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
auto output_dir = std::filesystem::temp_directory_path();
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();

View File

@@ -996,6 +996,131 @@ void explicit_gemm_conv_1D_cpu(
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,

View File

@@ -46,6 +46,7 @@ void eig_impl(
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -48,15 +49,9 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,

View File

@@ -88,47 +88,4 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core

View File

@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr = nullptr;
const void* b_mask_ptr = nullptr;
const void* out_mask_ptr = nullptr;
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool = false;
bool b_mask_bool = false;
bool out_mask_bool = false;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
@@ -423,6 +423,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());

View File

@@ -91,6 +91,7 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
@@ -107,9 +108,6 @@ void matmul_general(
} else if (out.dtype() == float64) {
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
@@ -130,6 +128,10 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;

View File

@@ -1,11 +1,10 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -14,35 +13,6 @@ namespace mlx::core {
namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
@@ -437,229 +407,6 @@ void _qmm_dispatch(
}
}
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
@@ -766,198 +513,115 @@ void _bs_qmm_dispatch(
}
}
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& lhs_indices = inputs[inputs.size() - 2];
auto& rhs_indices = inputs[inputs.size() - 1];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(),
&encoder](const array& arr) {
&temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
}
template <typename T, typename U>
@@ -1041,7 +705,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::Quantize::eval_cpu(
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
@@ -1100,47 +764,7 @@ void fast::Quantize::eval_cpu(
}
} else {
throw std::runtime_error(
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
}
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
});
}

View File

@@ -1,6 +1,5 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -10,7 +9,7 @@
#include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in simd/base_simd.h
// There seems to be a bug in sims/base.h
// __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15
@@ -201,15 +200,6 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
@@ -217,20 +207,14 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
auto out = Simd<T, N>(asd::max(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
// TODO add isnan
return asd::max(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
auto out = Simd<T, N>(asd::min(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
// TODO add isnan
return asd::min(a.value, b.value);
}
template <typename T, int N>
@@ -250,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
@@ -268,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined
if (any(exp < 0)) {
return 0;
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
exp = exp >> 1;
}
return res;

View File

@@ -171,11 +171,6 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;

View File

@@ -15,18 +15,6 @@ namespace mlx::core {
namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -39,7 +27,7 @@ struct StridedIterator {
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: stride_(stride), ptr_(ptr + offset * stride) {}
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
@@ -142,7 +130,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, nan_aware_less<T>);
std::stable_sort(st, ed);
src_it.step();
}
}
@@ -196,15 +184,6 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}
@@ -240,7 +219,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed, nan_aware_less<T>);
std::nth_element(st, md, ed);
}
}
@@ -297,15 +276,6 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@@ -83,6 +83,8 @@ void svd_impl(
auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not

View File

@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
simd::store(dst, Op{}(simd::load<T, N>(src)));
size -= N;
src += N;
dst += N;

View File

@@ -77,8 +77,7 @@ struct Real {
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
return 1.0f / (1.0f + simd::exp(-x));
}
SINGLE()
};
@@ -108,73 +107,4 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail

View File

@@ -22,11 +22,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -51,19 +52,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -96,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Keep ptx around for inspection
target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
@@ -177,6 +174,16 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# Fetch and make available cutlass
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v4.1.0)
FetchContent_Populate(cutlass)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)

View File

@@ -30,20 +30,8 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
@@ -91,12 +79,13 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
@@ -130,7 +119,7 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.lock();
}
active_memory_ += buf->size;
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.

View File

@@ -6,33 +6,23 @@
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
struct Arange {
const T start;
const T step;
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
}
} else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
}
};
} // namespace cu
@@ -46,23 +36,19 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
constexpr int N_WRITES = 16 / sizeof(OutType);
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
encoder.add_kernel_node(
cu::arange<OutType, IdxT, N_WRITES>,
num_blocks,
block_dims,
0,
out.data<OutType>(),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}

View File

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
encoder.set_output_array(out);
}
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
get_launch_args(outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}

View File

@@ -47,7 +47,7 @@ auto& conv_cache() {
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
cache(/* capacity */ 128);
return cache;
}
@@ -382,19 +382,20 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
}
if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
if (!plan) {
throw std::runtime_error("[conv] Unable to find an execution plan.");
}
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}

View File

@@ -15,8 +15,8 @@ void copy_gpu_inplace(
int64_t offset_out,
CopyType ctype,
const Stream& s,
std::optional<array> dynamic_offset_in,
std::optional<array> dynamic_offset_out) {
const std::optional<array>& dynamic_offset_in,
const std::optional<array>& dynamic_offset_out) {
if (out.size() == 0) {
return;
}
@@ -44,16 +44,6 @@ void copy_gpu_inplace(
strides_vec[0]);
} else {
if (dynamic_offset_in || dynamic_offset_out) {
if (!dynamic_offset_in) {
dynamic_offset_in = array(0, int64);
encoder.add_temporary(*dynamic_offset_in);
}
if (!dynamic_offset_out) {
dynamic_offset_out = array(0, int64);
encoder.add_temporary(*dynamic_offset_out);
}
encoder.set_input_array(*dynamic_offset_in);
encoder.set_input_array(*dynamic_offset_out);
copy_general_dynamic(
encoder,
ctype,
@@ -64,8 +54,8 @@ void copy_gpu_inplace(
shape_collapsed,
strides_vec[0],
strides_vec[1],
*dynamic_offset_in,
*dynamic_offset_out);
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
} else {
copy_general(
encoder,

View File

@@ -210,9 +210,6 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}

View File

@@ -14,6 +14,10 @@ namespace mlx::core::cu {
namespace {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -23,11 +27,11 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
}();
return use_graphs;
return cache_size;
}
} // namespace
@@ -64,8 +68,8 @@ Device::~Device() {
void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs.
static thread_local int current = 0;
// actual calls of CUDA APIs. This function assumes single-thread in host.
static int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
@@ -82,19 +86,11 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
if (!use_cuda_graphs()) {
return;
}
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
enc.node_count_++;
return;
}
graph.end_capture(enc.stream());
if (discard) {
return;
@@ -109,9 +105,6 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
enc.in_concurrent_ = false;
if (!use_cuda_graphs()) {
return;
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -193,43 +186,35 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
graph_cache_(cuda_graph_cache_size()) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
void CommandEncoder::set_output_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
active_outputs_.push_back(id);
}
void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
commit();
}
}
void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
}
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
@@ -245,23 +230,6 @@ void CommandEncoder::add_kernel_node(
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
smem_bytes,
stream(),
params,
nullptr));
return;
}
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDimX = grid_dim.x;
@@ -288,38 +256,20 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
node_count_++;
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
return;
}
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
}
int CommandEncoder::get_num_ops() {
return node_count_;
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
if (use_cuda_graphs() && node_count_ > 0) {
if (node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_,
from_nodes_.data(),
to_nodes_.data(),
#if CUDART_VERSION >= 13000
nullptr, // edgeData
#endif // CUDART_VERSION >= 13000
from_nodes_.size()));
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
}
graph_key_ += ".";
@@ -353,6 +303,7 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
@@ -364,7 +315,6 @@ void CommandEncoder::commit() {
// Put completion handlers in a batch.
worker_.commit(stream_);
node_count_ = 0;
}
void CommandEncoder::synchronize() {

View File

@@ -76,6 +76,9 @@ class CommandEncoder {
uint32_t smem_bytes,
void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
@@ -83,7 +86,7 @@ class CommandEncoder {
}
void add_completed_handler(std::function<void()> task);
int get_num_ops();
void maybe_commit();
void commit();
Device& device() {
@@ -98,9 +101,6 @@ class CommandEncoder {
void synchronize();
private:
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
struct GraphNode {
cudaGraphNode_t node;
// K = kernel
@@ -140,7 +140,7 @@ class Device {
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, this method is thread-safe.
// Make this device the current cuda device, required by some cuda calls.
void make_current();
CommandEncoder& get_command_encoder(Stream s);

View File

@@ -204,12 +204,6 @@ struct Power {
__device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) {
T res = 1;
// Raising an integer to a negative power is undefined
if constexpr (cuda::std::is_signed_v<T>) {
if (exp < 0) {
return 0;
}
}
while (exp) {
if (exp & 1) {
res *= base;

View File

@@ -6,6 +6,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
@@ -115,4 +116,15 @@ inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;
} else {
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
}
}
} // namespace mlx::core::cu

View File

@@ -2,8 +2,6 @@
#pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -259,8 +257,8 @@ struct Round {
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
T y = 1 / (1 + exp(abs(x)));
return (x < 0) ? y : 1 - y;
T y = 1 / (1 + exp(-abs(x)));
return (x < 0) ? 1 - y : y;
}
};
@@ -336,17 +334,4 @@ struct Tanh {
}
};
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilities that work under both
// This file must not include any host-only code, utilies that work under both
// host and device can be put here.
//
// See more about the requirements at:
@@ -202,7 +202,7 @@ struct Limits<
}
};
// CUDA 11 does not have host side arithmetic operators for half types.
// CUDA 11 does not have host side arithmatic operators for half types.
template <typename T>
struct Limits<
T,

View File

@@ -1,56 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core::distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
} else if (in.is_donatable()) {
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace mlx::core::distributed

View File

@@ -5,24 +5,18 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initalization of CUDA, so CUDA runtime get destroyed at last.
// Force initalization of cuda, so cuda runtime get destroyed at last.
cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
cu::CudaEvent::init_pool();
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
}
@@ -40,8 +34,7 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& stream = arr.primitive().stream();
auto& encoder = cu::get_command_encoder(stream);
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
// Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) {
// Except for the donated one.
@@ -52,14 +45,7 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) {
encoder.add_temporary(s);
}
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
encoder.maybe_commit();
}
void finalize(Stream s) {

View File

@@ -3,12 +3,10 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"
#include <map>
#include <vector>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
@@ -19,180 +17,104 @@ namespace cu {
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
// Manage cached cudaEvent_t objects.
class CudaEventPool {
// Cuda event managed with RAII.
class CudaEventHandle {
public:
CudaEventHandle create(Device& d, int flags) {
if (!on_creation_thread()) {
return CudaEventHandle(d, flags);
}
auto& cache = cache_for(d, flags);
if (cache.empty()) {
return CudaEventHandle(d, flags);
} else {
CudaEventHandle ret = std::move(cache.back());
cache.pop_back();
return ret;
}
CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
}
void release(CudaEventHandle event) {
if (!on_creation_thread()) {
// Event will be destroyed directly instead of getting moved to cache.
return;
}
cache_for(event.device, event.flags).push_back(std::move(event));
~CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
}
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
}
private:
std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
return cache_[d.cuda_device()][flags];
}
bool on_creation_thread() {
return std::this_thread::get_id() == thread_id_;
}
// The CudaEvent may be created and destroyed on different threads (for
// example when waiting on GPU work in CPU stream), we don't want to make
// the cache thread-safe as it adds overhead, so we just skip cache when
// using events in worker threads.
std::thread::id thread_id_{std::this_thread::get_id()};
// {device: {flags: [events]}}
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
cudaEvent_t event_;
};
CudaEventPool& cuda_event_pool() {
static CudaEventPool pool;
return pool;
}
} // namespace
CudaEventHandle::CudaEventHandle(Device& d, int flags)
: device(d), flags(flags) {
device.make_current();
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr);
}
CudaEvent::CudaEvent(Device& d, int flags)
: event_(cuda_event_pool().create(d, flags)) {}
CudaEvent::~CudaEvent() {
cuda_event_pool().release(std::move(event_));
}
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
event_.device.make_current();
cudaEventSynchronize(event_);
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
}
void CudaEvent::wait(cudaStream_t stream) {
event_.device.make_current();
cudaStreamWaitEvent(stream, event_);
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
wait(enc.stream());
}
}
void CudaEvent::record(cudaStream_t stream) {
event_.device.make_current();
cudaEventRecord(event_, stream);
cudaEventRecord(*event_, stream);
recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
record(enc.stream());
}
}
bool CudaEvent::completed() const {
// Note: cudaEventQuery can be safely called from any device.
return cudaEventQuery(event_) == cudaSuccess;
return cudaEventQuery(*event_) == cudaSuccess;
}
// static
void CudaEvent::init_pool() {
cuda_event_pool();
}
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
event_->wait();
}
void wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable {
check_recorded();
event_->wait();
});
} else {
check_recorded();
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->wait(encoder.stream());
}
}
void record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
} else {
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->record(encoder.stream());
recorded_ = true;
}
}
bool is_signaled() const {
return recorded_ && event_->completed();
}
private:
void check_recorded() const {
if (!recorded_) {
throw std::runtime_error(
"Should not wait on a CudaEvent before recording.");
}
}
std::shared_ptr<CudaEvent> event_;
bool recorded_{false};
};
///////////////////////////////////////////////////////////////////////////////
// AtomicEvent implementations
// SharedEvent implementations
///////////////////////////////////////////////////////////////////////////////
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
ac->wait(current);
}
}
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
ac->store(value);
ac->notify_all();
}
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value);
}
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
AtomicEvent::AtomicEvent() {
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() {
buf_ = std::shared_ptr<Buffer>(
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
allocator().free(*ptr);
@@ -201,17 +123,17 @@ AtomicEvent::AtomicEvent() {
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
}
void AtomicEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait");
event_wait(atomic(), value);
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(to_atomic(buf_), value);
}
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
}
void AtomicEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
@@ -222,17 +144,17 @@ void AtomicEvent::wait(Stream s, uint64_t value) {
}
}
void AtomicEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::signal");
event_signal(atomic(), value);
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(to_atomic(buf_), value);
}
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
}
void AtomicEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
// Signal through a GPU stream so the atomic is updated in GPU - updating
// the atomic in CPU sometimes does not get GPU notified.
@@ -246,14 +168,14 @@ void AtomicEvent::signal(Stream s, uint64_t value) {
}
}
bool AtomicEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
return atomic()->load() >= value;
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return to_atomic(buf_)->load() >= value;
}
uint64_t AtomicEvent::value() const {
nvtx3::scoped_range r("cu::AtomicEvent::value");
return atomic()->load();
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return to_atomic(buf_)->load();
}
} // namespace cu
@@ -266,14 +188,14 @@ namespace {
struct EventImpl {
// CudaEvent is preferred when possible because it is fast, however we have
// to fallback to AtomicEvent in following cases:
// to fallback to SharedEvent in following cases:
// 1. the event is used to wait/signal a cpu stream;
// 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CopyableCudaEvent> cuda;
std::unique_ptr<cu::AtomicEvent> atomic;
std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared;
bool is_created() const {
return cuda || atomic;
return cuda || shared;
}
void ensure_created(Stream s, uint64_t signal_value) {
@@ -281,10 +203,10 @@ struct EventImpl {
return;
}
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow AtomicEvent");
atomic = std::make_unique<cu::AtomicEvent>();
nvtx3::mark("Using slow SharedEvent");
shared = std::make_unique<cu::SharedEvent>();
} else {
cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
cuda = std::make_unique<cu::CudaEvent>();
}
}
};
@@ -303,7 +225,7 @@ void Event::wait() {
assert(value() == 1);
event->cuda->wait();
} else {
event->atomic->wait(value());
event->shared->wait(value());
}
}
@@ -314,7 +236,7 @@ void Event::wait(Stream s) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->atomic->wait(s, value());
event->shared->wait(s, value());
}
}
@@ -325,7 +247,7 @@ void Event::signal(Stream s) {
assert(value() == 1);
event->cuda->record(s);
} else {
event->atomic->signal(s, value());
event->shared->signal(s, value());
}
}
@@ -336,9 +258,9 @@ bool Event::is_signaled() const {
}
if (event->cuda) {
assert(value() == 1);
return event->cuda->is_signaled();
return event->cuda->recorded() && event->cuda->completed();
} else {
return event->atomic->is_signaled(value());
return event->shared->is_signaled(value());
}
}

View File

@@ -3,60 +3,49 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h"
#include <memory>
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <memory>
namespace mlx::core::cu {
class Device;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(Device& d, int flags);
Device& device;
int flags;
};
class CudaEventHandle;
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent {
public:
CudaEvent(Device& d, int flags);
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
CudaEvent& operator=(CudaEvent&&) = default;
CudaEvent(const CudaEvent&) = delete;
CudaEvent& operator=(const CudaEvent&) = delete;
CudaEvent();
void wait();
void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called.
bool completed() const;
// Internal: make sure event pool is initialized.
static void init_pool();
bool recorded() const {
return recorded_;
}
private:
CudaEventHandle event_;
bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
};
// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
class AtomicEvent {
class SharedEvent {
public:
using Atomic = cuda::atomic<uint64_t>;
AtomicEvent();
SharedEvent();
void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value);
@@ -68,11 +57,7 @@ class AtomicEvent {
uint64_t value() const;
private:
Atomic* atomic() const {
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
}
std::shared_ptr<allocator::Buffer> buf_;
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
};
} // namespace mlx::core::cu

View File

@@ -7,7 +7,7 @@ namespace mlx::core {
struct FenceImpl {
uint32_t count;
cu::AtomicEvent event;
cu::SharedEvent event;
};
Fence::Fence(Stream s) {

View File

@@ -50,10 +50,8 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
return CUBLAS_COMPUTE_64F;
case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -87,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
@@ -128,47 +126,37 @@ CublasGemm::CublasGemm(
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
scale_type_ = dtype_to_cublas_type(dtype);
auto scale_type = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type_ = CUDA_R_32F;
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&a_op,
&op,
sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&b_op,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
CublasGemm::CublasGemm(
@@ -203,7 +191,7 @@ CublasGemm::CublasGemm(
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
CublasGemm::~CublasGemm() {
@@ -225,30 +213,14 @@ void CublasGemm::set_out(
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
cols,
rows,
cols,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
encoder.set_input_array(bias);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr,
sizeof(bias_ptr)));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
@@ -256,19 +228,11 @@ void CublasGemm::run(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha) {
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
@@ -276,13 +240,7 @@ void CublasGemm::run(
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
@@ -355,16 +313,6 @@ void CublasGemm::execute(
}
}
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
@@ -381,12 +329,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
alpha_ptr,
b, // a and b are swapped
a_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
beta_ptr,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,

View File

@@ -55,8 +55,6 @@ class CublasGemm {
int32_t batch_count,
int64_t batch_stride);
void set_bias(cu::CommandEncoder& encoder, const array& bias);
void run(
cu::CommandEncoder& encoder,
array& out,
@@ -64,8 +62,7 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha = 1.0f);
const Strides& b_batch_strides);
void run(
cu::CommandEncoder& encoder,
@@ -88,8 +85,7 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha);
const Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
@@ -115,7 +111,6 @@ class CublasGemm {
uint64_t M_;
uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};

View File

@@ -13,8 +13,7 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha) {
const Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -28,8 +27,7 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr,
alpha);
nullptr);
a_it.step();
b_it.step();
}

View File

@@ -154,8 +154,7 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha) {
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
@@ -227,8 +226,7 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr,
alpha);
nullptr);
}
void CublasGemm::run_batched(

View File

@@ -0,0 +1,396 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <iostream>
namespace mlx::core::cu {
namespace {
using namespace cute;
using bf16 = cute::bfloat16_t;
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
if (!initialized) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
}
template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
return (feature_size >= 64) ? 64 : feature_size;
}
constexpr int constexpr_log2(int x) {
return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}
template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
constexpr int swizzle_bits =
constexpr_log2(feature_size * itemsize / copy_bits);
return (swizzle_bits > 3) ? 3 : swizzle_bits;
}
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled =
make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled = make_composed_layout(
Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <
int num_threads,
int itemsize,
bool transpose,
int copy_bits,
typename Copier,
typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
constexpr int num_elements = copy_bits / itemsize;
constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
constexpr int copies_per_feature = feature_size / num_elements;
using E = Int<num_elements>;
using C = Int<copies_per_feature>;
using R = Int<num_threads / copies_per_feature>;
using ThreadLayout = std::conditional_t<
transpose,
Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
using ValueLayout = std::conditional_t<
transpose,
Layout<cute::Shape<E, _1>>,
Layout<cute::Shape<_1, E>>>;
return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}
template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
return {
x / rasterization_factor,
(x % rasterization_factor) + y * rasterization_factor};
}
template <
typename T,
typename SLayoutA,
typename SLayoutB,
typename SLayoutC,
typename CopyA,
typename CopyB,
typename CopyC,
typename MMA,
int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
const T* __restrict__ A,
const T* __restrict__ B,
T* __restrict__ C,
SLayoutA SA,
SLayoutB SB,
SLayoutC SC,
CopyA copy_a,
CopyB copy_b,
CopyC copy_c,
MMA mma,
int M,
int N,
int K) {
constexpr auto BM = size<0>(SA);
constexpr auto BN = size<0>(SB);
constexpr auto BK = size<1>(SA);
constexpr auto PIPE = size<2>(SA);
const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
const int blocks_m = ceil_div(M, BM);
const int blocks_n = ceil_div(N, BN);
// Exit early if the tile is OOB
if (tile.x >= blocks_m || tile.y >= blocks_n) {
return;
}
// Make the full tensors
Tensor full_A =
make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
Tensor full_B =
make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
Tensor full_C =
make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
// Partition the tensors into tiles and select the ones for this threadblock
Tensor local_A =
local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
Tensor local_B =
local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
Tensor local_C =
local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
// Make shared memory tensors
extern __shared__ char shared_memory[];
T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
T* shared_B_ptr =
reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
// Get the copies that correspond to this thread
auto thread_copy_a = copy_a.get_slice(threadIdx.x);
Tensor local_A_src = thread_copy_a.partition_S(local_A);
Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
auto thread_copy_b = copy_b.get_slice(threadIdx.x);
Tensor local_B_src = thread_copy_a.partition_S(local_B);
Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
auto thread_copy_c = copy_c.get_slice(threadIdx.x);
Tensor local_C_src = thread_copy_c.partition_S(shared_C);
Tensor local_C_dst = thread_copy_c.partition_D(local_C);
// Start fetches
int k_tile_count = size<2>(local_A);
int k_tile_next = 0;
CUTE_UNROLL
for (int k = 0; k < PIPE - 1; k++) {
copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
}
// Get the MMA that corresponds to this thread and allocate registers
auto thread_mma = mma.get_slice(threadIdx.x);
Tensor mma_shared_A = thread_mma.partition_A(shared_A);
Tensor mma_shared_B = thread_mma.partition_B(shared_B);
Tensor mma_shared_C = thread_mma.partition_C(shared_C);
Tensor mma_global_C = thread_mma.partition_C(local_C);
Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
clear(mma_frag_C);
// Make shared to register copies
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
constexpr auto RPIPE = size<2>(mma_shared_A);
int smem_read = 0;
int smem_write = PIPE - 1;
Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
// Start the register pipeline
if constexpr (RPIPE > 1) {
cp_async_wait<PIPE - 2>();
__syncthreads();
copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
}
CUTE_NO_UNROLL
while (k_tile_count > -(PIPE - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < RPIPE; k_block++) {
if (k_block == RPIPE - 1) {
mma_A_src_p = mma_A_src(_, _, _, smem_read);
mma_B_src_p = mma_B_src(_, _, _, smem_read);
cp_async_wait<PIPE - 2>();
__syncthreads();
}
// Load the next register tile
auto k_block_next = (k_block + 1) % RPIPE;
copy(
s2r_copy_a,
mma_A_src_p(_, _, k_block_next),
mma_A_dst(_, _, k_block_next));
copy(
s2r_copy_b,
mma_B_src_p(_, _, k_block_next),
mma_B_dst(_, _, k_block_next));
if (k_block == 0) {
copy(
copy_a,
local_A_src(_, _, _, k_tile_next),
local_A_dst(_, _, _, smem_write));
copy(
copy_b,
local_B_src(_, _, _, k_tile_next),
local_B_dst(_, _, _, smem_write));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
smem_write = smem_read;
smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
}
gemm(
mma,
mma_frag_A(_, _, k_block),
mma_frag_B(_, _, k_block),
mma_frag_C);
}
}
copy(mma_frag_C, mma_shared_C);
__syncthreads();
copy(copy_c, local_C_src, local_C_dst);
// if (threadIdx.x == 0) {
// print("fC: "); print(mma_frag_C); print("\n");
// print("sC: "); print(mma_shared_C); print("\n");
// print("dC: "); print(local_C_dst); print("\n");
//
// print(s2r_atom_a); print("\n");
// }
}
} // namespace
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;
// Tile definitions
auto BM = Int<128>{};
auto BN = Int<128>{};
auto BK = Int<64>{};
auto BP = Int<3>{};
auto GM = Int<8>{};
// Thread definitions
using TM = Int<2>;
using TN = Int<2>;
using TK = Int<1>;
constexpr int num_threads = TM::value * TN::value * 32;
auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
auto async_copy_op =
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BM, BK));
auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BN, BK));
auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
sync_copy_op, make_shape(BM, BN));
auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
auto tiled_mma = make_tiled_mma(
mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
auto kernel = matmul_kernel<
bf16,
decltype(SA),
decltype(SB),
decltype(SC),
decltype(tiled_copy_a),
decltype(tiled_copy_b),
decltype(tiled_copy_c),
decltype(tiled_mma),
GM.value>;
configure_matmul(kernel, smem_size);
dim3 block(size(tiled_mma));
dim3 grid(
size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
enc.add_kernel_node(
kernel,
grid,
block,
smem_size,
a.data<bf16>(),
b.data<bf16>(),
out.data<bf16>(),
SA,
SB,
SC,
tiled_copy_a,
tiled_copy_b,
tiled_copy_c,
tiled_mma,
M,
N,
K);
} else {
throw std::runtime_error("Only bfloat16 supported");
}
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -13,37 +13,6 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -55,8 +24,7 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
using Acc = typename GemvAccType<T>::type;
Acc sum = Acc(0);
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
@@ -64,11 +32,12 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<Acc>{});
sum = cg::reduce(warp, sum, cg::plus<float>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
@@ -138,7 +107,7 @@ void gemv(
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;

View File

@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"
#include <iostream>
namespace mlx::core::cu {
namespace {
template <typename Kernel>
static void configure_smem(Kernel kernel, int SM) {
static bool done = false;
if (done) {
return;
}
std::cout << "configuring" << std::endl;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
cudaFuncSetAttribute(
kernel,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared);
done = true;
}
} // namespace
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int PIPE = 3;
constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
constexpr int WM = 2;
constexpr int WN = 4;
auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
configure_smem(kernel, SM);
dim3 grid(N / BN, M / BM);
enc.add_kernel_node(
kernel,
grid,
WM * WN * WARP_SIZE,
SM,
a.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
N,
K);
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_);
args.append(slice_size);
args.append(axes_);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape());
args.append_ndim(out.strides());
args.append<int32_t>(out.ndim());
args.append(axes_);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(

View File

@@ -67,11 +67,9 @@ const std::string& cccl_dir() {
return path.string();
}
// Finally check the environment variable.
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
path = env;
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
@@ -99,30 +97,6 @@ const std::filesystem::path& ptx_cache_dir() {
return cache;
}
std::filesystem::path get_ptx_path(
const std::filesystem::path& cache_dir,
const std::string& module_name) {
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (module_name.size() <= max_file_name_length) {
return cache_dir / (module_name + ".ptx");
}
auto ptx_path = cache_dir;
int offset = 0;
while (module_name.size() - offset > max_file_name_length) {
ptx_path /= module_name.substr(offset, max_file_name_length);
offset += max_file_name_length;
}
ptx_path /= module_name.substr(offset) + ".ptx";
return ptx_path;
}
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
@@ -133,7 +107,7 @@ bool read_cached_ptx(
return false;
}
auto ptx_path = get_ptx_path(cache_dir, module_name);
auto ptx_path = cache_dir / (module_name + ".ptx");
std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error);
if (error) {
@@ -146,7 +120,7 @@ bool read_cached_ptx(
ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
@@ -168,26 +142,16 @@ void write_cached_ptx(
return;
}
auto ptx_path = get_ptx_path(cache_dir, module_name);
// Ensure that the directory exists
auto parent = ptx_path.parent_path();
if (parent != cache_dir) {
std::filesystem::create_directories(parent);
}
// Write the compiled code and mangled names
std::ofstream ptx_file(ptx_path, std::ios::binary);
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size());
}
std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
// Write the generated code
std::ofstream source_file(ptx_path.replace_extension(".cu"));
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
@@ -331,8 +295,7 @@ void load_module(
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
kernels) {
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -349,7 +312,7 @@ void load_module(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels[name] = std::make_tuple(kernel, false, 0);
kernels[name] = std::make_pair(kernel, false);
}
}
@@ -393,7 +356,7 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
@@ -404,22 +367,14 @@ std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
// If it is the first time we run this kernel then configure it. Do it only
// once!
auto kernel = std::get<0>(it->second);
if (!std::get<1>(it->second)) {
if (!it->second.second) {
if (configure_kernel) {
configure_kernel(kernel);
configure_kernel(it->second.first);
}
std::get<1>(it->second) = true;
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
it->second.second = true;
}
return {kernel, std::get<2>(it->second)};
}
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
return it->second.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

View File

@@ -46,11 +46,6 @@ struct KernelArgs {
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
}
template <typename T>
void append(const std::vector<T>& vec) {
append(SmallVector<T>(vec.begin(), vec.end()));
}
// Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(SmallVector<T> vec) {
@@ -99,13 +94,10 @@ class JitModule {
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
std::pair<CUfunction, uint> get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@@ -35,10 +35,12 @@ std::tuple<dim3, uint> get_launch_args(
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread /* = 1 */,
uint max_block_dim /* = 1024 */) {
int work_per_thread) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
// This file includes host-only utilities for writing CUDA kernels, the
// difference from backend/cuda/device/utils.cuh is that the latter file only
// include device-only code.
// This file includes host-only utilies for writing CUDA kernels, the difference
// from backend/cuda/device/utils.cuh is that the latter file only include
// device-only code.
#pragma once
@@ -120,28 +120,19 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of |arr|.
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024);
int work_per_thread = 1);
inline std::tuple<dim3, uint> get_launch_args(
const array& arr,
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024) {
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
return get_launch_args(
arr.size(),
arr.shape(),
arr.strides(),
large,
work_per_thread,
max_block_dim);
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core

View File

@@ -2,15 +2,11 @@
#pragma once
#include "mlx/utils.h"
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
#include <fmt/format.h>
namespace mlx::core {
template <
@@ -31,14 +27,6 @@ class LRUCache {
}
}
// Initialize with capacity read from |env_name|.
LRUCache(const char* env_name, int default_capacity)
: LRUCache(env::get_var(env_name, default_capacity)) {
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
env_name_ = env_name;
}
}
size_t size() const {
return map_.size();
}
@@ -88,14 +76,6 @@ class LRUCache {
return {it->second, false};
}
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
throw std::runtime_error(fmt::format(
"Cache thrashing is happening, please set the environment variable "
"{} to a larger value than {} to fix degraded performance.",
env_name_,
capacity_));
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
@@ -126,9 +106,6 @@ class LRUCache {
}
}
const char* env_name_{nullptr};
size_t cache_misses_{0};
list_type vlist_;
map_type map_;
size_t capacity_;

View File

@@ -3,7 +3,9 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/gemms/simple_gemm.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
@@ -14,6 +16,11 @@ namespace mlx::core {
namespace {
int get_test_gemm() {
static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
return t;
}
std::tuple<bool, int64_t, array>
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
@@ -29,80 +36,6 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
}
}
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
const std::optional<array>& bias = std::nullopt,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
if (a.dtype() == complex64) {
throw std::runtime_error(
"[gemm_and_bias] complex64 bias epilogue isnt supported in cublasLtMatmul.");
}
gemm.set_bias(encoder, *bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -123,6 +56,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
@@ -132,8 +68,70 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
gemm_and_bias(
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 && get_test_gemm() == 1) {
cu::simple_gemm(a, b, out, M, N, K, encoder);
return;
}
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 && get_test_gemm() == 2) {
cu::cutlass_gemm(a, b, out, M, N, K, encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -158,29 +156,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c,
alpha_);
return;
}
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
@@ -222,7 +197,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt with AddMM settings
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),

View File

@@ -24,6 +24,8 @@ namespace mlx::core {
}
NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
@@ -40,6 +42,7 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)

View File

@@ -306,7 +306,7 @@ void affine_dequantize(
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;

View File

@@ -1,19 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"
namespace mlx::core {
void fast::ConvertFP8::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
auto& in = inputs[0];
auto& out = outputs[0];
auto& s = out.primitive().stream();
if (to_fp8_) {
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
} else {
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -1,83 +0,0 @@
#pragma once
struct __nv_fp8_e8m0 {
__device__ __nv_fp8_e8m0(float x) {
if (!std::isfinite(x)) {
__x = 0xFF;
return;
}
if (x < 0.0f) {
__x = 0x00;
return;
}
float le = std::log2f(x);
int n = static_cast<int>(std::nearbyintf(le));
n = n < -127 ? -127 : n;
n = n > 127 ? 127 : n;
__x = static_cast<uint8_t>(n + 127);
}
__device__ operator float() {
if (__x == 0xFF) {
return std::numeric_limits<float>::quiet_NaN();
}
return std::ldexp(1.0f, static_cast<int>(__x) - 127);
}
uint8_t __x{0};
};
struct __nv_fp4_e2m1 {
__device__ __nv_fp4_e2m1(float x) {
if (std::isnan(x)) {
__x = 0x7;
return;
}
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
x = std::abs(x);
if (x > 5.0f) {
__x = 0x7;
} else if (x >= 3.5f) {
__x = 0x6;
} else if (x > 2.5f) {
__x = 0x5;
} else if (x >= 1.75f) {
__x = 0x4;
} else if (x > 1.25f) {
__x = 0x3;
} else if (x >= 0.75f) {
__x = 0x2;
} else if (x > 0.25f) {
__x = 0x1;
} else {
__x = 0x0;
}
__x |= sign_bit;
}
__device__ operator float() {
static const float LUT[16] = {
0.0f,
0.5f,
1.0f,
1.5f,
2.0f,
3.0f,
4.0f,
6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
return LUT[__x];
}
uint8_t __x{0};
};

View File

@@ -1,216 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>
namespace mlx::core {
namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits>
struct Dequantize {
__device__ float operator()(uint8_t x) {
if constexpr (bits == 8) {
return float(*(__nv_fp8_e4m3*)(&x));
} else {
return float(*(__nv_fp4_e2m1*)(&x));
}
}
};
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) {
return;
}
float w_thread = w[index];
cg::greater<float> max_op;
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.__x;
scale = float(s);
// Write out the scales
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = warp.shfl_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
}
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto scale = float(((ScaleType*)(scales))[gindex]);
out += oindex;
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}
} // namespace cu
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>;
}
bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<uint8_t>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not quantize input with type float64.");
}
});
}
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
constexpr int uint8_per_uint32 = 4;
int packs_per_int = 8 / bits;
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_dequantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_dequantize<T, 16, 4, false>;
}
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
w.data<T>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
}
});
}
} // namespace mlx::core

View File

@@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix(
} // namespace
void fast::Quantize::eval_gpu(
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("Quantize::eval_gpu");
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
@@ -57,30 +57,23 @@ void fast::Quantize::eval_gpu(
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
}
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
}
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}

View File

@@ -24,22 +24,4 @@ void affine_dequantize(
cu::CommandEncoder& enc,
const Stream& s);
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -181,47 +181,6 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
}
}
template <typename T, typename U, typename Op, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
size_t total) {
Op op;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
const auto idx = grid.thread_rank() * N_READS;
const auto before_axis = idx / args.reduction_stride;
const auto after_axis = idx % args.reduction_stride;
const auto offset =
before_axis * args.reduction_stride * args.reduction_size + after_axis;
if (idx >= total) {
return;
}
in += offset;
out += idx;
AlignedVector<U, N_READS> accumulator;
for (int i = 0; i < N_READS; i++) {
accumulator[i] = ReduceInit<Op, T>::value();
}
for (int i = 0; i < args.reduction_size; i++) {
auto values = load_vector<N_READS>(in, 0);
for (int j = 0; j < N_READS; j++) {
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
}
in += args.reduction_stride;
}
store_vector(out, 0, accumulator);
}
} // namespace cu
inline auto output_grid_for_col_reduce(
@@ -247,7 +206,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
cu::ColReduceArgs args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -271,55 +230,12 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
kernel, grid, blocks, 0, indata, out.data<U>(), args);
});
});
});
}
void col_reduce_small(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
block,
0,
in.data<T>(),
out.data<U>(),
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -342,13 +258,6 @@ void col_reduce(
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes);
// Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -7,6 +7,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -81,8 +83,7 @@ struct RowReduceArgs {
};
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -90,8 +91,8 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
AlignedVector<T, N> vals[M];
AlignedVector<U, M> accs;
T vals[M][N];
U accs[M];
for (int i = 0; i < M; i++) {
accs[i] = init;
}
@@ -100,31 +101,43 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size + block.thread_rank() * N;
in += start_row * size;
out += start_row;
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
vals[k] = load_vector<N>(in + k * size, 0);
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
if (size % N == 0) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
in += block.size() * N;
}
if (final_offset < size) {
for (int k = 0; k < M; k++) {
for (int i = 0; i < N; i++) {
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
? in[k * size + i]
: cast_to<T>(init);
}
}
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + final_offset,
vals[k],
size,
cast_to<T>(init));
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
@@ -132,11 +145,13 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
}
__shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
store_vector(out, 0, accs);
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
@@ -146,10 +161,17 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
}
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM,
int N_READS = 4>
__global__ void row_reduce_looped(
const T* in,
T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
@@ -163,60 +185,36 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
const size_t full_blocks = args.row_size / (block.size() * N_READS);
const size_t final_offset = full_blocks * (block.size() * N_READS);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < full_blocks; r++) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
{
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
cast_to<T>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
__shared__ U shared_accumulators[32];
@@ -236,6 +234,8 @@ void row_reduce_simple(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
@@ -250,15 +250,14 @@ void row_reduce_simple(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
@@ -268,7 +267,6 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -284,6 +282,8 @@ void row_reduce_looped(
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -295,27 +295,34 @@ void row_reduce_looped(
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
});
encoder.add_kernel_node(
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
});
});
}

View File

@@ -103,21 +103,15 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
const int* offset,
int offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t offset_stride,
int n_head,
int64_t n_batch,
uint3 pos,
uint3 dims) {
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
float L = scale * static_cast<float>(pos.y + offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
@@ -129,19 +123,20 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
mat_idx * out_strides[0];
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
mat_idx * out_strides[0];
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -172,8 +167,7 @@ __global__ void rope(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t offset_stride,
int n_head,
int64_t n_batch,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
@@ -188,13 +182,12 @@ __global__ void rope(
rope_impl<T, traditional, forward>(
in,
out,
offset,
*offset,
inv_freq,
scale,
strides,
out_strides,
offset_stride,
n_head,
n_batch,
pos,
dims);
}
@@ -209,8 +202,7 @@ __global__ void rope_freqs(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t offset_stride,
int n_head,
int64_t n_batch,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
@@ -225,13 +217,12 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>(
in,
out,
offset,
*offset,
inv_freq,
scale,
strides,
out_strides,
offset_stride,
n_head,
n_batch,
pos,
dims);
}
@@ -254,28 +245,23 @@ void RoPE::eval_gpu(
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
int dispatch_ndim = in.ndim();
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
size_t mat_size = in.shape(-2) * in.shape(-1);
// We apply rope to less that the whole vector so copy to output and then
// apply in-place.
if (dims_ < D) {
if (dims_ < in.shape(-1)) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -316,7 +302,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
@@ -333,7 +319,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, N);
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -350,7 +336,7 @@ void RoPE::eval_gpu(
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, N);
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -368,14 +354,10 @@ void RoPE::eval_gpu(
} else if (with_freqs) {
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -389,20 +371,15 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
offset_stride,
N,
in.size() / mat_size,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -415,8 +392,7 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
offset_stride,
N,
in.size() / mat_size,
dims);
}
});

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
@@ -45,7 +46,6 @@ __global__ void kernel_sdpav_1pass(
const T* K,
const T* V,
T* O,
const T* sinks,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
@@ -65,7 +65,7 @@ __global__ void kernel_sdpav_1pass(
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * M_LOG2E;
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
@@ -108,12 +108,8 @@ __global__ void kernel_sdpav_1pass(
o[i] = 0.f;
}
U max_score = Limits<U>::finite_min();
U max_score = -INFINITY;
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -171,7 +167,7 @@ __global__ void kernel_sdpav_1pass(
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
sum_exp_score = __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
@@ -197,7 +193,6 @@ __global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
const T* sinks,
float* partials,
float* sums,
float* maxs,
@@ -273,12 +268,8 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f;
}
U max_score = Limits<U>::finite_min();
U max_score = -1e9;
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -419,7 +410,7 @@ __global__ void kernel_sdpav_2pass_2(
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
sum_exp_score = __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
@@ -472,14 +463,10 @@ void sdpa_vector_1pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks) {
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(o);
cu::AttnParams params{
@@ -502,7 +489,7 @@ void sdpa_vector_1pass_fallback(
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -517,7 +504,6 @@ void sdpa_vector_1pass_fallback(
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
params);
});
});
@@ -532,8 +518,7 @@ void sdpa_vector_2pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks) {
bool do_causal_ = false) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
@@ -574,7 +559,7 @@ void sdpa_vector_2pass_fallback(
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -585,10 +570,6 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
@@ -604,7 +585,6 @@ void sdpa_vector_2pass_fallback(
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
@@ -647,16 +627,15 @@ void sdpa_vector_fallback(
const array& v,
const float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks) {
bool do_causal_ = false) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal, sinks);
s, encoder, q, k, v, scale, o, do_causal_);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal, sinks);
s, encoder, q, k, v, scale, o, do_causal_);
}
}
@@ -712,7 +691,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(inputs.size());
copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -724,16 +703,6 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
@@ -771,6 +740,10 @@ void ScaledDotProductAttention::eval_gpu(
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
@@ -779,26 +752,22 @@ void ScaledDotProductAttention::eval_gpu(
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ o.size() == o.shape(3),
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
o.size(),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
}
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here

View File

@@ -1,11 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/dtype_utils.h"
#include <numeric>
@@ -30,7 +27,8 @@ void concatenate_gpu(
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto concurrent = cu::get_command_encoder(s).concurrent_context();
// TODO: Handle concurrent outputs:
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
@@ -40,71 +38,4 @@ void concatenate_gpu(
}
}
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
const Stream& s) {
Dtype dtype = indices.dtype();
int nidx = axes.size();
std::string module_name =
fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
std::string kernel_name = fmt::format(
"mlx::core::cu::compute_dynamic_offset<{}, {}>",
dtype_to_cuda_type(dtype),
nidx);
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::string source = R"(
#include "mlx/backend/cuda/device/utils.cuh"
namespace mlx::core::cu {
template <typename T, int NIDX>
__global__ void compute_dynamic_offset(
const T* indices,
int64_t* offset,
const __grid_constant__ Strides strides,
const __grid_constant__ cuda::std::array<int, NIDX> axes) {
int64_t acc = 0;
#pragma unroll
for (int i = 0; i < NIDX; ++i) {
acc += indices[i] * strides[axes[i]];
}
*offset = acc;
}
} // namespace mlx::core::cu
)";
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
});
// Prepare output.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(offset);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
cu::KernelArgs args;
args.append(indices);
args.append(offset);
args.append_ndim(strides);
args.append(axes);
auto kernel = mod.get_kernel(kernel_name);
encoder.add_kernel_node(kernel, 1, 1, 0, args.args());
return offset;
}
} // namespace mlx::core

View File

@@ -9,7 +9,7 @@
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cub/device/device_segmented_radix_sort.cuh>
#include <cub/device/device_segmented_sort.cuh>
#include <cassert>
@@ -79,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.add_temporary(discard);
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
nullptr,
size,
in.data<Type>(),
@@ -90,8 +90,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -106,7 +104,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(),
size,
in.data<Type>(),
@@ -117,12 +115,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
} else {
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
nullptr,
size,
in.data<Type>(),
@@ -131,8 +127,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -140,7 +134,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(),
size,
in.data<Type>(),
@@ -149,8 +143,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
}
} else {

View File

@@ -4,95 +4,189 @@
namespace mlx::core::cu {
template <typename T, int BM, int BN, int BK, int WM, int WN>
__device__ inline void gemm_ab_t(
RegisterTile<float, BM / WM, BN / WN>& C,
SharedTile<T, BM, BK>& As,
SharedTile<T, BN, BK>& Bs,
RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
RegisterTile<T, BM / WM, 16> A[2];
RegisterTile<T, BN / WN, 16> B[2];
rloader_a.load(A[0], As.base_addr(), 0);
rloader_b.load(B[0], Bs.base_addr(), 0);
MLX_UNROLL
for (int k = 1; k < BK / 16; k++) {
rloader_a.load(A[k & 1], As.base_addr(), k);
rloader_b.load(B[k & 1], Bs.base_addr(), k);
mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
}
mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
}
/**
* An example gemm written with the utils.
*
* Computes A @ B.T when A and B are all aligned with the block sizes.
*/
template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int WARPS_M = 2;
constexpr int WARPS_N = 2;
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M;
constexpr int WARP_STEP_N = BN / WARPS_N;
// template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
// constexpr int NUM_WARPS = WM * WN;
// constexpr int WARP_STEP_M = BM / WM;
// constexpr int WARP_STEP_N = BN / WN;
//
// // Precompute some offsets for each thread
// const int warpid = threadIdx.x / 32;
// const int laneid = threadIdx.x % 32;
// const int wm = warpid / WN;
// const int wn = warpid % WN;
// const int offset_m = wm * WARP_STEP_M;
// const int offset_n = wn * WARP_STEP_N;
//
// // Allocate shared memory
// extern __shared__ char shmem[];
// SharedTile<T, BM, BK>(&as)[PIPE] =
// *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
// SharedTile<T, BN, BK>(&bs)[PIPE] =
// *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
//
// // Move the global pointers to the tile
// a += blockIdx.y * BM * K;
// b += blockIdx.x * BN * K;
// y += blockIdx.y * BM * N + blockIdx.x * BN;
//
// // Make the loaders to/from SMEM
// SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
// SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
// RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
// RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
//
// // Start the SM pipeline
// MLX_UNROLL
// for (int i = 0; i < PIPE - 1; i++) {
// sloader_a.load_async(as[i].base_addr());
// sloader_b.load_async(bs[i].base_addr());
// cp_async_commit();
// sloader_a.next();
// sloader_b.next();
// }
//
// // Allocate and zero the MMA accumulator
// RegisterTile<float, BM / WM, BN / WN> C;
// C.fill(0);
//
// // Matmul loop
// int num_blocks = K / BK;
// int sread = 0;
// int swrite = PIPE - 1;
// for (int i = 0; i < num_blocks; i++) {
// cp_async_wait<PIPE - 1>();
//
// gemm_ab_t<T, BM, BN, BK, WM, WN>(
// C, as[sread], bs[sread], rloader_a, rloader_b);
//
// sloader_a.load_async(as[swrite].base_addr());
// sloader_b.load_async(bs[swrite].base_addr());
// cp_async_commit();
// sloader_a.next(i + PIPE < num_blocks);
// sloader_b.next(i + PIPE < num_blocks);
//
// swrite = sread;
// sread = (sread + 1) % PIPE;
// }
//
// C.store_global(y, N, offset_m, offset_n);
// }
template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
__global__ __launch_bounds__(
WM* WN* WARP_SIZE,
1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int NUM_WARPS = WM * WN;
constexpr int WARP_STEP_M = BM / WM;
constexpr int WARP_STEP_N = BN / WN;
// Precompute some offsets for each thread
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N;
const int wn = warpid % WARPS_N;
const int wm = warpid / WN;
const int wn = warpid % WN;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
// Allocate shared memory
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
SharedTile<T, BN, BK>(&bs)[2] =
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
// Allocate registers for the MMA
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
SharedTile<T, BM, BK>(&as)[PIPE] =
*(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
SharedTile<T, BN, BK>(&bs)[PIPE] =
*(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
// Move the global pointers to the tile
a += blockIdx.y * BM * K;
b += blockIdx.x * BN * K;
y += blockIdx.y * BM * N + blockIdx.x * BN;
// Zero the accumulators
C.fill(0);
// Make the loaders to/from SMEM
using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
const int scol =
(threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
a += srow * K + scol;
b += srow * K + scol;
uint32_t sm_offsets[PIPE][2];
MLX_UNROLL
for (int s = 0; s < PIPE; s++) {
sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
}
RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
// Start the SM pipeline
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
cp_async_commit();
int tic = 0;
for (int k_block = BK; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
cp_async_commit();
cp_async_wait<1>();
__syncthreads();
MLX_UNROLL
for (int s = 0; s < PIPE - 1; s++) {
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
a += sloader::STEP_ROWS * K;
b += sloader::STEP_ROWS * K;
}
tic ^= 1;
cp_async_commit();
}
// Empty the pipeline
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
// Allocate and zero the MMA accumulator
RegisterTile<float, BM / WM, BN / WN> C;
C.fill(0);
mma_t(C, A, B);
// Matmul loop
int num_blocks = K / BK;
int sread = 0;
int swrite = PIPE - 1;
for (int i = 0; i < num_blocks; i++) {
cp_async_wait<PIPE - 1>();
gemm_ab_t<T, BM, BN, BK, WM, WN>(
C, as[sread], bs[sread], rloader_a, rloader_b);
if (false) {
MLX_UNROLL
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
a += sloader::STEP_ROWS * K;
b += sloader::STEP_ROWS * K;
}
}
cp_async_commit();
swrite = sread;
sread = (sread + 1) % PIPE;
}
C.store_global(y, N, offset_m, offset_n);

View File

@@ -223,59 +223,10 @@ struct RegisterTile {
}
};
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
using value_type = T;
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
@@ -317,23 +268,26 @@ struct SharedTile {
}
}
// Return the location of the element at (row, col) using the swizzle.
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
__device__ static inline uint32_t offset(int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint32_t addr = ptr +
sizeof(T) *
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols);
const uint32_t addr = sizeof(T) *
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (addr ^ swizzle);
} else {
return ptr + sizeof(T) * (row * COLS + col);
return sizeof(T) * (row * COLS + col);
}
}
// Return the location of the element at (row, col) using the swizzle.
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
return ptr + offset(row, col);
}
// Convenience functions to edit elements going through the swizzle.
__device__ inline T& operator()(int row, int col) {
return *ptr(data, row, col);
@@ -364,6 +318,76 @@ struct SharedTile {
}
};
template <int NUM_WARPS, typename Tile>
struct SharedTileLoader {
using T = typename Tile::value_type;
static constexpr int NUM_THREADS = NUM_WARPS * 32;
static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const T* x_;
int N_;
uint32_t offset_;
__device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x_ += row * N + col * ELEMENTS_PER_LOAD;
offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
}
__device__ inline void load_async(uint32_t base_address) {
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async<16>(
base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
x_ + i * STEP_ROWS * N_);
}
}
__device__ inline void next() {
x_ += Tile::COLS;
}
};
template <typename Tile>
struct RegisterTileLoader {
using T = typename Tile::value_type;
uint32_t offset_[Tile::COLS / 16];
__device__ RegisterTileLoader(int offset_row, int laneid) {
const int row = offset_row + laneid & 15;
const int col = (laneid >> 4) << 3;
MLX_UNROLL
for (int i = 0; i < Tile::COLS / 16; i++) {
offset_[i] = Tile::offset(row, col + i * 16);
}
}
template <typename T, int ROWS, int COLS>
__device__ inline void
load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
x.data[i * TILES_X + j].load(
base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
}
}
}
};
/**
* Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately.

View File

@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
#if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 16) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x)));
} else if constexpr (N == 8) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
"cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int2*>(x)));
} else if constexpr (N == 4) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int*>(x)));
}
#endif

View File

@@ -156,25 +156,7 @@ void ternary_op_gpu_inplace(
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::VectorVectorVector ||
topt == TernaryOpType::ScalarScalarScalar) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
} else {
if (topt == TernaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -243,6 +225,23 @@ void ternary_op_gpu_inplace(
ndim);
}
});
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
}
});
}

Some files were not shown because too many files have changed in this diff Show More