Compare commits

..

10 Commits

Author SHA1 Message Date
Joona Havukainen
9a742090ae Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR. 2025-07-08 20:56:06 -07:00
Joona Havukainen
aca7fac9ef Make the max nanpropagation test more meaningful for integer types 2025-07-08 16:42:19 -07:00
Joona Havukainen
8b15773206 Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16. 2025-07-08 16:41:56 -07:00
Joona Havukainen
3e885f583a Cleanup using namespace alias 2025-07-07 18:25:57 -07:00
Joona Havukainen
c7af3016eb Only check nans on non-integral types in simd_reduce_impl. 2025-07-07 18:24:30 -07:00
Joona Havukainen
9794ec6b8e Improve the cpp unittest 2025-07-06 16:04:50 -07:00
Joona Havukainen
e0bb9f3ef8 Fix max complex64 nan propagation and add test 2025-07-06 15:53:50 -07:00
Joona Havukainen
5b089dc5da Pre-commit formatting 2025-07-06 15:50:16 -07:00
Joona Havukainen
af74818528 Adding benchmarks and testing for max op nanpropagation 2025-07-06 15:50:16 -07:00
Joona Havukainen
0d30e9e8ec Make max op NaN propagation rules align with numpy 2025-07-06 15:50:16 -07:00
144 changed files with 1598 additions and 4067 deletions
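The commit series above makes the `max` reduction propagate NaN the way NumPy does. A minimal illustration of the intended behavior (not taken from this diff; assumes `mlx` is installed and imported as `mlx.core`):

```python
import numpy as np
import mlx.core as mx

x = [1.0, float("nan"), 3.0]

# NumPy propagates NaN through max/min reductions.
print(np.max(np.array(x)))  # nan

# With this change, mx.max is expected to match that behavior
# rather than silently skipping the NaN.
print(mx.max(mx.array(x)))  # expected: nan
```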

View File

@@ -7,6 +7,18 @@ parameters:
nightly_build: nightly_build:
type: boolean type: boolean
default: false default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
cuda_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@@ -29,7 +41,7 @@ jobs:
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install -r docs/requirements.txt pip install -r docs/requirements.txt
pip install . -v CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when: - when:
condition: condition:
not: << parameters.upload-docs >> not: << parameters.upload-docs >>
@@ -61,9 +73,9 @@ jobs:
git push -f origin gh-pages git push -f origin gh-pages
linux_build_and_test: linux_build_and_test:
machine: docker:
image: ubuntu-2204:current - image: cimg/python:3.9
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
@@ -75,17 +87,21 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get upgrade -y
pip install --upgrade cmake pip install --upgrade cmake
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev pip install nanobind==2.4.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
pip install -e ".[dev]" CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
@@ -95,14 +111,13 @@ jobs:
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
python -m unittest discover python/tests -v python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |
mkdir -p build && cd build mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc` make -j `nproc`
- run: - run:
@@ -142,7 +157,8 @@ jobs:
name: Install Python package name: Install Python package
command: | command: |
source env/bin/activate source env/bin/activate
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v pip install -e . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
@@ -157,8 +173,7 @@ jobs:
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run: - run:
name: Build example extension name: Build example extension
command: | command: |
@@ -193,7 +208,8 @@ jobs:
name: Run Python tests with JIT name: Run Python tests with JIT
command: | command: |
source env/bin/activate source env/bin/activate
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \ LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \ METAL_DEBUG_ERROR_MODE=0 \
@@ -201,7 +217,7 @@ jobs:
cuda_build_and_test: cuda_build_and_test:
machine: machine:
image: linux-cuda-12:2023.11.1 image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
steps: steps:
- checkout - checkout
@@ -210,9 +226,10 @@ jobs:
command: | command: |
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python3 -m venv env python -m venv env
source env/bin/activate source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]" pip install -e ".[dev]"
- run: - run:
name: Run Python tests name: Run Python tests
@@ -261,6 +278,7 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \ env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v pip install . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
@@ -272,18 +290,9 @@ jobs:
name: Build Python package name: Build Python package
command: | command: |
source env/bin/activate source env/bin/activate
python setup.py clean --all << parameters.build_env >> \
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
- when: python -m build -w
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when: - when:
condition: << parameters.build_env >> condition: << parameters.build_env >>
steps: steps:
@@ -300,71 +309,63 @@ jobs:
python_version: python_version:
type: string type: string
default: "3.9" default: "3.9"
build_env: extra_env:
type: string type: string
default: "" default: "DEV_RELEASE=1"
machine: docker:
image: ubuntu-2204:current - image: ubuntu:20.04
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
name: Build wheel name: Build wheel
command: | command: |
PYTHON=python<< parameters.python_version >> PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive apt-get update
export NEEDRESTART_MODE=a apt-get upgrade -y
sudo apt-get update DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
sudo apt-get upgrade -y apt-get install -y apt-utils
TZ=Etc/UTC sudo apt-get -y install tzdata apt-get install -y software-properties-common
sudo apt-get install -y apt-utils add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y software-properties-common apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo add-apt-repository -y ppa:deadsnakes/ppa apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full apt-get install -y build-essential git
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y build-essential git
$PYTHON -m venv env $PYTHON -m venv env
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.build_env >> pip install ".[dev]" -v << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
python setup.py clean --all << parameters.extra_env >> \
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
bash python/scripts/repair_linux.sh python -m build --wheel
- when: auditwheel show dist/*
condition: auditwheel repair dist/* --plat manylinux_2_31_x86_64
equal: ["3.9", << parameters.python_version >>] - run:
steps: name: Upload package
- run: command: |
name: Build common package source env/bin/activate
command: | twine upload wheelhouse/*
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
build_cuda_release: build_cuda_release:
parameters: parameters:
build_env: python_version:
type: string type: string
default: "" default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
machine: machine:
image: linux-cuda-12:default image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
@@ -375,25 +376,27 @@ jobs:
command: | command: |
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
python -m venv env python -m venv env
source env/bin/activate source env/bin/activate
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.build_env >> MLX_BUILD_STAGE=2 \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w pip install ".[dev]" -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh bash python/scripts/repair_cuda.sh
- when: - run:
condition: << parameters.build_env >> name: Upload package
steps: command: |
- run: source env/bin/activate
name: Upload package twine upload wheelhouse/*.whl
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -405,6 +408,8 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$" pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >> value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
@@ -418,6 +423,8 @@ workflows:
when: when:
and: and:
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- build_release: - build_release:
filters: filters:
@@ -499,25 +506,6 @@ workflows:
branches: branches:
ignore: /.*/ ignore: /.*/
upload-docs: true upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb: prb:
when: when:
@@ -596,8 +584,99 @@ workflows:
- macosx_deployment_target: "15.0" - macosx_deployment_target: "15.0"
xcode_version: "15.0.0" xcode_version: "15.0.0"
python_version: "3.13" python_version: "3.13"
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release: - build_linux_release:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -19,7 +19,6 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -64,8 +64,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
endif() endif()
else() else()
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif() endif()
# ----------------------------- Lib ----------------------------- # ----------------------------- Lib -----------------------------

View File

@@ -203,11 +203,6 @@ void time_reductions() {
   TIME(max_along_0);

   auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
   TIME(max_along_1);
-
-  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
-  TIME(min_along_0);
-
-  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
-  TIME(min_along_1);
 }

 void time_gather_scatter() {

View File

@@ -58,13 +58,6 @@ def time_max():
     time_fn(mx.max, a, 0)


-def time_min():
-    a = mx.random.uniform(shape=(32, 1024, 1024))
-    a[1, 1] = mx.nan
-    mx.eval(a)
-    time_fn(mx.min, a, 0)
-
-
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)

@@ -122,7 +115,6 @@ if __name__ == "__main__":
     time_add()
     time_matmul()
-    time_min()
     time_max()
     time_maximum()
     time_exp()

View File

@@ -138,13 +138,13 @@ more concrete:
  * representing the vectorized computation and the axis which
  * corresponds to the output vectorized dimension.
  */
-std::pair<std::vector<array>, std::vector<int>> vmap(
+virtual std::pair<std::vector<array>, std::vector<int>> vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) override;

-/** The name of primitive. */
-const char* name() const override {
-  return "Axpby";
+/** Print the primitive. */
+void print(std::ostream& os) override {
+  os << "Axpby";
 }

 /** Equivalence check **/

View File

@@ -23,6 +23,13 @@ To install from PyPI you must meet the following requirements:

   MLX is only available on devices running macOS >= 13.5
   It is highly recommended to use macOS 14 (Sonoma)

+MLX is also available on conda-forge. To install MLX with conda do:
+
+.. code-block:: shell
+
+  conda install conda-forge::mlx
+
 CUDA
 ^^^^

@@ -31,16 +38,8 @@ and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:

 .. code-block:: shell

-  pip install "mlx[cuda]"
-
-CPU-only (Linux)
-^^^^^^^^^^^^^^^^
-
-For a CPU-only version of MLX that runs on Linux use:
-
-.. code-block:: shell
-
-  pip install "mlx[cpu]"
+  pip install mlx-cuda

 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -89,20 +88,20 @@ Then simply build and install MLX using pip:

 .. code-block:: shell

-  pip install .
+  CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .

 For developing, install the package with development dependencies, and use an
 editable install:

 .. code-block:: shell

-  pip install -e ".[dev]"
+  CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"

 Once the development dependencies are installed, you can build faster with:

 .. code-block:: shell

-  python setup.py build_ext --inplace
+  CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace

 Run the tests with:

@@ -263,7 +262,7 @@ When building either the Python or C++ APIs make sure to pass the cmake flag

 .. code-block:: shell

-  CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
+  CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"

 To build the C++ package run:
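After installing MLX through any of the routes shown in this file, a quick smoke test along these lines confirms the package imports and reports a default device (the output comments below are illustrative, not guaranteed):

```python
import mlx.core as mx

a = mx.arange(5)
print(a + a)                # array([0, 2, 4, 6, 8], dtype=int32)
print(mx.default_device())  # e.g. Device(gpu, 0) on Apple silicon
```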

View File

@@ -19,4 +19,3 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
-   Muon

View File

@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
       const std::vector<mx::array>& inputs,
       const std::vector<int>& axes) override;

-  /** The name of primitive. */
-  const char* name() const override {
-    return "Axpby";
+  /** Print the primitive. */
+  void print(std::ostream& os) override {
+    os << "Axpby";
   }

   /** Equivalence check **/

View File

@@ -1,20 +1,14 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
std::filesystem::path current_binary_dir() { std::string get_primitive_string(Primitive* primitive) {
static std::filesystem::path binary_dir = []() { std::ostringstream op_t;
Dl_info info; primitive->print(op_t);
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) { return op_t.str();
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path();
}();
return binary_dir;
} }
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims( std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(

View File

@@ -2,7 +2,6 @@
#pragma once #pragma once
#include <filesystem>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
@@ -10,8 +9,7 @@
namespace mlx::core { namespace mlx::core {
// Return the directory that contains current shared library. std::string get_primitive_string(Primitive* primitive);
std::filesystem::path current_binary_dir();
inline int64_t inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) { elem_to_loc(int elem, const Shape& shape, const Strides& strides) {

View File

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// The decomposition is computed in place, so just copy the input to the // The decomposition is computed in place, so just copy the input to the
// output. // output.
copy_cpu( copy(
a, a,
factor, factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -231,7 +231,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl; << namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else { } else {
os << x.primitive().name(); x.primitive().print(os);
os << "()("; os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";

View File

@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
// Fill with zeros // Fill with zeros
std::vector<array> temps; std::vector<array> temps;
temps.push_back(array(0, conv_dtype)); temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1]; size_t data_offset = padding_lo[0] * in_padded.strides()[1];
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded_slice.size(), in_padded_slice.size(),
data_offset); data_offset);
// Copy input values into the slice // Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice); temps.push_back(in_padded_slice);
// Make strided view // Make strided view
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
// Materialize strided view // Materialize strided view
Shape strided_reshape = {N * oH, wH * C}; Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream); copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided); temps.push_back(in_strided);
// Check wt dtype and prepare // Check wt dtype and prepare
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
wt.size(), wt.size(),
0); 0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {}); gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream); copy(wt_transpose, gemm_wt, CopyType::General, stream);
temps.push_back(gemm_wt); temps.push_back(gemm_wt);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) { } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype = auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {}); gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream); copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt); temps.push_back(gemm_wt);
} }
@@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu(
// Copy results if needed // Copy results if needed
if (out.dtype() != float32) { if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); copy_inplace(gemm_out, out, CopyType::Vector, stream);
} }
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
@@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu(
// Fill with zeros // Fill with zeros
std::vector<array> temps; std::vector<array> temps;
temps.push_back(array(0, conv_dtype)); temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] + size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
@@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu(
temps.push_back(in_padded_slice); temps.push_back(in_padded_slice);
// Copy input values into the slice // Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view // Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C}; Shape strided_shape = {N, oH, oW, wH, wW, C};
@@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu(
// Materialize strided view // Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C}; Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream); copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided); temps.push_back(in_strided);
// Check wt dtype and prepare // Check wt dtype and prepare
@@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu(
auto ctype = auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {}); gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream); copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt); temps.push_back(gemm_wt);
} }
@@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu(
// Copy results if needed // Copy results if needed
if (out.dtype() != float32) { if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); copy_inplace(gemm_out, out, CopyType::Vector, stream);
} }
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
@@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
// Fill with zeros // Fill with zeros
std::vector<array> temps = {array(0, conv_dtype)}; std::vector<array> temps = {array(0, conv_dtype)};
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = 0; size_t data_offset = 0;
@@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
data_offset); data_offset);
// Copy input values into the slice // Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice); temps.push_back(in_padded_slice);
// Make strided view // Make strided view
@@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
} }
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream); copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided); temps.push_back(in_strided);
// Check wt dtype and prepare // Check wt dtype and prepare
@@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
auto ctype = auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {}); gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream); copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt); temps.push_back(gemm_wt);
} }
if (flip) { if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {}); auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream); copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
temps.push_back(gemm_wt_); temps.push_back(gemm_wt_);
// Calculate the total size of the spatial dimensions // Calculate the total size of the spatial dimensions
@@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
// Copy results if needed // Copy results if needed
if (out.dtype() != float32) { if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); copy_inplace(gemm_out, out, CopyType::Vector, stream);
} }
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }

View File

@@ -295,11 +295,7 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_cpu_inplace( void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src); encoder.set_input_array(src);
encoder.set_output_array(dst); encoder.set_output_array(dst);
@@ -309,7 +305,7 @@ void copy_cpu_inplace(
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); }); ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
} }
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) { void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype); bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) { if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to // If the output has the same type as the input then there is nothing to
@@ -319,10 +315,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
if (ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General; ctype = CopyType::General;
} }
copy_cpu_inplace(src, dst, ctype, stream); copy_inplace(src, dst, ctype, stream);
} }
void copy_cpu_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const Shape& data_shape, const Shape& data_shape,

View File

@@ -10,14 +10,10 @@
namespace mlx::core { namespace mlx::core {
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace( void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy_cpu_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const Shape& data_shape, const Shape& data_shape,

View File

@@ -14,7 +14,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
return {arr, false}; return {arr, false};
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream); copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true}; return {arr_copy, true};
} }
}; };
@@ -35,7 +35,7 @@ void AllReduce::eval_cpu(
return in; return in;
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_cpu(in, arr_copy, CopyType::General, s); copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy); out.copy_shared_buffer(arr_copy);
return arr_copy; return arr_copy;
} }

View File

@@ -135,7 +135,7 @@ void Eig::eval_cpu(
: array(a.shape(), complex64, nullptr, {}); : array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {}); auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy_cpu( copy(
a, a,
a_copy, a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -196,7 +196,7 @@ void Eigh::eval_cpu(
values.set_data(allocator::malloc(values.nbytes())); values.set_data(allocator::malloc(values.nbytes()));
copy_cpu( copy(
a, a,
vectors, vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.flags().row_contiguous && in.is_donatable()) { if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy_cpu( copy(
in, in,
out, out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General, in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
auto ctype = auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream()); copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds; std::vector<array> inds;
@@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
auto ctype = auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream()); copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx); encoder.set_input_array(idx);

View File

@@ -115,7 +115,7 @@ void inverse_impl(
// (A⁻¹)ᵀ = (Aᵀ)⁻¹ // (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output. // The inverse is computed in place, so just copy the input to the output.
copy_cpu( copy(
a, a,
inv, inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -88,7 +88,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, x_copy, CopyType::General, s); copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -31,7 +31,7 @@ void luf_impl(
strides[ndim - 1] = M; strides[ndim - 1] = M;
strides[ndim - 2] = 1; strides[ndim - 2] = 1;
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags); lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_cpu_inplace( copy_inplace(
a, a,
lu, lu,
a.shape(), a.shape(),

View File

@@ -6,7 +6,6 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -53,58 +52,6 @@ inline void mask_matrix(
} }
} }
template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
Shape a_copy = a_shape;
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
for (int i = 0; i < num_segments; i++) {
uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue;
}
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
matmul<T>(
a + k_start * a_strides[ndim - 1],
b + k_start * b_strides[ndim - 2],
out + i * M * N,
a_transposed,
b_transposed,
lda,
ldb,
N,
1.0,
0.0,
1,
a_copy,
a_strides,
b_copy,
b_strides);
}
}
} // namespace } // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) { void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -124,20 +71,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (!expand_all && stx == arr.shape(-1) && sty == 1) { if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) { if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s); copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
return std::make_tuple(false, stx, arr, false); return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) { } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) { if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s); copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true); return std::make_tuple(true, sty, arr_copy, true);
} }
return std::make_tuple(true, sty, arr, false); return std::make_tuple(true, sty, arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s); copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
@@ -386,7 +333,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s); copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back()); return std::make_tuple(false, stx, temps.back());
} }
@@ -490,121 +437,4 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cpu::get_command_encoder(stream());
auto check_transpose = [&s, &encoder](const array& x) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
if (stx == x.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, x);
} else if (stx == 1 && sty == x.shape(-2)) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);
}
};
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
auto& segments = inputs[2];
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(segments);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
segments = array::unsafe_weak_copy(segments),
out_ptr = out.data<void>(),
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb]() {
switch (a.dtype()) {
case float64:
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float32:
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float16:
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case bfloat16:
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
default:
throw std::invalid_argument(
"Segmented mm supports only real float types.");
}
});
}
} // namespace mlx::core } // namespace mlx::core
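For context, the `segmented_mm` kernel that appears on only one side of this file computes one (M, N) output panel per `(k_start, k_end)` segment of the shared inner dimension, zero-filling empty segments. A NumPy sketch of the same computation (a hypothetical reference helper, not part of the diff):

```python
import numpy as np

def segmented_mm_ref(a, b, segments):
    # One (M, N) panel per (k_start, k_end) pair, using only that
    # slice of the inner dimension; empty segments produce zeros.
    outs = []
    for k_start, k_end in segments:
        if k_end <= k_start:
            outs.append(np.zeros((a.shape[0], b.shape[1]), dtype=a.dtype))
        else:
            outs.append(a[:, k_start:k_end] @ b[k_start:k_end, :])
    return np.stack(outs)

a = np.random.rand(4, 8).astype(np.float32)
b = np.random.rand(8, 3).astype(np.float32)
print(segmented_mm_ref(a, b, [(0, 4), (4, 8), (2, 2)]).shape)  # (3, 4, 3)
```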

View File

@@ -81,7 +81,7 @@ void matmul_general(
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, stream); copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1); stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back()); return std::make_tuple(false, stx, temps.back());
} }
@@ -142,7 +142,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1 CopyType ctype = c.data_size() == 1
? CopyType::Scalar ? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream()); copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) { if (inputs[0].shape(-1) == 0) {
return; return;
} }

View File

@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream()); copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(in, out, ctype, stream()); copy(in, out, ctype, stream());
} }
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) { void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i]; size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer( out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset); out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
} }
} }
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous))) { (allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy_cpu(in, out, CopyType::General, stream()); copy(in, out, CopyType::General, stream());
} }
} }
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
ctype = CopyType::General; ctype = CopyType::General;
} }
copy_cpu(in, out, ctype, stream()); copy(in, out, ctype, stream());
} }
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) { void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val // Fill output with val
copy_cpu(val, out, CopyType::Scalar, stream()); copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values // Find offset for start of input values
size_t data_offset = 0; size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset); out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice // Copy input values into the slice
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
} }
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) { void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] = auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_cpu_inplace( copy_inplace(
/* const array& src = */ in, /* const array& src = */ in,
/* array& dst = */ out, /* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(), /* const Shape& data_shape = */ out.shape(),
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [out_offset, donated] = auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream()); compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_cpu_inplace( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(), /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_); prepare_slice(out, start_indices_, strides_);
// Do copy // Do copy
copy_cpu_inplace( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(), /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {}); auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in); in_tmp.copy_shared_buffer(in);
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream()); copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else { } else {
copy_cpu_inplace(in, tmp, CopyType::General, stream()); copy_inplace(in, tmp, CopyType::General, stream());
} }
auto flags = out.flags(); auto flags = out.flags();

View File

@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
strides[in.ndim() - 2] = 1; strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M; strides[in.ndim() - 1] = M;
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags); in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream); copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes())); q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes())); r.set_data(allocator::malloc(r.nbytes()));

View File

@@ -529,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
return arr; return arr;
} else { } else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s); copy(arr, temps.back(), CopyType::General, s);
return temps.back(); return temps.back();
} }
}; };
@@ -579,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return arr; return arr;
} else { } else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s); copy(arr, temps.back(), CopyType::General, s);
return temps.back(); return temps.back();
} }
}; };
@@ -713,7 +713,7 @@ void fast::AffineQuantize::eval_cpu(
return std::make_pair(arr, false); return std::make_pair(arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s); copy(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true); return std::make_pair(arr_copy, true);
} }
}; };

View File

@@ -350,15 +350,7 @@ struct MinReduce {
   };

   template <int N, typename T>
-  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
-    return simd::min(x);
-  };
-
-  template <int N, typename T>
-  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
-    if (simd::any(x != x)) {
-      return static_cast<T>(NAN);
-    }
+  T operator()(simd::Simd<T, N> x) {
     return simd::min(x);
   };
 };
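The overload guarded by `simd::any(x != x)` mirrors NumPy's reduction semantics, where a single NaN poisons `min`/`max` (in contrast to `nanmin`/`nanmax`). A quick NumPy reference for the behavior in question:

```python
import numpy as np

x = np.array([2.0, np.nan, -1.0])

print(np.min(x), np.max(x))        # nan nan
print(np.nanmin(x), np.nanmax(x))  # -1.0 2.0
```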

View File

@@ -251,7 +251,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0]; auto in = inputs[0];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_cpu(in, arr_copy, CopyType::General, stream()); copy(in, arr_copy, CopyType::General, stream());
in = arr_copy; in = arr_copy;
encoder.add_temporary(arr_copy); encoder.add_temporary(arr_copy);
} }

View File

@@ -132,7 +132,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, x_copy, CopyType::General, s); copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -334,10 +334,8 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Copy input to output
-  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
-      ? CopyType::Vector
-      : CopyType::General;
-  copy_cpu(in, out, ctype, stream());
+  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+  copy(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);

@@ -428,10 +426,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Copy input to output
-  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
-      ? CopyType::Vector
-      : CopyType::General;
-  copy_cpu(in, out, ctype, stream());
+  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+  copy(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);

View File

@@ -31,7 +31,7 @@ void svd_impl(
// lapack clobbers the input, so we have to make a copy. // lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {}); array in(a.shape(), a.dtype(), nullptr, {});
copy_cpu( copy(
a, a,
in, in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -35,14 +35,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -69,11 +67,6 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>") PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive. # CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning, it is safe to set it to # Explicitly pass this flag to suppress the warning, it is safe to set it to
# true but the warning wouldn't be suppressed. # true but the warning wouldn't be suppressed.
@@ -126,7 +119,3 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Suppress nvcc warnings on MLX headers. # Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>) --diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh" #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -16,86 +17,35 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = Op{}(a[0], b[0]);
for (int i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = Op{}(a[0], b[index]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[i]);
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = Op{}(a[index], b[0]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[0]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = Op{}(a[index], b[index]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
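One side of the binary kernels above gives each thread N_READS elements: a full, aligned chunk goes through vector loads and stores, while the final partial chunk falls back to a scalar loop. The indexing logic, written as plain C++ over flat buffers; the CUDA-side load_vector/store_vector helpers are omitted and the function name is just illustrative.

#include <cstdint>
#include <vector>

// "Thread" with rank `index` covers elements [index*N, index*N + N); the
// ragged tail at the end is handled element by element.
template <int N, typename Op, typename T>
void process_chunk(int64_t index, const std::vector<T>& a,
                   const std::vector<T>& b, std::vector<T>& out) {
  int64_t size = static_cast<int64_t>(out.size());
  if ((index + 1) * N > size) {
    for (int64_t i = index * N; i < size; ++i) {  // partial chunk
      out[i] = Op{}(a[i], b[i]);
    }
  } else {
    for (int i = 0; i < N; ++i) {  // full chunk: vectorizable
      out[index * N + i] = Op{}(a[index * N + i], b[index * N + i]);
    }
  }
}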
@@ -176,7 +126,7 @@ template <typename Op>
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, std::string_view op,
const Stream& s) { const Stream& s) {
assert(inputs.size() > 1); assert(inputs.size() > 1);
const auto& a = inputs[0]; const auto& a = inputs[0];
@@ -246,25 +196,18 @@ void binary_op_gpu_inplace(
} }
}); });
} else { } else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
constexpr int N_READS = 4;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) { } else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) { } else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, kernel, out.data_size(), out.shape(), out.strides(), large());
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,
@@ -290,7 +233,7 @@ template <typename Op>
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, std::string_view op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@@ -299,11 +242,11 @@ void binary_op_gpu(
binary_op_gpu_inplace<Op>(inputs, out, op, s); binary_op_gpu_inplace<Op>(inputs, out, op, s);
} }
#define BINARY_GPU(func) \ #define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \ nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \ auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, name(), s); \ binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
} }
BINARY_GPU(Add) BINARY_GPU(Add)
@@ -327,31 +270,33 @@ BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) { void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu"); nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) { if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s); binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else { } else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s); binary_op_gpu<cu::Equal>(inputs, out, op, s);
} }
} }
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s); binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s); binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s); binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s); binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s); binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break; break;
} }
} }

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -16,119 +17,52 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void __global__ void
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { auto out = Op{}(a[0], b[0]);
for (IdxT i = index * N_READS; i < size; ++i) { out_a[0] = out[0];
auto out = Op{}(a[0], b[0]); out_b[0] = out[1];
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void __global__ void
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { auto out = Op{}(a[0], b[index]);
for (IdxT i = index * N_READS; i < size; ++i) { out_a[index] = out[0];
auto out = Op{}(a[0], b[i]); out_b[index] = out[1];
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void __global__ void
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { auto out = Op{}(a[index], b[0]);
for (IdxT i = index * N_READS; i < size; ++i) { out_a[index] = out[0];
auto out = Op{}(a[i], b[0]); out_b[index] = out[1];
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void __global__ void
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { auto out = Op{}(a[index], b[index]);
for (IdxT i = index * N_READS; i < size; ++i) { out_a[index] = out[0];
auto out = Op{}(a[i], b[i]); out_b[index] = out[1];
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
} }
} }
template <typename Op, typename In, typename Out, typename IdxT, int NDIM> template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_two_g_nd( __global__ void binary_g_nd(
const In* a, const In* a,
const In* b, const In* b,
Out* out_a, Out* out_a,
@@ -148,7 +82,7 @@ __global__ void binary_two_g_nd(
} }
template <typename Op, typename In, typename Out, typename IdxT> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_two_g( __global__ void binary_g(
const In* a, const In* a,
const In* b, const In* b,
Out* out_a, Out* out_a,
@@ -169,7 +103,7 @@ __global__ void binary_two_g(
} }
template <typename Op, typename In, typename Out> template <typename Op, typename In, typename Out>
constexpr bool supports_binary_two_op() { constexpr bool supports_binary_op() {
if (std::is_same_v<Op, DivMod>) { if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> && return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>); (std::is_integral_v<Out> || is_floating_v<Out>);
@@ -180,10 +114,10 @@ constexpr bool supports_binary_two_op() {
} // namespace cu } // namespace cu
template <typename Op> template <typename Op>
void binary_two_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op, std::string_view op,
const Stream& s) { const Stream& s) {
assert(inputs.size() > 1); assert(inputs.size() > 1);
const auto& a = inputs[0]; const auto& a = inputs[0];
@@ -207,7 +141,7 @@ void binary_two_op_gpu_inplace(
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag); using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) { if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>; using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>; using OutType = cuda_type_t<CTYPE_OUT>;
@@ -227,12 +161,8 @@ void binary_two_op_gpu_inplace(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd< auto kernel = cu::
Op, binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large()); get_launch_args(kernel, out_a, large());
encoder.add_kernel_node( encoder.add_kernel_node(
@@ -249,7 +179,7 @@ void binary_two_op_gpu_inplace(
const_param<dims_constant()>(b_strides)); const_param<dims_constant()>(b_strides));
}); });
} else { } else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>; auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large()); get_launch_args(kernel, out_a, large());
encoder.add_kernel_node( encoder.add_kernel_node(
@@ -268,25 +198,22 @@ void binary_two_op_gpu_inplace(
} }
}); });
} else { } else {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) { dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
constexpr int N_READS = 4;
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) { } else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) { } else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, kernel,
out_a.data_size(), out_a.data_size(),
out_a.shape(), out_a.shape(),
out_a.strides(), out_a.strides(),
large(), large());
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,
@@ -310,17 +237,17 @@ void binary_two_op_gpu_inplace(
} }
template <typename Op> template <typename Op>
void binary_two_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op, std::string_view op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt); set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt); set_binary_op_output_data(a, b, outputs[1], bopt);
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s); binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
} }
void DivMod::eval_gpu( void DivMod::eval_gpu(
@@ -328,7 +255,7 @@ void DivMod::eval_gpu(
std::vector<array>& outputs) { std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu"); nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream(); auto& s = outputs[0].primitive().stream();
binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s); binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
} }
} // namespace mlx::core } // namespace mlx::core
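The file above handles binary ops with two outputs (DivMod): the functor returns a pair-like result whose components are written to separate output buffers. A small host-side analogue using std::array; the rounding convention here is only illustrative, not necessarily the one the backend defines.

#include <array>
#include <cmath>
#include <cstddef>
#include <vector>

struct DivModOp {
  // Returns {quotient, remainder}, mirroring the two-output kernel contract.
  std::array<float, 2> operator()(float x, float y) const {
    return {std::trunc(x / y), std::fmod(x, y)};
  }
};

void divmod(const std::vector<float>& a, const std::vector<float>& b,
            std::vector<float>& out_a, std::vector<float>& out_b) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    auto out = DivModOp{}(a[i], b[i]);
    out_a[i] = out[0];
    out_b[i] = out[1];
  }
}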

View File

@@ -53,10 +53,9 @@ struct FusedKernelBuilder {
// Build function signature. // Build function signature.
if (contiguous) { if (contiguous) {
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n"; os += "template <typename IdxT = uint32_t>\n";
} else { } else {
os += os += "template <int NDIM, typename IdxT = uint32_t>\n";
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
} }
os += fmt::format("__global__ void {}(\n", kernel_name + name); os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) { for (size_t i = 0; i < params.size(); ++i) {
@@ -68,46 +67,12 @@ struct FusedKernelBuilder {
} }
os += ") {\n"; os += ") {\n";
// Index. For non contiguous kernels we create a separate index // Index.
// variable per variable otherwise everyone uses `index`.
os += os +=
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n" " IdxT index = cg::this_grid().thread_rank();\n"
" if (index >= size) {\n" " if (index >= size) {\n"
" return;\n" " return;\n"
" }\n"; " }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs. // Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
@@ -124,9 +89,12 @@ struct FusedKernelBuilder {
} else if (contiguous) { } else if (contiguous) {
value = fmt::format("{}[index]", xname); value = fmt::format("{}[index]", xname);
} else { } else {
value = fmt::format("{}[{}_idx]", xname, xname); std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
} }
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
} }
// Write tape. // Write tape.
@@ -138,37 +106,23 @@ struct FusedKernelBuilder {
value = fmt::format( value = fmt::format(
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
} else { } else {
value = x.primitive().name(); std::ostringstream ss;
x.primitive().print(ss);
value = ss.str();
value += "{}("; value += "{}(";
for (size_t i = 0; i < x.inputs().size() - 1; ++i) { for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
} }
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back())); value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
} }
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
} }
// Write output. // Write output.
for (const auto& x : outputs) { for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
} }
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
os += "}\n"; os += "}\n";
} }
}; };
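Both versions of the fused-kernel builder above turn a flat element index into a per-input memory offset from the array's shape and strides, either through an elem_to_loc helper or by unrolling the same modulo/divide loop into the generated source. The underlying computation in plain C++:

#include <cstdint>
#include <vector>

// Map a flat row-major element index to a buffer offset for an array
// described by `shape` and (possibly broadcast) `strides`.
int64_t elem_to_loc(int64_t index, const std::vector<int32_t>& shape,
                    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}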
@@ -204,28 +158,15 @@ void Compiled::eval_gpu(
builder.build("_strided", false); builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n"; builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names. // Build kernel names.
std::vector<std::string> kernel_names; std::vector<std::string> kernel_names = {
for (auto work_per_thread : std::array<int, 2>{1, 4}) { fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format( kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>", "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
lib_name(), kernel_names.push_back(
work_per_thread)); fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
} }
return std::make_pair(std::move(builder.os), std::move(kernel_names)); return std::make_pair(std::move(builder.os), std::move(kernel_names));
}); });
@@ -268,21 +209,13 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size()); args.append<uint32_t>(outputs[0].data_size());
} }
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel. // Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t"; const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name()); std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) { if (contiguous) {
kernel_name += kernel_name += fmt::format("_contiguous<{}>", index_type);
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
} else { } else {
kernel_name += fmt::format( kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
} }
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) { for (const auto& in : inputs) {
@@ -293,8 +226,7 @@ void Compiled::eval_gpu(
} }
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }
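The JIT path above precompiles one kernel per combination of layout (contiguous or strided), index type, and, in one version, work-per-thread factor, and looks them up by name at launch time. A sketch of generating that instantiation list with standard string formatting; the exact naming scheme is illustrative.

#include <string>
#include <vector>

std::vector<std::string> make_kernel_names(const std::string& lib, int max_ndim) {
  std::vector<std::string> names;
  for (int wpt : {1, 4}) {
    for (const char* idx : {"uint32_t", "int64_t"}) {
      names.push_back(lib + "_contiguous<" + idx + ", " + std::to_string(wpt) + ">");
      for (int ndim = 1; ndim <= max_ndim; ++ndim) {
        names.push_back(lib + "_strided<" + std::to_string(ndim) + ", " + idx +
                        ", " + std::to_string(wpt) + ">");
      }
    }
  }
  return names;
}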

View File

@@ -10,43 +10,19 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int N_READS> template <typename In, typename Out, typename IdxT>
__global__ void copy_s(const In* in, Out* out, IdxT size) { __global__ void copy_s(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = CastOp<In, Out>{}(in[0]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in[0]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
template <typename In, typename Out, typename IdxT, int N_READS> template <typename In, typename Out, typename IdxT>
__global__ void copy_v(const In* in, Out* out, IdxT size) { __global__ void copy_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = CastOp<In, Out>{}(in[index]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
@@ -65,19 +41,12 @@ void copy_contiguous(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. auto kernel = cu::copy_s<InType, OutType, IdxT>;
constexpr int N_READS = 4;
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>; kernel = cu::copy_v<InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, kernel, out.data_size(), out.shape(), out.strides(), large());
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -57,15 +57,8 @@ void Device::make_current() {
} }
} }
CommandEncoder& Device::get_command_encoder(Stream s) {
auto it = encoders_.find(s.index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
} }
@@ -175,7 +168,15 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
} }
} }
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) { CommandEncoder& Device::get_command_encoder(Stream s) {
auto it = encoders_.find(s.index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
} }
@@ -263,30 +264,22 @@ void CommandEncoder::commit() {
graph_key_ += std::to_string(graph_node_count_); graph_key_ += std::to_string(graph_node_count_);
graph_key_ += "."; graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_); graph_key_ += std::to_string(empty_node_count_);
auto [it, _] = graph_cache_.emplace(graph_key_, nullptr);
auto& graph_exec = it->second;
cudaGraphExec_t& graph_exec = graph_cache_[graph_key_]; if (graph_exec != NULL) {
cudaGraphExecUpdateResultInfo update_result;
if (graph_exec != nullptr) { cudaGraphExecUpdate(graph_exec, graph_, &update_result);
cudaGraphExecUpdateResult update_result; if (update_result.result != cudaGraphExecUpdateSuccess) {
#if CUDART_VERSION >= 12000 cudaGetLastError();
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
graph_exec = nullptr; graph_exec = NULL;
} }
} }
if (graph_exec == nullptr) { if (graph_exec == NULL) {
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
} }
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// TODO smarter cache policy // TODO smarter cache policy

View File

@@ -93,7 +93,6 @@ class CommandEncoder {
void insert_graph_dependencies(GraphNode node); void insert_graph_dependencies(GraphNode node);
void insert_graph_dependencies(std::vector<GraphNode> nodes); void insert_graph_dependencies(std::vector<GraphNode> nodes);
Device& device_;
CudaStream stream_; CudaStream stream_;
cudaGraph_t graph_; cudaGraph_t graph_;
Worker worker_; Worker worker_;

View File

@@ -2,7 +2,7 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/complex.cuh" #include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include <cuda/atomic> #include <cuda/atomic>
@@ -48,7 +48,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
atomicAdd(out, val); atomicAdd(out, val);
} }
inline __device__ void atomic_add(complex64_t* out, complex64_t val) { inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
#if __CUDA_ARCH__ < 900 #if __CUDA_ARCH__ < 900
atomic_add_general(out, val); atomic_add_general(out, val);
#else #else
@@ -58,7 +58,12 @@ inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800 #if __CUDA_ARCH__ < 800
#if CCCL_VERSION >= 2008000
atomic_add_general(out, val); atomic_add_general(out, val);
#else
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
assert(cccl_version_too_old_for_bfloat16_atomic_add);
#endif
#else #else
atomicAdd(out, val); atomicAdd(out, val);
#endif #endif
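The hunk above falls back to a software atomic add (atomic_add_general) on architectures without a hardware instruction for the type. The usual fallback is a compare-and-swap loop; here is a host-side C++ analogue with std::atomic, purely for illustration of the technique.

#include <atomic>

// CAS-loop addition for a type without a usable native fetch_add.
void atomic_add_general(std::atomic<float>& out, float val) {
  float expected = out.load(std::memory_order_relaxed);
  while (!out.compare_exchange_weak(
      expected, expected + val, std::memory_order_relaxed)) {
    // On failure `expected` is refreshed with the current value; retry.
  }
}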

View File

@@ -1,7 +1,10 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda/std/array> #include <cuda/std/array>
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -44,7 +47,7 @@ struct Remainder {
} else { } else {
return x % y; return x % y;
} }
} else if constexpr (is_complex_v<T>) { } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return x % y; return x % y;
} else { } else {
T r = fmod(x, y); T r = fmod(x, y);
@@ -66,12 +69,14 @@ struct Equal {
struct NaNEqual { struct NaNEqual {
template <typename T> template <typename T>
__device__ bool operator()(T x, T y) { __device__ bool operator()(T x, T y) {
if constexpr (is_complex_v<T>) { if constexpr (std::is_same_v<T, cuComplex>) {
return x == y || return x == y ||
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) && (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
isnan(y.imag())) || isnan(cuCimagf(y))) ||
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) || (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag()); isnan(cuCimagf(y))) ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
cuCimagf(x) == cuCimagf(y));
} else { } else {
return x == y || (isnan(x) && isnan(y)); return x == y || (isnan(x) && isnan(y));
} }
@@ -109,38 +114,36 @@ struct LessEqual {
struct LogAddExp { struct LogAddExp {
template <typename T> template <typename T>
__device__ T operator()(T x, T y) { __device__ T operator()(T x, T y) {
if constexpr (is_complex_v<T>) { if (isnan(x) || isnan(y)) {
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) || return cuda::std::numeric_limits<T>::quiet_NaN();
isnan(y.imag())) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
auto max = x.real() > y.real() ? x : y;
auto min = x.real() < y.real() ? x : y;
auto min_real = min.real();
auto max_real = max.real();
if (!isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
return min;
} else {
return Log{}(Exp{}(min) + Exp{}(max));
}
} else {
return Log1p{}(Exp{}(min - max)) + max;
}
} else {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
} }
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
}; };
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
}; };
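Both sides of the LogAddExp hunk compute log(exp(x) + exp(y)) by ordering the complex operands by real part, so the exponent of the difference is never positive and exp() cannot overflow, with explicit NaN and infinity handling. A host-side sketch with std::complex, not a drop-in replacement for the device functor:

#include <cmath>
#include <complex>
#include <limits>

std::complex<float> log_add_exp(std::complex<float> x, std::complex<float> y) {
  if (std::isnan(x.real()) || std::isnan(x.imag()) ||
      std::isnan(y.real()) || std::isnan(y.imag())) {
    float nan = std::numeric_limits<float>::quiet_NaN();
    return {nan, nan};
  }
  auto hi = x.real() > y.real() ? x : y;
  auto lo = x.real() > y.real() ? y : x;
  if (!std::isfinite(lo.real()) && lo.real() == hi.real()) {
    // Both real parts are -inf (result collapses to lo) or both +inf.
    return lo.real() < 0 ? lo : std::log(std::exp(lo) + std::exp(hi));
  }
  return std::log(1.0f + std::exp(lo - hi)) + hi;
}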
struct Maximum { struct Maximum {
@@ -148,8 +151,8 @@ struct Maximum {
__device__ T operator()(T x, T y) { __device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
return max(x, y); return max(x, y);
} else if constexpr (is_complex_v<T>) { } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(x.real()) || isnan(x.imag())) { if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x; return x;
} }
return x > y ? x : y; return x > y ? x : y;
@@ -167,8 +170,8 @@ struct Minimum {
__device__ T operator()(T x, T y) { __device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
return min(x, y); return min(x, y);
} else if constexpr (is_complex_v<T>) { } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(x.real()) || isnan(x.imag())) { if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x; return x;
} }
return x < y ? x : y; return x < y ? x : y;
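The Maximum/Minimum hunks above apply the NaN rule to complex values: if either component of an operand is NaN, that operand is returned so the NaN propagates. A scalar sketch with std::complex, using the real-then-imaginary ordering from one side of the diff:

#include <cmath>
#include <complex>

std::complex<float> nan_propagating_max(std::complex<float> x,
                                        std::complex<float> y) {
  if (std::isnan(x.real()) || std::isnan(x.imag())) {
    return x;
  }
  if (std::isnan(y.real()) || std::isnan(y.imag())) {
    return y;
  }
  bool x_greater = x.real() > y.real() ||
                   (x.real() == y.real() && x.imag() > y.imag());
  return x_greater ? x : y;
}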
@@ -191,8 +194,8 @@ struct Multiply {
struct NotEqual { struct NotEqual {
template <typename T> template <typename T>
__device__ bool operator()(T x, T y) { __device__ bool operator()(T x, T y) {
if constexpr (is_complex_v<T>) { if constexpr (std::is_same_v<T, cuComplex>) {
return x.real() != y.real() || x.imag() != y.imag(); return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
} else { } else {
return x != y; return x != y;
} }
@@ -212,8 +215,19 @@ struct Power {
base *= base; base *= base;
} }
return res; return res;
} else if constexpr (is_complex_v<T>) { } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return pow(base, exp); if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
auto phase = exp.y * x_ln_r + exp.x * x_theta;
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
} else { } else {
return powf(base, exp); return powf(base, exp);
} }
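One side of the Power hunk expands complex exponentiation through polar form: writing the base as r*e^(i*theta), base^w has magnitude exp(Re(w)*ln r - Im(w)*theta) and phase Im(w)*ln r + Re(w)*theta, with a special case for a zero base. The same computation with std::complex:

#include <cmath>
#include <complex>
#include <limits>

std::complex<float> complex_pow(std::complex<float> base, std::complex<float> w) {
  if (base.real() == 0.0f && base.imag() == 0.0f) {
    if (std::isnan(w.real()) || std::isnan(w.imag())) {
      float nan = std::numeric_limits<float>::quiet_NaN();
      return {nan, nan};
    }
    return {0.0f, 0.0f};
  }
  float theta = std::atan2(base.imag(), base.real());  // arg(base)
  float ln_r = 0.5f * std::log(base.real() * base.real() + base.imag() * base.imag());
  float mag = std::exp(w.real() * ln_r - w.imag() * theta);
  float phase = w.imag() * ln_r + w.real() * theta;
  return {mag * std::cos(phase), mag * std::sin(phase)};
}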

View File

@@ -2,10 +2,7 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/complex.cuh" #include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h> #include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -20,48 +17,34 @@ struct CastOp {
} }
}; };
// Castings between complex and boolean.
template <typename T>
struct CastOp<complex_t<T>, bool> {
static constexpr bool is_castable = true;
__device__ bool operator()(complex_t<T> x) {
return x.real() != 0 && x.imag() != 0;
}
};
template <typename T>
struct CastOp<bool, complex_t<T>> {
static constexpr bool is_castable = true;
__device__ complex_t<T> operator()(bool x) {
return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
}
};
// Converting a complex number to real number discards the imaginary part. // Converting a complex number to real number discards the imaginary part.
template <typename T, typename DstT> template <typename DstT>
struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> { struct CastOp<
static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>; cuComplex,
DstT,
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
__device__ DstT operator()(complex_t<T> x) { __device__ DstT operator()(cuComplex x) {
static_assert(!is_complex_v<DstT>); static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
return static_cast<DstT>(x.real()); return static_cast<DstT>(cuCrealf(x));
} }
}; };
// Allow converting a real number to complex number. // Allow converting a real number to complex number.
template <typename SrcT, typename T> template <typename SrcT>
struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> { struct CastOp<
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>; SrcT,
cuComplex,
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
__device__ complex_t<T> operator()(SrcT x) { __device__ cuComplex operator()(SrcT x) {
static_assert(!is_complex_v<SrcT>); static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
return complex_t<T>{static_cast<T>(x), 0}; return cuComplex{static_cast<float>(x), 0};
} }
}; };
// Do nothing when no casting is needed.
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
struct CastOp< struct CastOp<
SrcT, SrcT,
@@ -74,51 +57,9 @@ struct CastOp<
} }
}; };
// In CUDA 11 the half types do not define conversions between some types,
// provide fallbacks here.
#if CUDART_VERSION < 12000
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
(cuda::std::is_same_v<DstT, __half> ||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
!cuda::std::is_same_v<DstT, __half> &&
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
(cuda::std::is_same_v<SrcT, __half> ||
cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
#endif // CUDART_VERSION < 12000
// Helper to deduce the SrcT.
template <typename DstT, typename SrcT>
inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that cast the value to DstT using CastOp. // Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator> template <typename DstT, typename Iterator>
inline __host__ __device__ auto make_cast_iterator(Iterator it) { __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type; using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) { if constexpr (std::is_same_v<SrcT, DstT>) {
return it; return it;
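The CastOp hunks above special-case conversions that cross the complex boundary: complex to real keeps only the real part, and real to complex zero-fills the imaginary part (one side also adds a dedicated complex-to-bool rule, which this sketch leaves aside). The same rules as a small host-side C++ template, illustrative only:

#include <complex>

template <typename T>
inline constexpr bool is_complex_v = false;
template <typename T>
inline constexpr bool is_complex_v<std::complex<T>> = true;

// cast_to<Dst>(x): drop the imaginary part when leaving the complex domain,
// zero-fill it when entering; plain static_cast otherwise.
template <typename Dst, typename Src>
Dst cast_to(Src x) {
  if constexpr (is_complex_v<Src> && !is_complex_v<Dst>) {
    return static_cast<Dst>(x.real());
  } else if constexpr (!is_complex_v<Src> && is_complex_v<Dst>) {
    return Dst{static_cast<typename Dst::value_type>(x), 0};
  } else {
    return static_cast<Dst>(x);
  }
}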

View File

@@ -1,60 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
// Make multiplication and division faster.
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
#include <cuda/std/complex>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
// TODO: Consider using a faster implementation as cuda::std::complex has to
// conform to C++ standard.
template <typename T>
using complex_t = cuda::std::complex<T>;
using complex64_t = complex_t<float>;
using complex128_t = complex_t<double>;
template <typename T>
struct is_complex : cuda::std::false_type {};
template <typename T>
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;
// cuda::std::complex is missing some operators.
template <typename T>
inline __host__ __device__ complex_t<T> operator%(
complex_t<T> a,
complex_t<T> b) {
T r = a.real() - floor(a.real() / b.real()) * b.real();
T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
return complex_t<T>{r, i};
}
template <typename T>
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}
template <typename T>
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
return operator>(b, a);
}
template <typename T>
inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
return !(a > b);
}
template <typename T>
inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
return !(a < b);
}
} // namespace mlx::core::cu
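The header above fills in operators that cuda::std::complex lacks: an elementwise modulo and an ordering that compares real parts first with the imaginary part as a tie-break. Equivalent host-side definitions with std::complex, as a sketch only:

#include <cmath>
#include <complex>

template <typename T>
std::complex<T> operator%(std::complex<T> a, std::complex<T> b) {
  // Elementwise floored remainder on the real and imaginary parts.
  T r = a.real() - std::floor(a.real() / b.real()) * b.real();
  T i = a.imag() - std::floor(a.imag() / b.imag()) * b.imag();
  return {r, i};
}

template <typename T>
bool operator>(std::complex<T> a, std::complex<T> b) {
  return a.real() > b.real() || (a.real() == b.real() && a.imag() > b.imag());
}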

View File

@@ -0,0 +1,240 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
#pragma once
#include <cuComplex.h>
// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator overloads
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
// all provided by CUDA
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCadd(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCsub(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCmul(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCdiv(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
return make_cuDoubleComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
return make_cuDoubleComplex(
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCaddf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCsubf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCmulf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCdivf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
return make_cuFloatComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuFloatComplex& a,
const cuFloatComplex& b) {
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
return make_cuFloatComplex(
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}

View File

@@ -14,6 +14,8 @@ struct Abs {
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) { if constexpr (cuda::std::is_unsigned_v<T>) {
return x; return x;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
} else { } else {
return abs(x); return abs(x);
} }
@@ -25,6 +27,8 @@ struct ArcCos {
__device__ T operator()(T x) { __device__ T operator()(T x) {
return acos(x); return acos(x);
} }
__device__ cuComplex operator()(cuComplex x);
}; };
struct ArcCosh { struct ArcCosh {
@@ -39,6 +43,8 @@ struct ArcSin {
__device__ T operator()(T x) { __device__ T operator()(T x) {
return asin(x); return asin(x);
} }
__device__ cuComplex operator()(cuComplex x);
}; };
struct ArcSinh { struct ArcSinh {
@@ -53,6 +59,8 @@ struct ArcTan {
__device__ T operator()(T x) { __device__ T operator()(T x) {
return atan(x); return atan(x);
} }
__device__ cuComplex operator()(cuComplex x);
}; };
struct ArcTanh { struct ArcTanh {
@@ -74,8 +82,6 @@ struct Ceil {
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
return x; return x;
} else if constexpr (is_complex_v<T>) {
return T{ceil(x.real()), ceil(x.imag())};
} else { } else {
return ceil(x); return ceil(x);
} }
@@ -83,23 +89,34 @@ struct Ceil {
}; };
struct Conjugate { struct Conjugate {
template <typename T> __device__ cuComplex operator()(cuComplex x) {
__device__ complex_t<T> operator()(complex_t<T> x) { return {cuCrealf(x), -cuCimagf(x)};
return conj(x);
} }
}; };
struct Cos { struct Cos {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return cos(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return cos(x);
}
} }
}; };
struct Cosh { struct Cosh {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return cosh(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return cosh(x);
}
} }
}; };
@@ -132,7 +149,12 @@ struct ErfInv {
struct Exp { struct Exp {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return exp(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sin(cuCimagf(x))};
} else {
return exp(x);
}
} }
}; };
@@ -154,8 +176,6 @@ struct Floor {
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
return x; return x;
} else if constexpr (is_complex_v<T>) {
return T{floor(x.real()), floor(x.imag())};
} else { } else {
return floor(x); return floor(x);
} }
@@ -163,25 +183,30 @@ struct Floor {
}; };
struct Imag { struct Imag {
template <typename T> __device__ float operator()(cuComplex x) {
__device__ auto operator()(complex_t<T> x) { return cuCimagf(x);
return x.imag();
} }
}; };
struct Log { struct Log {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return log(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto r = log(cuCrealf(Abs{}(x)));
auto i = atan2f(cuCimagf(x), cuCrealf(x));
return {r, i};
} else {
return log(x);
}
} }
}; };
struct Log2 { struct Log2 {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (is_complex_v<T>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x); auto y = Log{}(x);
return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F}; return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
} else { } else {
return log2(x); return log2(x);
} }
@@ -191,31 +216,20 @@ struct Log2 {
struct Log10 { struct Log10 {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return log10(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
} else {
return log10(x);
}
} }
}; };
struct Log1p { struct Log1p {
template <typename T> template <typename T>
__device__ T operator()(T z) { __device__ T operator()(T x) {
if constexpr (is_complex_v<T>) { return log1p(x);
float x = z.real();
float y = z.imag();
float zabs = Abs{}(z).real();
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
float z0 = hypotf(x + 1, y);
return {logf(z0), theta};
}
} else {
return log1p(z);
}
} }
}; };
@@ -228,8 +242,8 @@ struct LogicalNot {
struct Negative { struct Negative {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (is_complex_v<T>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{0, 0} - x; return 0 - x;
} else { } else {
return -x; return -x;
} }
@@ -237,17 +251,16 @@ struct Negative {
}; };
struct Real { struct Real {
template <typename T> __device__ float operator()(cuComplex x) {
__device__ auto operator()(complex_t<T> x) { return cuCrealf(x);
return x.real();
} }
}; };
struct Round { struct Round {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (is_complex_v<T>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {rint(x.real()), rint(x.imag())}; return {rint(cuCrealf(x)), rint(cuCimagf(x))};
} else { } else {
return rint(x); return rint(x);
} }
@@ -267,8 +280,8 @@ struct Sign {
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) { if constexpr (cuda::std::is_unsigned_v<T>) {
return x != 0; return x != 0;
} else if constexpr (is_complex_v<T>) { } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (x.real() == 0 && x.imag() == 0) { if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
return x; return x;
} else { } else {
return x / Abs()(x); return x / Abs()(x);
@@ -284,14 +297,26 @@ struct Sign {
struct Sin { struct Sin {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return sin(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return sin(x);
}
} }
}; };
struct Sinh { struct Sinh {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return sinh(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return sinh(x);
}
} }
}; };
@@ -307,31 +332,77 @@ struct Sqrt {
__device__ T operator()(T x) { __device__ T operator()(T x) {
return sqrt(x); return sqrt(x);
} }
__device__ cuComplex operator()(cuComplex x) {
auto xr = cuCrealf(x);
auto xi = cuCimagf(x);
if (xr == 0.0f && xi == 0.0f) {
return {0.0f, 0.0f};
}
auto r = cuCrealf(Abs{}(x));
auto a = sqrt((r + xr) / 2.0f);
auto b_abs = sqrt((r - xr) / 2.0f);
auto b = copysign(b_abs, xi);
return {a, b};
}
}; };
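For reference, the cuComplex branch of Sqrt above computes the principal square root as
\[
  \sqrt{z} \;=\; \sqrt{\tfrac{|z| + \operatorname{Re} z}{2}}
  \;+\; i\,\operatorname{sign}(\operatorname{Im} z)\,\sqrt{\tfrac{|z| - \operatorname{Re} z}{2}},
\]
with copysign carrying the sign of the imaginary part so the result lands in the correct half-plane.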
struct Rsqrt { struct Rsqrt {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (is_complex_v<T>) { return rsqrt(x);
return 1.0f / Sqrt{}(x); }
} else { __device__ cuComplex operator()(cuComplex x) {
return rsqrt(x); return 1.0f / Sqrt{}(x);
}
} }
}; };
struct Tan { struct Tan {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return tan(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tan_a = tan(cuCrealf(x));
float tanh_b = tanh(cuCimagf(x));
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
} else {
return tan(x);
}
} }
}; };
struct Tanh { struct Tanh {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
return tanh(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tanh_a = tanh(cuCrealf(x));
float tan_b = tan(cuCimagf(x));
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
} else {
return tanh(x);
}
} }
}; };
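The complex Tan and Tanh branches above follow from the addition formulas after multiplying by the conjugate of the denominator; writing t = tan(a)·tanh(b) for Tan (and t = tanh(a)·tan(b) for Tanh),
\[
  \tan(a+ib) = \frac{\tan a + i\,\tanh b}{1 - i\,\tan a\,\tanh b}
             = \frac{(\tan a - \tanh b\,t) + i\,(\tanh b + \tan a\,t)}{1 + t^2},
\]
and analogously tanh(a+ib) = ((tanh a + tan b·t) + i(tan b − tanh a·t)) / (1 + t²), which is the real/imaginary pair computed with denom = 1 + t1·t1.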
__device__ cuComplex ArcCos::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(x + i * Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcSin::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcTan::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto ix = i * x;
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
};
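The out-of-line definitions above use the standard log-based inverse identities; multiplying a value y by −i maps (Re y, Im y) to (Im y, −Re y), which is what the returned pairs implement:
\[
  \arccos z = -i\,\log\!\bigl(z + i\sqrt{1 - z^2}\bigr), \qquad
  \arcsin z = -i\,\log\!\bigl(iz + \sqrt{1 - z^2}\bigr), \qquad
  \arctan z = \tfrac{1}{2i}\,\log\!\frac{1 + iz}{1 - iz}.
\]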
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -8,9 +8,9 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/complex.cuh"
#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/config.h"
#include <cuComplex.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda/std/array> #include <cuda/std/array>
@@ -28,27 +28,6 @@ namespace mlx::core::cu {
using Shape = cuda::std::array<int32_t, MAX_NDIM>; using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>; using Strides = cuda::std::array<int64_t, MAX_NDIM>;
// Vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
T val[N];
};
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -99,20 +78,20 @@ struct Limits<
return cuda::std::numeric_limits<T>::infinity(); return cuda::std::numeric_limits<T>::infinity();
} }
static constexpr __host__ __device__ T min() { static constexpr __host__ __device__ T min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return -cuda::std::numeric_limits<float>::infinity();
#else
return -cuda::std::numeric_limits<T>::infinity(); return -cuda::std::numeric_limits<T>::infinity();
#else
return -cuda::std::numeric_limits<float>::infinity();
#endif #endif
} }
static constexpr __host__ __device__ T finite_max() { static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max(); return cuda::std::numeric_limits<T>::max();
} }
static constexpr __host__ __device__ T finite_min() { static constexpr __host__ __device__ T finite_min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return cuda::std::numeric_limits<float>::lowest();
#else
return cuda::std::numeric_limits<T>::lowest(); return cuda::std::numeric_limits<T>::lowest();
#else
return cuda::std::numeric_limits<float>::lowest();
#endif #endif
} }
}; };
@@ -127,13 +106,13 @@ struct Limits<bool> {
} }
}; };
template <typename T> template <>
struct Limits<complex_t<T>> { struct Limits<cuComplex> {
static constexpr __host__ __device__ complex_t<T> max() { static constexpr __host__ __device__ cuComplex max() {
return {Limits<T>::max(), Limits<T>::max()}; return {Limits<float>::max(), Limits<float>::max()};
} }
static constexpr __host__ __device__ complex_t<T> min() { static constexpr __host__ __device__ cuComplex min() {
return {Limits<T>::min(), Limits<T>::min()}; return {Limits<float>::min(), Limits<float>::min()};
} }
}; };
@@ -359,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
} }
}; };
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
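The |z| < 0.5 branch above uses the cancellation-free form of the real part, which is why log1pf is safe there:
\[
  \log(1+z) = \tfrac{1}{2}\log\bigl((1+x)^2 + y^2\bigr) + i\,\operatorname{atan2}(y,\,1+x)
            = \tfrac{1}{2}\,\mathrm{log1p}\bigl(x(2+x) + y^2\bigr) + i\,\operatorname{atan2}(y,\,1+x).
\]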
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -90,6 +90,8 @@ bool CudaEvent::completed() const {
// SharedEvent implementations // SharedEvent implementations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
namespace {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { __host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current; uint64_t current;
while ((current = ac->load()) < value) { while ((current = ac->load()) < value) {
@@ -110,6 +112,8 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value); event_signal(ac, value);
} }
} // namespace
SharedEvent::SharedEvent() { SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory. // Allocate cuda::atomic on managed memory.
Atomic* ac; Atomic* ac;

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/version.h"
#include "cuda_jit_sources.h" #include "cuda_jit_sources.h"
@@ -13,7 +12,6 @@
#include <fmt/format.h> #include <fmt/format.h>
#include <nvrtc.h> #include <nvrtc.h>
#include <unistd.h>
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -51,41 +49,14 @@ const std::string& cuda_home() {
return home; return home;
} }
// Return the location of CCCL headers shipped with the distribution.
const std::string& cccl_dir() {
static std::string dir = []() {
std::filesystem::path path;
#if defined(MLX_CCCL_DIR)
// First search the install dir if defined.
path = MLX_CCCL_DIR;
if (std::filesystem::exists(path)) {
return path.string();
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
}
// Get the cache directory for storing compiled results. // Get the cache directory for storing compiled results.
const std::filesystem::path& ptx_cache_dir() { const std::filesystem::path& ptx_cache_dir() {
static std::filesystem::path cache = []() -> std::filesystem::path { static std::filesystem::path cache = []() -> std::filesystem::path {
std::filesystem::path cache; std::filesystem::path cache;
if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) { if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
cache = c; cache = c;
} else { } else {
cache = cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
} }
if (!std::filesystem::exists(cache)) { if (!std::filesystem::exists(cache)) {
std::error_code error; std::error_code error;
@@ -137,8 +108,7 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
const std::string& module_name, const std::string& module_name,
const std::vector<char>& ptx, const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels, const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
const std::string& source_code) {
if (cache_dir.empty()) { if (cache_dir.empty()) {
return; return;
} }
@@ -151,9 +121,6 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl; txt_file << name << "\t" << mangled << std::endl;
} }
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
} }
// Return if |device|'s version is not newer than |major|.|minor| version. // Return if |device|'s version is not newer than |major|.|minor| version.
@@ -193,7 +160,7 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "complex.cuh", INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh", INCLUDE_PREFIX "fp16_math.cuh",
INCLUDE_PREFIX "indexing.cuh", INCLUDE_PREFIX "indexing.cuh",
INCLUDE_PREFIX "scatter_ops.cuh", INCLUDE_PREFIX "scatter_ops.cuh",
@@ -209,7 +176,7 @@ constexpr const char* g_headers[] = {
jit_source_binary_ops, jit_source_binary_ops,
jit_source_cast_op, jit_source_cast_op,
jit_source_config, jit_source_config,
jit_source_complex, jit_source_cucomplex_math,
jit_source_fp16_math, jit_source_fp16_math,
jit_source_indexing, jit_source_indexing,
jit_source_scatter_ops, jit_source_scatter_ops,
@@ -246,24 +213,16 @@ JitModule::JitModule(
} }
// Compile program. // Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device); bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format( std::string compute = fmt::format(
"--gpu-architecture={}_{}{}", "--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute", use_sass ? "sm" : "compute",
device.compute_capability_major(), device.compute_capability_major(),
device.compute_capability_minor()); device.compute_capability_minor());
args.push_back(compute.c_str()); std::string include = fmt::format("--include-path={}/include", cuda_home());
std::string cccl_include = cccl_dir(); const char* args[] = {compute.c_str(), include.c_str()};
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result = nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data()); nvrtcCompileProgram(prog, std::size(args), args);
if (compile_result != NVRTC_SUCCESS) { if (compile_result != NVRTC_SUCCESS) {
size_t log_size; size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
@@ -293,8 +252,7 @@ JitModule::JitModule(
} else { } else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
} }
write_cached_ptx( write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
} }
// Load module. // Load module.

View File

@@ -11,6 +11,7 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
@@ -78,7 +79,7 @@ struct CTypeToCudaType<bfloat16_t> {
template <> template <>
struct CTypeToCudaType<complex64_t> { struct CTypeToCudaType<complex64_t> {
using type = cu::complex64_t; using type = cuComplex;
}; };
template <typename T> template <typename T>
@@ -90,14 +91,10 @@ inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> || cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>; cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type traits for detecting complex numbers.
template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
cuda::std::is_same_v<T, complex128_t>;
// Type traits for detecting complex or real floating point numbers. // Type traits for detecting complex or real floating point numbers.
template <typename T> template <typename T>
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>; inline constexpr bool is_inexact_v =
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
// Utility to copy data from vector to array in host. // Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t> template <int NDIM = MAX_NDIM, typename T = int32_t>

View File

@@ -237,7 +237,8 @@ void LayerNorm::eval_gpu(
} }
return x; return x;
} else { } else {
array x_copy = contiguous_copy_gpu(x, s); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
@@ -294,7 +295,9 @@ void LayerNormVJP::eval_gpu(
return x; return x;
} }
copied = true; copied = true;
return contiguous_copy_gpu(x, s); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable(); bool donate_g = inputs[3].is_donatable();

View File

@@ -108,7 +108,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
array x_copy = contiguous_copy_gpu(x, s); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -27,35 +27,6 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
} }
} }
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul { class MatMul {
public: public:
MatMul( MatMul(
@@ -72,7 +43,7 @@ class MatMul {
int32_t batch_count, int32_t batch_count,
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride) int64_t b_batch_stride)
: handle_(device.lt_handle()), pref_(cublas_preference(device)) { : handle_(device.lt_handle()) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype); auto scale_type = dtype_to_cuda_type(dtype);
@@ -106,6 +77,20 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
} }
MatMul( MatMul(
@@ -119,6 +104,7 @@ class MatMul {
uint64_t b_rows, uint64_t b_rows,
uint64_t b_cols, uint64_t b_cols,
int64_t ldb, int64_t ldb,
bool c_transposed,
int64_t ldc, int64_t ldc,
int32_t batch_count, int32_t batch_count,
int64_t a_batch_stride, int64_t a_batch_stride,
@@ -140,15 +126,15 @@ class MatMul {
b_batch_stride) { b_batch_stride) {
auto type = dtype_to_cuda_type(dtype); auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout( c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
} }
~MatMul() { ~MatMul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); cublasLtMatrixLayoutDestroy(a_desc_);
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); cublasLtMatrixLayoutDestroy(b_desc_);
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); cublasLtMatrixLayoutDestroy(c_desc_);
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); cublasLtMatrixLayoutDestroy(out_desc_);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); cublasLtMatmulDescDestroy(matmul_desc_);
} }
void run( void run(
@@ -273,9 +259,9 @@ class MatMul {
return desc; return desc;
} }
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr}; cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr}; cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr}; cublasLtMatrixLayout_t c_desc_{nullptr};
@@ -296,7 +282,8 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) { } else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy = contiguous_copy_gpu(arr, s); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
enc.add_temporary(arr_copy); enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy);
} }
@@ -402,7 +389,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3); assert(inputs.size() == 3);
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];
auto c = inputs[2]; auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Init checks and prep // Init checks and prep
@@ -415,24 +404,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// the arrays // the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions // Check and collapse batch dimensions
@@ -470,6 +442,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K, K,
N, N,
ldb, ldb,
c_transposed,
ldc, ldc,
batch_shape.back(), batch_shape.back(),
a_batch_strides.back(), a_batch_strides.back(),

View File

@@ -82,7 +82,7 @@ NO_GPU(Load)
NO_GPU_MULTI(LUF) NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF) NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul) NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM) NO_GPU(Scan)
NO_GPU_MULTI(SVD) NO_GPU_MULTI(SVD)
NO_GPU(Inverse) NO_GPU(Inverse)
NO_GPU(Cholesky) NO_GPU(Cholesky)
@@ -91,6 +91,7 @@ NO_GPU_MULTI(Eigh)
namespace fast { namespace fast {
NO_GPU(ScaledDotProductAttention) NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel) NO_GPU_MULTI(CustomKernel)
} // namespace fast } // namespace fast

View File

@@ -1,386 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
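A host-side replica of the two helpers above, spelling out the packing used for each supported bit width; the free-function names here are hypothetical stand-ins, and the checks are compile-time only.
constexpr int pack_factor(int bits, int wsize = 8) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
constexpr int bytes_per_pack(int bits, int wsize = 8) {
  return ((bits & (bits - 1)) == 0) ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
static_assert(pack_factor(3) == 8 && bytes_per_pack(3) == 3, "3-bit: 8 values in 3 bytes");
static_assert(pack_factor(5) == 8 && bytes_per_pack(5) == 5, "5-bit: 8 values in 5 bytes");
static_assert(pack_factor(6) == 4 && bytes_per_pack(6) == 3, "6-bit: 4 values in 3 bytes");
static_assert(pack_factor(4) == 2 && bytes_per_pack(4) == 1, "4-bit: 2 values per byte");
static_assert(pack_factor(8) == 1 && bytes_per_pack(8) == 1, "8-bit: 1 value per byte");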
template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
constexpr float eps = 1e-7;
constexpr int simd_size = WARP_SIZE;
constexpr float n_bins = (1 << bits) - 1;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = pack_factor / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t in_index = offset * values_per_reduce;
if (in_index >= size) {
return;
}
size_t out_index = power_of_2_bits
? offset * writes_per_pack
: offset * bytes_per_pack / writes_per_reduce;
float w_thread[values_per_reduce];
float w_min = Limits<float>::max();
float w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
float val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
cg::greater<float> max_op;
cg::less<float> min_op;
auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());
w_min = cg::reduce(warp, w_min, min_op);
w_max = cg::reduce(warp, w_max, max_op);
float scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
float edge = side ? w_min : w_max;
float q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
float bias = at_zero ? 0 : edge;
// Write out the scales and biases
size_t gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = static_cast<T>(scale);
biases[gindex] = static_cast<T>(bias);
}
using OutType = std::conditional_t<bits == 5, uint64_t, uint32_t>;
OutType output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output |= val << (bits * (i % pack_factor));
}
if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
out[out_index + i / pack_factor] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = warp.shfl_down(val, j);
output |= static_cast<OutType>(sval)
<< (bits * (j * values_per_reduce + i));
}
}
}
if constexpr (bits == 3 || bits == 6) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else if constexpr (bits == 5) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
out[out_index + 3] = (output & 0xff000000) >> 24;
out[out_index + 4] = (output & 0xff00000000) >> 32;
}
} else {
if constexpr (writes_per_reduce > 0) {
if (out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
}
}
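Setting aside the edge-snapping of the scale (q0 = round(edge / scale)), the kernel above applies the usual affine mapping per group, which the dequantize kernel below inverts:
\[
  q_i = \min\!\Bigl(\operatorname{round}\Bigl(\frac{w_i - \beta}{s}\Bigr),\, 2^b - 1\Bigr),
  \qquad \hat w_i = s\,q_i + \beta,
\]
with s and β derived from the group's min/max as computed in the warp reduction.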
template <typename T, int group_size, int bits>
__global__ void affine_dequantize(
const uint8_t* w,
const T* scales,
const T* biases,
T* out,
size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
out += oindex;
if constexpr (bits == 3) {
w += offset * bytes_per_pack;
out[0] = static_cast<T>(w[0] & 0x7) * scale + bias;
out[1] = static_cast<T>((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (static_cast<T>((w[0] & 0xc0) >> 6) +
static_cast<T>((w[1] & 0x1) << 2)) *
scale +
bias;
out[3] = static_cast<T>((w[1] & 0xe) >> 1) * scale + bias;
out[4] = static_cast<T>((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (static_cast<T>((w[1] & 0x80) >> 7) +
static_cast<T>((w[2] & 0x3) << 1)) *
scale +
bias;
out[6] = static_cast<T>((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = static_cast<T>((w[2] & 0xe0) >> 5) * scale + bias;
} else if constexpr (bits == 5) {
w += offset * bytes_per_pack;
out[0] = static_cast<T>(w[0] & 0x1f) * scale + bias;
out[1] = (static_cast<T>((w[0] & 0xe0) >> 5) +
static_cast<T>((w[1] & 0x3) << 3)) *
scale +
bias;
out[2] = static_cast<T>((w[1] & 0x7c) >> 2) * scale + bias;
out[3] = (static_cast<T>((w[1] & 0x80) >> 7) +
static_cast<T>((w[2] & 0xf) << 1)) *
scale +
bias;
out[4] = (static_cast<T>((w[2] & 0xf0) >> 4) +
static_cast<T>((w[3] & 0x1) << 4)) *
scale +
bias;
out[5] = static_cast<T>((w[3] & 0x3e) >> 1) * scale + bias;
out[6] = (static_cast<T>((w[3] & 0xc0) >> 6) +
static_cast<T>((w[4] & 0x7) << 2)) *
scale +
bias;
out[7] = static_cast<T>((w[4] & 0xf8) >> 3) * scale + bias;
} else if constexpr (bits == 6) {
w += offset * bytes_per_pack;
out[0] = static_cast<T>(w[0] & 0x3f) * scale + bias;
out[1] = (static_cast<T>((w[0] >> 6) & 0x03) +
static_cast<T>((w[1] & 0x0f) << 2)) *
scale +
bias;
out[2] = (static_cast<T>((w[1] >> 4) & 0x0f) +
static_cast<T>((w[2] & 0x03) << 4)) *
scale +
bias;
out[3] = static_cast<T>((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = scale * static_cast<T>(d) + bias;
}
}
}
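A standalone sketch of the 3-bit layout that the unrolled expressions above decode: value i of a pack occupies bits [3i, 3i + 3) of a little-endian 24-bit group, so some values straddle byte boundaries. The constants below are arbitrary and only check one of the straddling cases.
#include <cassert>
#include <cstdint>
int main() {
  const uint8_t w[3] = {0xc7, 0xad, 0x5c}; // one arbitrary 3-bit pack
  uint32_t group = w[0] | (uint32_t(w[1]) << 8) | (uint32_t(w[2]) << 16);
  // Generic extraction of value 2 (bits 6..8) ...
  int generic = (group >> 6) & 0x7;
  // ... matches the unrolled form used for out[2] in the kernel above.
  int unrolled = ((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2);
  assert(generic == unrolled);
  return 0;
}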
} // namespace cu
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread;
}
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
auto kernel =
cu::affine_dequantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<uint8_t>(),
inputs[1].data<DataType>(),
inputs[2].data<DataType>(),
out.data<DataType>(),
out.size());
} else {
auto kernel =
cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<DataType>(),
out.data<uint8_t>(),
outputs[1].data<DataType>(),
outputs[2].data<DataType>(),
w.size());
}
});
});
});
}
} // namespace mlx::core

View File

@@ -47,7 +47,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy = contiguous_copy_gpu(in, s); array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy); encoder.add_temporary(in_copy);
in = in_copy; in = in_copy;
plan = get_reduction_plan(in, axes_); plan = get_reduction_plan(in, axes_);

View File

@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
for (; i + block.size() * N <= check; i += block.size() * N) { for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals); cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], cast_to<U>(vals[j])); accs[0] = op(accs[0], __cast<U, T>(vals[j]));
} }
} }
if (i < check) { if (i < check) {
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, cast_to<T>(init)); block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], cast_to<U>(vals[i])); accs[0] = op(accs[0], __cast<U, T>(vals[i]));
} }
} }

View File

@@ -3,6 +3,7 @@
#include <numeric> #include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -127,7 +128,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], cast_to<U>(vals[i])); totals[i] = op(totals[i], __cast<U, T>(vals[i]));
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
@@ -136,7 +137,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], cast_to<U>(vals[i])); totals[i] = op(totals[i], __cast<U, T>(vals[i]));
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
@@ -149,9 +150,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
in + loop.location(), in + loop.location(),
vals, vals,
args.reduction_stride - tile_x * BN, args.reduction_stride - tile_x * BN,
cast_to<T>(ReduceInit<Op, T>::value())); __cast<T, U>(ReduceInit<Op, T>::value()));
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], cast_to<U>(vals[i])); totals[i] = op(totals[i], __cast<U, T>(vals[i]));
} }
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }

View File

@@ -3,6 +3,7 @@
#include <type_traits> #include <type_traits>
#include "mlx/backend/common/reduce.h" #include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"

View File

@@ -2,8 +2,6 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/atomic_ops.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_utils.cuh"
@@ -42,15 +40,15 @@ struct Sum {
} }
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) { __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
atomic_add(x, y); atomicAdd(x, y);
} }
__device__ void atomic_update(int* x, int y) { __device__ void atomic_update(int* x, int y) {
atomic_add(x, y); atomicAdd(x, y);
} }
__device__ void atomic_update(float* x, float y) { __device__ void atomic_update(float* x, float y) {
atomic_add(x, y); atomicAdd(x, y);
} }
}; };
@@ -69,18 +67,6 @@ struct Prod {
struct Min { struct Min {
template <typename T> template <typename T>
__device__ __forceinline__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
if constexpr (is_complex_v<T>) {
if (isnan(a.real()) || isnan(a.imag())) {
return a;
}
if (isnan(b.real()) || isnan(b.imag())) {
return b;
}
} else if constexpr (!cuda::std::is_integral_v<T>) {
if (isnan(a) || isnan(b)) {
return cuda::std::numeric_limits<float>::quiet_NaN();
}
}
return a < b ? a : b; return a < b ? a : b;
} }
@@ -93,18 +79,6 @@ struct Min {
struct Max { struct Max {
template <typename T> template <typename T>
__device__ __forceinline__ T operator()(T a, T b) { __device__ __forceinline__ T operator()(T a, T b) {
if constexpr (is_complex_v<T>) {
if (isnan(a.real()) || isnan(a.imag())) {
return a;
}
if (isnan(b.real()) || isnan(b.imag())) {
return b;
}
} else if constexpr (!cuda::std::is_integral_v<T>) {
if (isnan(a) || isnan(b)) {
return cuda::std::numeric_limits<float>::quiet_NaN();
}
}
return a > b ? a : b; return a > b ? a : b;
} }
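A host-side illustration of why the bare ternary differs from the guarded variant on the other side of this hunk: with a NaN operand its result depends on argument order, while the guarded version always returns the NaN.
#include <cmath>
#include <cstdio>
int main() {
  float nan = std::nanf("");
  auto mx = [](float a, float b) { return a > b ? a : b; };
  std::printf("%f %f\n", mx(1.0f, nan), mx(nan, 1.0f)); // prints: nan 1.000000
  return 0;
}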
@@ -175,10 +149,10 @@ struct ReduceInit<Or, T> {
template <typename T> template <typename T>
struct ReduceInit<Sum, T> { struct ReduceInit<Sum, T> {
static constexpr __host__ __device__ auto value() { static constexpr __host__ __device__ auto value() {
if constexpr (is_complex_v<T>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{0, 0}; return T{0, 0};
} else { } else {
return cast_to<typename ReduceResult<Sum, T>::type>(0); return typename ReduceResult<Sum, T>::type{0};
} }
} }
}; };
@@ -186,10 +160,10 @@ struct ReduceInit<Sum, T> {
template <typename T> template <typename T>
struct ReduceInit<Prod, T> { struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() { static constexpr __host__ __device__ auto value() {
if constexpr (is_complex_v<T>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 0}; return T{1, 0};
} else { } else {
return cast_to<typename ReduceResult<Prod, T>::type>(1); return typename ReduceResult<Prod, T>::type{1};
} }
} }
}; };

View File

@@ -4,7 +4,6 @@
#include <numeric> #include <numeric>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -56,6 +55,22 @@ __device__ void atomic_reduce(T* x, T y) {
} }
} }
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
return x.x != 0 && x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
template <typename T, int N, typename Block, typename Warp, typename Op> template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {

View File

@@ -3,6 +3,7 @@
#include <numeric> #include <numeric>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -112,7 +113,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in + k * size + r * (block.size() * N), in + k * size + r * (block.size() * N),
vals[k]); vals[k]);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
} }
} }
} }
@@ -124,7 +125,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in + k * size + r * (block.size() * N), in + k * size + r * (block.size() * N),
vals[k]); vals[k]);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
} }
} }
} }
@@ -137,9 +138,9 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
in + k * size + final_offset, in + k * size + final_offset,
vals[k], vals[k],
size, size,
cast_to<T>(init)); __cast<T, U>(init));
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
} }
} }
} }
@@ -198,7 +199,7 @@ __global__ void row_reduce_looped(
in + loop.location() + r * BLOCK_DIM * N_READS, in + loop.location() + r * BLOCK_DIM * N_READS,
vals); vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i])); total[0] = op(total[0], __cast<U, T>(vals[i]));
} }
} }
if (final_offset < args.row_size) { if (final_offset < args.row_size) {
@@ -208,9 +209,9 @@ __global__ void row_reduce_looped(
in + loop.location() + final_offset, in + loop.location() + final_offset,
vals, vals,
args.row_size - final_offset, args.row_size - final_offset,
cast_to<T>(init)); __cast<T, U>(init));
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i])); total[0] = op(total[0], __cast<U, T>(vals[i]));
} }
} }
// TODO: Maybe block.sync() here? // TODO: Maybe block.sync() here?

View File

@@ -74,7 +74,7 @@ __global__ void rms_norm(
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
normalizer += t * t; normalizer += t * t;
@@ -130,7 +130,7 @@ __global__ void rms_norm_vjp(
T wn[N_READS] = {}; T wn[N_READS] = {};
T gn[N_READS] = {}; T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -206,7 +206,8 @@ void RMSNorm::eval_gpu(
} }
return x; return x;
} else { } else {
array x_copy = contiguous_copy_gpu(x, s); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
@@ -258,7 +259,9 @@ void RMSNormVJP::eval_gpu(
return x; return x;
} }
copied = true; copied = true;
return contiguous_copy_gpu(x, s); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable(); bool donate_g = inputs[2].is_donatable();

View File

@@ -1,465 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/scan.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename T>
struct ScanResult {
using type = T;
};
template <>
struct ScanResult<Sum, bool> {
using type = int32_t;
};
template <typename T>
struct ReduceInit<LogAddExp, T> {
static constexpr __host__ __device__ T value() {
return Limits<T>::min();
}
};
template <bool reverse, typename T, typename U, int N_READS>
inline __device__ void
load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
int remaining = size - index * N_READS;
if constexpr (reverse) {
in += remaining - N_READS;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
values[N_READS - i - 1] =
(N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
}
} else {
for (int i = 0; i < N_READS; ++i) {
values[N_READS - i - 1] = cast_to<U>(in[i]);
}
}
} else {
in += index * N_READS;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
values[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
}
} else {
for (int i = 0; i < N_READS; ++i) {
values[i] = cast_to<U>(in[i]);
}
}
}
}
template <bool reverse, int offset, typename T, int N_READS>
inline __device__ void
store_values(int index, T* out, T (&values)[N_READS], int size) {
int start = index * N_READS + offset;
int remaining = size - start;
if constexpr (reverse) {
out += remaining - N_READS;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
if (N_READS - i - 1 < remaining) {
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i = 0; i < N_READS; ++i) {
out[i] = values[N_READS - i - 1];
}
}
} else {
out += start;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
if (i < remaining) {
out[i] = values[i];
}
}
} else {
for (int i = 0; i < N_READS; ++i) {
out[i] = values[i];
}
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
__shared__ U warp_sums[WARP_SIZE];
Op op;
U init = ReduceInit<Op, T>::value();
U prefix = init;
// Scan per block.
for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
int32_t index = r * block.size() + block.thread_rank();
U values[N_READS];
load_values<reverse>(index, in, values, axis_size, init);
// Compute an inclusive scan per thread.
for (int i = 1; i < N_READS; ++i) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums.
U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
if (warp.thread_rank() == 0) {
prev_thread_sum = init;
}
// Write the warp's sum to shared memory.
if (warp.thread_rank() == WARP_SIZE - 1) {
warp_sums[warp.meta_group_rank()] =
op(prev_thread_sum, values[N_READS - 1]);
}
block.sync();
// Compute exclusive scan of warp sums.
if (warp.meta_group_rank() == 0) {
U prev_warp_sum =
cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);
if (warp.thread_rank() == 0) {
prev_warp_sum = init;
}
warp_sums[warp.thread_rank()] = prev_warp_sum;
}
block.sync();
// Compute the output.
for (int i = 0; i < N_READS; ++i) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
values[i] = op(values[i], prev_thread_sum);
}
// Write the values.
if (inclusive) {
store_values<reverse, 0>(index, out, values, axis_size);
} else {
store_values<reverse, 1>(index, out, values, axis_size);
if (reverse) {
if (block.thread_rank() == 0 && index == 0) {
out[axis_size - 1] = init;
}
} else {
if (block.thread_rank() == 0 && index == 0) {
out[0] = init;
}
}
}
block.sync();
// Share the prefix.
if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
(warp.thread_rank() == WARP_SIZE - 1)) {
warp_sums[0] = values[N_READS - 1];
}
block.sync();
prefix = warp_sums[0];
}
}
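For reference, a sequential sketch of the result the thread/warp/block decomposition above reproduces; op and init play the same roles as in the kernel (illustrative only).
#include <cstddef>
#include <vector>
template <typename T, typename Op>
std::vector<T> inclusive_scan_ref(const std::vector<T>& in, Op op, T init) {
  std::vector<T> out(in.size());
  T acc = init; // running prefix, like `prefix` carried across row chunks
  for (std::size_t i = 0; i < in.size(); ++i) {
    acc = op(acc, in[i]);
    out[i] = acc;
  }
  return out; // the exclusive variant shifts this right by one and writes init first
}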
template <
typename T,
typename U,
typename Op,
int N_READS,
int BM,
int BN,
bool inclusive,
bool reverse>
__global__ void strided_scan(
const T* in,
U* out,
int32_t axis_size,
int64_t stride,
int64_t stride_blocks) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
constexpr int n_warps = BN / N_READS;
constexpr int n_scans = BN / n_warps;
__shared__ U read_buffer[BM * BN_pad];
Op op;
U init = ReduceInit<Op, T>::value();
U values[n_scans];
U prefix[n_scans];
for (int i = 0; i < n_scans; ++i) {
prefix[i] = init;
}
// Compute offsets.
int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
uint read_offset_y = (block.thread_rank() * N_READS) / BN;
uint read_offset_x = (block.thread_rank() * N_READS) % BN;
uint scan_offset_y = warp.thread_rank();
uint scan_offset_x = warp.meta_group_rank() * n_scans;
uint stride_limit = stride - global_index_x;
in += offset + global_index_x + read_offset_x;
out += offset + global_index_x + read_offset_x;
U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;
for (uint j = 0; j < axis_size; j += BM) {
// Calculate the indices for the current thread.
uint index_y = j + read_offset_y;
uint check_index_y = index_y;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM.
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; ++i) {
read_into[i] = in[index_y * stride + i];
}
} else {
for (int i = 0; i < N_READS; ++i) {
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
read_into[i] = in[index_y * stride + i];
} else {
read_into[i] = init;
}
}
}
block.sync();
// Read strided into registers.
for (int i = 0; i < n_scans; ++i) {
values[i] = read_from[i];
}
// Perform the scan.
for (int i = 0; i < n_scans; ++i) {
values[i] = cg::inclusive_scan(warp, values[i], op);
values[i] = op(values[i], prefix[i]);
prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
}
// Write to SM.
for (int i = 0; i < n_scans; ++i) {
read_from[i] = values[i];
}
block.sync();
// Write to device memory.
if (!inclusive) {
if (check_index_y == 0) {
if ((read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; ++i) {
out[index_y * stride + i] = init;
}
} else {
for (int i = 0; i < N_READS; ++i) {
if ((read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; ++i) {
out[index_y * stride + i] = read_into[i];
}
} else {
for (int i = 0; i < N_READS; ++i) {
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = read_into[i];
}
}
}
}
}
} // namespace cu
template <typename F>
void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {
if (scan_op == Scan::ReduceType::Max) {
f(type_identity<cu::Max>{});
} else if (scan_op == Scan::ReduceType::Min) {
f(type_identity<cu::Min>{});
} else if (scan_op == Scan::ReduceType::Sum) {
f(type_identity<cu::Sum>{});
} else if (scan_op == Scan::ReduceType::Prod) {
f(type_identity<cu::Prod>{});
} else if (scan_op == Scan::ReduceType::LogAddExp) {
f(type_identity<cu::LogAddExp>{});
} else {
throw std::invalid_argument("Unknown reduce type.");
}
}
template <typename Op>
const char* op_to_string() {
if (cuda::std::is_same_v<Op, cu::Max>) {
return "Max";
} else if (cuda::std::is_same_v<Op, cu::Min>) {
return "Min";
} else if (cuda::std::is_same_v<Op, cu::Sum>) {
return "Sum";
} else if (cuda::std::is_same_v<Op, cu::Prod>) {
return "Prod";
} else if (cuda::std::is_same_v<Op, cu::LogAddExp>) {
return "LogAddExp";
} else {
throw std::invalid_argument("Unknown op.");
}
}
template <typename Op, typename T>
constexpr bool supports_scan_op() {
if constexpr (cuda::std::is_same_v<Op, cu::LogAddExp>) {
return is_inexact_v<T>;
} else {
return true;
}
}
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Scan::eval_gpu");
assert(inputs.size() == 1);
auto in = inputs[0];
auto& s = stream();
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
in = contiguous_copy_gpu(in, s);
out.copy_shared_buffer(in);
}
constexpr int N_READS = 4;
int32_t axis_size = in.shape(axis_);
bool contiguous = in.strides()[axis_] == 1;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
using Op = MLX_GET_TYPE(scan_op_tag);
if constexpr (supports_scan_op<Op, T>()) {
using U = typename cu::ScanResult<Op, T>::type;
dispatch_bool(inclusive_, [&](auto inclusive) {
dispatch_bool(reverse_, [&](auto reverse) {
if (contiguous) {
auto kernel = cu::contiguous_scan<
T,
U,
Op,
N_READS,
inclusive.value,
reverse.value>;
int block_dim = cuda::ceil_div(axis_size, N_READS);
block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
encoder.add_kernel_node(
kernel,
in.data_size() / axis_size,
block_dim,
in.data<T>(),
out.data<U>(),
axis_size);
} else {
constexpr int BM = WARP_SIZE;
constexpr int BN = WARP_SIZE;
auto kernel = cu::strided_scan<
T,
U,
Op,
N_READS,
BM,
BN,
inclusive.value,
reverse.value>;
int64_t stride = in.strides()[axis_];
int64_t stride_blocks = cuda::ceil_div(stride, BN);
dim3 num_blocks = get_2d_grid_dims(
in.shape(), in.strides(), axis_size * stride);
if (num_blocks.x * stride_blocks <= UINT32_MAX) {
num_blocks.x *= stride_blocks;
} else {
num_blocks.y *= stride_blocks;
}
int block_dim = (BN / N_READS) * WARP_SIZE;
encoder.add_kernel_node(
kernel,
num_blocks,
block_dim,
in.data<T>(),
out.data<U>(),
axis_size,
stride,
stride_blocks);
}
});
});
} else {
throw std::runtime_error(fmt::format(
"Can not do scan op {} on inputs of {} with result of {}.",
op_to_string<Op>(),
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
} // namespace mlx::core
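For reference, a minimal NumPy sketch (illustrative only, not part of the change set) of the scan semantics the contiguous_scan kernel above implements for the Sum op with init = 0: an inclusive scan accumulates left to right, the exclusive variant shifts the result by one and seeds the boundary with init, and reverse runs the same scan from the other end of the axis.

import numpy as np

def reference_scan(x, inclusive=True, reverse=False, init=0.0):
    # Reference model only; the CUDA kernel computes the same values per row.
    x = np.asarray(x, dtype=np.float64)
    if reverse:
        x = x[::-1]
    out = np.cumsum(x)  # inclusive scan with op = Sum
    if not inclusive:
        out = np.concatenate(([init], out[:-1]))  # shift by one, seed with init
    if reverse:
        out = out[::-1]
    return out

print(reference_scan([1, 2, 3, 4]))                   # [ 1.  3.  6. 10.]
print(reference_scan([1, 2, 3, 4], inclusive=False))  # [0. 1. 3. 6.]
print(reference_scan([1, 2, 3, 4], reverse=True))     # [10.  9.  7.  4.]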

View File

@@ -43,7 +43,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
// Thread reduce. // Thread reduce.
AccT prevmax; AccT prevmax;
AccT maxval = Limits<AccT>::finite_min(); AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = cast_to<AccT>(0); AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS]; AccT vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
@@ -125,7 +125,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
array x_copy = contiguous_copy_gpu(x, s); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -72,7 +72,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) { if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim); array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s); in = array(trans.shape(), trans.dtype(), nullptr, {});
copy_gpu(trans, in, CopyType::General, s);
encoder.add_temporary(in); encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out); encoder.add_temporary(out);

View File

@@ -15,27 +15,12 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename Op, typename T, typename IdxT, int N_READS> template <typename Op, typename T, typename IdxT>
__global__ void __global__ void
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = Op{}(a[index], b[index], c[index]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i], c[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
auto c_vec = load_vector<N_READS>(c, index);
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
@@ -164,18 +149,11 @@ void ternary_op_gpu_inplace(
} }
}); });
} else { } else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. auto kernel = cu::ternary_v<Op, DType, IdxT>;
constexpr int N_READS = 4;
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, kernel, out.data_size(), out.shape(), out.strides(), large());
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/common/unary.h" #include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh" #include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
@@ -17,24 +18,11 @@ namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS> template <typename Op, typename In, typename Out, typename IdxT>
__global__ void unary_v(const In* in, Out* out, IdxT size) { __global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) {
if ((index + 1) * N_READS > size) { out[index] = Op{}(in[index]);
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(in_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
} }
} }
@@ -70,10 +58,10 @@ constexpr bool supports_unary_op() {
!std::is_same_v<In, bool>; !std::is_same_v<In, bool>;
} }
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) { if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>; return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
} }
if (std::is_same_v<Op, Conjugate>) { if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>; return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
} }
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> || if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> || std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
@@ -87,7 +75,7 @@ constexpr bool supports_unary_op() {
return std::is_same_v<In, Out> && is_inexact_v<In>; return std::is_same_v<In, Out> && is_inexact_v<In>;
} }
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) { if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>; return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
} }
if (std::is_same_v<Op, LogicalNot>) { if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>; return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
@@ -101,7 +89,7 @@ template <typename Op>
void unary_op_gpu_inplace( void unary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, const std::string& op,
const Stream& s) { const Stream& s) {
auto& in = inputs[0]; auto& in = inputs[0];
if (in.size() == 0) { if (in.size() == 0) {
@@ -124,20 +112,14 @@ void unary_op_gpu_inplace(
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) { if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) { dispatch_bool(large, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using InType = cuda_type_t<CTYPE_IN>; using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>; using OutType = cuda_type_t<CTYPE_OUT>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
if (contig) { if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, kernel, out.data_size(), out.shape(), out.strides(), large);
out.data_size(),
out.shape(),
out.strides(),
large,
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,
@@ -146,7 +128,6 @@ void unary_op_gpu_inplace(
out.data<OutType>(), out.data<OutType>(),
out.data_size()); out.data_size());
} else { } else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in); auto [shape, strides] = collapse_contiguous_dims(in);
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>; auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
@@ -177,17 +158,17 @@ template <typename Op>
void unary_op_gpu( void unary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, const std::string& op,
const Stream& s) { const Stream& s) {
set_unary_output_data(inputs[0], out); set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s); unary_op_gpu_inplace<Op>(inputs, out, op, s);
} }
#define UNARY_GPU(func) \ #define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \ nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \ auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \ unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
} }
UNARY_GPU(Abs) UNARY_GPU(Abs)
@@ -223,15 +204,16 @@ UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) { void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu"); nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_op_gpu<cu::Log>(inputs, out, name(), s); unary_op_gpu<cu::Log>(inputs, out, op, s);
break; break;
case Base::two: case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, name(), s); unary_op_gpu<cu::Log2>(inputs, out, op, s);
break; break;
case Base::ten: case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, name(), s); unary_op_gpu<cu::Log10>(inputs, out, op, s);
break; break;
} }
} }
@@ -242,7 +224,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, name(), s); unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
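As an aside on the indexing pattern in the ternary_v and unary_v hunks above, here is a hypothetical Python sketch of how the vectorized variant maps threads to elements: each thread owns N_READS contiguous elements and the final partial chunk falls back to a scalar loop. The helper name and the N_READS value are illustrative, not taken from the launch utilities.

N_READS = 4  # assumed vector width for illustration

def thread_chunks(size):
    # One entry per thread that a launch helper would need to cover `size` elements.
    num_threads = (size + N_READS - 1) // N_READS
    for index in range(num_threads):
        start = index * N_READS
        if start + N_READS > size:
            yield index, list(range(start, size))             # scalar tail
        else:
            yield index, list(range(start, start + N_READS))  # full vector load/store

print(list(thread_chunks(10)))
# [(0, [0, 1, 2, 3]), (1, [4, 5, 6, 7]), (2, [8, 9])]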

View File

@@ -61,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
case float64: case float64:
return "double"; return "double";
case complex64: case complex64:
return "complex64_t"; return "cuComplex";
default: default:
return "unknown"; return "unknown";
} }

View File

@@ -46,10 +46,4 @@ void copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
} }
array contiguous_copy_gpu(const array& arr, const Stream& s) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
return arr_copy;
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -43,7 +43,4 @@ void copy_gpu_inplace(
// Fill the output with the scalar val // Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s); void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -63,7 +63,6 @@ if(MLX_METAL_JIT)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(steel/gemm/kernels/steel_gemm_segmented)
make_jit_source( make_jit_source(
steel/conv/conv steel/conv/conv
kernels/steel/utils.h kernels/steel/utils.h

View File

@@ -7,20 +7,20 @@
#define BINARY_GPU(func) \ #define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \ void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, name()); \ binary_op_gpu(inputs, out, get_primitive_string(this)); \
} }
#define BINARY_GPU_MULTI(func) \ #define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \ void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \ const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, name()); \ binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
} }
namespace mlx::core { namespace mlx::core {
std::string get_kernel_name( std::string get_kernel_name(
BinaryOpType bopt, BinaryOpType bopt,
const char* op, const std::string& op,
const array& a, const array& a,
bool large, bool large,
int ndim, int ndim,
@@ -65,7 +65,7 @@ std::string get_kernel_name(
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op, const std::string& op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op, const std::string& op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@@ -179,7 +179,7 @@ void binary_op_gpu(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op) { const std::string& op) {
auto& s = outputs[0].primitive().stream(); auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s); binary_op_gpu(inputs, outputs, op, s);
} }
@@ -187,7 +187,7 @@ void binary_op_gpu(
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, const std::string& op,
const Stream& s) { const Stream& s) {
std::vector<array> outputs = {out}; std::vector<array> outputs = {out};
binary_op_gpu_inplace(inputs, outputs, op, s); binary_op_gpu_inplace(inputs, outputs, op, s);
@@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, const std::string& op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@@ -209,7 +209,7 @@ void binary_op_gpu(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op) { const std::string& op) {
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s); binary_op_gpu(inputs, out, op, s);
} }
@@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
binary_op_gpu(inputs, out, name()); binary_op_gpu(inputs, out, get_primitive_string(this));
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
binary_op_gpu(inputs, out, name()); binary_op_gpu(inputs, out, get_primitive_string(this));
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, name()); binary_op_gpu(inputs, out, get_primitive_string(this));
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, name()); binary_op_gpu(inputs, out, get_primitive_string(this));
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, name()); binary_op_gpu(inputs, out, get_primitive_string(this));
break; break;
} }
} }

View File

@@ -9,25 +9,25 @@ namespace mlx::core {
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op, const std::string& op,
const Stream& s); const Stream& s);
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, const std::string& op,
const Stream& s); const Stream& s);
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const char* op, const std::string& op,
const Stream& s); const Stream& s);
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const char* op, const std::string& op,
const Stream& s); const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -212,7 +212,9 @@ inline void build_kernel(
get_type_string(x.dtype()), get_type_string(x.dtype()),
namer.get_name(x.inputs()[0])); namer.get_name(x.inputs()[0]));
} else { } else {
os += x.primitive().name(); std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += "()("; os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));

View File

@@ -149,7 +149,8 @@ void explicit_gemm_conv_group_ND_gpu(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size()); wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize // Materialize
array wt_transpose = contiguous_copy_gpu(wt_view, s); auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm // Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose}; std::vector<array> copies = {in_unfolded, wt_transpose};
@@ -960,12 +961,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0]; auto in = inputs[0];
auto wt = inputs[1]; auto wt = inputs[1];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copies.push_back(in); copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
} }
if (!wt.flags().row_contiguous) { if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s); array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
copies.push_back(wt); copy_gpu(wt, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
wt = arr_copy;
} }
// 3D conv // 3D conv

View File

@@ -86,7 +86,7 @@ void copy_gpu_inplace(
} }
} else { } else {
work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
if (!large && work_per_thread > 1) { if (work_per_thread > 1) {
kernel_name += "n"; kernel_name += "n";
} }
} }

View File

@@ -1,18 +1,20 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cstdlib> #include <cstdlib>
#include <filesystem>
#include <sstream> #include <sstream>
#define NS_PRIVATE_IMPLEMENTATION #define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal { namespace mlx::core::metal {
namespace { namespace {
@@ -78,7 +80,12 @@ MTL::Library* try_load_bundle(
std::pair<MTL::Library*, NS::Error*> load_colocated_library( std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device, MTL::Device* device,
const std::string& relative_path) { const std::string& relative_path) {
auto path = current_binary_dir() / relative_path; std::string binary_dir = get_binary_directory();
if (binary_dir.size() == 0) {
return {nullptr, nullptr};
}
auto path = fs::path(binary_dir) / relative_path;
if (!path.has_extension()) { if (!path.has_extension()) {
path.replace_extension(".metallib"); path.replace_extension(".metallib");
} }
@@ -190,7 +197,7 @@ MTL::Library* load_library(
std::ostringstream msg; std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. " msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << current_binary_dir() << "/" << "We attempted to load it from <" << get_binary_directory() << "/"
<< lib_name << ".metallib" << ">"; << lib_name << ".metallib" << ">";
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle."; msg << " and from the Swift PM bundle.";

View File

@@ -3,6 +3,8 @@
#pragma once #pragma once
#include <Metal/Metal.hpp> #include <Metal/Metal.hpp>
#include <dlfcn.h>
#include <filesystem>
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <shared_mutex> #include <shared_mutex>
@@ -13,8 +15,22 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/device.h" #include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal { namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_binary_directory() {
Dl_info info;
std::string directory;
int success = dladdr((void*)get_binary_directory, &info);
if (success) {
directory = fs::path(info.dli_fname).remove_filename().c_str();
}
return directory;
}
using MTLFCList = using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>; std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

View File

@@ -575,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Set source info // Set source info
if (ndim > 1) { compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
} else {
// The following will be ignored in the kernel but we still have to set
// some value so that metal validation passes.
compute_encoder.set_vector_bytes(idx.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4);
compute_encoder.set_vector_bytes(idx.strides(), 5);
}
compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(out.shape(axis_), 8); compute_encoder.set_bytes(out.shape(axis_), 8);

View File

@@ -34,7 +34,6 @@ const char* steel_gemm_fused();
const char* steel_gemm_masked(); const char* steel_gemm_masked();
const char* steel_gemm_splitk(); const char* steel_gemm_splitk();
const char* steel_gemm_gather(); const char* steel_gemm_gather();
const char* steel_gemm_segmented();
const char* conv(); const char* conv();
const char* steel_conv(); const char* steel_conv();
const char* steel_conv_general(); const char* steel_conv_general();

View File

@@ -8,6 +8,12 @@ using namespace fmt::literals;
namespace mlx::core { namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel( MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -27,7 +33,7 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type); auto in_t = get_type_string(in_type);
@@ -52,10 +58,10 @@ MTL::ComputePipelineState* get_unary_kernel(
} }
void append_binary_kernels( void append_binary_kernels(
const std::string& lib_name, const std::string lib_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op, const std::string op,
std::string& kernel_source) { std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
@@ -106,7 +112,7 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source; std::string kernel_source;
@@ -123,7 +129,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
@@ -138,7 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype type, Dtype type,
const char* op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type); auto t_str = get_type_string(type);
@@ -646,43 +652,6 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
return d.get_kernel(kernel_name, lib, hash_name, func_consts); return d.get_kernel(kernel_name, lib, hash_name, func_consts);
} }
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_segmented(),
get_template_definition(
lib_name,
"segmented_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,

View File

@@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op); const std::string op);
MTL::ComputePipelineState* get_binary_kernel( MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op); const std::string op);
MTL::ComputePipelineState* get_binary_two_kernel( MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const char* op); const std::string op);
MTL::ComputePipelineState* get_ternary_kernel( MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype type, Dtype type,
const char* op); const std::string op);
MTL::ComputePipelineState* get_copy_kernel( MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d, metal::Device& d,
@@ -175,20 +175,6 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
int wn, int wn,
bool rhs); bool rhs);
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn);
MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
@@ -257,10 +243,8 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
// Create a GPU kernel template definition for JIT compilation // Create a GPU kernel template definition for JIT compilation
template <typename... Args> template <typename... Args>
std::string get_template_definition( std::string
std::string_view name, get_template_definition(std::string name, std::string func, Args... args) {
std::string_view func,
Args... args) {
std::ostringstream s; std::ostringstream s;
s << func << "<"; s << func << "<";
bool first = true; bool first = true;

View File

@@ -71,7 +71,6 @@ set(STEEL_HEADERS
steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_segmented.h
steel/gemm/kernels/steel_gemm_splitk.h steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h steel/utils/type_traits.h
steel/utils/integral_constant.h) steel/utils/integral_constant.h)
@@ -121,7 +120,6 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h) build_kernel(gemv_masked steel/utils.h)
endif() endif()

View File

@@ -1,134 +0,0 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// can not be used in JIT.
#pragma once
#include <metal_math>
using ieee_float_shape_type = union {
float value;
uint32_t word;
};
inline void get_float_word(thread uint32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline void get_float_word(thread int32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline void set_float_word(thread float& d, uint32_t i) {
ieee_float_shape_type sf_u;
sf_u.word = (i);
(d) = sf_u.value;
}
inline float frexp_expf(float x, thread int* expt) {
const uint32_t k = 235;
const float kln2 = 162.88958740F;
float exp_x;
uint32_t hx;
exp_x = metal::exp(x - kln2);
get_float_word(hx, exp_x);
*expt = (hx >> 23) - (0x7f + 127) + k;
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
return exp_x;
}
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
float x, y, exp_x, scale1, scale2;
int ex_expt, half_expt;
x = z.real;
y = z.imag;
exp_x = frexp_expf(x, &ex_expt);
expt += ex_expt;
half_expt = expt / 2;
set_float_word(scale1, (0x7f + half_expt) << 23);
half_expt = expt - half_expt;
set_float_word(scale2, (0x7f + half_expt) << 23);
return complex64_t{
metal::cos(y) * exp_x * scale1 * scale2,
metal::sin(y) * exp_x * scale1 * scale2};
}
inline complex64_t cexpf(const thread complex64_t& z) {
float x, y, exp_x;
uint32_t hx, hy;
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
x = z.real;
y = z.imag;
get_float_word(hy, y);
hy &= 0x7fffffff;
/* cexp(x + I 0) = exp(x) + I 0 */
if (hy == 0) {
return complex64_t{metal::exp(x), y};
}
get_float_word(hx, x);
/* cexp(0 + I y) = cos(y) + I sin(y) */
if ((hx & 0x7fffffff) == 0) {
return complex64_t{metal::cos(y), metal::sin(y)};
}
if (hy >= 0x7f800000) {
if ((hx & 0x7fffffff) != 0x7f800000) {
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
return complex64_t{y - y, y - y};
} else if (hx & 0x80000000) {
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
return complex64_t{0.0, 0.0};
} else {
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
return complex64_t{x, y - y};
}
}
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
/*
* x is between 88.7 and 192, so we must scale to avoid
* overflow in expf(x).
*/
return ldexp_cexpf(z, 0);
} else {
/*
* Cases covered here:
* - x < exp_ovfl and exp(x) won't overflow (common case)
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
* - x = +-Inf (generated by exp())
* - x = NaN (spurious inexact exception from y)
*/
exp_x = metal::exp(x);
return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
}
}

View File

@@ -31,7 +31,6 @@ inline void threadgroup_sum(
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
x[i] = simd_sum(x[i]); x[i] = simd_sum(x[i]);
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
xs[N * simd_group_id + i] = x[i]; xs[N * simd_group_id + i] = x[i];

View File

@@ -643,14 +643,14 @@ struct QuantizedBlockLoader {
return; return;
} }
if (reduction_dim == 1 && bi >= src_tile_dim.x) { if (reduction_dim == 1 && bi >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) { for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0); dst[i] = T(0);
} }
return; return;
} }
if (reduction_dim == 0 && bi >= src_tile_dim.y) { if (reduction_dim == 0 && bi >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) { for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0); dst[i] = T(0);
} }

View File

@@ -164,15 +164,7 @@ struct Min {
DEFINE_SIMD_REDUCE() DEFINE_SIMD_REDUCE()
template <typename T> template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) { T simd_reduce_impl(T val) {
return simd_min(val);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
if (simd_any(val != val)) {
return static_cast<T>(NAN);
}
return simd_min(val); return simd_min(val);
} }
@@ -184,38 +176,11 @@ struct Min {
} }
// Operator // Operator
template <typename T> U operator()(U a, U b) {
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
return a < b ? a : b; return a < b ? a : b;
} }
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
if (metal::isnan(a) || metal::isnan(b)) {
return static_cast<T>(NAN);
} else {
return a < b ? a : b;
}
}
template <>
complex64_t operator()(complex64_t a, complex64_t b) {
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
if (!real_is_nan && !imag_is_nan) {
return a < b ? a : b;
} else if (real_is_nan && !imag_is_nan) {
return complex64_t(
static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
} else if (!real_is_nan && imag_is_nan) {
return complex64_t(
a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
} else {
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
}
};
}; };
template <typename U> template <typename U>
struct Max { struct Max {
DEFINE_SIMD_REDUCE() DEFINE_SIMD_REDUCE()
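For context on the Min/Max hunk above: the NaN-aware branches shown propagate NaN for non-integral types, which matches NumPy's max/min reductions. A quick NumPy check (illustrative only) of that behaviour:

import numpy as np

a = np.array([1.0, np.nan, 3.0])
print(np.max(a), np.min(a))     # nan nan -> NaN propagates through max/min reductions
print(np.maximum(2.0, np.nan))  # nan     -> elementwise maximum propagates NaN as well
print(np.nanmax(a))             # 3.0     -> the NaN-ignoring variant, for contrast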

View File

@@ -1,266 +0,0 @@
// Copyright © 2025 Apple Inc.
using namespace mlx::steel;
constant bool segments_contiguous [[function_constant(199)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* segments [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Move the pointers to the output tile
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Move the pointers to the start of the segment
uint32_t k_start, k_end;
if (segments_contiguous) {
k_start = segments[2 * tid.z];
k_end = segments[2 * tid.z + 1];
} else {
// We accept either contiguous pairs (above) or weird strides where the
// beginning of the next segment is the end of the previous one. Basically
// the last two strides are both 1!
k_start = segments[tid.z];
k_end = segments[tid.z + 1];
}
A += transpose_a ? k_start * params->lda : k_start;
B += transpose_b ? k_start : k_start * params->ldb;
C += tid.z * params->batch_stride_d;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Matrix level alignment so only check K
if (align_M && align_N) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, params->ldd);
} else {
// Tile aligned do the same as above
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned check rows
else if (align_N || tgp_bn == BN) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned check cols
else if (align_M || tgp_bm == BM) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_safe(
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
loader_b.load_safe(
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
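For orientation, a small NumPy sketch (illustrative shapes and values, not from the repo) of what segmented_mm computes: each output batch z is the product of A and B restricted to the K-range [k_start, k_end) that segments assigns to that batch; in the contiguous layout those bounds are segments[2*z] and segments[2*z + 1].

import numpy as np

M, N, K = 4, 3, 10
A = np.random.rand(M, K)
B = np.random.rand(K, N)
segments = np.array([[0, 4], [4, 7], [7, 10]])  # (num_segments, 2): per-batch K-ranges

C = np.stack([A[:, s:e] @ B[s:e] for s, e in segments])  # shape (num_segments, M, N)
# Since the segments partition K, the per-segment products sum back to the full GEMM.
assert np.allclose(C.sum(axis=0), A @ B)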

View File

@@ -1,43 +0,0 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h"
#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
segmented_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
instantiate_segmented_mm_shapes_helper(
bfloat16,
bfloat16_t,
bfloat16,
bfloat16_t);
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);

View File

@@ -5,7 +5,6 @@
#include <metal_integer> #include <metal_integer>
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/cexpf.h"
#include "mlx/backend/metal/kernels/erf.h" #include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h" #include "mlx/backend/metal/kernels/expm1f.h"
@@ -179,7 +178,8 @@ struct Exp {
return metal::precise::exp(x); return metal::precise::exp(x);
}; };
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return cexpf(x); auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
} }
}; };
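Both versions of the complex Exp path in this hunk rely on the same identity, exp(x + iy) = e^x (cos y + i sin y); the cexpf helper shown earlier additionally guards against overflow and Inf/NaN inputs. A quick Python check of the identity (illustrative only):

import cmath
import math

z = 1.5 + 0.75j
direct = cmath.exp(z)
manual = complex(math.exp(z.real) * math.cos(z.imag),
                 math.exp(z.real) * math.sin(z.imag))
assert abs(direct - manual) < 1e-12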

View File

@@ -25,7 +25,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
array x_copy = contiguous_copy_gpu(x, s); auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return x_copy; return x_copy;
} }

View File

@@ -33,7 +33,8 @@ std::tuple<bool, int64_t, array> check_transpose(
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy = contiguous_copy_gpu(arr, s); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy);
} }
@@ -42,7 +43,8 @@ std::tuple<bool, int64_t, array> check_transpose(
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
  if (!x.flags().row_contiguous) {
-    array x_copy = contiguous_copy_gpu(x, s);
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
    d.add_temporary(x_copy, s.index);
    return x_copy;
  } else {
@@ -73,7 +75,8 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
    }
  }
-  array x_copy = contiguous_copy_gpu(x, s);
+  array x_copy(x.shape(), x.dtype(), nullptr, {});
+  copy_gpu(x, x_copy, CopyType::General, s);
  d.add_temporary(x_copy, s.index);
  return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
@@ -1861,165 +1864,4 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}
void segmented_mm(
const array& a_,
const array& b_,
const array& segments_,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
auto check_segments_layout = [&d, &s](const array& x) {
// Contiguous so return early
if (x.flags().row_contiguous) {
return std::make_tuple(true, x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 2; i++) {
rc &=
(x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
}
rc &= x.strides(x.ndim() - 1) == 1;
if (x.ndim() > 1) {
rc &= x.strides(x.ndim() - 2) == 1;
}
if (rc) {
return std::make_tuple(false, x);
}
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy);
};
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
auto [segments_contiguous, segments] = check_segments_layout(segments_);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_segmented_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = {
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
};
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_segments_contiguous_",
segments_contiguous ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n');
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_segmented_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);
// Prepare the matmul params
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ 0,
/* const int batch_ndim = */ 0};
// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(segments, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& segments = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
segmented_mm(a, b, segments, out, M, N, K, d, s);
}
} // namespace mlx::core
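For reference, the removed launcher dispatches one M x N tile grid per output batch entry (the grid's z extent is out.size() / M / N). Below is a hedged CPU sketch of what it appears to compute, under the assumption that `segments` stores [start, end) pairs over the shared K dimension; that reading comes from the kernel parameters above and is not spelled out in this diff.

#include <cstdint>
#include <vector>

// Presumed semantics only (not MLX code): out slice i = A[:, s:e] * B[s:e, :]
// for each segment [s, e) over K.
void segmented_mm_reference(
    const float* a,           // M x K, row-major
    const float* b,           // K x N, row-major
    const int32_t* segments,  // num_segments x 2, [start, end) over K
    float* out,               // num_segments x M x N, row-major
    int M,
    int N,
    int K,
    int num_segments) {
  for (int seg = 0; seg < num_segments; ++seg) {
    int start = segments[2 * seg];
    int end = segments[2 * seg + 1];
    float* o = out + static_cast<int64_t>(seg) * M * N;
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (int k = start; k < end; ++k) {
          acc += a[i * K + k] * b[k * N + j];
        }
        o[i * N + j] = acc;
      }
    }
  }
}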

View File

@@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
    const std::string& kernel_name,
    Dtype,
    Dtype,
-    const char*) {
+    const std::string) {
  return d.get_kernel(kernel_name);
}
@@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
    const std::string& kernel_name,
    Dtype,
    Dtype,
-    const char*) {
+    const std::string) {
  return d.get_kernel(kernel_name);
}
@@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
    const std::string& kernel_name,
    Dtype,
    Dtype,
-    const char*) {
+    const std::string) {
  return d.get_kernel(kernel_name);
}
@@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype,
-    const char*) {
+    const std::string) {
  return d.get_kernel(kernel_name);
}
@@ -210,22 +210,6 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
  return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int) {
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,

View File

@@ -40,7 +40,8 @@ void RMSNorm::eval_gpu(
      }
      return x;
    } else {
-      array x_copy = contiguous_copy_gpu(x, s);
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
@@ -106,7 +107,9 @@ void RMSNormVJP::eval_gpu(
    if (x.flags().row_contiguous) {
      return {x, false};
    }
-    array x_copy = contiguous_copy_gpu(x, s);
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
    return {x_copy, true};
  };
  bool donate_x = inputs[0].is_donatable();
@@ -238,7 +241,8 @@ void LayerNorm::eval_gpu(
      }
      return x;
    } else {
-      array x_copy = contiguous_copy_gpu(x, s);
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
@@ -315,7 +319,8 @@ void LayerNormVJP::eval_gpu(
    if (x.flags().row_contiguous) {
      return {x, false};
    }
-    array x_copy = contiguous_copy_gpu(x, s);
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
    return {x_copy, true};
  };
  bool donate_x = inputs[0].is_donatable();

View File

@@ -20,7 +20,8 @@ namespace {
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
  if (!x.flags().row_contiguous) {
-    array x_copy = contiguous_copy_gpu(x, s);
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
    d.add_temporary(x_copy, s.index);
    return x_copy;
  } else {
@@ -37,7 +38,8 @@ inline array ensure_row_contiguous_matrix(
  if (stride_0 == x.shape(-1) && stride_1 == 1) {
    return x;
  } else {
-    array x_copy = contiguous_copy_gpu(x, s);
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
    d.add_temporary(x_copy, s.index);
    return x_copy;
  }

View File

@@ -989,7 +989,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  // input for the axes with stride smaller than the minimum reduction
  // stride.
  if (plan.type == GeneralReduce) {
-    array in_copy = contiguous_copy_gpu(in, s);
+    array in_copy(in.shape(), in.dtype(), nullptr, {});
+    copy_gpu(in, in_copy, CopyType::General, s);
    d.add_temporary(in_copy, s.index);
    in = in_copy;
    plan = get_reduction_plan(in, axes_);

View File

@@ -398,7 +398,8 @@ void ScaledDotProductAttention::eval_gpu(
  auto copy_unless = [&copies, &s](
                         auto predicate, const array& arr) -> const array& {
    if (!predicate(arr)) {
-      array arr_copy = contiguous_copy_gpu(arr, s);
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(std::move(arr_copy));
      return copies.back();
    } else {
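Hypothetical usage of the copy_unless() helper above; the predicate and variable names here are placeholders, not the ones used later in this function.

// Illustrative only: copy_unless() copies the array to a contiguous temporary
// unless the supplied predicate already holds for it.
auto is_row_contiguous = [](const array& a) { return a.flags().row_contiguous; };
const array& q = copy_unless(is_row_contiguous, q_input);  // q_input is a placeholder name
const array& k = copy_unless(is_row_contiguous, k_input);  // copied only if non-contiguous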

Some files were not shown because too many files have changed in this diff.