Compare commits

..

8 Commits

Author | SHA1 | Message | Date
Angelos Katharopoulos | a22d0bf273 | Add stricter condition to matrix sdpa | 2025-08-06 19:51:14 -07:00
Jagrit Digani | 99d8de8445 | Fix cudnn routing | 2025-08-06 15:05:58 -07:00
Jagrit Digani | c66b76a8c8 | Update routing | 2025-08-06 15:01:15 -07:00
Jagrit Digani | f81edd184f | Complete 2 pass sdpav | 2025-08-06 13:57:40 -07:00
Jagrit Digani | 7f8ba2a003 | [WIP] 2 pass sdpav | 2025-08-06 09:56:39 -07:00
Jagrit Digani | c28249b81a | Add more nvtx range for debug | 2025-08-06 09:56:39 -07:00
Jagrit Digani | e74bcdc5e3 | Add sdpa file | 2025-08-06 09:56:39 -07:00
Jagrit Digani | d8ed6c1aa3 | Add base cudnn attention support | 2025-08-06 09:56:39 -07:00
291 changed files with 4413 additions and 13295 deletions

View File

@@ -18,14 +18,13 @@ jobs:
type: boolean type: boolean
default: false default: false
macos: macos:
xcode: "26.0.0" xcode: "16.2.0"
resource_class: m4pro.medium resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
name: Install name: Install
command: | command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9 brew install python@3.9
brew install doxygen brew install doxygen
python3.9 -m venv env python3.9 -m venv env
@@ -90,8 +89,7 @@ jobs:
command: | command: |
uv venv uv venv
uv pip install cmake uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ uv pip install -e ".[dev]" -v
uv pip install -e ".[dev]" -v
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
@@ -120,7 +118,7 @@ jobs:
parameters: parameters:
xcode_version: xcode_version:
type: string type: string
default: "26.0.0" default: "16.2.0"
macosx_deployment_target: macosx_deployment_target:
type: string type: string
default: "" default: ""
@@ -128,13 +126,12 @@ jobs:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
environment: environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium resource_class: m2pro.medium
steps: steps:
- checkout - checkout
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \ HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv brew install openmpi uv
- run: - run:
@@ -199,7 +196,7 @@ jobs:
name: Run Python tests with JIT name: Run Python tests with JIT
command: | command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v uv pip install -e .
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \ LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \ METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \ uv run --no-project python -m xmlrunner discover \
@@ -225,20 +222,15 @@ jobs:
sudo apt-get update sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12 sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf - curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64 rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
uv venv uv venv
uv pip install cmake CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v uv pip install -e ".[dev]" -v
- run: - run:
name: Run Python tests name: Run Python tests
@@ -246,23 +238,12 @@ jobs:
source .venv/bin/activate source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run: - run:
name: CCache report name: CCache report
command: | command: |
ccache --show-stats ccache --show-stats
ccache --zero-stats ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup ccache --cleanup
- save_cache: - save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }} key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
@@ -276,7 +257,7 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "26.0.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
@@ -285,7 +266,7 @@ jobs:
default: "" default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: m4pro.medium resource_class: m2pro.medium
environment: environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
@@ -293,15 +274,11 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
xcodebuild -downloadComponent MetalToolchain brew install python@<< parameters.python_version >>
mkdir -p ~/miniconda3 brew install openmpi
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh python<< parameters.python_version >> -m venv env
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 source env/bin/activate
rm ~/miniconda3/miniconda.sh pip install --upgrade pip
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.4.0 pip install nanobind==2.4.0
pip install --upgrade setuptools pip install --upgrade setuptools
@@ -311,19 +288,19 @@ jobs:
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
conda activate env source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \ env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v pip install . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
conda activate env source env/bin/activate
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Build Python package name: Build Python package
command: | command: |
conda activate env source env/bin/activate
python setup.py clean --all python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when: - when:
@@ -333,7 +310,7 @@ jobs:
- run: - run:
name: Build common package name: Build common package
command: | command: |
conda activate env source env/bin/activate
python setup.py clean --all python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when: - when:
@@ -342,7 +319,7 @@ jobs:
- run: - run:
name: Upload package name: Upload package
command: | command: |
conda activate env source env/bin/activate
twine upload dist/* twine upload dist/*
- store_artifacts: - store_artifacts:
path: dist/ path: dist/
@@ -415,7 +392,7 @@ jobs:
default: "" default: ""
machine: machine:
image: ubuntu-2204:current image: ubuntu-2204:current
resource_class: xlarge resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
@@ -462,7 +439,7 @@ workflows:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
macosx_deployment_target: ["13.5", "15.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test - linux_build_and_test
- cuda_build_and_test: - cuda_build_and_test:
matrix: matrix:
@@ -487,7 +464,68 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"] xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation: - build_documentation:
filters: filters:
tags: tags:
@@ -529,7 +567,7 @@ workflows:
requires: [ hold ] requires: [ hold ]
matrix: matrix:
parameters: parameters:
macosx_deployment_target: ["13.5", "15.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test: - linux_build_and_test:
requires: [ hold ] requires: [ hold ]
- cuda_build_and_test: - cuda_build_and_test:
@@ -548,7 +586,53 @@ workflows:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"] xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
- build_linux_release: - build_linux_release:
matrix: matrix:
parameters: parameters:
@@ -567,7 +651,68 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"] macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"] build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"] xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release: - build_linux_release:
matrix: matrix:
parameters: parameters:

View File

@@ -19,17 +19,12 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation - Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors"> <a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a> </a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software # Third-Party Software
MLX leverages several third-party software, listed here together with MLX leverages several third-party software, listed here together with

View File

@@ -26,7 +26,6 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER) set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration ----------------------------- # ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON) option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -88,21 +87,22 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx) add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA) if(MLX_BUILD_CUDA)
enable_language(CUDA) enable_language(CUDA)
endif() endif()
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL AND NOT METAL_LIB)
find_library(METAL_LIB Metal) message(STATUS "Metal not found. Unable to build GPU")
find_library(FOUNDATION_LIB Foundation) set(MLX_BUILD_METAL OFF)
find_library(QUARTZ_LIB QuartzCore) set(MLX_METAL_DEBUG OFF)
if(METAL_LIB) elseif(MLX_BUILD_METAL)
message(STATUS "Metal found ${METAL_LIB}") message(STATUS "Building METAL sources")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,8 +111,7 @@ if(MLX_BUILD_METAL)
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process( execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0) if(${MACOS_SDK_VERSION} LESS 14.0)
message( message(
@@ -141,12 +140,6 @@ if(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif() endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32) if(WIN32)
if(MSVC) if(MSVC)
# GGUF does not build with MSVC. # GGUF does not build with MSVC.
@@ -174,7 +167,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON) set(MLX_BUILD_ACCELERATE ON)
else() else()
message(STATUS "Accelerate not found, using default backend.") message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
endif() endif()

View File

@@ -11,31 +11,31 @@ brought to you by Apple machine learning research.
Some key features of MLX include: Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models. more complex models.
- **Composable function transformations**: MLX supports composable function - **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization, transformations for automatic differentiation, automatic vectorization,
and computation graph optimization. and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only - **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed. materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed - **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive. slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices - **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU). (currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks - **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory. is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported Operations on MLX arrays can be performed on any of the supported
device types without transferring data. device types without transferring data.
MLX is designed by machine learning researchers for machine learning MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient researchers. The framework is intended to be user-friendly, but still efficient
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following MLX useful in your research and wish to cite it, please use the following
BibTex entry: BibTex entry:
```text ```
@software{mlx2023, @software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert}, author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon}, title = {{MLX}: Efficient and flexible machine learning on Apple silicon},

View File

@@ -142,7 +142,9 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1) t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b) c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype) c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4 atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -161,7 +163,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks") parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16", "complex64") dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn") transposes = ("nn", "nt", "tn")
shapes = ( shapes = (
(16, 234, 768, 3072), (16, 234, 768, 3072),
@@ -185,7 +187,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0 diff = gflops_mx / gflops_pt - 1.0
print( print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%" f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
) )
if gflops_pt >= 2.0 * gflops_mx: if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^") print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True): for transpose in (False, True):
for dtype in ("float32", "float16", "complex64"): for dtype in ("float32", "float16"):
fig, axs = plt.subplots( fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
) )
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}") fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig( fig.savefig(
os.path.join( os.path.join(
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf" results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
) )
) )
plt.close(fig) plt.close(fig)

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,5 +1,4 @@
sphinx sphinx
breathe breathe
sphinx-book-theme sphinx-book-theme
sphinx-copybutton
mlx mlx

View File

@@ -18,7 +18,6 @@ release = version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
extensions = [ extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc", "sphinx.ext.autodoc",
"sphinx.ext.autosummary", "sphinx.ext.autosummary",
"sphinx.ext.intersphinx", "sphinx.ext.intersphinx",

View File

@@ -127,8 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided", name="myexp_strided",
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
source=source, source=source
ensure_row_contiguous=False,
) )
def exp_elementwise(a: mx.array): def exp_elementwise(a: mx.array):
@@ -139,6 +138,7 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes=[a.shape], output_shapes=[a.shape],
output_dtypes=[a.dtype], output_dtypes=[a.dtype],
ensure_row_contiguous=False,
) )
return outputs[0] return outputs[0]

View File

@@ -70,7 +70,6 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/cuda
python/memory_management python/memory_management
python/nn python/nn
python/optimizers python/optimizers

View File

@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y apt-get update -y
apt-get -y install cuda-toolkit-12-9 apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag When building either the Python or C++ APIs make sure to pass the cmake flag

View File

@@ -1,9 +0,0 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,4 +13,3 @@ Fast
rope rope
scaled_dot_product_attention scaled_dot_product_attention
metal_kernel metal_kernel
cuda_kernel

View File

@@ -27,7 +27,6 @@ simple functions.
mish mish
prelu prelu
relu relu
relu2
relu6 relu6
selu selu
sigmoid sigmoid

View File

@@ -50,7 +50,6 @@ Layers
QuantizedLinear QuantizedLinear
RMSNorm RMSNorm
ReLU ReLU
ReLU2
ReLU6 ReLU6
RNN RNN
RoPE RoPE

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads) optimizer.update(model, grads)
# Save the state # Save the state
state = tree_flatten(optimizer.state, destination={}) state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", state) mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint, # Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state # recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2) optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors")) state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For Note, not every optimizer configuation parameter is saved in the state. For
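
A minimal sketch of the save/load flow this hunk documents, following the list-based `tree_flatten` variant shown on the right-hand side; the tiny `nn.Linear` model and the `optimizer.init` call are illustrative assumptions, not part of the original example.

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten

model = nn.Linear(4, 4)                        # hypothetical tiny model
optimizer = optim.Adam(learning_rate=1e-2)
optimizer.init(model.trainable_parameters())   # populate optimizer.state

# Save: flatten the nested state into {dotted_key: array} and write it out.
state = dict(tree_flatten(optimizer.state))
mx.save_safetensors("optimizer.safetensors", state)

# Later, e.g. when resuming from a checkpoint: recreate the optimizer and
# rebuild the nested state from the flat key/value pairs.
optimizer = optim.Adam(learning_rate=1e-2)
optimizer.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
```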

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python .. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096)) x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x) timeit(nn.gelu, x)
timeit(mx.compile(gelu), x) timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y): def fun(x, y):
z = x + y z = x + y
state.append(z) state.append(z)
return mx.exp(z) return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0)) fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)] # Prints [array(3, dtype=float32)]
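
A minimal sketch of the compile benchmark described above, assuming an Apple-silicon MLX install; the `timeit` helper here is a local stand-in for the one defined earlier in that guide, not an MLX API.

```python
import time
import mlx.core as mx
import mlx.nn as nn

def timeit(fn, x, iters=100):
    mx.eval(fn(x))                                  # warm-up (and compile, if wrapped)
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(x))                              # force the lazy graph to run
    ms = (time.perf_counter() - tic) / iters * 1e3
    print(f"{ms:.2f} ms per call")

x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)                                  # eager
timeit(mx.compile(nn.gelu), x)                      # compiled
```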

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y): def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y) loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads) optimizer.update(model, grads)
return loss return loss

View File

@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
model.update(tree_unflatten(list(params.items()))) model.update(tree_unflatten(list(params.items())))
return model(x) return model(x)
params = tree_flatten(model.parameters(), destination={}) params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params) mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python .. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True) mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn") imported_abs = mx.import_function("fun.mlxfn")
# Ok # Ok
out, = imported_abs(mx.array([-1.0])) out, = imported_abs(mx.array(-1.0))
# Also ok # Also ok
out, = imported_abs(mx.array([-1.0, -2.0])) out, = imported_abs(mx.array([-1.0, -2.0]))
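
A runnable sketch of the shapeless export shown in this hunk, following the left-hand variant that traces with `mx.array([0.0])`; it writes `fun.mlxfn` to the working directory.

```python
import mlx.core as mx

# Export mx.abs traced on a 1-element array, but marked shapeless so the
# imported function accepts other input shapes as well.
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)

imported_abs = mx.import_function("fun.mlxfn")

out, = imported_abs(mx.array([-1.0]))         # same shape as the trace: ok
out, = imported_abs(mx.array([-1.0, -2.0]))   # different shape: ok with shapeless=True
print(out)
```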

View File

@@ -107,20 +107,8 @@ same array:
>>> a >>> a
array([1, 2, 0], dtype=int32) array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
.. code-block:: shell Note, unlike NumPy, updates to the same location are nondeterministic:
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell .. code-block:: shell
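
The slicing note added on the left of this hunk, condensed into a runnable sketch (assumes a recent MLX build where basic slices copy rather than view):

```python
import mlx.core as mx

a = mx.array([1, 2, 3])
b = a[:]          # a copy, not a view
b[2] = 0          # write through the copy

print(b)          # array([1, 2, 0], dtype=int32)
print(a)          # array([1, 2, 3], dtype=int32) -- original is untouched
```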

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a, const array& a,
const array& b) { const array& b) {
if (a.ndim() == 2) { if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}}; return {{1}, {0}, {0}};
} }
Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides> inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) { collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) { if (a.ndim() == 2) {
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}}; return {{1}, {0}, {0}, {0}};
} }
Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -11,8 +11,6 @@ namespace mlx::core {
enum class TernaryOpType { enum class TernaryOpType {
ScalarScalarScalar, ScalarScalarScalar,
VectorVectorVector, VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General, General,
}; };
@@ -27,14 +25,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous && (a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) { c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector; topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else { } else {
topt = TernaryOpType::General; topt = TernaryOpType::General;
} }
@@ -69,8 +59,6 @@ inline void set_ternary_op_output_data(
b.flags()); b.flags());
} }
break; break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General: case TernaryOpType::General:
// Try to donate an input which is row_contiguous // Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||

View File

@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
} }
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -196,6 +196,9 @@ void shared_buffer_reshape(
const Strides& out_strides, const Strides& out_strides,
array& out); array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T> template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) { inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index)); vec.erase(std::next(vec.begin(), index));

View File

@@ -15,7 +15,6 @@
#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/graph_utils.h" #include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core { namespace mlx::core {
@@ -95,11 +94,7 @@ void* compile(
kernel_file_name = kernel_name; kernel_file_name = kernel_name;
} }
auto output_dir = auto output_dir = std::filesystem::temp_directory_path();
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so"; std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string(); auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -162,12 +157,10 @@ inline void build_kernel(
#endif #endif
// Start the kernel // Start the kernel
os << "void " << kernel_name os << "void " << kernel_name << "(void** args) {" << std::endl;
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments // Add the input arguments
int cnt = 0; int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list // Skip constants from the input list
if (is_constant(i)) { if (is_constant(i)) {
@@ -182,8 +175,8 @@ inline void build_kernel(
<< "];" << std::endl; << "];" << std::endl;
// Scalars and contiguous need no strides // Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) { if (!is_scalar(x) && !contiguous) {
os << " const int64_t* " << xname << "_strides = strides[" os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< strides_index++ << "];" << std::endl; << "];" << std::endl;
} }
} }
@@ -193,8 +186,10 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl; << "*)args[" << cnt++ << "];" << std::endl;
} }
// Add output size // Add output strides and shape to extract the indices.
if (contiguous) { if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
} }
@@ -293,8 +288,17 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] = auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments. // Collect function input arguments.
std::vector<void*> args; std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) { if (is_constant_(i)) {
continue; continue;
@@ -302,6 +306,9 @@ void Compiled::eval_cpu(
const auto& x = inputs[i]; const auto& x = inputs[i];
encoder.set_input_array(x); encoder.set_input_array(x);
args.push_back((void*)x.data<void>()); args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
} }
// Get the kernel name from the lib // Get the kernel name from the lib
@@ -336,20 +343,16 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>()); args.push_back(x.data<void>());
encoder.set_output_array(x); encoder.set_output_array(x);
} }
if (contiguous) { if (!contiguous) {
args.push_back((void*)shape.data());
} else {
args.push_back((void*)outputs[0].data_size()); args.push_back((void*)outputs[0].data_size());
} }
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr); auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch([fun, encoder.dispatch([fun,
args = std::move(args), args = std::move(args),
strides = std::move(strides), strides = std::move(strides),
shape = std::move(shape)]() mutable { shape = std::move(shape)]() mutable { fun(args.data()); });
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -88,47 +88,4 @@ void matmul<double>(
} }
} }
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev) INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd) INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri) INSTANTIATE_LAPACK_REAL(trtri)

View File

@@ -108,9 +108,6 @@ void matmul_general(
} else if (out.dtype() == float64) { } else if (out.dtype() == float64) {
matmul_dispatch<double>( matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else { } else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
} }
@@ -131,6 +128,10 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
return; return;

View File

@@ -1,5 +1,7 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
@@ -11,35 +13,6 @@ namespace mlx::core {
namespace { namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) { inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
} }
@@ -434,231 +407,6 @@ void _qmm_dispatch(
} }
} }
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T> template <typename T>
void _bs_qmm_dispatch_typed( void _bs_qmm_dispatch_typed(
array& out, array& out,
@@ -765,198 +513,115 @@ void _bs_qmm_dispatch(
} }
} }
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace } // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) { void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& encoder = cpu::get_command_encoder(stream()); std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) { auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return arr; return arr;
} else { } else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, arr_cpy, CopyType::General, s); copy_cpu(arr, temps.back(), CopyType::General, s);
encoder.add_temporary(arr_cpy); return temps.back();
return arr_cpy;
} }
}; };
auto x = ensure_row_contiguous(x_pre); auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre); auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre); auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(scales); encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out); encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) { encoder.dispatch([out = array::unsafe_weak_copy(out),
auto biases = ensure_row_contiguous(inputs[3]); x = array::unsafe_weak_copy(x),
encoder.set_input_array(biases); w = array::unsafe_weak_copy(w),
encoder.dispatch([out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales),
x = array::unsafe_weak_copy(x), biases = array::unsafe_weak_copy(biases),
w = array::unsafe_weak_copy(w), group_size_ = group_size_,
scales = array::unsafe_weak_copy(scales), bits_ = bits_,
biases = array::unsafe_weak_copy(biases), transpose_ = transpose_]() mutable {
group_size_ = group_size_, _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
bits_ = bits_, });
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
} }
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
auto& scales_pre = inputs[2]; auto& scales_pre = inputs[2];
auto& lhs_indices = inputs[inputs.size() - 2]; auto& biases_pre = inputs[3];
auto& rhs_indices = inputs[inputs.size() - 1]; auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& encoder = cpu::get_command_encoder(stream()); std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(), auto ensure_row_contiguous_last_dims = [s = stream(),
&encoder](const array& arr) { &temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2]; auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1]; auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) { if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr; return arr;
} else { } else {
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, arr_cpy, CopyType::General, s); copy_cpu(arr, temps.back(), CopyType::General, s);
encoder.add_temporary(arr_cpy); return temps.back();
return arr_cpy;
} }
}; };
auto x = ensure_row_contiguous_last_dims(x_pre); auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre); auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(scales); encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices); encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices); encoder.set_input_array(rhs_indices);
encoder.set_output_array(out); encoder.set_output_array(out);
if (mode_ == QuantizationMode::Affine) { encoder.dispatch([out = array::unsafe_weak_copy(out),
auto biases = ensure_row_contiguous_last_dims(inputs[3]); x = array::unsafe_weak_copy(x),
encoder.set_input_array(biases); w = array::unsafe_weak_copy(w),
encoder.dispatch([out = array::unsafe_weak_copy(out), scales = array::unsafe_weak_copy(scales),
x = array::unsafe_weak_copy(x), biases = array::unsafe_weak_copy(biases),
w = array::unsafe_weak_copy(w), lhs_indices = array::unsafe_weak_copy(lhs_indices),
scales = array::unsafe_weak_copy(scales), rhs_indices = array::unsafe_weak_copy(rhs_indices),
biases = array::unsafe_weak_copy(biases), group_size_ = group_size_,
lhs_indices = array::unsafe_weak_copy(lhs_indices), bits_ = bits_,
rhs_indices = array::unsafe_weak_copy(rhs_indices), transpose_ = transpose_]() mutable {
group_size_ = group_size_, _bs_qmm_dispatch(
bits_ = bits_, out,
transpose_ = transpose_]() mutable { x,
_bs_qmm_dispatch( w,
out, scales,
x, biases,
w, lhs_indices,
scales, rhs_indices,
biases, group_size_,
lhs_indices, bits_,
rhs_indices, transpose_);
group_size_, });
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
} }
template <typename T, typename U> template <typename T, typename U>
@@ -1040,7 +705,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
} }
void fast::Quantize::eval_cpu( void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) { auto ensure_row_contiguous = [s = stream()](const array& arr) {
@@ -1099,7 +764,7 @@ void fast::Quantize::eval_cpu(
} }
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"[fast::Quantize::eval_cpu] Only supports floating point inputs"); "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
} }
}); });
} }
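
The contiguity handling in the quantized-matmul paths above hinges on a simple stride test. A minimal standalone restatement of that test (hypothetical helper, plain int64 vectors instead of MLX's array accessors): the last two dimensions are row-contiguous exactly when the innermost stride is 1 and the second-innermost stride equals the innermost extent; anything else gets copied into a temporary first.

    #include <cstdint>
    #include <vector>

    // Hypothetical helper mirroring the check in ensure_row_contiguous_last_dims.
    bool last_two_dims_row_contiguous(
        const std::vector<int64_t>& shape,
        const std::vector<int64_t>& strides) {
      size_t n = shape.size();
      return n >= 2 && strides[n - 1] == 1 && strides[n - 2] == shape[n - 1];
    }

    int main() {
      // (4, 8) row-major: strides (8, 1) -> contiguous, no copy needed.
      bool a = last_two_dims_row_contiguous({4, 8}, {8, 1});  // true
      // Same shape transposed: strides (1, 4) -> copy into a temporary.
      bool b = last_two_dims_row_contiguous({4, 8}, {1, 4});  // false
      return (a && !b) ? 0 : 1;
    }
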

View File

@@ -491,27 +491,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
case uint8: case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8: case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break; break;
case int16: case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break; break;
case int32: case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break; break;
case int64: case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_); reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break; break;
case float16: case float16:

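One way to read the collapsed dispatch above (my reading, not stated in the diff): for sums and products, unsigned lanes can reuse the same-width signed instantiation because two's-complement addition and multiplication produce identical bit patterns regardless of signedness. A small self-contained check of that identity, done in unsigned arithmetic so the demo itself avoids signed-overflow UB:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    int main() {
      uint32_t ua = 4000000000u, ub = 500000000u;  // sum wraps past 2^32
      int32_t sa, sb;
      std::memcpy(&sa, &ua, sizeof(sa));
      std::memcpy(&sb, &ub, sizeof(sb));
      // The hardware add is the same instruction either way; do it via unsigned
      // here only to keep this host-side demo free of signed-overflow UB.
      int32_t ssum = static_cast<int32_t>(
          static_cast<uint32_t>(sa) + static_cast<uint32_t>(sb));
      uint32_t usum = ua + ub;
      uint32_t roundtrip;
      std::memcpy(&roundtrip, &ssum, sizeof(roundtrip));
      assert(roundtrip == usum);  // identical bit pattern
      return 0;
    }
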
View File

@@ -9,7 +9,7 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in simd/base_simd.h // There seems to be a bug in sims/base.h
// __XROS_2_0 is not defined, the expression evaluates // __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library // to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15 // higher than it should be even on macOS < 15
@@ -234,7 +234,6 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N> template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) { Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) { if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value)); return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) { } else if constexpr (sizeof(T1) == 2) {
@@ -252,13 +251,9 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value); return asd::pow(base.value, exp.value);
} else { } else {
Simd<T, N> res = 1; Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined while (any(exp)) {
if (any(exp < 0)) { res = select(exp & 1, res * base, res);
return 0; base = select(exp, base * base, base);
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1; exp = exp >> 1;
} }
return res; return res;
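
A scalar sketch of the square-and-multiply loop above, including the negative-exponent guard present on one side of the hunk (returning 0 there is the convention chosen by that code; this standalone version just mirrors it):

    #include <cassert>

    // Square-and-multiply for integer exponents; negative exponents yield 0,
    // matching the guard in the vectorized version above.
    long long ipow(long long base, long long exp) {
      if (exp < 0) {
        return 0;
      }
      long long res = 1;
      while (exp > 0) {
        if (exp & 1) {
          res *= base;
        }
        base *= base;
        exp >>= 1;
      }
      return res;
    }

    int main() {
      assert(ipow(3, 5) == 243);
      assert(ipow(7, 0) == 1);
      assert(ipow(2, -3) == 0);
      return 0;
    }
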

View File

@@ -15,18 +15,6 @@ namespace mlx::core {
namespace { namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T> template <typename T>
struct StridedIterator { struct StridedIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
@@ -142,7 +130,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, nan_aware_less<T>); std::stable_sort(st, ed);
src_it.step(); src_it.step();
} }
} }
@@ -196,15 +184,6 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -240,7 +219,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth); StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed, nan_aware_less<T>); std::nth_element(st, md, ed);
} }
} }
@@ -297,15 +276,6 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
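
The nan_aware_less comparator shown on one side of these hunks treats every NaN as larger than any number, so stable sorts and partitions push NaNs to the end. A scalar sketch of that ordering (float only; the real template also special-cases complex64_t):

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <vector>

    // NaNs compare greater than everything and never less than each other,
    // which is a valid strict weak ordering, so std::stable_sort keeps them last.
    bool nan_aware_less(float a, float b) {
      if (std::isnan(a)) return false;
      if (std::isnan(b)) return true;
      return a < b;
    }

    int main() {
      std::vector<float> v = {2.0f, std::nanf(""), -1.0f, 3.0f};
      std::stable_sort(v.begin(), v.end(), nan_aware_less);
      assert(v[0] == -1.0f && v[1] == 2.0f && v[2] == 3.0f && std::isnan(v[3]));
      return 0;
    }
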

View File

@@ -81,7 +81,9 @@ void svd_impl(
// Vᵀ of shape N x N. (M x M in lapack). // Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M; const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N"; auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned. // Will contain the number of singular values after the call has returned.
int ns = 0; int ns = 0;
@@ -89,20 +91,30 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not // Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack). // used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)}; auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1; static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info; int info;
// Compute workspace size. // Compute workspace size.
gesdd<T>( gesvdx<T>(
/* jobz = */ jobz, /* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ nullptr, /* a = */ nullptr,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr, /* s = */ nullptr,
/* u = */ nullptr, /* u = */ nullptr,
/* ldu = */ &ldu, /* ldu = */ &ldu,
@@ -124,13 +136,20 @@ void svd_impl(
// Loop over matrices. // Loop over matrices.
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
gesdd<T>( gesvdx<T>(
/* jobz = */ jobz, /* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ in_ptr + M * N * i, /* a = */ in_ptr + M * N * i,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i, /* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U. // According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr, /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@@ -148,6 +167,13 @@ void svd_impl(
ss << "svd_impl: sgesvdx_ failed with code " << info; ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
} }
}); });
encoder.add_temporary(in); encoder.add_temporary(in);

View File

@@ -77,8 +77,7 @@ struct Real {
struct Sigmoid { struct Sigmoid {
template <int N, typename T> template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) { Simd<T, N> operator()(Simd<T, N> x) {
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x))); return 1.0f / (1.0f + simd::exp(-x));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
} }
SINGLE() SINGLE()
}; };
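
Both columns compute the logistic function; the select-based form evaluates exp only on |x| and then reflects the result by sign. A scalar version of that form, with a couple of sanity checks (the checks are my examples, not from the diff):

    #include <cassert>
    #include <cmath>

    // sigmoid(x) = 1 / (1 + e^{-x}); the reflected form below equals it exactly:
    // for x < 0,  1 / (1 + e^{|x|}) is already sigmoid(x);
    // for x >= 0, 1 - 1 / (1 + e^{x}) = e^{x} / (1 + e^{x}) = sigmoid(x).
    float sigmoid_reflected(float x) {
      float y = 1.0f / (1.0f + std::exp(std::fabs(x)));
      return x < 0.0f ? y : 1.0f - y;
    }

    int main() {
      assert(std::fabs(sigmoid_reflected(0.0f) - 0.5f) < 1e-6f);
      assert(sigmoid_reflected(-100.0f) == 0.0f);  // exp(100.f) overflows to inf -> 0, not NaN
      assert(sigmoid_reflected(100.0f) == 1.0f);
      return 0;
    }
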

View File

@@ -8,6 +8,7 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -16,13 +17,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
@@ -49,20 +45,18 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources( target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu) mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else() else()
target_sources( target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp) mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif() endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -154,7 +148,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare( FetchContent_Declare(
cudnn cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0 GIT_TAG v1.12.1
GIT_SHALLOW TRUE GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON) set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
@@ -170,10 +164,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers. # Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>) --diag_suppress=997>)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT. # Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda

View File

@@ -30,20 +30,8 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_; next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size)); CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
CHECK_CUDA_ERROR(
int device_count = 0; cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_; auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) { for (size_t i = 1; i < num_blocks; ++i) {
@@ -91,7 +79,7 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device. // TODO: Set memory limit for multi-device.
size_t free, total; size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95; memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_; max_pool_size_ = memory_limit_;
} }

View File

@@ -6,33 +6,23 @@
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups; template <typename T>
struct Arange {
const T start;
const T step;
template <typename T, typename IdxT, int N_WRITES> __device__ T operator()(uint32_t i) const {
__global__ void arange(T* out, IdxT size, T start, T step) { return start + i * step;
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
}
} else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
} }
} };
} // namespace cu } // namespace cu
@@ -46,23 +36,19 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream()); auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out); encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>; using OutType = cuda_type_t<CTYPE>;
constexpr int N_WRITES = 16 / sizeof(OutType); CTYPE step =
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
using IdxT = std::conditional_t<large(), int64_t, int32_t>; thrust::transform(
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES); cu::thrust_policy(encoder.stream()),
encoder.add_kernel_node( thrust::counting_iterator<uint32_t>(0),
cu::arange<OutType, IdxT, N_WRITES>, thrust::counting_iterator<uint32_t>(out.data_size()),
num_blocks, thrust::device_pointer_cast(out.data<OutType>()),
block_dims, cu::Arange<OutType>{
0, static_cast<OutType>(start_), static_cast<OutType>(step)});
out.data<OutType>(),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
}); });
} }
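
A host-side sketch of the chunked write pattern used by the vectorized arange kernel on one side of this hunk: each index owns N consecutive outputs, and only the final partial chunk falls back to an element-by-element loop (the names here are mine):

    #include <cassert>
    #include <vector>

    // Fill out[i] = start + i * step in chunks of N, with a scalar tail.
    template <typename T, int N>
    void arange_chunked(std::vector<T>& out, T start, T step) {
      size_t size = out.size();
      size_t num_chunks = (size + N - 1) / N;
      for (size_t index = 0; index < num_chunks; ++index) {
        if ((index + 1) * N > size) {
          for (size_t i = index * N; i < size; ++i) {
            out[i] = start + static_cast<T>(i) * step;  // partial tail chunk
          }
        } else {
          for (size_t i = index * N; i < (index + 1) * N; ++i) {
            out[i] = start + static_cast<T>(i) * step;  // full, vectorizable chunk
          }
        }
      }
    }

    int main() {
      std::vector<float> out(10);
      arange_chunked<float, 4>(out, 1.0f, 0.5f);
      assert(out[0] == 1.0f && out[9] == 5.5f);
      return 0;
    }
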

View File

@@ -99,89 +99,39 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
} }
} }
template < template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_g_nd( __global__ void binary_g_nd(
const In* a, const In* a,
const In* b, const In* b,
Out* out, Out* out,
IdxT size_rest, IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape, const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides, const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) { const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
auto block = cg::this_thread_block(); IdxT index = cg::this_grid().thread_rank();
auto grid = cg::this_grid(); if (index < size) {
IdxT index_rest = auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
grid.block_index().y * block.dim_threads().y + block.thread_index().y; index, shape.data(), a_strides.data(), b_strides.data());
if (index_rest >= size_rest) { out[index] = Op{}(a[a_idx], b[b_idx]);
return;
} }
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g( __global__ void binary_g(
const In* a, const In* a,
const In* b, const In* b,
Out* out, Out* out,
IdxT size_rest, IdxT size,
const __grid_constant__ Shape shape, const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides, const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides, const __grid_constant__ Strides b_strides,
int ndim) { int ndim) {
auto block = cg::this_thread_block(); IdxT index = cg::this_grid().thread_rank();
auto grid = cg::this_grid(); if (index < size) {
IdxT index_rest = auto [a_idx, b_idx] = elem_to_loc(
grid.block_index().y * block.dim_threads().y + block.thread_index().y; index, shape.data(), a_strides.data(), b_strides.data(), ndim);
if (index_rest >= size_rest) { out[index] = Op{}(a[a_idx], b[b_idx]);
return;
} }
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
} }
template <typename Op, typename In, typename Out> template <typename Op, typename In, typename Out>
@@ -259,61 +209,39 @@ void binary_op_gpu_inplace(
auto& a_strides = strides[0]; auto& a_strides = strides[0];
auto& b_strides = strides[1]; auto& b_strides = strides[1];
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_g_nd< auto [num_blocks, block_dims] =
Op, get_launch_args(out, large());
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_g_nd<
{num_blocks_x, num_blocks_y}, Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims, block_dims,
0, 0,
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),
out.data<OutType>(), out.data<OutType>(),
rest, out.size(),
const_param<dims_constant()>(shape), const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides), const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides)); const_param<dims_constant()>(b_strides));
}); });
} else { } else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>; auto [num_blocks, block_dims] = get_launch_args(out, large());
if (work_per_thread == 4) {
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_g<Op, InType, OutType, IdxT>,
{num_blocks_x, num_blocks_y}, num_blocks,
block_dims, block_dims,
0, 0,
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),
out.data<OutType>(), out.data<OutType>(),
rest, out.size(),
const_param(shape), const_param(shape),
const_param(a_strides), const_param(a_strides),
const_param(b_strides), const_param(b_strides),
@@ -376,4 +304,54 @@ void binary_op_gpu(
binary_op_gpu<cu::func>(inputs, out, name(), s); \ binary_op_gpu<cu::func>(inputs, out, name(), s); \
} }
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core } // namespace mlx::core
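
Both kernel variants in this file lean on elem_to_loc-style index math: a flat element index is decomposed against the shape, innermost dimension first, and each input accumulates its own strided offset. A reference restatement of that idea (the real elem_to_loc_nd lives in MLX's device headers and may differ in detail):

    #include <array>
    #include <cstdint>
    #include <utility>

    // Decompose a flat row-major index into per-input strided offsets.
    template <int NDIM>
    std::pair<int64_t, int64_t> elem_to_loc_ref(
        int64_t index,
        const std::array<int32_t, NDIM>& shape,
        const std::array<int64_t, NDIM>& a_strides,
        const std::array<int64_t, NDIM>& b_strides) {
      int64_t a_idx = 0;
      int64_t b_idx = 0;
      for (int d = NDIM - 1; d >= 0; --d) {
        int64_t pos = index % shape[d];
        index /= shape[d];
        a_idx += pos * a_strides[d];
        b_idx += pos * b_strides[d];
      }
      return {a_idx, b_idx};
    }

    int main() {
      // Shape (2, 3); a is row-major (strides 3, 1), b broadcasts along rows (0, 1).
      auto [a_idx, b_idx] = elem_to_loc_ref<2>(4, {2, 3}, {3, 1}, {0, 1});
      return (a_idx == 4 && b_idx == 1) ? 0 : 1;
    }
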

View File

@@ -1,21 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(GreaterEqual)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Less)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LessEqual)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogAddExp)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalAnd)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalOr)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Maximum)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Minimum)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Multiply)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(NotEqual)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Power)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Remainder)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Subtract)
} // namespace mlx::core

View File

@@ -127,99 +127,45 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
} }
} }
template < template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_two_g_nd( __global__ void binary_two_g_nd(
const In* a, const In* a,
const In* b, const In* b,
Out* out_a, Out* out_a,
Out* out_b, Out* out_b,
IdxT size_rest, IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape, const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides, const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) { const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
auto block = cg::this_thread_block(); IdxT index = cg::this_grid().thread_rank();
auto grid = cg::this_grid(); if (index < size) {
IdxT index_rest = auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
grid.block_index().y * block.dim_threads().y + block.thread_index().y; index, shape.data(), a_strides.data(), b_strides.data());
if (index_rest >= size_rest) { auto out = Op{}(a[a_idx], b[b_idx]);
return; out_a[index] = out[0];
out_b[index] = out[1];
} }
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_two_g( __global__ void binary_two_g(
const In* a, const In* a,
const In* b, const In* b,
Out* out_a, Out* out_a,
Out* out_b, Out* out_b,
IdxT size_rest, IdxT size,
const __grid_constant__ Shape shape, const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides, const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides, const __grid_constant__ Strides b_strides,
int ndim) { int ndim) {
auto block = cg::this_thread_block(); IdxT index = cg::this_grid().thread_rank();
auto grid = cg::this_grid(); if (index < size) {
IdxT index_rest = auto [a_idx, b_idx] = elem_to_loc(
grid.block_index().y * block.dim_threads().y + block.thread_index().y; index, shape.data(), a_strides.data(), b_strides.data(), ndim);
if (index_rest >= size_rest) { auto out = Op{}(a[a_idx], b[b_idx]);
return; out_a[index] = out[0];
out_b[index] = out[1];
} }
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
} }
template <typename Op, typename In, typename Out> template <typename Op, typename In, typename Out>
@@ -279,64 +225,42 @@ void binary_two_op_gpu_inplace(
auto& a_strides = strides[0]; auto& a_strides = strides[0];
auto& b_strides = strides[1]; auto& b_strides = strides[1];
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out_a.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd< auto [num_blocks, block_dims] =
Op, get_launch_args(out_a, large());
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_two_g_nd<
{num_blocks_x, num_blocks_y}, Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims, block_dims,
0, 0,
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),
out_a.data<OutType>(), out_a.data<OutType>(),
out_b.data<OutType>(), out_b.data<OutType>(),
rest, out_a.size(),
const_param<dims_constant()>(shape), const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides), const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides)); const_param<dims_constant()>(b_strides));
}); });
} else { } else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>; auto [num_blocks, block_dims] =
if (work_per_thread == 4) { get_launch_args(out_a, large());
kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_two_g<Op, InType, OutType, IdxT>,
{num_blocks_x, num_blocks_y}, num_blocks,
block_dims, block_dims,
0, 0,
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),
out_a.data<OutType>(), out_a.data<OutType>(),
out_b.data<OutType>(), out_b.data<OutType>(),
rest, out_a.size(),
const_param(shape), const_param(shape),
const_param(a_strides), const_param(a_strides),
const_param(b_strides), const_param(b_strides),

View File

@@ -267,8 +267,7 @@ void Compiled::eval_gpu(
} }
} }
return std::make_tuple( return std::make_pair(std::move(builder.os), std::move(kernel_names));
false, std::move(builder.os), std::move(kernel_names));
}); });
// Collapse contiguous dims to route to a faster kernel if possible. Also // Collapse contiguous dims to route to a faster kernel if possible. Also
@@ -332,9 +331,9 @@ void Compiled::eval_gpu(
encoder.set_output_array(out); encoder.set_output_array(out);
} }
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread, max_block_dims); get_launch_args(outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }

View File

@@ -1,12 +1,18 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cassert> #include <cassert>
@@ -15,6 +21,9 @@ namespace mlx::core {
namespace { namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
// Alias for better readability. // Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR #define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \ #define CONV_BACKWARD_INPUT \
@@ -22,9 +31,6 @@ namespace {
#define CONV_BACKWARD_WEIGHT \ #define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
struct ConvCacheKey { struct ConvCacheKey {
int device_id; int device_id;
cudnnDataType_t cudnn_dtype; cudnnDataType_t cudnn_dtype;
@@ -44,13 +50,203 @@ struct ConvCacheKey {
auto& conv_cache() { auto& conv_cache() {
static LRUBytesKeyCache< static LRUBytesKeyCache<
ConvCacheKey, ConvCacheKey,
std::pair< std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
cudnnBackendDescriptorType_t, cache(/* capacity */ 128);
std::optional<cudnn_frontend::ExecutionPlan>>>
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache; return cache;
} }
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
array& x,
array& w,
array& y) {
int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
x.data<void>(),
w.data<void>(),
y.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;
}
bool try_engines(
cu::CommandEncoder& encoder,
const ConvCacheKey& cache_key,
cudnnBackendDescriptorType_t backend_type,
cudnn_frontend::EngineConfigList& configs,
const std::string& op_graph_tag,
array& x,
array& w,
array& y) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag)
.build();
if (execute_plan(encoder, plan, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(plan)));
return true;
}
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return false;
}
auto get_conv_op_settings( auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type, cudnnBackendDescriptorType_t backend_type,
array& x, array& x,
@@ -95,7 +291,7 @@ auto get_conv_op_settings(
} }
} }
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph( std::optional<cudnn_frontend::OperationGraph> build_op_graph(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type, cudnnBackendDescriptorType_t backend_type,
Dtype dtype, Dtype dtype,
@@ -121,9 +317,9 @@ std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
.build(); .build();
auto op = cudnn_frontend::OperationBuilder(backend_type) auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_cudnn_tensor_nchw('x', x)) .setxDesc(build_tensor('x', x))
.setwDesc(build_cudnn_tensor_nchw('w', w)) .setwDesc(build_tensor('w', w))
.setyDesc(build_cudnn_tensor_nchw('y', y)) .setyDesc(build_tensor('y', y))
.setcDesc(conv_desc) .setcDesc(conv_desc)
.build(); .build();
@@ -140,42 +336,6 @@ std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
} }
} }
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
const array& x,
int groups,
int group_dim,
int axis1,
int axis2,
Stream s) {
if (groups == 1) {
return swapaxes_in_eval(x, axis1, axis2);
}
int ndim = x.ndim();
if (group_dim < 0) {
group_dim += ndim;
}
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
if (group_dim <= axis1) {
axis1 += 1;
}
if (group_dim <= axis2) {
axis2 += 1;
}
auto shape = x.shape();
shape.insert(shape.begin() + group_dim, groups);
shape[group_dim + 1] = shape[group_dim + 1] / groups;
array x_trans = reshape_in_eval(x, std::move(shape), s);
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
return x_trans;
}
// Do necessary transposes and copies to prepare the inputs and outputs for // Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one // building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies. // eval_gpu, with cost of possible redundant copies.
@@ -185,14 +345,13 @@ std::tuple<array, array, array> prepare_args(
array in, array in,
array wt, array wt,
array out, array out,
int groups,
Stream s) { Stream s) {
// Transpose the args depending on the backend type. // Transpose the args depending on the backend type.
// TODO: Handle groups. // TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) { if (backend_type == CONV_BACKWARD_INPUT) {
wt = group_transpose(wt, groups, 0, 0, -1, s); wt = swapaxes_in_eval(wt, 0, -1);
} else if (backend_type == CONV_BACKWARD_WEIGHT) { } else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = group_transpose(in, groups, -1, 0, -1, s); in = swapaxes_in_eval(in, 0, -1);
wt = swapaxes_in_eval(wt, 0, -1); wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim // Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped. // C_in and C_out swapped.
@@ -285,12 +444,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
ConvCacheKey cache_key{ ConvCacheKey cache_key{
encoder.device().cuda_device(), encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype), dtype_to_cudnn_type(dtype),
vector_key(in.shape()), fixed_vector(in.shape()),
vector_key(wt.shape()), fixed_vector(wt.shape()),
vector_key(kernel_strides_), fixed_vector(kernel_strides_),
vector_key(padding_lo_), fixed_vector(padding_lo_),
vector_key(padding_hi_), fixed_vector(padding_hi_),
vector_key(kernel_dilation_), fixed_vector(kernel_dilation_),
groups_, groups_,
flip_, flip_,
get_alignment(in), get_alignment(in),
@@ -298,29 +457,11 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(out)}; get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) { if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second; auto& [backend_type, plan] = it->second;
if (plan) { std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
// Run cached plan. register_args(encoder, backend_type, in, wt, out, out_);
std::tie(in, wt, out) = auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
prepare_args(encoder, backend_type, in, wt, out, groups_, s); if (!execute_plan(encoder, plan, x, w, y)) {
register_args(encoder, backend_type, in, wt, out, out_); throw std::runtime_error("[conv] Cached plan failed to execute.");
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
} else {
// Run fallback kernel.
gemm_conv(
encoder,
in,
wt,
out,
kernel_strides_,
padding_lo_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
s);
} }
return; return;
} }
@@ -349,7 +490,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
std::optional<cudnn_frontend::OperationGraph> op_graph; std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) { for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] = auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s); prepare_args(encoder, try_backend, in, wt, out, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy); auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings( auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend, try_backend,
@@ -361,7 +502,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_hi_, padding_hi_,
kernel_dilation_, kernel_dilation_,
input_dilation_); input_dilation_);
op_graph = build_conv_op_graph( op_graph = build_op_graph(
encoder, encoder,
try_backend, try_backend,
dtype, dtype,
@@ -380,38 +521,26 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
break; break;
} }
} }
if (!op_graph) {
if (op_graph) { throw std::runtime_error("[conv] Can not build op graph.");
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
} }
// Use fallback kernel for settings not supported by cuDNN. // Get ready to execute the graph.
gemm_conv( register_args(encoder, backend_type, in, wt, out, out_);
encoder,
in, // Try to run plans based on heuristics.
wt, auto configs = get_engine_configs(backend_type, dtype, *op_graph);
out, auto tag = op_graph->getTag();
kernel_strides_, auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
padding_lo_, if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
kernel_dilation_, return;
input_dilation_, }
groups_, // Then try fallback plans.
flip_, configs = get_engine_configs(backend_type, dtype, *op_graph);
s); if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt)); return;
}
throw std::runtime_error("[conv] Unable to find a working engine.");
} }
} // namespace mlx::core } // namespace mlx::core
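
A standalone restatement of the get_alignment helper above: it reports the largest power of two (capped at 32) that divides the buffer address, which is what gets passed to the cuDNN tensor descriptor's setAlignment:

    #include <cstdint>

    // Largest power-of-two divisor of the address, capped at 32 bytes,
    // mirroring get_alignment above.
    uint8_t pointer_alignment(const void* ptr) {
      auto address = reinterpret_cast<uintptr_t>(ptr);
      uint8_t alignment = 1;
      for (; alignment < 32; alignment *= 2) {
        if (address % (alignment * 2)) {
          return alignment;
        }
      }
      return alignment;
    }

    int main() {
      alignas(64) char buf[64];
      // A 64-byte aligned buffer reports 32 (the cap); buf + 4 reports 4.
      return (pointer_alignment(buf) == 32 && pointer_alignment(buf + 4) == 4) ? 0 : 1;
    }
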

View File

@@ -1,126 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
template <int NDIM>
struct ConvParams {
int N; // Batch size
int C; // In channels
int O; // Out channels
int strides[NDIM];
int padding[NDIM];
int kernel_dilation[NDIM];
int input_dilation[NDIM];
int groups;
bool flip;
int in_spatial_dims[NDIM];
int wt_spatial_dims[NDIM];
int out_spatial_dims[NDIM];
int64_t in_strides[NDIM + 2];
ConvParams(
const array& in,
const array& wt,
const array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip)
: N(in.shape(0)),
C(in.shape(-1)),
O(wt.shape(0)),
groups(groups),
flip(flip) {
std::copy_n(strides.begin(), NDIM, this->strides);
std::copy_n(padding.begin(), NDIM, this->padding);
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
}
};
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s);
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s);
inline void gemm_conv(
cu::CommandEncoder& encoder,
array in,
array wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
if (groups == 1) {
gemm_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
flip,
s);
} else {
gemm_grouped_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip,
s);
}
}
} // namespace mlx::core
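For readers following the gemm_conv path declared above, the out_spatial_dims stored in ConvParams obey the usual convolution size arithmetic. A standalone check of that formula (illustrative only; padding is assumed symmetric here, whereas the real code tracks padding_lo and padding_hi separately):

// Standalone illustration of conv output-size arithmetic; not MLX code.
#include <cassert>

int conv_out_dim(int in, int kernel, int stride, int pad, int kdil, int idil) {
  int effective_in = idil * (in - 1) + 1;     // input dilation expands the input
  int effective_k = kdil * (kernel - 1) + 1;  // kernel dilation expands the filter
  return (effective_in + 2 * pad - effective_k) / stride + 1;
}

int main() {
  // 32x32 input, 3x3 kernel, stride 1, pad 1, no dilation -> 32x32 output.
  assert(conv_out_dim(32, 3, 1, 1, 1, 1) == 32);
  // Stride 2 halves the spatial extent (rounding down) -> 16.
  assert(conv_out_dim(32, 3, 2, 1, 1, 1) == 16);
  return 0;
}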

View File

@@ -1,217 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_unfold_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array unfold_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = mat_K / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_unfold_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
int mat_N = params.O; // O
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded =
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
wt_reshaped.copy_shared_buffer(
wt,
{1, mat_K},
{false, false, /* col_contiguous */ true},
wt.data_size());
// Single batch.
Shape batch_shape{1};
Strides a_batch_strides{0};
Strides b_batch_strides{0};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
1, // groups
flip);
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core
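To make the GEMM mapping in gemm_conv_nd concrete: the unfolded input has shape (N * H_out * W_out, C * H_wt * W_wt) and is multiplied against the weights viewed as (O, C * H_wt * W_wt) transposed. A standalone sanity check of those shapes (illustrative only):

// Illustrative shape bookkeeping for the im2col GEMM: out = unfold(in) * wt^T.
#include <cassert>

int main() {
  int N = 8, H_out = 16, W_out = 16;   // batch and output spatial dims
  int C = 64, H_wt = 3, W_wt = 3;      // in channels and kernel spatial dims
  int O = 128;                         // out channels

  int mat_M = N * H_out * W_out;       // rows of the unfolded input
  int mat_K = C * H_wt * W_wt;         // shared contraction dimension
  int mat_N = O;                       // columns of the output

  // (mat_M x mat_K) * (mat_K x mat_N) -> (mat_M x mat_N), i.e. one vector of
  // output channels per output pixel per batch element.
  assert(mat_M * mat_N == N * H_out * W_out * O);
  assert(mat_K == 64 * 9);
  return 0;
}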

View File

@@ -1,231 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_grouped_unfold_transpose_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + tid.y * (filter_size / params.C);
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
int wt_stride = 1;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
out += index_wt * wt_stride;
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
wt_stride *= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array grouped_unfold_transpose_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = (mat_K * params.groups) / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_grouped_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int C_per_group = params.C / params.groups;
int O_per_group = params.O / params.groups;
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
int mat_N = O_per_group; // O_per_group
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
array wt_view(
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
// Batch with size of groups.
Shape batch_shape{params.groups};
Strides a_batch_strides{mat_K};
Strides b_batch_strides{mat_N * mat_K};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K * params.groups, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.set_out(
out.dtype(),
false, // out_transposed
mat_M, // out_rows
mat_N, // out_cols
mat_N * params.groups, // out_ld
params.groups, // batch_count
mat_N); // batch_stride
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip);
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core
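The grouped variant turns the group index into the GEMM batch dimension: each group multiplies its slice of the unfolded input (offset by mat_K columns per group) with that group's reshaped weights (offset by mat_N * mat_K elements per group). A standalone check of those per-group offsets (illustrative only):

// Illustrative per-group offsets for the batched GEMM in gemm_grouped_conv_nd.
#include <cassert>

int main() {
  int groups = 4, C = 64, O = 128, H_wt = 3, W_wt = 3;
  int C_per_group = C / groups;            // 16
  int O_per_group = O / groups;            // 32
  int mat_K = C_per_group * H_wt * W_wt;   // 144: contraction dim per group
  int mat_N = O_per_group;                 // 32: output channels per group

  // Group g reads unfolded-input columns starting at g * mat_K (the lda spans
  // all groups), and group g's reshaped weights start at g * mat_N * mat_K.
  int g = 3;
  assert(g * mat_K == 432);                // matches a_batch_strides{mat_K}
  assert(g * mat_N * mat_K == 13824);      // matches b_batch_strides{mat_N * mat_K}
  return 0;
}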

View File

@@ -15,8 +15,8 @@ void copy_gpu_inplace(
     int64_t offset_out,
     CopyType ctype,
     const Stream& s,
-    std::optional<array> dynamic_offset_in,
-    std::optional<array> dynamic_offset_out) {
+    const std::optional<array>& dynamic_offset_in,
+    const std::optional<array>& dynamic_offset_out) {
   if (out.size() == 0) {
     return;
   }
@@ -44,16 +44,6 @@ void copy_gpu_inplace(
           strides_vec[0]);
     } else {
       if (dynamic_offset_in || dynamic_offset_out) {
-        if (!dynamic_offset_in) {
-          dynamic_offset_in = array(0, int64);
-          encoder.add_temporary(*dynamic_offset_in);
-        }
-        if (!dynamic_offset_out) {
-          dynamic_offset_out = array(0, int64);
-          encoder.add_temporary(*dynamic_offset_out);
-        }
-        encoder.set_input_array(*dynamic_offset_in);
-        encoder.set_input_array(*dynamic_offset_out);
         copy_general_dynamic(
             encoder,
             ctype,
@@ -64,8 +54,8 @@
             shape_collapsed,
             strides_vec[0],
             strides_vec[1],
-            *dynamic_offset_in,
-            *dynamic_offset_out);
+            dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
+            dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
       } else {
         copy_general(
             encoder,

View File

@@ -10,80 +10,37 @@ namespace cu {
 namespace cg = cooperative_groups;

-template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
+template <typename In, typename Out, typename IdxT, int NDIM>
 __global__ void copy_gg_nd(
     const In* in,
     Out* out,
-    IdxT size_rest,
+    IdxT size,
     const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
     const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
     const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
-  auto block = cg::this_thread_block();
-  auto grid = cg::this_grid();
-  IdxT index_rest =
-      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
-  if (index_rest >= size_rest) {
-    return;
-  }
-  auto shape_x = shape[NDIM - 1];
-  auto in_stride_x = strides_in[NDIM - 1];
-  auto out_stride_x = strides_out[NDIM - 1];
-  IdxT index_x =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
-      index_rest * shape_x,
-      shape.data(),
-      strides_in.data(),
-      strides_out.data());
-  auto in_vec =
-      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
-  AlignedVector<Out, N_READS> out_vec;
-#pragma unroll
-  for (int i = 0; i < N_READS; ++i) {
-    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
-  }
-  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
+        index, shape.data(), strides_in.data(), strides_out.data());
+    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
+  }
 }

-template <typename In, typename Out, typename IdxT, int N_READS>
+template <typename In, typename Out, typename IdxT>
 __global__ void copy_gg(
     const In* in,
     Out* out,
-    IdxT size_rest,
+    IdxT size,
     const __grid_constant__ Shape shape,
     const __grid_constant__ Strides strides_in,
     const __grid_constant__ Strides strides_out,
     int ndim) {
-  auto block = cg::this_thread_block();
-  auto grid = cg::this_grid();
-  IdxT index_rest =
-      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
-  if (index_rest >= size_rest) {
-    return;
-  }
-  auto shape_x = shape[ndim - 1];
-  auto in_stride_x = strides_in[ndim - 1];
-  auto out_stride_x = strides_out[ndim - 1];
-  IdxT index_x =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  auto [idx_in, idx_out] = elem_to_loc(
-      index_rest * shape_x,
-      shape.data(),
-      strides_in.data(),
-      strides_out.data(),
-      ndim);
-  auto in_vec =
-      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
-  AlignedVector<Out, N_READS> out_vec;
-#pragma unroll
-  for (int i = 0; i < N_READS; ++i) {
-    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
-  }
-  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [idx_in, idx_out] = elem_to_loc(
+        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
+    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
+  }
 }

 } // namespace cu

@@ -112,52 +69,33 @@ void copy_general(
   size_t data_size = 1;
   for (auto& s : shape)
     data_size *= s;
-  int work_per_thread = 1;
-  auto dim0 = ndim > 0 ? shape.back() : 1;
-  auto rest = data_size / dim0;
-  if (dim0 >= 4) {
-    work_per_thread = 4;
-  }
-  dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
-  auto block_dims = get_block_dims(dim0, rest, 1);
-  uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
-  uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
   if (ndim <= 3) {
     dispatch_1_2_3(ndim, [&](auto ndim_constant) {
-      auto kernel =
-          cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
-      if (work_per_thread == 4) {
-        kernel =
-            cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
-      }
+      auto [num_blocks, block_dims] =
+          get_launch_args(data_size, shape, out.strides(), large());
       encoder.add_kernel_node(
-          kernel,
-          {num_blocks_x, num_blocks_y},
+          cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
+          num_blocks,
           block_dims,
           0,
           in_ptr,
           out_ptr,
-          rest,
+          data_size,
           const_param<ndim_constant()>(shape),
           const_param<ndim_constant()>(strides_in),
           const_param<ndim_constant()>(strides_out));
     });
   } else { // ndim >= 4
-    auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
-    if (work_per_thread == 4) {
-      kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
-    }
+    auto [num_blocks, block_dims] =
+        get_launch_args(data_size, shape, out.strides(), large());
     encoder.add_kernel_node(
-        kernel,
-        {num_blocks_x, num_blocks_y},
+        cu::copy_gg<InType, OutType, IdxT>,
+        num_blocks,
         block_dims,
         0,
         in_ptr,
         out_ptr,
-        rest,
+        data_size,
         const_param(shape),
         const_param(strides_in),
         const_param(strides_out),
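For context on the N_READS template parameter carried by the vectorized copy kernels in this hunk: each thread handles N_READS contiguous elements of the innermost axis so loads and stores can use aligned vector accesses, with a scalar tail when the axis length is not a multiple of N_READS. A CPU-only sketch of that work split (illustrative, not MLX code):

// CPU sketch of the work-per-thread split used by the vectorized copy kernels:
// "thread" t copies elements [t*N_READS, t*N_READS + N_READS) of the last axis,
// with a bounds check for the ragged tail.
#include <cassert>
#include <vector>

template <int N_READS>
void copy_row(const float* in, float* out, int shape_x) {
  int num_threads = (shape_x + N_READS - 1) / N_READS;  // ceiling division
  for (int t = 0; t < num_threads; ++t) {               // "threads" run serially here
    for (int i = 0; i < N_READS && t * N_READS + i < shape_x; ++i) {
      out[t * N_READS + i] = in[t * N_READS + i];
    }
  }
}

int main() {
  std::vector<float> in(10, 1.0f), out(10, 0.0f);
  copy_row<4>(in.data(), out.data(), 10);  // 3 "threads": 4 + 4 + 2 elements
  assert(out == in);
  return 0;
}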

View File

@@ -10,67 +10,33 @@ namespace cu {
 namespace cg = cooperative_groups;

-template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
+template <typename In, typename Out, typename IdxT, int NDIM>
 __global__ void copy_g_nd(
     const In* in,
     Out* out,
-    IdxT size_rest,
+    IdxT size,
     const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
-    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
-  auto block = cg::this_thread_block();
-  auto grid = cg::this_grid();
-  IdxT index_rest =
-      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
-  if (index_rest >= size_rest) {
-    return;
-  }
-  auto shape_x = shape[NDIM - 1];
-  auto stride_x = strides[NDIM - 1];
-  IdxT index_x =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  auto idx =
-      elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
-  auto in_vec =
-      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
-  AlignedVector<Out, N_READS> out_vec;
-#pragma unroll
-  for (int i = 0; i < N_READS; ++i) {
-    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
-  }
-  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
+    out[index] = CastOp<In, Out>{}(in[idx_in]);
+  }
 }

-template <typename In, typename Out, typename IdxT, int N_READS>
+template <typename In, typename Out, typename IdxT>
 __global__ void copy_g(
     const In* in,
     Out* out,
-    IdxT size_rest,
+    IdxT size,
     const __grid_constant__ Shape shape,
-    const __grid_constant__ Strides strides,
+    const __grid_constant__ Strides strides_in,
     int ndim) {
-  auto block = cg::this_thread_block();
-  auto grid = cg::this_grid();
-  IdxT index_rest =
-      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
-  if (index_rest >= size_rest) {
-    return;
-  }
-  auto shape_x = shape[ndim - 1];
-  auto stride_x = strides[ndim - 1];
-  IdxT index_x =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  auto idx =
-      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
-  auto in_vec =
-      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
-  AlignedVector<Out, N_READS> out_vec;
-#pragma unroll
-  for (int i = 0; i < N_READS; ++i) {
-    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
-  }
-  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
+    out[index] = CastOp<In, Out>{}(in[idx_in]);
+  }
 }

 } // namespace cu

@@ -95,49 +61,30 @@ void copy_general_input(
   const InType* in_ptr = in.data<InType>() + offset_in;
   OutType* out_ptr = out.data<OutType>() + offset_out;
   int ndim = shape.size();
-  int work_per_thread = 1;
-  auto dim0 = ndim > 0 ? shape.back() : 1;
-  auto rest = out.size() / dim0;
-  if (dim0 >= 4) {
-    work_per_thread = 4;
-  }
-  dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
-  auto block_dims = get_block_dims(dim0, rest, 1);
-  uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
-  uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
   if (ndim <= 3) {
     dispatch_1_2_3(ndim, [&](auto dims_constant) {
-      auto kernel =
-          cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
-      if (work_per_thread == 4) {
-        kernel =
-            cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
-      }
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
      encoder.add_kernel_node(
-          kernel,
-          {num_blocks_x, num_blocks_y},
+          cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
+          num_blocks,
           block_dims,
           0,
           in_ptr,
           out_ptr,
-          rest,
+          out.size(),
           const_param<dims_constant()>(shape),
           const_param<dims_constant()>(strides_in));
     });
   } else { // ndim >= 4
-    auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
-    if (work_per_thread == 4) {
-      kernel = cu::copy_g<InType, OutType, IdxT, 4>;
-    }
+    auto [num_blocks, block_dims] = get_launch_args(out, large());
     encoder.add_kernel_node(
-        kernel,
-        {num_blocks_x, num_blocks_y},
+        cu::copy_g<InType, OutType, IdxT>,
+        num_blocks,
         block_dims,
         0,
         in_ptr,
         out_ptr,
-        rest,
+        out.size(),
         const_param(shape),
         const_param(strides_in),
         ndim);

View File

@@ -1,275 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// shape[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plan
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
encoder.add_temporary(workspace);
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core
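The normalized_strides workaround above is easiest to see on a concrete shape: a row-contiguous MLX array of shape (1, 4) may carry any stride on the singleton dimension, while cuDNN only recognizes strides satisfying strides[i] == shape[i + 1] * strides[i + 1] as contiguous. A standalone illustration of the same normalization (not MLX code):

// Standalone illustration of the singleton-dim stride normalization: rewrite
// strides of size-1 dims so that strides[i] == shape[i + 1] * strides[i + 1].
#include <cassert>
#include <vector>

std::vector<long long> normalize(std::vector<int> shape, std::vector<long long> strides) {
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  return strides;
}

int main() {
  // Shape (1, 4): any stride on the leading singleton dim describes the same
  // data, but cuDNN only treats {4, 1} as contiguous.
  assert(normalize({1, 4}, {0, 1}) == (std::vector<long long>{4, 1}));
  assert(normalize({1, 4}, {7, 1}) == (std::vector<long long>{4, 1}));
  return 0;
}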

View File

@@ -1,164 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
//
// There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
} // namespace mlx::core

View File

@@ -1,379 +0,0 @@
// Copyright © 2025 Apple Inc.
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::fast {
namespace {
constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
std::string template_arguments_hash(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
if (template_args.empty()) {
return "";
}
std::string hash;
hash.reserve(512);
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
hash += fmt::format("_{}", std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
hash += (std::get<bool>(arg)) ? "_t" : "_f";
} else if (std::holds_alternative<Dtype>(arg)) {
hash += "_";
hash += get_type_string(std::get<Dtype>(arg));
}
}
return hash;
}
std::string build_kernel(
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
kernel_source += header;
kernel_source +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
kernel_source += "__global__ void ";
kernel_source += func_name;
kernel_source += "(\n";
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
kernel_source += " const ";
kernel_source += dtype_to_cuda_type(arr.dtype());
kernel_source += "* ";
kernel_source += name;
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " ";
kernel_source += dtype_to_cuda_type(dtype);
kernel_source += "* ";
kernel_source += name;
if (i < output_names.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
}
// Set compile time constants
if (!template_args.empty()) {
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
kernel_source +=
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
kernel_source += fmt::format(
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
} else {
kernel_source += fmt::format(
" using {} = {};\n",
name,
dtype_to_cuda_type(std::get<Dtype>(arg)));
}
}
kernel_source += "\n";
}
kernel_source += source;
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
return kernel_source;
}
} // namespace
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_memory) {
if (output_names.empty()) {
throw std::invalid_argument(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
return [=, shape_infos = std::move(shape_infos)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
}
std::string kernel_name =
"custom_kernel_" + name + template_arguments_hash(template_args);
std::string kernel_source = build_kernel(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos);
if (verbose) {
std::cout << "Generated source code for `" << kernel_name
<< "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
std::move(inputs));
};
}
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
std::vector<array> copies;
// Allocate and initialize the output arrays
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
// Create the input arrays and copy if needed
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
// Compile the custom kernel
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector{kernel_name});
},
false);
// Make the arguments
cu::KernelArgs args;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
args.append<int32_t>(in.ndim());
}
}
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast
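As a note on the launch geometry at the end of CustomKernel::eval_gpu: grid_ holds total thread counts per axis, so the number of blocks comes from ceiling division against the threadgroup size. A standalone check of that arithmetic (illustrative only):

// Illustrative check of the block/grid split: the grid spec counts total
// threads per axis, so blocks = ceil(total_threads / block_size).
#include <algorithm>
#include <cassert>

int main() {
  int gx = 1000, tx = 256;            // total threads and threadgroup size
  int block_x = std::min(tx, gx);     // never launch a block wider than the work
  int blocks_x = (gx + tx - 1) / tx;  // ceiling division -> 4 blocks
  assert(block_x == 256 && blocks_x == 4);
  assert(blocks_x * block_x >= gx);   // covers all 1000 threads, with slack
  return 0;
}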

View File

@@ -14,6 +14,10 @@ namespace mlx::core::cu {
 namespace {

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+// This should be less than 255
+constexpr int default_max_nodes_per_graph = 20;
+
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))

 void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -23,11 +27,11 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
   }
 }

-bool use_cuda_graphs() {
-  static bool use_graphs = []() {
-    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
-  }();
-  return use_graphs;
+int cuda_graph_cache_size() {
+  static int cache_size = []() {
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+  }();
+  return cache_size;
 }

 } // namespace
@@ -64,8 +68,8 @@ Device::~Device() {
 void Device::make_current() {
   // We need to set/get current CUDA device very frequently, cache it to reduce
-  // actual calls of CUDA APIs.
-  static thread_local int current = 0;
+  // actual calls of CUDA APIs. This function assumes single-thread in host.
+  static int current = 0;
   if (current != device_) {
     CHECK_CUDA_ERROR(cudaSetDevice(device_));
     current = device_;
@@ -82,20 +86,14 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
   enc.device().make_current();
-  if (!use_cuda_graphs()) {
-    return;
-  }
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }

 CommandEncoder::CaptureContext::~CaptureContext() {
-  if (!use_cuda_graphs()) {
-    enc.node_count_++;
-    return;
-  }
-  graph.end_capture(enc.stream());
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
   if (discard) {
     return;
   }
@@ -109,9 +107,6 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
 CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   enc.in_concurrent_ = false;
-  if (!use_cuda_graphs()) {
-    return;
-  }
   // Use an empty graph node for synchronization
   CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -190,46 +185,37 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 }

 CommandEncoder::CommandEncoder(Device& d)
-    : device_(d),
-      stream_(d),
-      graph_(d),
-      worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
+    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
+  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
 }

 void CommandEncoder::set_input_array(const array& arr) {
-  if (!use_cuda_graphs()) {
-    return;
-  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }

 void CommandEncoder::set_output_array(const array& arr) {
-  if (!use_cuda_graphs()) {
-    return;
-  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
   active_outputs_.push_back(id);
 }

+void CommandEncoder::maybe_commit() {
+  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    commit();
+  }
+}
+
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
-  if (!use_cuda_graphs()) {
-    node_count_++;
-    CHECK_CUDA_ERROR(cudaLaunchKernel(
-        func, grid_dim, block_dim, params, smem_bytes, stream()));
-    return;
-  }
   cudaKernelNodeParams kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDim = grid_dim;
@@ -245,23 +231,6 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
-  if (!use_cuda_graphs()) {
-    node_count_++;
-    CHECK_CUDA_ERROR(cuLaunchKernel(
-        func,
-        grid_dim.x,
-        grid_dim.y,
-        grid_dim.z,
-        block_dim.x,
-        block_dim.y,
-        block_dim.z,
-        smem_bytes,
-        stream(),
-        params,
-        nullptr));
-    return;
-  }
   CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDimX = grid_dim.x;
@@ -288,38 +257,20 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 }

 void CommandEncoder::add_graph_node(cudaGraph_t child) {
-  if (!use_cuda_graphs()) {
-    node_count_++;
-    CudaGraphExec graph_exec;
-    graph_exec.instantiate(child);
-    device_.make_current();
-    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
-    return;
-  }
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(GraphNode{node, 'G'});
 }

-int CommandEncoder::get_num_ops() {
-  return node_count_;
-}
-
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (use_cuda_graphs() && node_count_ > 0) {
+  if (node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
-          graph_,
-          from_nodes_.data(),
-          to_nodes_.data(),
-#if CUDART_VERSION >= 13000
-          nullptr, // edgeData
-#endif // CUDART_VERSION >= 13000
-          from_nodes_.size()));
+          graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
     }

     graph_key_ += ".";
@@ -353,18 +304,19 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();
     to_nodes_.clear();
     graph_key_.clear();
     node_map_.clear();
-    graph_ = CudaGraph(device_);
+    CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
+    CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
   }

   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }

 void CommandEncoder::synchronize() {
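The maybe_commit helper in this hunk bounds how many nodes accumulate before the encoder flushes: once node_count_ reaches the MLX_MAX_OPS_PER_BUFFER limit (defaulting to default_max_nodes_per_graph), commit() runs and a fresh graph starts. A standalone sketch of that batching policy (hypothetical Encoder type, not the MLX class):

// Sketch of the "commit every N nodes" batching policy behind maybe_commit().
#include <cassert>

struct Encoder {                   // hypothetical stand-in for CommandEncoder
  int node_count = 0;
  int commits = 0;
  int max_nodes;
  explicit Encoder(int max_nodes) : max_nodes(max_nodes) {}
  void add_node() {
    ++node_count;
    if (node_count >= max_nodes) { // maybe_commit(): flush the current graph
      ++commits;
      node_count = 0;
    }
  }
};

int main() {
  Encoder enc(20);                 // default_max_nodes_per_graph
  for (int i = 0; i < 45; ++i) enc.add_node();
  assert(enc.commits == 2 && enc.node_count == 5);
  return 0;
}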

View File

@@ -21,7 +21,7 @@ class CommandEncoder {
   struct CaptureContext {
     CaptureContext(CommandEncoder& enc);
     ~CaptureContext();

-    CudaGraph graph;
+    cudaGraph_t graph;
     CommandEncoder& enc;
     bool discard{false};
   };
@@ -76,6 +76,9 @@ class CommandEncoder {
       uint32_t smem_bytes,
       void** params);

+  // Low-level graph helpers.
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
   void add_graph_node(cudaGraph_t child);

   void add_temporary(const array& arr) {
@@ -83,7 +86,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  int get_num_ops();
+  void maybe_commit();
   void commit();

   Device& device() {
@@ -98,9 +101,6 @@ class CommandEncoder {
   void synchronize();

  private:
-  void add_kernel_node(const cudaKernelNodeParams& params);
-  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
-
   struct GraphNode {
     cudaGraphNode_t node;
     // K = kernel
@@ -115,7 +115,7 @@ class CommandEncoder {
   Device& device_;
   CudaStream stream_;
-  CudaGraph graph_;
+  cudaGraph_t graph_;
   Worker worker_;
   char node_count_{0};
   char graph_node_count_{0};
@@ -140,7 +140,7 @@ class Device {
   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;

-  // Make this device the current cuda device, this method is thread-safe.
+  // Make this device the current cuda device, required by some cuda calls.
   void make_current();

   CommandEncoder& get_command_encoder(Stream s);

View File

@@ -204,12 +204,6 @@ struct Power {
   __device__ T operator()(T base, T exp) {
     if constexpr (cuda::std::is_integral_v<T>) {
       T res = 1;
-      // Raising an integer to a negative power is undefined
-      if constexpr (cuda::std::is_signed_v<T>) {
-        if (exp < 0) {
-          return 0;
-        }
-      }
       while (exp) {
         if (exp & 1) {
           res *= base;
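The guard shown in this hunk short-circuits negative exponents for signed integer types: a negative result would not be an integer, and the shift-based exponentiation-by-squaring loop never reaches zero for a negative exp on typical arithmetic-shift implementations, so returning 0 avoids that. A standalone version with the guard kept (illustrative, not the kernel code):

// Standalone integer pow by squaring with the negative-exponent guard kept.
#include <cassert>
#include <type_traits>

template <typename T>
T int_pow(T base, T exp) {
  if constexpr (std::is_signed_v<T>) {
    if (exp < 0) return 0;  // guard: negative integer powers are not representable
  }
  T res = 1;
  while (exp) {
    if (exp & 1) res *= base;  // multiply in the current bit of the exponent
    exp >>= 1;
    base *= base;              // square the base for the next bit
  }
  return res;
}

int main() {
  assert(int_pow(3, 4) == 81);
  assert(int_pow(2, 10) == 1024);
  assert(int_pow(2, -3) == 0);  // guarded case
  return 0;
}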

View File

@@ -6,6 +6,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+#include <thrust/iterator/transform_iterator.h>

 namespace mlx::core::cu {
@@ -115,4 +116,15 @@ inline __host__ __device__ auto cast_to(SrcT x) {
   return CastOp<SrcT, DstT>{}(x);
 }

+// Return an iterator that cast the value to DstT using CastOp.
+template <typename DstT, typename Iterator>
+inline __host__ __device__ auto make_cast_iterator(Iterator it) {
+  using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
+  if constexpr (std::is_same_v<SrcT, DstT>) {
+    return it;
+  } else {
+    return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
+  }
+}
+
 } // namespace mlx::core::cu
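make_cast_iterator only wraps the iterator in thrust::transform_iterator when the source and destination types differ, so same-type paths keep the raw pointer. A host-only analogue of that constexpr dispatch, using a small wrapper instead of Thrust (illustrative, not the MLX helper):

// Host-side analogue of make_cast_iterator's dispatch: return the input
// unchanged when no cast is needed, otherwise a wrapper that casts on read.
#include <cassert>
#include <type_traits>

template <typename DstT, typename It>
struct CastingIt {
  It it;
  DstT operator*() const { return static_cast<DstT>(*it); }
};

template <typename DstT, typename It>
auto make_cast_iterator(It it) {
  using SrcT = std::remove_cv_t<std::remove_reference_t<decltype(*it)>>;
  if constexpr (std::is_same_v<SrcT, DstT>) {
    return it;                       // no cast: keep the raw iterator
  } else {
    return CastingIt<DstT, It>{it};  // cast on dereference
  }
}

int main() {
  int xs[] = {1, 2, 3};
  auto same = make_cast_iterator<int>(xs);    // stays a plain int*
  auto cast = make_cast_iterator<float>(xs);  // casting wrapper
  static_assert(std::is_same_v<decltype(same), int*>);
  assert(*cast == 1.0f);
  return 0;
}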

View File

@@ -257,8 +257,8 @@ struct Round {
 struct Sigmoid {
   template <typename T>
   __device__ T operator()(T x) {
-    T y = 1 / (1 + exp(abs(x)));
-    return (x < 0) ? y : 1 - y;
+    T y = 1 / (1 + exp(-abs(x)));
+    return (x < 0) ? 1 - y : y;
   }
 };
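Both forms of Sigmoid in this hunk compute the same function; the variant that evaluates exp(-abs(x)) only ever passes non-positive arguments to exp, so it cannot overflow for large |x|. A standalone check against the textbook definition (illustrative):

// Check that the abs-based Sigmoid matches 1 / (1 + exp(-x)) while only
// calling exp on non-positive values.
#include <cassert>
#include <cmath>

float stable_sigmoid(float x) {
  float y = 1.0f / (1.0f + std::exp(-std::fabs(x)));  // exp argument is <= 0
  return (x < 0) ? 1.0f - y : y;
}

int main() {
  for (float x : {-100.0f, -2.5f, 0.0f, 2.5f, 100.0f}) {
    // Naive form; for very negative x, exp(-x) overflows to inf and the
    // quotient collapses to 0, which still matches to within tolerance.
    float reference = 1.0f / (1.0f + std::exp(-x));
    assert(std::fabs(stable_sigmoid(x) - reference) < 1e-6f);
  }
  return 0;
}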

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilities that work under both // This file must not include any host-only code, utilies that work under both
// host and device can be put here. // host and device can be put here.
// //
// See more about the requirements at: // See more about the requirements at:
@@ -146,23 +146,6 @@ inline __device__ void store_vector(
} }
} }
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size,
int64_t stride) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[stride * (offset * N + i)] = vec[i];
}
}
}
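The strided store_vector overload above takes the vectorized path only when the destination is N-aligned, the whole vector fits within size, and the stride is 1; anything else falls back to a scalar loop that also covers the ragged tail. A simplified device-side sketch of that guard structure (hypothetical kernel built on float4, not the MLX helper):

// Vectorized store with scalar fallback: scales `in` by `s` into `out`.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void scale_kernel(float* out, const float* in, int size, float s) {
  constexpr int N = 4; // elements per vectorized access (one float4)
  int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int base = vec_idx * N;
  if (base >= size) {
    return;
  }
  bool aligned =
      reinterpret_cast<uintptr_t>(in) % (N * sizeof(float)) == 0 &&
      reinterpret_cast<uintptr_t>(out) % (N * sizeof(float)) == 0;
  if (aligned && base + N <= size) {
    // Fast path: one 16-byte load and one 16-byte store per thread.
    float4 v = reinterpret_cast<const float4*>(in)[vec_idx];
    v.x *= s; v.y *= s; v.z *= s; v.w *= s;
    reinterpret_cast<float4*>(out)[vec_idx] = v;
  } else {
    // Fallback: scalar loop that also handles the ragged tail.
    for (int i = base; i < base + N && i < size; ++i) {
      out[i] = in[i] * s;
    }
  }
}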
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -202,7 +185,7 @@ struct Limits<
} }
}; };
// CUDA 11 does not have host side arithmetic operators for half types. // CUDA 11 does not have host side arithmatic operators for half types.
template <typename T> template <typename T>
struct Limits< struct Limits<
T, T,

View File

@@ -1,56 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core::distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
} else if (in.is_donatable()) {
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace mlx::core::distributed

View File

@@ -5,24 +5,18 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h" #include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu { namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() { bool is_available() {
return true; return true;
} }
void new_stream(Stream s) { void new_stream(Stream s) {
// Force initalization of CUDA, so CUDA runtime get destroyed at last. // Force initalization of cuda, so cuda runtime get destroyed at last.
cudaFree(nullptr); cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
cu::CudaEvent::init_pool();
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
} }
@@ -40,8 +34,7 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs); arr.primitive().eval_gpu(arr.inputs(), outputs);
} }
auto& stream = arr.primitive().stream(); auto& encoder = cu::get_command_encoder(arr.primitive().stream());
auto& encoder = cu::get_command_encoder(stream);
// Keep used buffers alive until kernel finishes running. // Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) { for (auto& in : arr.inputs()) {
// Except for the donated one. // Except for the donated one.
@@ -52,14 +45,7 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) { for (auto& s : arr.siblings()) {
encoder.add_temporary(s); encoder.add_temporary(s);
} }
encoder.maybe_commit();
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
} }
void finalize(Stream s) { void finalize(Stream s) {
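One side of the eval hunk above keeps an explicit op count and, once it reaches env::max_ops_per_buffer (defaulting to the 20-node default_max_nodes_per_graph), notifies the scheduler and commits the encoder; the other side folds the same policy into maybe_commit(). A minimal sketch of the threshold-commit idea in isolation (illustrative class, not the MLX encoder API):

// Queue ops and flush the batch whenever it reaches a fixed size.
#include <functional>
#include <utility>
#include <vector>

class BatchingEncoder {
 public:
  explicit BatchingEncoder(int max_ops) : max_ops_(max_ops) {}

  void add_op(std::function<void()> op) {
    pending_.push_back(std::move(op));
    if (static_cast<int>(pending_.size()) >= max_ops_) {
      commit(); // keep any single batch (graph) from growing unboundedly
    }
  }

  void commit() {
    for (auto& op : pending_) {
      op(); // stand-in for instantiating and launching one CUDA graph
    }
    pending_.clear();
  }

 private:
  int max_ops_;
  std::vector<std::function<void()>> pending_;
};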

View File

@@ -3,12 +3,10 @@
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h" #include "mlx/event.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include <map>
#include <vector>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core { namespace mlx::core {
@@ -19,180 +17,104 @@ namespace cu {
// CudaEvent implementations // CudaEvent implementations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
namespace { // Cuda event managed with RAII.
class CudaEventHandle {
// Manage cached cudaEvent_t objects.
class CudaEventPool {
public: public:
CudaEventHandle create(Device& d, int flags) { CudaEventHandle() {
if (!on_creation_thread()) { CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
return CudaEventHandle(d, flags); &event_, cudaEventDisableTiming | cudaEventBlockingSync));
}
auto& cache = cache_for(d, flags);
if (cache.empty()) {
return CudaEventHandle(d, flags);
} else {
CudaEventHandle ret = std::move(cache.back());
cache.pop_back();
return ret;
}
} }
void release(CudaEventHandle event) { ~CudaEventHandle() {
if (!on_creation_thread()) { CHECK_CUDA_ERROR(cudaEventDestroy(event_));
// Event will be destroyed directly instead of getting moved to cache. }
return;
} CudaEventHandle(const CudaEventHandle&) = delete;
cache_for(event.device, event.flags).push_back(std::move(event)); CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
} }
private: private:
std::vector<CudaEventHandle>& cache_for(Device& d, int flags) { cudaEvent_t event_;
return cache_[d.cuda_device()][flags];
}
bool on_creation_thread() {
return std::this_thread::get_id() == thread_id_;
}
// The CudaEvent may be created and destroyed on different threads (for
// example when waiting on GPU work in CPU stream), we don't want to make
// the cache thread-safe as it adds overhead, so we just skip cache when
// using events in worker threads.
std::thread::id thread_id_{std::this_thread::get_id()};
// {device: {flags: [events]}}
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
}; };
CudaEventPool& cuda_event_pool() { CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
static CudaEventPool pool;
return pool;
}
} // namespace
CudaEventHandle::CudaEventHandle(Device& d, int flags)
: device(d), flags(flags) {
device.make_current();
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr);
}
CudaEvent::CudaEvent(Device& d, int flags)
: event_(cuda_event_pool().create(d, flags)) {}
CudaEvent::~CudaEvent() {
cuda_event_pool().release(std::move(event_));
}
void CudaEvent::wait() { void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait"); nvtx3::scoped_range r("cu::CudaEvent::wait");
event_.device.make_current(); if (!recorded_) {
cudaEventSynchronize(event_); throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
} }
void CudaEvent::wait(cudaStream_t stream) { void CudaEvent::wait(cudaStream_t stream) {
event_.device.make_current(); if (!recorded_) {
cudaStreamWaitEvent(stream, event_); throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
wait(enc.stream());
}
} }
void CudaEvent::record(cudaStream_t stream) { void CudaEvent::record(cudaStream_t stream) {
event_.device.make_current(); cudaEventRecord(*event_, stream);
cudaEventRecord(event_, stream); recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
record(enc.stream());
}
} }
bool CudaEvent::completed() const { bool CudaEvent::completed() const {
// Note: cudaEventQuery can be safely called from any device. return cudaEventQuery(*event_) == cudaSuccess;
return cudaEventQuery(event_) == cudaSuccess;
} }
// static
void CudaEvent::init_pool() {
cuda_event_pool();
}
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
event_->wait();
}
void wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable {
check_recorded();
event_->wait();
});
} else {
check_recorded();
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->wait(encoder.stream());
}
}
void record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
} else {
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->record(encoder.stream());
recorded_ = true;
}
}
bool is_signaled() const {
return recorded_ && event_->completed();
}
private:
void check_recorded() const {
if (!recorded_) {
throw std::runtime_error(
"Should not wait on a CudaEvent before recording.");
}
}
std::shared_ptr<CudaEvent> event_;
bool recorded_{false};
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// AtomicEvent implementations // SharedEvent implementations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) { __host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current; uint64_t current;
while ((current = ac->load()) < value) { while ((current = ac->load()) < value) {
ac->wait(current); ac->wait(current);
} }
} }
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) { __host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
ac->store(value); ac->store(value);
ac->notify_all(); ac->notify_all();
} }
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) { __global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value); event_wait(ac, value);
} }
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) { __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value); event_signal(ac, value);
} }
AtomicEvent::AtomicEvent() { SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() {
buf_ = std::shared_ptr<Buffer>( buf_ = std::shared_ptr<Buffer>(
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) { new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
allocator().free(*ptr); allocator().free(*ptr);
@@ -201,17 +123,17 @@ AtomicEvent::AtomicEvent() {
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0; *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
} }
void AtomicEvent::wait(uint64_t value) { void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait"); nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(atomic(), value); event_wait(to_atomic(buf_), value);
} }
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) { void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value); event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void AtomicEvent::wait(Stream s, uint64_t value) { void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)"); nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) { if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else { } else {
@@ -222,17 +144,17 @@ void AtomicEvent::wait(Stream s, uint64_t value) {
} }
} }
void AtomicEvent::signal(uint64_t value) { void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::signal"); nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(atomic(), value); event_signal(to_atomic(buf_), value);
} }
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) { void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value); event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void AtomicEvent::signal(Stream s, uint64_t value) { void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)"); nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) { if (s.device == mlx::core::Device::cpu) {
// Signal through a GPU stream so the atomic is updated in GPU - updating // Signal through a GPU stream so the atomic is updated in GPU - updating
// the atomic in CPU sometimes does not get GPU notified. // the atomic in CPU sometimes does not get GPU notified.
@@ -246,14 +168,14 @@ void AtomicEvent::signal(Stream s, uint64_t value) {
} }
} }
bool AtomicEvent::is_signaled(uint64_t value) const { bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled"); nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return atomic()->load() >= value; return to_atomic(buf_)->load() >= value;
} }
uint64_t AtomicEvent::value() const { uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::AtomicEvent::value"); nvtx3::scoped_range r("cu::SharedEvent::value");
return atomic()->load(); return to_atomic(buf_)->load();
} }
} // namespace cu } // namespace cu
@@ -266,14 +188,14 @@ namespace {
struct EventImpl { struct EventImpl {
// CudaEvent is preferred when possible because it is fast, however we have // CudaEvent is preferred when possible because it is fast, however we have
// to fallback to AtomicEvent in following cases: // to fallback to SharedEvent in following cases:
// 1. the event is used to wait/signal a cpu stream; // 1. the event is used to wait/signal a cpu stream;
// 2. signal value other than 1 has been specified. // 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CopyableCudaEvent> cuda; std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::AtomicEvent> atomic; std::unique_ptr<cu::SharedEvent> shared;
bool is_created() const { bool is_created() const {
return cuda || atomic; return cuda || shared;
} }
void ensure_created(Stream s, uint64_t signal_value) { void ensure_created(Stream s, uint64_t signal_value) {
@@ -281,10 +203,10 @@ struct EventImpl {
return; return;
} }
if (s.device == mlx::core::Device::cpu || signal_value > 1) { if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow AtomicEvent"); nvtx3::mark("Using slow SharedEvent");
atomic = std::make_unique<cu::AtomicEvent>(); shared = std::make_unique<cu::SharedEvent>();
} else { } else {
cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device)); cuda = std::make_unique<cu::CudaEvent>();
} }
} }
}; };
@@ -303,7 +225,7 @@ void Event::wait() {
assert(value() == 1); assert(value() == 1);
event->cuda->wait(); event->cuda->wait();
} else { } else {
event->atomic->wait(value()); event->shared->wait(value());
} }
} }
@@ -314,7 +236,7 @@ void Event::wait(Stream s) {
assert(value() == 1); assert(value() == 1);
event->cuda->wait(s); event->cuda->wait(s);
} else { } else {
event->atomic->wait(s, value()); event->shared->wait(s, value());
} }
} }
@@ -325,7 +247,7 @@ void Event::signal(Stream s) {
assert(value() == 1); assert(value() == 1);
event->cuda->record(s); event->cuda->record(s);
} else { } else {
event->atomic->signal(s, value()); event->shared->signal(s, value());
} }
} }
@@ -336,9 +258,9 @@ bool Event::is_signaled() const {
} }
if (event->cuda) { if (event->cuda) {
assert(value() == 1); assert(value() == 1);
return event->cuda->is_signaled(); return event->cuda->recorded() && event->cuda->completed();
} else { } else {
return event->atomic->is_signaled(value()); return event->shared->is_signaled(value());
} }
} }
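Both event flavors above (SharedEvent on one side, AtomicEvent on the other) are built on a cuda::atomic<uint64_t> placed in memory that host and device can both reach: the host waits with load/wait and the GPU waits or signals through a single-thread kernel. A standalone sketch of that mechanism using managed memory (assumed here only for illustration and requiring a platform with concurrent managed access; the MLX version allocates through its own allocator instead):

// Host/device event on a cuda::atomic in managed memory (minimal sketch).
#include <cuda/atomic>
#include <cuda_runtime.h>
#include <cstdio>
#include <new>

using Atomic = cuda::atomic<uint64_t>; // defaults to system scope

__global__ void signal_kernel(Atomic* ac, uint64_t value) {
  ac->store(value);
  ac->notify_all(); // wake any host thread blocked in wait()
}

int main() {
  Atomic* ac;
  cudaMallocManaged(&ac, sizeof(Atomic));
  new (ac) Atomic(0);

  signal_kernel<<<1, 1>>>(ac, 1); // GPU-side signal

  // Host-side wait: block until the stored value reaches the target.
  uint64_t current;
  while ((current = ac->load()) < 1) {
    ac->wait(current);
  }
  std::printf("signaled with %llu\n", (unsigned long long)ac->load());

  cudaDeviceSynchronize();
  cudaFree(ac);
}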

View File

@@ -3,60 +3,49 @@
#pragma once #pragma once
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include <memory>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda/atomic> #include <cuda/atomic>
#include <memory>
namespace mlx::core::cu { namespace mlx::core::cu {
class Device; class CudaEventHandle;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(Device& d, int flags);
Device& device;
int flags;
};
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait // Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream. // on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent { class CudaEvent {
public: public:
CudaEvent(Device& d, int flags); CudaEvent();
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
CudaEvent& operator=(CudaEvent&&) = default;
CudaEvent(const CudaEvent&) = delete;
CudaEvent& operator=(const CudaEvent&) = delete;
void wait(); void wait();
void wait(cudaStream_t stream); void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream); void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method // Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called. // returns true if record() has not been called.
bool completed() const; bool completed() const;
// Internal: make sure event pool is initialized. bool recorded() const {
static void init_pool(); return recorded_;
}
private: private:
CudaEventHandle event_; bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
}; };
// Event that can synchronize between CPU and GPU. It is much slower than // Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible. // CudaEvent so the latter should always be preferred when possible.
class AtomicEvent { class SharedEvent {
public: public:
using Atomic = cuda::atomic<uint64_t>; using Atomic = cuda::atomic<uint64_t>;
AtomicEvent(); SharedEvent();
void wait(uint64_t value); void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value); void wait(cudaStream_t stream, uint64_t value);
@@ -68,11 +57,7 @@ class AtomicEvent {
uint64_t value() const; uint64_t value() const;
private: private:
Atomic* atomic() const { std::shared_ptr<mlx::core::allocator::Buffer> buf_;
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
}
std::shared_ptr<allocator::Buffer> buf_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -7,7 +7,7 @@ namespace mlx::core {
struct FenceImpl { struct FenceImpl {
uint32_t count; uint32_t count;
cu::AtomicEvent event; cu::SharedEvent event;
}; };
Fence::Fence(Stream s) { Fence::Fence(Stream s) {

View File

@@ -4,17 +4,16 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core { namespace mlx::core::cu {
void CublasGemm::run_batched( void Matmul::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const Shape& batch_shape, const mlx::core::Shape& batch_shape,
const Strides& a_batch_strides, const mlx::core::Strides& a_batch_strides,
const Strides& b_batch_strides, const mlx::core::Strides& b_batch_strides) {
float alpha) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
@@ -23,28 +22,27 @@ void CublasGemm::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context(); auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
execute( run_impl(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc, b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr, nullptr);
alpha);
a_it.step(); a_it.step();
b_it.step(); b_it.step();
} }
} }
void CublasGemm::run_batched( void Matmul::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const Shape& batch_shape, const mlx::core::Shape& batch_shape,
const Strides& a_batch_strides, const mlx::core::Strides& a_batch_strides,
const Strides& b_batch_strides, const mlx::core::Strides& b_batch_strides,
const Strides& c_batch_strides, const mlx::core::Strides& c_batch_strides,
float alpha, float alpha,
float beta) { float beta) {
encoder.set_input_array(a); encoder.set_input_array(a);
@@ -58,7 +56,7 @@ void CublasGemm::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context(); auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
execute( run_impl(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -72,4 +70,4 @@ void CublasGemm::run_batched(
} }
} }
} // namespace mlx::core } // namespace mlx::core::cu
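Both run_batched variants above walk the batch with ContiguousIterator: every step maps the next batch index to an element offset through the (possibly broadcast) batch strides, and one GEMM is issued per batch inside a concurrent context. A host-side sketch of the offset computation such an iterator performs (illustrative helper; a broadcast axis simply carries stride 0):

// Map a linear batch index to an element offset via shape and strides.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t batch_offset(
    int64_t index,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  // A [2, 3] batch: `a` is broadcast along the second batch axis (stride 0),
  // while `b` has a distinct 6-element matrix per batch entry.
  std::vector<int> batch_shape{2, 3};
  std::vector<int64_t> a_strides{6, 0};
  std::vector<int64_t> b_strides{18, 6};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf(
        "batch %lld: a offset %lld, b offset %lld\n",
        (long long)i,
        (long long)batch_offset(i, batch_shape, a_strides),
        (long long)batch_offset(i, batch_shape, b_strides));
  }
  // a offsets: 0 0 0 6 6 6, b offsets: 0 6 12 18 24 30
}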

View File

@@ -0,0 +1,208 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu
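The batched path in this file switches the cublasLt layouts to CUBLASLT_BATCH_MODE_POINTER_ARRAY and fills a single temporary buffer with per-batch pointers laid out as [A pointers | B pointers | (C pointers |) out pointers], with the offsets computed on device so batches with arbitrary or broadcast strides can be addressed. A host-side sketch of that buffer layout for the regular-stride case (illustration only; the real kernel derives offsets with elem_to_loc):

// One flat allocation partitioned into per-operand pointer arrays.
#include <cstdint>
#include <vector>

struct BatchedPointers {
  std::vector<int8_t*> storage; // 3 * batch_count entries
  int batch_count;
  int8_t** a() { return storage.data(); }
  int8_t** b() { return storage.data() + batch_count; }
  int8_t** out() { return storage.data() + 2 * batch_count; }
};

BatchedPointers make_pointers(
    int8_t* a_start,
    int8_t* b_start,
    int8_t* out_start,
    int item_size,
    int64_t a_batch_stride, // elements between consecutive A matrices
    int64_t b_batch_stride,
    int64_t out_batch_stride,
    int batch_count) {
  BatchedPointers p{std::vector<int8_t*>(3 * batch_count), batch_count};
  for (int i = 0; i < batch_count; ++i) {
    p.a()[i] = a_start + item_size * i * a_batch_stride;
    p.b()[i] = b_start + item_size * i * b_batch_stride;
    p.out()[i] = out_start + item_size * i * out_batch_stride;
  }
  return p;
}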

View File

@@ -7,12 +7,10 @@
#include <fmt/format.h> #include <fmt/format.h>
namespace mlx::core { namespace mlx::core::cu {
namespace {
struct CublasPreference { struct CublasPreference {
CublasPreference(cu::Device& device) { CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+: // for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -35,7 +33,7 @@ struct CublasPreference {
cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr};
}; };
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) { cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device); static CublasPreference pref(device);
return pref.pref_; return pref.pref_;
} }
@@ -50,13 +48,11 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F; : CUBLAS_COMPUTE_32F;
case float64: case float64:
return CUBLAS_COMPUTE_64F;
case complex64: case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 return CUBLAS_COMPUTE_64F;
: CUBLAS_COMPUTE_32F;
default: default:
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype))); "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
} }
} }
@@ -74,7 +70,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
return CUDA_C_32F; return CUDA_C_32F;
default: default:
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype))); "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
} }
} }
@@ -87,10 +83,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count, int32_t batch_count,
int64_t batch_stride) { int64_t batch_stride) {
cublasLtMatrixLayout_t desc; cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) { if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, desc,
@@ -106,10 +102,8 @@ cublasLtMatrixLayout_t create_matrix_layout(
return desc; return desc;
} }
} // namespace Matmul::Matmul(
Device& device,
CublasGemm::CublasGemm(
cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@@ -128,51 +122,41 @@ CublasGemm::CublasGemm(
N_(b_cols) { N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
scale_type_ = dtype_to_cublas_type(dtype); auto scale_type = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) { if (dtype == bfloat16 || dtype == float16) {
scale_type_ = CUDA_R_32F; scale_type = CUDA_R_32F;
} }
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_)); &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE, CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode, &pointer_mode,
sizeof(int32_t))); sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA, CUBLASLT_MATMUL_DESC_TRANSA,
&a_op, &op,
sizeof(cublasOperation_t))); sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB, CUBLASLT_MATMUL_DESC_TRANSB,
&b_op, &op,
sizeof(cublasOperation_t))); sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype); auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout( a_desc_ = create_matrix_layout(
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride); type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout( b_desc_ = create_matrix_layout(
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride); type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
} }
CublasGemm::CublasGemm( Matmul::Matmul(
cu::Device& device, Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@@ -187,7 +171,7 @@ CublasGemm::CublasGemm(
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride, int64_t b_batch_stride,
int64_t c_batch_stride) int64_t c_batch_stride)
: CublasGemm( : Matmul(
device, device,
dtype, dtype,
a_transposed, a_transposed,
@@ -203,10 +187,10 @@ CublasGemm::CublasGemm(
b_batch_stride) { b_batch_stride) {
auto type = dtype_to_cublas_type(dtype); auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout( c_desc_ = create_matrix_layout(
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride); type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
} }
CublasGemm::~CublasGemm() { Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -214,122 +198,7 @@ CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
} }
void CublasGemm::set_out( void Matmul::run_impl(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
cols,
rows,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
encoder.set_input_array(bias);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr,
sizeof(bias_ptr)));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
void* out, void* out,
const void* a, const void* a,
@@ -355,16 +224,6 @@ void CublasGemm::execute(
} }
} }
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr; void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) { if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned // Ensure workspace is 256-byte aligned
@@ -381,12 +240,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_, handle_,
matmul_desc_, matmul_desc_,
alpha_ptr, &alpha,
b, // a and b are swapped
a_desc_,
a, a,
a_desc_,
b,
b_desc_, b_desc_,
beta_ptr, &beta,
c ? c : out, c ? c : out,
c ? c_desc_ : out_desc_, c ? c_desc_ : out_desc_,
out, out,
@@ -397,4 +256,29 @@ void CublasGemm::execute(
encoder.stream())); encoder.stream()));
} }
} // namespace mlx::core void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu
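One side of the constructor hunk in this file explains the operand swap: cublasLt consumes column-major matrices, and a buffer holding a row-major matrix is, byte for byte, its transpose in column-major, so requesting B^T · A^T = (A · B)^T in column-major yields A · B laid out row-major, with no need for the CUBLASLT_ORDER_ROW attribute (which the comment notes breaks the bias epilogue). A tiny host check of that identity against a naive column-major GEMM (illustration only):

// Verify the row-major swap trick with a naive column-major gemm.
#include <cstdio>
#include <vector>

// C(m x n) = A(m x k) * B(k x n), all column-major, leading dim = rows.
void gemm_colmajor(
    int m, int n, int k,
    const std::vector<float>& A,
    const std::vector<float>& B,
    std::vector<float>& C) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += A[i + p * m] * B[p + j * k];
      }
      C[i + j * m] = acc;
    }
  }
}

int main() {
  // Row-major A (2x3) and B (3x2).
  std::vector<float> A{1, 2, 3, 4, 5, 6};
  std::vector<float> B{7, 8, 9, 10, 11, 12};
  std::vector<float> C(4);
  // Reinterpret the row-major buffers as column-major transposes and swap
  // the operands: column-major B^T(2x3) * A^T(3x2) = (A * B)^T.
  gemm_colmajor(2, 2, 3, B, A, C);
  for (float c : C) {
    std::printf("%g ", c); // 58 64 139 154, i.e. A * B read row-major
  }
  std::printf("\n");
}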

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include <cublasLt.h> #include <cublasLt.h>
#include <optional>
namespace mlx::core { namespace mlx::core::cu {
class Matmul {
class CublasGemm {
public: public:
CublasGemm( Matmul(
cu::Device& device, Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@@ -25,8 +25,8 @@ class CublasGemm {
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride); int64_t b_batch_stride);
CublasGemm( Matmul(
cu::Device& device, Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@@ -42,69 +42,41 @@ class CublasGemm {
int64_t b_batch_stride, int64_t b_batch_stride,
int64_t c_batch_stride); int64_t c_batch_stride);
~CublasGemm(); ~Matmul();
// The output's descriptor is inferred from inputs by default, use this method
// for unusual output.
void set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride);
void set_bias(cu::CommandEncoder& encoder, const array& bias);
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const Shape& batch_shape, const std::optional<array>& c = std::nullopt,
const Strides& a_batch_strides, float alpha = 1,
const Strides& b_batch_strides, float beta = 0);
float alpha = 1.0f);
void run( void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const Shape& batch_shape, const mlx::core::Shape& batch_shape,
const Strides& a_batch_strides, const mlx::core::Strides& a_batch_strides,
const Strides& b_batch_strides, const mlx::core::Strides& b_batch_strides,
const Strides& c_batch_strides, const mlx::core::Strides& c_batch_strides,
float alpha, float alpha,
float beta); float beta);
private: private:
void run_batched( void run_impl(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
void execute(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
void* out, void* out,
const void* a, const void* a,
@@ -115,7 +87,6 @@ class CublasGemm {
uint64_t M_; uint64_t M_;
uint64_t N_; uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr}; cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};
@@ -126,4 +97,4 @@ class CublasGemm {
cublasLtMatmulHeuristicResult_t heuristic_; cublasLtMatmulHeuristicResult_t heuristic_;
}; };
} // namespace mlx::core } // namespace mlx::core::cu

View File

@@ -1,329 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
} // namespace
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
{batch_count * 3},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_mm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr,
alpha);
}
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{batch_count * 4},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core

View File

@@ -13,37 +13,6 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8; static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread> template <typename T, int rows_per_block, int n_per_thread>
__device__ void __device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) { gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -55,8 +24,7 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y; int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) { if (row < rows) {
using Acc = typename GemvAccType<T>::type; float sum = 0.0f;
Acc sum = Acc(0);
for (int col = n_per_thread * warp.thread_rank(); col < cols; for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) { col += (WARP_SIZE * n_per_thread)) {
auto local_mat = auto local_mat =
@@ -64,11 +32,12 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0); auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll #pragma unroll
for (int j = 0; j < n_per_thread; ++j) { for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]); sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
} }
} }
sum = cg::reduce(warp, sum, cg::plus<Acc>{}); sum = cg::reduce(warp, sum, cg::plus<float>{});
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum); out[row] = static_cast<T>(sum);
} }
@@ -138,7 +107,7 @@ void gemv(
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) { dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block}; dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat; const DataType* mat;
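One side of the gemv hunk above selects the accumulator per element type through GemvAccType (float for __half and __nv_bfloat16, double and complex64_t kept as-is), while the other always accumulates in float. Widening half-precision products matters because a running __half sum stops growing once it exceeds the format's 11-bit significand. A small sketch demonstrating the effect (hypothetical kernel; assumes managed memory and native half arithmetic, sm_53 or newer):

// Sum 4096 ones in __half vs float: the half accumulator stalls at 2048.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void sum_ones(float* out, int n) {
  __half acc_h = __float2half(0.0f);
  float acc_f = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc_h = __hadd(acc_h, __float2half(1.0f)); // half-precision accumulator
    acc_f += 1.0f;                             // float accumulator
  }
  out[0] = __half2float(acc_h);
  out[1] = acc_f;
}

int main() {
  float* out;
  cudaMallocManaged(&out, 2 * sizeof(float));
  sum_ones<<<1, 1>>>(out, 4096);
  cudaDeviceSynchronize();
  std::printf("half sum = %g, float sum = %g\n", out[0], out[1]);
  // Expected: half sum = 2048 (2048 + 1 rounds back to 2048), float sum = 4096.
  cudaFree(out);
}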

View File

@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t")); large ? "int64_t" : "int32_t"));
} }
} }
return std::make_tuple(false, jit_source_gather, std::move(kernel_names)); return std::make_pair(jit_source_gather, std::move(kernel_names));
}); });
cu::KernelArgs args; cu::KernelArgs args;
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim()); args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_); args.append_ndim(slice_sizes_);
args.append(slice_size); args.append(slice_size);
args.append(axes_); args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t")); large ? "int64_t" : "int32_t"));
} }
} }
return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); return std::make_pair(jit_source_scatter, std::move(kernel_names));
}); });
cu::KernelArgs args; cu::KernelArgs args;
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape()); args.append_ndim(out.shape());
args.append_ndim(out.strides()); args.append_ndim(out.strides());
args.append<int32_t>(out.ndim()); args.append<int32_t>(out.ndim());
args.append(axes_); args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
@@ -268,8 +268,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
} }
return std::make_tuple( return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
false, jit_source_gather_axis, std::move(kernel_names));
}); });
size_t idx_size_pre = 1; size_t idx_size_pre = 1;
@@ -372,8 +371,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
} }
return std::make_tuple( return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
false, jit_source_scatter_axis, std::move(kernel_names));
}); });
size_t idx_size_pre = 1; size_t idx_size_pre = 1;

View File

@@ -67,11 +67,9 @@ const std::string& cccl_dir() {
return path.string(); return path.string();
} }
// Finally check the environment variable. // Finally check the environment variable.
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) { path = std::getenv("MLX_CCCL_DIR");
path = env; if (!path.empty() && std::filesystem::exists(path)) {
if (!path.empty() && std::filesystem::exists(path)) { return path.string();
return path.string();
}
} }
return std::string(); return std::string();
}(); }();
@@ -99,41 +97,17 @@ const std::filesystem::path& ptx_cache_dir() {
return cache; return cache;
} }
std::filesystem::path get_ptx_path(
const std::filesystem::path& cache_dir,
const std::string& module_name) {
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (module_name.size() <= max_file_name_length) {
return cache_dir / (module_name + ".ptx");
}
auto ptx_path = cache_dir;
int offset = 0;
while (module_name.size() - offset > max_file_name_length) {
ptx_path /= module_name.substr(offset, max_file_name_length);
offset += max_file_name_length;
}
ptx_path /= module_name.substr(offset) + ".ptx";
return ptx_path;
}
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
bool read_cached_ptx( bool read_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
const std::string& module_name, const std::string& module_name,
std::string& ptx, std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) { std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
if (cache_dir.empty()) { if (cache_dir.empty()) {
return false; return false;
} }
auto ptx_path = get_ptx_path(cache_dir, module_name); auto ptx_path = cache_dir / (module_name + ".ptx");
std::error_code error; std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error); auto ptx_size = std::filesystem::file_size(ptx_path, error);
if (error) { if (error) {
@@ -143,15 +117,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) { if (!ptx_file.good()) {
return false; return false;
} }
ptx.resize(ptx_size); ptx->resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size); ptx_file.read(ptx->data(), ptx_size);
std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary); std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line; std::string line;
while (std::getline(txt_file, line)) { while (std::getline(txt_file, line)) {
auto tab = line.find('\t'); auto tab = line.find('\t');
if (tab != std::string::npos) { if (tab != std::string::npos) {
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1)); ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
} }
} }
return true; return true;
@@ -161,33 +135,23 @@ bool read_cached_ptx(
void write_cached_ptx( void write_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
const std::string& module_name, const std::string& module_name,
const std::string& ptx, const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels, const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) { const std::string& source_code) {
if (cache_dir.empty()) { if (cache_dir.empty()) {
return; return;
} }
auto ptx_path = get_ptx_path(cache_dir, module_name); std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
// Ensure that the directory exists
auto parent = ptx_path.parent_path();
if (parent != cache_dir) {
std::filesystem::create_directories(parent);
}
// Write the compiled code and mangled names
std::ofstream ptx_file(ptx_path, std::ios::binary);
if (!ptx.empty()) { if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size()); ptx_file.write(&ptx.front(), ptx.size());
} }
std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary); std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl; txt_file << name << "\t" << mangled << std::endl;
} }
// Write the generated code std::ofstream source_file(cache_dir / (module_name + ".cu"));
std::ofstream source_file(ptx_path.replace_extension(".cu"));
source_file << source_code; source_file << source_code;
} }
@@ -253,86 +217,85 @@ constexpr const char* g_headers[] = {
jit_source_utils, jit_source_utils,
}; };
void compile( } // namespace
JitModule::JitModule(
Device& device, Device& device,
const std::string& module_name, const std::string& module_name,
const std::string& source, const KernelBuilder& builder) {
const std::vector<std::string>& kernel_names, // Check cache.
std::string& ptx, std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>>& ptx_kernels) { std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Create the program if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
nvrtcProgram prog; // Create program.
CHECK_NVRTC_ERROR(nvrtcCreateProgram( auto [source_code, kernel_names] = builder();
&prog, nvrtcProgram prog;
source.c_str(), CHECK_NVRTC_ERROR(nvrtcCreateProgram(
(module_name + ".cu").c_str(), &prog,
std::size(g_headers), source_code.c_str(),
g_headers, (module_name + ".cu").c_str(),
g_include_names)); std::size(g_headers),
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer( g_headers,
&prog, g_include_names));
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
for (const auto& name : kernel_names) { &prog,
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
for (const auto& name : kernel_names) {
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size, 0);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
} }
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
}
void load_module(
const std::string& module_name,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
kernels) {
// Load module. // Load module.
char jit_log[4089] = {}; char jit_log[4089] = {};
CUjit_option options[] = { CUjit_option options[] = {
@@ -349,77 +312,21 @@ void load_module(
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel; CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels[name] = std::make_tuple(kernel, false, 0); kernels_[name] = kernel;
} }
} }
} // namespace
JitModule::JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder,
bool use_disk_cache) {
// Will hold the actual device executable source code and kernel names
std::string ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Try to load them from the file cache
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
auto [precompiled, source_code, kernel_names] = builder();
// Get the PTX or cubin
if (precompiled) {
ptx = std::move(source_code);
for (auto& name : kernel_names) {
ptx_kernels.emplace_back(name, name);
}
} else {
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
}
// If requested save them in the file cache for the next launch
if (use_disk_cache) {
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
}
// Load the module
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}
JitModule::~JitModule() { JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_)); CHECK_CUDA_ERROR(cuModuleUnload(module_));
} }
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims( CUfunction JitModule::get_kernel(const std::string& kernel_name) {
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name); auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) { if (it == kernels_.end()) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name)); fmt::format("There is no kernel named {}.", kernel_name));
} }
return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
auto kernel = std::get<0>(it->second);
if (!std::get<1>(it->second)) {
if (configure_kernel) {
configure_kernel(kernel);
}
std::get<1>(it->second) = true;
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
}
return {kernel, std::get<2>(it->second)};
}
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
} }
std::unordered_map<std::string, JitModule>& get_jit_module_cache() { std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
@@ -430,12 +337,11 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
const KernelBuilder& builder, const KernelBuilder& builder) {
bool cache) {
auto& map = get_jit_module_cache(); auto& map = get_jit_module_cache();
auto it = map.find(name); auto it = map.find(name);
if (it == map.end()) { if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder, cache).first; it = map.try_emplace(name, cu::device(device), name, builder).first;
} }
return it->second; return it->second;
} }
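
One side of this file keeps a get_ptx_path helper that protects the on-disk PTX cache from per-filename length limits. A minimal standalone sketch of that logic; ptx_path_for is an illustrative name, and in the diff it is write_cached_ptx that creates the intermediate directories before writing:

#include <filesystem>
#include <string>

// Module names longer than the per-component limit (245 on POSIX, 140 on
// Windows in the diff) are split into nested directories so that every path
// component stays legal.
std::filesystem::path ptx_path_for(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    std::size_t max_component = 245) {
  if (module_name.size() <= max_component) {
    return cache_dir / (module_name + ".ptx");
  }
  auto path = cache_dir;
  std::size_t offset = 0;
  while (module_name.size() - offset > max_component) {
    path /= module_name.substr(offset, max_component);  // directory component
    offset += max_component;
  }
  return path / (module_name.substr(offset) + ".ptx");
}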


@@ -19,8 +19,7 @@ namespace mlx::core::cu {
class Device; class Device;
using KernelBuilderResult = std::tuple< using KernelBuilderResult = std::pair<
/* precompiled */ bool,
/* source code */ std::string, /* source code */ std::string,
/* kernel names */ std::vector<std::string>>; /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>; using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -46,11 +45,6 @@ struct KernelArgs {
append_ptr(std::get<SmallVector<T>>(storage_.back()).data()); append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
} }
template <typename T>
void append(const std::vector<T>& vec) {
append(SmallVector<T>(vec.begin(), vec.end()));
}
// Make sure the arg is copied to an array with size of NDIM. // Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T> template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(SmallVector<T> vec) { void append_ndim(SmallVector<T> vec) {
@@ -69,16 +63,14 @@ struct KernelArgs {
private: private:
std::vector<void*> args_; std::vector<void*> args_;
// The cuGraphAddKernelNode API requires passing pointers to arguments so // The cuLaunchKernel API requires passing pointers to arguments so store
// store temporary values until the node is created. // temporary values untill kernel is launched.
using Arg = std::variant< using Arg = std::variant<
std::monostate, std::monostate,
CUdeviceptr, CUdeviceptr,
bool,
int32_t, int32_t,
uint32_t, uint32_t,
int64_t, int64_t,
float,
SmallVector<const void*>, SmallVector<const void*>,
SmallVector<int32_t>, SmallVector<int32_t>,
SmallVector<int64_t>>; SmallVector<int64_t>>;
@@ -90,22 +82,16 @@ class JitModule {
JitModule( JitModule(
Device& device, Device& device,
const std::string& module_name, const std::string& module_name,
const KernelBuilder& builder, const KernelBuilder& builder);
bool cache);
~JitModule(); ~JitModule();
JitModule(const JitModule&) = delete; JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete; JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel( CUfunction get_kernel(const std::string& kernel_name);
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
std::pair<CUfunction, uint> get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private: private:
CUmodule module_{nullptr}; CUmodule module_{nullptr};
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_; std::unordered_map<std::string, CUfunction> kernels_;
}; };
std::unordered_map<std::string, JitModule>& get_jit_module_cache(); std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -113,7 +99,6 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
const KernelBuilder& builder, const KernelBuilder& builder);
bool use_disk_cache = true);
} // namespace mlx::core::cu } // namespace mlx::core::cu
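
The header change mirrors the .cpp change above: one side stores a bare CUfunction per kernel, the other a (CUfunction, configured, max_block_dim) tuple so that an optional configure callback runs exactly once and the occupancy-derived block size is cached. A minimal sketch of that lookup; kernel_and_dims and query_max_block_dim are illustrative stand-ins for the members and helper used in the diff:

#include <cuda.h>
#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

using KernelEntry = std::tuple<CUfunction, bool, unsigned int>;

std::pair<CUfunction, unsigned int> kernel_and_dims(
    std::unordered_map<std::string, KernelEntry>& kernels,
    const std::string& name,
    const std::function<void(CUfunction)>& configure,
    unsigned int (*query_max_block_dim)(CUfunction)) {
  auto& entry = kernels.at(name);
  CUfunction kernel = std::get<0>(entry);
  if (!std::get<1>(entry)) {  // first launch: configure exactly once
    if (configure) {
      configure(kernel);
    }
    std::get<1>(entry) = true;
    std::get<2>(entry) = query_max_block_dim(kernel);
  }
  return {kernel, std::get<2>(entry)};
}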


@@ -35,10 +35,12 @@ std::tuple<dim3, uint> get_launch_args(
const Shape& shape, const Shape& shape,
const Strides& strides, const Strides& strides,
bool large, bool large,
int work_per_thread /* = 1 */, int work_per_thread) {
uint max_block_dim /* = 1024 */) {
size_t nthreads = cuda::ceil_div(size, work_per_thread); size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads; uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks; dim3 num_blocks;
if (large) { if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread); num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
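
get_launch_args differs only in how the block size is capped: a fixed 1024 on one side, a caller-provided max_block_dim on the other. A minimal sketch of the capped computation; pick_block_dim is an illustrative name:

#include <algorithm>
#include <cstddef>

// The block size is the smaller of the cap and the number of threads actually
// needed once each thread handles work_per_thread elements.
inline unsigned int pick_block_dim(
    std::size_t size, int work_per_thread, unsigned int max_block_dim = 1024) {
  std::size_t nthreads = (size + work_per_thread - 1) / work_per_thread;  // ceil_div
  return static_cast<unsigned int>(
      std::min<std::size_t>(max_block_dim, nthreads));
}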


@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// This file includes host-only utilities for writing CUDA kernels, the // This file includes host-only utilies for writing CUDA kernels, the difference
// difference from backend/cuda/device/utils.cuh is that the latter file only // from backend/cuda/device/utils.cuh is that the latter file only include
// include device-only code. // device-only code.
#pragma once #pragma once
@@ -120,28 +120,19 @@ dim3 get_2d_grid_dims(
size_t divisor); size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2); std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims assuming each thread handles // Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// |work_per_thread| elements of |arr|. // assuming each thread handles |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args( std::tuple<dim3, uint> get_launch_args(
size_t size, size_t size,
const Shape& shape, const Shape& shape,
const Strides& strides, const Strides& strides,
bool large, bool large,
int work_per_thread = 1, int work_per_thread = 1);
uint max_block_dim = 1024);
inline std::tuple<dim3, uint> get_launch_args( inline std::tuple<dim3, uint>
const array& arr, get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024) {
return get_launch_args( return get_launch_args(
arr.size(), arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
arr.shape(),
arr.strides(),
large,
work_per_thread,
max_block_dim);
} }
} // namespace mlx::core } // namespace mlx::core


@@ -2,15 +2,11 @@
#pragma once #pragma once
#include "mlx/utils.h"
#include <cstring> #include <cstring>
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <fmt/format.h>
namespace mlx::core { namespace mlx::core {
template < template <
@@ -31,14 +27,6 @@ class LRUCache {
} }
} }
// Initialize with capacity read from |env_name|.
LRUCache(const char* env_name, int default_capacity)
: LRUCache(env::get_var(env_name, default_capacity)) {
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
env_name_ = env_name;
}
}
size_t size() const { size_t size() const {
return map_.size(); return map_.size();
} }
@@ -88,14 +76,6 @@ class LRUCache {
return {it->second, false}; return {it->second, false};
} }
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
throw std::runtime_error(fmt::format(
"Cache thrashing is happening, please set the environment variable "
"{} to a larger value than {} to fix degraded performance.",
env_name_,
capacity_));
}
vlist_.emplace_front(key, std::forward<U>(value)); vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin(); map_[key] = vlist_.begin();
@@ -126,9 +106,6 @@ class LRUCache {
} }
} }
const char* env_name_{nullptr};
size_t cache_misses_{0};
list_type vlist_; list_type vlist_;
map_type map_; map_type map_;
size_t capacity_; size_t capacity_;
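
One side of lru_cache.h adds a capacity-from-environment constructor plus a miss counter that flags cache thrashing. A minimal sketch of just the guard, with the thresholds and message taken from the diff; ThrashingGuard is illustrative, and in the diff this state lives directly inside LRUCache:

#include <cstddef>
#include <stdexcept>
#include <string>

struct ThrashingGuard {
  const char* env_name{nullptr};  // null when the check is disabled
  std::size_t capacity{0};
  std::size_t misses{0};

  // Called on every cache miss; more than 2x capacity misses means the cache
  // is too small to be useful, so fail loudly and name the variable to raise.
  void on_miss() {
    if (env_name && ++misses > 2 * capacity) {
      throw std::runtime_error(
          "Cache thrashing is happening, please set the environment variable " +
          std::string(env_name) + " to a larger value than " +
          std::to_string(capacity) + " to fix degraded performance.");
    }
  }
};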


@@ -11,7 +11,6 @@
#include <numeric> #include <numeric>
namespace mlx::core { namespace mlx::core {
namespace { namespace {
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
@@ -29,80 +28,6 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} }
} }
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
const std::optional<array>& bias = std::nullopt,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
if (a.dtype() == complex64) {
throw std::runtime_error(
"[gemm_and_bias] complex64 bias epilogue isnt supported in cublasLtMatmul.");
}
gemm.set_bias(encoder, *bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace } // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -123,6 +48,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2); int M = a_pre.shape(-2);
int N = b_pre.shape(-1); int N = b_pre.shape(-1);
int K = a_pre.shape(-1); int K = a_pre.shape(-1);
@@ -132,8 +60,65 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
gemm_and_bias( /////////////////////////////////////////////////////////////////////////////
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b); // Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -158,29 +143,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c,
alpha_);
return;
}
int64_t ldc; int64_t ldc;
{ {
auto stx = c.strides()[c.ndim() - 2]; auto stx = c.strides()[c.ndim() - 2];
@@ -222,9 +184,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt with AddMM settings // Invoke cublasLt
CublasGemm gemm( cu::Matmul matmul(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -240,7 +202,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
gemm.run(
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
encoder, encoder,
out, out,
a, a,
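
Both sides of matmul.cpp keep the same batch-collapsing and gemv/cuBLASLt routing; the diff mainly moves it into a gemm_and_bias helper that can also fuse a bias epilogue for AddMM. A minimal sketch of the collapse condition, assuming a single already-collapsed batch dimension as in the diff's batch_shape.size() == 1 check; GemmShape and maybe_collapse_batch are illustrative names:

#include <cstdint>

struct GemmShape {
  int M, N, K;
  int64_t batch_count;
};

// When A is a contiguous stack of row-major (M, K) matrices and every batch
// element reuses the same B, treat the stack as one (batch * M, K) GEMM.
inline GemmShape maybe_collapse_batch(
    GemmShape s,
    bool a_transposed,
    int64_t a_row_stride,      // stride between consecutive rows of A
    int64_t a_batch_stride,    // stride between consecutive A matrices
    int64_t b_batch_stride) {  // 0 when every batch element reuses the same B
  if (s.batch_count > 1 && !a_transposed && a_row_stride == s.K &&
      a_batch_stride == int64_t(s.M) * s.K && b_batch_stride == 0) {
    s.M *= static_cast<int>(s.batch_count);
    s.batch_count = 1;
  }
  return s;
}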


@@ -1,47 +1,11 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/cuda/cuda.h"
#include "mlx/fast.h"
namespace mlx::core { namespace mlx::core::cu {
namespace cu {
bool is_available() { bool is_available() {
return false; return false;
} }
} // namespace cu } // namespace mlx::core::cu
namespace fast {
CustomKernelFunction cuda_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool,
int) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
std::vector<array> precompiled_cuda_kernel(
const std::string&,
const std::string&,
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
const std::vector<ScalarArg>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
} // namespace fast
} // namespace mlx::core


@@ -24,6 +24,8 @@ namespace mlx::core {
} }
NO_GPU(BlockMaskedMM) NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT) NO_GPU(FFT)
NO_GPU(GatherMM) NO_GPU(GatherMM)
NO_GPU(GatherQMM) NO_GPU(GatherQMM)
@@ -39,7 +41,12 @@ NO_GPU(Cholesky)
NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed { namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather) NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send) NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv) NO_GPU_MULTI(Recv)


@@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix(
} // namespace } // namespace
void fast::Quantize::eval_gpu( void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
nvtx3::scoped_range r("Quantize::eval_gpu"); nvtx3::scoped_range r("AffineQuantize::eval_gpu");
auto& s = stream(); auto& s = stream();
auto& d = cu::device(s.device); auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s); auto& enc = d.get_command_encoder(s);


@@ -181,47 +181,6 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
} }
} }
template <typename T, typename U, typename Op, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
size_t total) {
Op op;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
const auto idx = grid.thread_rank() * N_READS;
const auto before_axis = idx / args.reduction_stride;
const auto after_axis = idx % args.reduction_stride;
const auto offset =
before_axis * args.reduction_stride * args.reduction_size + after_axis;
if (idx >= total) {
return;
}
in += offset;
out += idx;
AlignedVector<U, N_READS> accumulator;
for (int i = 0; i < N_READS; i++) {
accumulator[i] = ReduceInit<Op, T>::value();
}
for (int i = 0; i < args.reduction_size; i++) {
auto values = load_vector<N_READS>(in, 0);
for (int j = 0; j < N_READS; j++) {
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
}
in += args.reduction_stride;
}
store_vector(out, 0, accumulator);
}
} // namespace cu } // namespace cu
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
@@ -247,7 +206,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
const cu::ColReduceArgs& args) { cu::ColReduceArgs args) {
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -271,55 +230,12 @@ void col_reduce_looped(
auto kernel = auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>; cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel, grid, blocks, 0, indata, out.data<U>(), args);
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
}); });
}); });
}); });
} }
void col_reduce_small(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
block,
0,
in.data<T>(),
out.data<U>(),
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});
}
void col_reduce( void col_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -342,13 +258,6 @@ void col_reduce(
// Make the args struct to help route to the best kernel // Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes); cu::ColReduceArgs args(in, plan, axes);
// Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }
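
One side of col_reduce.cu adds a register-resident col_reduce_small kernel and a heuristic for when to prefer it over the looped fallback. A minimal sketch of that heuristic, with the thresholds taken from the diff; use_small_col_reduce is an illustrative name:

#include <cstdint>

// The small kernel is only worth it for a single, short reduction whose
// stride permits 16-byte vectorized loads.
inline bool use_small_col_reduce(
    int64_t non_col_reductions,
    int64_t reduction_size,
    int64_t reduction_stride,
    int itemsize) {
  return non_col_reductions == 1 && reduction_size <= 32 &&
      reduction_stride % (16 / itemsize) == 0;
}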


@@ -7,6 +7,8 @@
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core { namespace mlx::core {
@@ -81,8 +83,7 @@ struct RowReduceArgs {
}; };
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1> template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -90,8 +91,8 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
const U init = cu::ReduceInit<ReduceOp, T>::value(); const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op; ReduceOp op;
AlignedVector<T, N> vals[M]; T vals[M][N];
AlignedVector<U, M> accs; U accs[M];
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
accs[i] = init; accs[i] = init;
} }
@@ -100,31 +101,43 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M)); min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N); const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N); const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size + block.thread_rank() * N; in += start_row * size;
out += start_row; out += start_row;
for (size_t r = 0; r < full_blocks; r++) { if (size % N == 0) {
for (int k = 0; k < M; k++) { for (size_t r = 0; r < full_blocks; r++) {
vals[k] = load_vector<N>(in + k * size, 0); for (int k = 0; k < M; k++) {
} cub::LoadDirectBlockedVectorized<T, N>(
for (int k = 0; k < M; k++) { block.thread_rank(),
for (int j = 0; j < N; j++) { in + k * size + r * (block.size() * N),
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
} }
} }
in += block.size() * N;
} }
if (final_offset < size) { if (final_offset < size) {
for (int k = 0; k < M; k++) { for (int k = 0; k < M; k++) {
for (int i = 0; i < N; i++) { cub::LoadDirectBlocked(
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size) block.thread_rank(),
? in[k * size + i] in + k * size + final_offset,
: cast_to<T>(init); vals[k],
} size,
} cast_to<T>(init));
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
} }
@@ -132,11 +145,13 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
} }
__shared__ U shared_accumulators[32 * M]; __shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs.val, shared_accumulators, op, init); block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) { if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) { if (grid.block_rank() * M + M <= n_rows) {
store_vector(out, 0, accs); for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
} else { } else {
short offset = grid.block_rank() * M + M - n_rows; short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) { for (int i = offset; i < M; i++) {
@@ -146,10 +161,17 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
} }
} }
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4> template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM,
int N_READS = 4>
__global__ void row_reduce_looped( __global__ void row_reduce_looped(
const T* in, T* in,
U* out, U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) { const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
@@ -163,60 +185,36 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value(); U init = ReduceInit<Op, T>::value();
total[0] = init; total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
const size_t full_blocks = args.row_size / (block.size() * N_READS); size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
const size_t final_offset = full_blocks * (block.size() * N_READS); size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
// Unaligned reduce for (size_t n = 0; n < args.non_row_reductions; n++) {
if (final_offset < args.row_size) { for (size_t r = 0; r < full_blocks; r++) {
bool mask[N_READS]; T vals[N_READS];
for (int i = 0; i < N_READS; i++) { cub::LoadDirectBlockedVectorized<T, N_READS>(
mask[i] = block.thread_rank(),
(final_offset + block.thread_rank() * N_READS + i) < args.row_size; in + loop.location() + r * BLOCK_DIM * N_READS,
} vals);
for (int i = 0; i < N_READS; i++) {
for (size_t n = 0; n < args.non_row_reductions; n++) { total[0] = op(total[0], cast_to<U>(vals[i]));
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
} }
{
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
} if (final_offset < args.row_size) {
T vals[N_READS];
// Aligned case cub::LoadDirectBlocked(
else { block.thread_rank(),
for (size_t n = 0; n < args.non_row_reductions; n++) { in + loop.location() + final_offset,
const T* inlocal = in + loop.location(); vals,
args.row_size - final_offset,
for (size_t r = 0; r < full_blocks; r++) { cast_to<T>(init));
auto vals = load_vector<N_READS>(inlocal, 0); for (int i = 0; i < N_READS; i++) {
for (int i = 0; i < N_READS; i++) { total[0] = op(total[0], cast_to<U>(vals[i]));
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
} }
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
__shared__ U shared_accumulators[32]; __shared__ U shared_accumulators[32];
@@ -236,6 +234,8 @@ void row_reduce_simple(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the // Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel. // kernel.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -250,15 +250,14 @@ void row_reduce_simple(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T); // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; int threads = std::min(1024UL, reductions);
warps /= 4; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
// Pick the kernel // Pick the kernel
@@ -268,7 +267,6 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>; kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
} }
T* indata = const_cast<T*>(in.data<T>());
int size = plan.shape.back(); int size = plan.shape.back();
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size); kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -284,6 +282,8 @@ void row_reduce_looped(
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::RowReduceArgs args) { cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -295,27 +295,34 @@ void row_reduce_looped(
using OP = MLX_GET_TYPE(reduce_type_tag); using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
constexpr int N_READS = 16 / sizeof(T); T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
args.sort_access_pattern(in, axes); args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS; size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; int threads = std::min(1024UL, reductions);
warps /= 4; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
// Pick the kernel // Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>; auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>; dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
}); });
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args); kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
}); });
}); });
} }
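
One side of row_reduce.cu uses MLX's load_vector/AlignedVector helpers with a quarter-of-the-warps block-sizing heuristic; the other uses cub block loads and rounds the needed threads up to a warp multiple capped at 1024. A minimal sketch of the two sizing heuristics side by side; the function names are illustrative:

#include <algorithm>
#include <cstddef>

// Target roughly a quarter of the warps needed to cover one row, where each
// thread consumes n_reads elements, clamped to [1, 32] warps.
inline int row_reduce_threads_quarter(
    std::size_t row_size, int n_reads, int warp_size = 32) {
  std::size_t reductions = (row_size + n_reads - 1) / n_reads;
  int warps = static_cast<int>((reductions + warp_size - 1) / warp_size);
  warps = std::max(std::min(warps / 4, 32), 1);
  return warps * warp_size;
}

// Launch as many threads as there are reductions, rounded up to a warp
// multiple and capped at 1024.
inline int row_reduce_threads_capped(
    std::size_t row_size, int n_reads, int warp_size = 32) {
  std::size_t reductions = (row_size + n_reads - 1) / n_reads;
  std::size_t threads = std::min<std::size_t>(1024, reductions);
  return static_cast<int>(((threads + warp_size - 1) / warp_size) * warp_size);
}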

Some files were not shown because too many files have changed in this diff.