Compare commits

...

50 Commits

Author SHA1 Message Date
Awni Hannun
39b04ce638 use faster dequant for fp4 qmv (#2720)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-10-31 11:49:59 -07:00
Mike Drob
d9e6349657 fix docs path (#2719) 2025-10-30 19:12:49 -05:00
Angelos Katharopoulos
b901a9f311 Fix the order of hosts in the ring (#2718)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-10-30 15:02:39 -07:00
Awni Hannun
68c5fa1c95 fix memory count bug (#2717) 2025-10-30 14:27:15 -07:00
Christopher Webb
793a31eeb6 Fix missing domain_uuid_key in thunderbolt ring setup (#2682) 2025-10-30 13:17:20 -07:00
Mike Drob
74c1ed25bb Migrate CircleCI to GitHub Actions (#2716)
Co-authored-by: Joseph Heck <j_heck@apple.com>
2025-10-30 12:26:55 -05:00
Awni Hannun
ec72b44417 Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
2025-10-28 16:23:12 -07:00
Melissa Kilby
460691a0e8 fix: linux-{fedora}x86_64-build (#2707)
Signed-off-by: Melissa Kilby <mkilby@apple.com>
2025-10-27 16:36:08 -07:00
Awni Hannun
969924cc69 Fp8 conversion (#2686)
* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
2025-10-27 16:35:50 -07:00
Awni Hannun
d1e06117e8 bump python (#2694) 2025-10-27 11:34:31 -07:00
Awni Hannun
539d8322d1 add median op (#2705) 2025-10-27 11:33:42 -07:00
Awni Hannun
c4767d110f fix addmm cpu (#2699) 2025-10-27 11:33:32 -07:00
David Koski
895217f25b optionally load metallib from framework (#2702)
* optionally load metallib from framework

* pre-commit

* adjust logic
2025-10-27 07:52:03 -07:00
Manuel Villanueva
0cfeeb60ca Einsum error msg improvement (#2690)
* Improved error message for Einsum

* Modifications via pre-commit

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-27 06:31:47 -07:00
Ronan Collobert
8f8af61a37 fix warnings showing up with -Wall (#2692) 2025-10-24 11:43:35 -07:00
Manuel Villanueva
233384161e Improved mx.split() docs (#2689)
* Improved mx.split() documentation

* Fix typo in docstring for array split function

* add example

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-24 09:48:41 -07:00
Awni Hannun
5bcf3a6794 format 2025-10-22 16:08:47 -07:00
wickedcoder
7707196297 Merge commit from fork
* add length validation to the header

* fix accessing out of bound index with .at()
2025-10-22 15:31:25 -07:00
wickedcoder
7e3471c987 Merge commit from fork
* add tensor->weights_data validation

* add null pointer check for tensor
2025-10-22 15:31:03 -07:00
Awni Hannun
9f0ba3ddf1 patch bump (#2680) 2025-10-17 12:12:07 -07:00
Awni Hannun
4bce5f9b2d suppress gcc 10.1 warnings (#2679)
* suppress gcc 10.1 warnings

* suppress gcc 10.1 warnings
2025-10-17 12:09:21 -07:00
Anastasiia Filippova
e9eab527eb Nccl timeout (#2673)
* print the error & delete nccl group

* timeout for nccl binding

* typo

* revert error

* fixed a typo
2025-10-14 12:29:54 -07:00
Awni Hannun
36ca62dba8 remove unused unary file (#2672) 2025-10-13 19:36:26 -07:00
Manuel Villanueva
9cbb1b0148 Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.

* Modified sort behavior when running CPU or Metal to match NumPy/JAX

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-13 14:36:45 -07:00
Fabrizio Milo
9bfc476d72 Normalize README bullet formatting (#2671) 2025-10-13 12:13:30 -07:00
Awni Hannun
25e2356316 speed up scalars (#2669) 2025-10-13 12:10:15 -07:00
Awni Hannun
226a1d24e0 Debug cuda conv (#2662)
* use t4

* use t4
2025-10-10 16:12:47 -07:00
Awni Hannun
630350ad3e Precise sigmoid (#2659)
* bump patch

* Sigmoid matches PyTorch and is more precise on tails
2025-10-10 10:05:23 -07:00
Awni Hannun
380aeb58ae enable admm low-precision cpu (#2661) 2025-10-10 09:50:54 -07:00
Awni Hannun
f37389d100 bump patch (#2658) 2025-10-10 08:36:41 -07:00
Awni Hannun
e89e8b4272 Export with callback (#2612)
* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
2025-10-08 19:24:33 -07:00
AN Long
85a8824a8c Fix cumulative operations when axis=None (#2653) 2025-10-08 15:25:38 -07:00
Awni Hannun
f5d4397e5c Fix fast synch when fence is waited before a command buffer is created (#2657) 2025-10-08 11:23:46 -07:00
Awni Hannun
343e33b6d5 fix all_gather vjp (#2654) 2025-10-07 06:05:23 -07:00
Angelos Katharopoulos
0073096dd1 Split name into directories for cuda jit (#2656) 2025-10-07 01:52:58 -07:00
Angelos Katharopoulos
e3d004fed9 Fix and refactor row-reduce (#2650) 2025-10-07 01:51:08 -07:00
Awni Hannun
a393435d28 Speed up compile for node with many parents (#2649) 2025-10-03 19:30:36 -07:00
Awni Hannun
a7a94b29d7 Fix compile when outputs change (#2648) 2025-10-03 08:40:57 -07:00
Daniel Yeh
22a5da76c8 Faster complex matmul (#2571) 2025-10-02 23:33:15 -07:00
Andrey Portnoy
287c63a093 Configure CMake to export compile_commands.json (#2645)
This helps enable LSP for code navigation using clangd.
2025-10-02 15:40:32 -07:00
Awni Hannun
1c9ae1eaa1 cuda fix flaky test (#2646) 2025-10-02 15:40:04 -07:00
Angelos Katharopoulos
c2c3e0b0a2 [CUDA] Add a small column specialization to reduce (#2642) 2025-10-02 14:41:05 -07:00
Awni Hannun
b0cc71ae71 Faster triu, tril, where with scalar (#2644) 2025-10-02 12:21:27 -07:00
Awni Hannun
e88f2d4a8e fix cross entropy axis param (#2641)
* fix cross entropy axis param

* faster grad clipping
2025-10-01 16:49:55 -07:00
Angelos Katharopoulos
9cee557423 Fix status message (#2638) 2025-10-01 16:43:45 -07:00
Awni Hannun
bbf1423953 wait for tasks in cuda (#2636) 2025-09-30 16:08:46 -07:00
Angelos Katharopoulos
eb24267b56 Compile now can attach arbitrary data to an entry (#2634) 2025-09-30 13:33:27 -07:00
Awni Hannun
dc371ae7a5 fix for max block dim (#2631) 2025-09-29 08:59:25 -07:00
AN Long
e76a8dd5c5 Fix incorrect path and typos (#2630) 2025-09-28 06:03:04 -07:00
Cheng
b466dea982 [CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
2025-09-27 11:27:17 +09:00
157 changed files with 4898 additions and 1772 deletions

View File

@@ -26,9 +26,9 @@ jobs:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install python@3.10
brew install doxygen
python3.9 -m venv env
python3.10 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
@@ -140,7 +140,7 @@ jobs:
- run:
name: Install Python package
command: |
uv venv --python 3.9
uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
@@ -273,7 +273,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.9"
default: "3.10"
xcode_version:
type: string
default: "26.0.0"
@@ -328,7 +328,7 @@ jobs:
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -351,7 +351,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.9"
default: "3.10"
build_env:
type: string
default: ""
@@ -387,7 +387,7 @@ jobs:
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -484,7 +484,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
@@ -503,7 +503,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
@@ -546,13 +546,13 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
@@ -564,14 +564,14 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:

View File

@@ -0,0 +1,24 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
MLX_BUILD_STAGE: 2
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
python -m build -w
if [ -f "python/scripts/repair_cuda.sh" ]; then
bash python/scripts/repair_cuda.sh
fi

68
.github/actions/build-cuda/action.yml vendored Normal file
View File

@@ -0,0 +1,68 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
# this value is dependent on the CUDA tools installed in the setup-linux workflow
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Check if build actually worked
shell: bash
run: python -c "import mlx.core"
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- name: Build Python package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: ${{ inputs.nvcc-location }}

38
.github/actions/build-docs/action.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

78
.github/actions/build-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,78 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
build-type:
description: 'Build type'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
type: boolean
runs:
using: "composite"
steps:
- name: Set DEBUG
shell: sh
if: inputs.build-type == 'debug'
run: echo "DEBUG=1" >> $GITHUB_ENV
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: sh
run: ./build/tests/tests
- name: Build Python package
if: inputs.build-type == 'release'
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
if [ -f "python/scripts/repair_linux.sh" ]; then
bash python/scripts/repair_linux.sh
fi
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64

View File

@@ -0,0 +1,22 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
runs:
using: "composite"
steps:
- name: Build Python package(s)
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

124
.github/actions/build-macos/action.yml vendored Normal file
View File

@@ -0,0 +1,124 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
build-jit:
description: 'Whether to build with JIT'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
DEV_RELEASE: 1
run: |
uv pip install --upgrade pip cmake setuptools
uv pip install nanobind==2.4.0 \
numpy torch tensorflow unittest-xml-reporting
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
if: inputs.run-tests == 'true'
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
if: inputs.build-jit == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
- name: Build macOS 13 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 13.0
- name: Build macOS 14 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
- name: Build macOS 15 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0

83
.github/actions/setup-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,83 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

31
.github/actions/setup-macos/action.yml vendored Normal file
View File

@@ -0,0 +1,31 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
install-mpi:
description: 'Whether to install MPI'
required: false
default: 'true'
type: boolean
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
if: inputs.install-mpi == 'true'
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

6
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

28
.github/workflows/documentation.yml vendored Normal file
View File

@@ -0,0 +1,28 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

93
.github/workflows/nightly.yml vendored Normal file
View File

@@ -0,0 +1,93 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
env:
MACOSX_DEPLOYMENT_TARGET: "15.0"
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
build_cuda_with_tests:
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7

View File

@@ -1,20 +1,46 @@
on:
pull_request:
branches:
- main
name: Build and Test
on: pull_request
permissions:
contents: read
jobs:
check_lint:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
mac_build_and_test:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_documentation:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs

188
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,188 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
build_documentation:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
with:
build-type: release
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiples: true
path: artifacts
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiples: true
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: build_cuda_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: build_linux_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: build_mac_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -1,4 +1,10 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:

View File

@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -87,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
@@ -173,7 +179,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()

View File

@@ -11,28 +11,28 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```
```text
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},

View File

@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16")
dtypes = ("float32", "float16", "complex64")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16"):
for dtype in ("float32", "float16", "complex64"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
)
)
plt.close(fig)

View File

@@ -16,7 +16,7 @@ silicon computer is
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- Using a native Python >= 3.10
- macOS >= 13.5
.. note::
@@ -39,7 +39,7 @@ requirements:
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
- Python >= 3.10
CPU-only (Linux)
@@ -55,7 +55,7 @@ To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
- Python >= 3.10
Troubleshooting

View File

@@ -112,6 +112,7 @@ Operations
max
maximum
mean
median
meshgrid
min
minimum

View File

@@ -241,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
status(Status::unscheduled),
inputs(std::move(inputs)) {
init();
}

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}};
return {Shape{1}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}, {0}};
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -11,6 +11,8 @@ namespace mlx::core {
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General,
};
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else {
topt = TernaryOpType::General;
}
@@ -59,6 +69,8 @@ inline void set_ternary_op_output_data(
b.flags());
}
break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General:
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||

View File

@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,

View File

@@ -46,7 +46,6 @@ void eig_impl(
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -49,9 +48,15 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,

View File

@@ -88,4 +88,47 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core

View File

@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
const void* a_mask_ptr = nullptr;
const void* b_mask_ptr = nullptr;
const void* out_mask_ptr = nullptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
bool a_mask_bool = false;
bool b_mask_bool = false;
bool out_mask_bool = false;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
@@ -423,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());

View File

@@ -91,7 +91,6 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
@@ -108,6 +107,9 @@ void matmul_general(
} else if (out.dtype() == float64) {
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
@@ -128,10 +130,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;

View File

@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -445,7 +448,6 @@ void mxfp4_qmm(
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -487,7 +489,6 @@ void mxfp4_qmm_t(
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -1104,4 +1105,44 @@ void fast::Quantize::eval_cpu(
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core

View File

@@ -1,5 +1,6 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -9,7 +10,7 @@
#include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in sims/base.h
// There seems to be a bug in simd/base_simd.h
// __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);

View File

@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;

View File

@@ -15,6 +15,18 @@ namespace mlx::core {
namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -27,7 +39,7 @@ struct StridedIterator {
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
: stride_(stride), ptr_(ptr + offset * stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
std::stable_sort(st, ed, nan_aware_less<T>);
src_it.step();
}
}
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
std::nth_element(st, md, ed, nan_aware_less<T>);
}
}
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@@ -83,8 +83,6 @@ void svd_impl(
auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not

View File

@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N;
src += N;
dst += N;

View File

@@ -77,7 +77,8 @@ struct Real {
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x));
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
}
SINGLE()
};
@@ -107,4 +108,73 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail

View File

@@ -51,12 +51,19 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -170,7 +177,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -30,15 +30,20 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = 0;
loc.id = i;
#else
int loc = 0;
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
@@ -86,13 +91,12 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
@@ -126,7 +130,7 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.lock();
}
active_memory_ += size;
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.

View File

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
encoder.set_output_array(out);
}
auto kernel = mod.get_kernel(kernel_name);
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread);
get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}

View File

@@ -382,15 +382,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
}
if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (!plan) {
throw std::runtime_error("[conv] Unable to find an execution plan.");
}
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
@@ -398,6 +396,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
return;
}
}
}
// Use fallback kernel for settings not supported by cuDNN.
gemm_conv(

View File

@@ -210,6 +210,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}

View File

@@ -14,10 +14,6 @@ namespace mlx::core::cu {
namespace {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -68,8 +64,8 @@ Device::~Device() {
void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host.
static int current = 0;
// actual calls of CUDA APIs.
static thread_local int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
enc.node_count_++;
return;
}
@@ -196,6 +193,7 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
@@ -220,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
active_outputs_.push_back(id);
}
void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
commit();
}
}
void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
@@ -233,6 +225,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
@@ -253,6 +246,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
@@ -295,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
node_count_++;
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
@@ -306,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
insert_graph_dependencies(GraphNode{node, 'G'});
}
int CommandEncoder::get_num_ops() {
return node_count_;
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
if (node_count_ > 0) {
if (use_cuda_graphs() && node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_,
@@ -354,7 +353,6 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
@@ -366,6 +364,7 @@ void CommandEncoder::commit() {
// Put completion handlers in a batch.
worker_.commit(stream_);
node_count_ = 0;
}
void CommandEncoder::synchronize() {

View File

@@ -83,7 +83,7 @@ class CommandEncoder {
}
void add_completed_handler(std::function<void()> task);
void maybe_commit();
int get_num_ops();
void commit();
Device& device() {
@@ -140,7 +140,7 @@ class Device {
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls.
// Make this device the current cuda device, this method is thread-safe.
void make_current();
CommandEncoder& get_command_encoder(Stream s);

View File

@@ -2,6 +2,8 @@
#pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -257,8 +259,8 @@ struct Round {
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x)));
return (x < 0) ? 1 - y : y;
T y = 1 / (1 + exp(abs(x)));
return (x < 0) ? y : 1 - y;
}
};
@@ -334,4 +336,17 @@ struct Tanh {
}
};
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilies that work under both
// This file must not include any host-only code, utilities that work under both
// host and device can be put here.
//
// See more about the requirements at:
@@ -202,7 +202,7 @@ struct Limits<
}
};
// CUDA 11 does not have host side arithmatic operators for half types.
// CUDA 11 does not have host side arithmetic operators for half types.
template <typename T>
struct Limits<
T,

View File

@@ -5,19 +5,24 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initalization of CUDA by creating an event, so the CUDA runtime and
// our CUDA event pool get destroyed last.
cu::CudaEvent(cudaEventDefault);
// Force initalization of CUDA, so CUDA runtime get destroyed at last.
cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
cu::CudaEvent::init_pool();
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
}
@@ -35,7 +40,8 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
auto& stream = arr.primitive().stream();
auto& encoder = cu::get_command_encoder(stream);
// Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) {
// Except for the donated one.
@@ -46,7 +52,14 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) {
encoder.add_temporary(s);
}
encoder.maybe_commit();
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
}
void finalize(Stream s) {

View File

@@ -22,11 +22,15 @@ namespace cu {
namespace {
// Manage cached cudaEvent_t objects.
struct CudaEventPool {
static CudaEventHandle create(int flags) {
auto& cache = cache_for(flags);
class CudaEventPool {
public:
CudaEventHandle create(Device& d, int flags) {
if (!on_creation_thread()) {
return CudaEventHandle(d, flags);
}
auto& cache = cache_for(d, flags);
if (cache.empty()) {
return CudaEventHandle(flags);
return CudaEventHandle(d, flags);
} else {
CudaEventHandle ret = std::move(cache.back());
cache.pop_back();
@@ -34,54 +38,89 @@ struct CudaEventPool {
}
}
static void release(CudaEventHandle event) {
cache_for(event.flags).push_back(std::move(event));
void release(CudaEventHandle event) {
if (!on_creation_thread()) {
// Event will be destroyed directly instead of getting moved to cache.
return;
}
cache_for(event.device, event.flags).push_back(std::move(event));
}
static std::vector<CudaEventHandle>& cache_for(int flags) {
static std::map<int, std::vector<CudaEventHandle>> cache;
return cache[flags];
private:
std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
return cache_[d.cuda_device()][flags];
}
bool on_creation_thread() {
return std::this_thread::get_id() == thread_id_;
}
// The CudaEvent may be created and destroyed on different threads (for
// example when waiting on GPU work in CPU stream), we don't want to make
// the cache thread-safe as it adds overhead, so we just skip cache when
// using events in worker threads.
std::thread::id thread_id_{std::this_thread::get_id()};
// {device: {flags: [events]}}
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
};
CudaEventPool& cuda_event_pool() {
static CudaEventPool pool;
return pool;
}
} // namespace
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
CudaEventHandle::CudaEventHandle(Device& d, int flags)
: device(d), flags(flags) {
device.make_current();
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr);
}
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
CudaEvent::CudaEvent(Device& d, int flags)
: event_(cuda_event_pool().create(d, flags)) {}
CudaEvent::~CudaEvent() {
CudaEventPool::release(std::move(event_));
cuda_event_pool().release(std::move(event_));
}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
event_.device.make_current();
cudaEventSynchronize(event_);
}
void CudaEvent::wait(cudaStream_t stream) {
event_.device.make_current();
cudaStreamWaitEvent(stream, event_);
}
void CudaEvent::record(cudaStream_t stream) {
event_.device.make_current();
cudaEventRecord(event_, stream);
}
bool CudaEvent::completed() const {
// Note: cudaEventQuery can be safely called from any device.
return cudaEventQuery(event_) == cudaSuccess;
}
// static
void CudaEvent::init_pool() {
cuda_event_pool();
}
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
CopyableCudaEvent()
explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
nvtx3::mark("Using slow AtomicEvent");
atomic = std::make_unique<cu::AtomicEvent>();
} else {
cuda = std::make_unique<cu::CopyableCudaEvent>();
cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
}
}
};

View File

@@ -13,9 +13,12 @@
namespace mlx::core::cu {
class Device;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(int flags);
CudaEventHandle(Device& d, int flags);
Device& device;
int flags;
};
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
// on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent {
public:
explicit CudaEvent(int flags);
CudaEvent(Device& d, int flags);
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
// returns true if record() has not been called.
bool completed() const;
// Internal: make sure event pool is initialized.
static void init_pool();
private:
CudaEventHandle event_;
};

View File

@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
scale_type_ = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
scale_type_ = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@@ -352,6 +355,16 @@ void CublasGemm::execute(
}
}
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
@@ -368,12 +381,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
alpha_ptr,
b, // a and b are swapped
a_desc_,
a,
b_desc_,
&beta,
beta_ptr,
c ? c : out,
c ? c_desc_ : out_desc_,
out,

View File

@@ -115,6 +115,7 @@ class CublasGemm {
uint64_t M_;
uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};

View File

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
using Acc = typename GemvAccType<T>::type;
Acc sum = Acc(0);
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
sum = cg::reduce(warp, sum, cg::plus<Acc>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
@@ -107,7 +138,7 @@ void gemv(
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;

View File

@@ -99,6 +99,30 @@ const std::filesystem::path& ptx_cache_dir() {
return cache;
}
std::filesystem::path get_ptx_path(
const std::filesystem::path& cache_dir,
const std::string& module_name) {
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (module_name.size() <= max_file_name_length) {
return cache_dir / (module_name + ".ptx");
}
auto ptx_path = cache_dir;
int offset = 0;
while (module_name.size() - offset > max_file_name_length) {
ptx_path /= module_name.substr(offset, max_file_name_length);
offset += max_file_name_length;
}
ptx_path /= module_name.substr(offset) + ".ptx";
return ptx_path;
}
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
@@ -109,7 +133,7 @@ bool read_cached_ptx(
return false;
}
auto ptx_path = cache_dir / (module_name + ".ptx");
auto ptx_path = get_ptx_path(cache_dir, module_name);
std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error);
if (error) {
@@ -122,7 +146,7 @@ bool read_cached_ptx(
ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
@@ -144,16 +168,26 @@ void write_cached_ptx(
return;
}
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
auto ptx_path = get_ptx_path(cache_dir, module_name);
// Ensure that the directory exists
auto parent = ptx_path.parent_path();
if (parent != cache_dir) {
std::filesystem::create_directories(parent);
}
// Write the compiled code and mangled names
std::ofstream ptx_file(ptx_path, std::ios::binary);
if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size());
}
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
// Write the generated code
std::ofstream source_file(ptx_path.replace_extension(".cu"));
source_file << source_code;
}
@@ -297,7 +331,8 @@ void load_module(
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -314,7 +349,7 @@ void load_module(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels[name] = std::make_pair(kernel, false);
kernels[name] = std::make_tuple(kernel, false, 0);
}
}
@@ -358,7 +393,7 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
@@ -369,14 +404,22 @@ CUfunction JitModule::get_kernel(
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
auto kernel = std::get<0>(it->second);
if (!std::get<1>(it->second)) {
if (configure_kernel) {
configure_kernel(it->second.first);
configure_kernel(kernel);
}
it->second.second = true;
std::get<1>(it->second) = true;
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
}
return it->second.first;
return {kernel, std::get<2>(it->second)};
}
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

View File

@@ -99,10 +99,13 @@ class JitModule {
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
std::pair<CUfunction, uint> get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread) {
int work_per_thread /* = 1 */,
uint max_block_dim /* = 1024 */) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
// This file includes host-only utilies for writing CUDA kernels, the difference
// from backend/cuda/device/utils.cuh is that the latter file only include
// device-only code.
// This file includes host-only utilities for writing CUDA kernels, the
// difference from backend/cuda/device/utils.cuh is that the latter file only
// include device-only code.
#pragma once
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1);
int work_per_thread = 1,
uint max_block_dim = 1024);
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
inline std::tuple<dim3, uint> get_launch_args(
const array& arr,
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024) {
return get_launch_args(
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
arr.size(),
arr.shape(),
arr.strides(),
large,
work_per_thread,
max_block_dim);
}
} // namespace mlx::core

View File

@@ -93,6 +93,10 @@ void gemm_and_bias(
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
if (a.dtype() == complex64) {
throw std::runtime_error(
"[gemm_and_bias] complex64 bias epilogue isnt supported in cublasLtMatmul.");
}
gemm.set_bias(encoder, *bias);
}
gemm.run(
@@ -157,7 +161,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,

View File

@@ -306,7 +306,7 @@ void affine_dequantize(
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;

View File

@@ -0,0 +1,19 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"
namespace mlx::core {
void fast::ConvertFP8::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
auto& in = inputs[0];
auto& out = outputs[0];
auto& s = out.primitive().stream();
if (to_fp8_) {
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
} else {
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,83 @@
#pragma once
struct __nv_fp8_e8m0 {
__device__ __nv_fp8_e8m0(float x) {
if (!std::isfinite(x)) {
__x = 0xFF;
return;
}
if (x < 0.0f) {
__x = 0x00;
return;
}
float le = std::log2f(x);
int n = static_cast<int>(std::nearbyintf(le));
n = n < -127 ? -127 : n;
n = n > 127 ? 127 : n;
__x = static_cast<uint8_t>(n + 127);
}
__device__ operator float() {
if (__x == 0xFF) {
return std::numeric_limits<float>::quiet_NaN();
}
return std::ldexp(1.0f, static_cast<int>(__x) - 127);
}
uint8_t __x{0};
};
struct __nv_fp4_e2m1 {
__device__ __nv_fp4_e2m1(float x) {
if (std::isnan(x)) {
__x = 0x7;
return;
}
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
x = std::abs(x);
if (x > 5.0f) {
__x = 0x7;
} else if (x >= 3.5f) {
__x = 0x6;
} else if (x > 2.5f) {
__x = 0x5;
} else if (x >= 1.75f) {
__x = 0x4;
} else if (x > 1.25f) {
__x = 0x3;
} else if (x >= 0.75f) {
__x = 0x2;
} else if (x > 0.25f) {
__x = 0x1;
} else {
__x = 0x0;
}
__x |= sign_bit;
}
__device__ operator float() {
static const float LUT[16] = {
0.0f,
0.5f,
1.0f,
1.5f,
2.0f,
3.0f,
4.0f,
6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
return LUT[__x];
}
uint8_t __x{0};
};

View File

@@ -0,0 +1,216 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>
namespace mlx::core {
namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits>
struct Dequantize {
__device__ float operator()(uint8_t x) {
if constexpr (bits == 8) {
return float(*(__nv_fp8_e4m3*)(&x));
} else {
return float(*(__nv_fp4_e2m1*)(&x));
}
}
};
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) {
return;
}
float w_thread = w[index];
cg::greater<float> max_op;
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.__x;
scale = float(s);
// Write out the scales
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = warp.shfl_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
}
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto scale = float(((ScaleType*)(scales))[gindex]);
out += oindex;
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}
} // namespace cu
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>;
}
bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<uint8_t>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not quantize input with type float64.");
}
});
}
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
constexpr int uint8_per_uint32 = 4;
int packs_per_int = 8 / bits;
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_dequantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_dequantize<T, 16, 4, false>;
}
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
w.data<T>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
}
});
}
} // namespace mlx::core

View File

@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
}
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
}
}
}

View File

@@ -24,4 +24,22 @@ void affine_dequantize(
cu::CommandEncoder& enc,
const Stream& s);
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
}
}
template <typename T, typename U, typename Op, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
size_t total) {
Op op;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
const auto idx = grid.thread_rank() * N_READS;
const auto before_axis = idx / args.reduction_stride;
const auto after_axis = idx % args.reduction_stride;
const auto offset =
before_axis * args.reduction_stride * args.reduction_size + after_axis;
if (idx >= total) {
return;
}
in += offset;
out += idx;
AlignedVector<U, N_READS> accumulator;
for (int i = 0; i < N_READS; i++) {
accumulator[i] = ReduceInit<Op, T>::value();
}
for (int i = 0; i < args.reduction_size; i++) {
auto values = load_vector<N_READS>(in, 0);
for (int j = 0; j < N_READS; j++) {
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
}
in += args.reduction_stride;
}
store_vector(out, 0, accumulator);
}
} // namespace cu
inline auto output_grid_for_col_reduce(
@@ -206,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -230,12 +271,55 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args);
kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
});
});
});
}
void col_reduce_small(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
block,
0,
in.data<T>(),
out.data<U>(),
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -258,6 +342,13 @@ void col_reduce(
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes);
// Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -7,8 +7,6 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -83,7 +81,8 @@ struct RowReduceArgs {
};
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
__global__ void
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[M][N];
U accs[M];
AlignedVector<T, N> vals[M];
AlignedVector<U, M> accs;
for (int i = 0; i < M; i++) {
accs[i] = init;
}
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
in += start_row * size + block.thread_rank() * N;
out += start_row;
if (size % N == 0) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
vals[k] = load_vector<N>(in + k * size, 0);
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
in += block.size() * N;
}
if (final_offset < size) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + final_offset,
vals[k],
size,
cast_to<T>(init));
for (int i = 0; i < N; i++) {
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
? in[k * size + i]
: cast_to<T>(init);
}
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
}
__shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs, shared_accumulators, op, init);
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
store_vector(out, 0, accs);
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM,
int N_READS = 4>
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_looped(
T* in,
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
@@ -185,37 +163,61 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
const size_t full_blocks = args.row_size / (block.size() * N_READS);
const size_t final_offset = full_blocks * (block.size() * N_READS);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
{
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
cast_to<T>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
__shared__ U shared_accumulators[32];
block_reduce(block, warp, total, shared_accumulators, op, init);
@@ -234,8 +236,6 @@ void row_reduce_simple(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
@@ -250,14 +250,15 @@ void row_reduce_simple(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
@@ -267,6 +268,7 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -282,8 +284,6 @@ void row_reduce_looped(
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -295,34 +295,27 @@ void row_reduce_looped(
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
});
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
});
});
}

View File

@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) {
if (topt == TernaryOpType::VectorVectorVector ||
topt == TernaryOpType::ScalarScalarScalar) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
} else {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
ndim);
}
});
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
}
});
}

View File

@@ -1,284 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
std::is_same_v<Op, Sigmoid>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
rest,
const_param(shape),
const_param(strides),
ndim);
}
});
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, name(), s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break;
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core

View File

@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, ToFP8>) {
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
}
if (std::is_same_v<Op, FromFP8>) {
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
}
return false;
}

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file include utilies that are used by C++ code (i.e. .cpp files).
// This file include utilities that are used by C++ code (i.e. .cpp files).
#pragma once
@@ -12,6 +12,7 @@ namespace mlx::core {
namespace cu {
class Device;
}
struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
explicit CudaStream(cu::Device& device);
};
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim;
}
} // namespace mlx::core

View File

@@ -5,9 +5,9 @@
namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
Worker::Worker(Device& d)
: signal_stream_(d),
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {

View File

@@ -15,7 +15,7 @@ namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
public:
Worker();
explicit Worker(Device& d);
~Worker();
Worker(const Worker&) = delete;

View File

@@ -29,7 +29,7 @@ make_jit_source(
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp4_quantized kernels/quantized_utils.h)
make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
kernels/fp4.h)
make_jit_source(gemv_masked)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)

View File

@@ -32,7 +32,6 @@ namespace metal {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(
vm_page_size,
[](MTL::Buffer* buf) { return buf->length(); },
@@ -41,7 +40,8 @@ MetalAllocator::MetalAllocator()
residency_set_.erase(buf);
}
buf->release();
}) {
}),
residency_set_(device_) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =

View File

@@ -65,7 +65,6 @@ class MetalAllocator : public allocator::Allocator {
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
size_t num_resources_{0};
size_t resource_limit_{0};

View File

@@ -327,6 +327,10 @@ CustomKernelFunction metal_kernel(
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// silence some warnings
(void)is_precompiled_;
(void)shared_memory_;
auto& s = stream();
std::vector<array> copies;

View File

@@ -72,6 +72,19 @@ MTL::Library* try_load_bundle(
}
return nullptr;
}
MTL::Library* try_load_framework(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
std::string resource_path = std::string(url->fileSystemRepresentation()) +
"/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
return nullptr;
}
#endif
// Firstly, search for the metallib in the same path as this binary
@@ -103,6 +116,17 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
return {library, nullptr};
}
}
// if SWIFTPM_BUNDLE is a framework identifier, try loading from that
auto frameworks = NS::Bundle::allFrameworks();
for (int i = 0, c = (int)frameworks->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(frameworks->object(i));
if (!strcmp(bundle->bundleIdentifier()->utf8String(), SWIFTPM_BUNDLE)) {
library = try_load_framework(device, bundle->resourceURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
}
}
#endif
return {nullptr, nullptr};
}
@@ -471,6 +495,10 @@ void Device::end_encoding(int index) {
CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence());
}
@@ -724,7 +752,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
mtl_linked_funcs->release();
// Add kernel to cache
auto inserted = kernel_map_.insert({hash_name, kernel});
kernel_map_.insert({hash_name, kernel});
return kernel;
}

View File

@@ -71,7 +71,7 @@ void eval(array& arr) {
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
@@ -82,7 +82,7 @@ void finalize(Stream s) {
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}

View File

@@ -150,7 +150,6 @@ FFTPlan plan_fft(int n) {
}
// See if we can use Rader's algorithm to Stockham decompose n - 1
auto rader_factors = prime_factors(factor - 1);
int last_factor = -1;
for (int rf : rader_factors) {
// We don't nest Rader's algorithm so if `factor - 1`
// isn't Stockham decomposable we give up and do Bluestein's.
@@ -313,8 +312,6 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
int length = 2 * n - 1;
std::vector<std::complex<float>> w_k_vec(n);
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
@@ -484,8 +481,6 @@ void four_step_fft(
std::vector<array>& copies,
const Stream& s,
bool in_place) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
// Fast no transpose implementation for powers of 2.
FourStepParams four_step_params = {
@@ -786,7 +781,6 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);

View File

@@ -378,7 +378,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
// Need placeholders so Metal doesn't complain
int shape_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
@@ -393,7 +393,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
// Need placeholders so Metal doesn't complain
int shape_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);

View File

@@ -24,7 +24,7 @@ const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
const char* fp4_quantized();
const char* fp_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();

View File

@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v2", "ternary_v2"},
const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source += get_template_definition(
"v2_" + lib_name, "ternary_v2", t_str, op, false, false);
kernel_source += get_template_definition(
"sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
kernel_source += get_template_definition(
"vs2_" + lib_name, "ternary_v2", t_str, op, false, true);
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
kernel_source += get_template_definition(
"vn_" + lib_name, "ternary_v", t_str, op, false, false);
kernel_source += get_template_definition(
"svn_" + lib_name, "ternary_v", t_str, op, true, false);
kernel_source += get_template_definition(
"vsn_" + lib_name, "ternary_v", t_str, op, false, true);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition(
"v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
kernel_source += get_template_definition(
"vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
@@ -814,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
metal::utils(),
metal::gemm(),
metal::quantized_utils(),
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
(mode == "affine") ? metal::quantized() : metal::fp_quantized(),
template_def);
return kernel_source;
});
@@ -841,13 +856,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
std::string kernel_source;
concatenate(
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
if (mode == "affine") {
bool is_affine = mode == "affine";
concatenate(
kernel_source,
metal::quantized(),
is_affine ? metal::quantized() : metal::fp_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
(is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
get_type_string(x.dtype()),
group_size,
bits,
@@ -857,23 +872,6 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
wm,
wn,
transpose));
} else {
concatenate(
kernel_source,
metal::fp4_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
"uint8_t",
bm,
bn,
bk,
wm,
wn,
transpose));
}
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);

View File

@@ -6,6 +6,7 @@ set(BASE_HEADERS
defines.h
erf.h
expm1f.h
fp8.h
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
@@ -109,7 +110,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)

View File

@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr threadgroup complex64_t& operator+=(
threadgroup complex64_t& a,
complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}

View File

@@ -0,0 +1,59 @@
#pragma once
constexpr constant static float FP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
struct fp4_e2m1 {
fp4_e2m1(float x) {
if (metal::isnan(x)) {
bits = 0x7;
return;
}
const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
x = metal::abs(x);
if (x > 5.0f) {
bits = 0x7;
} else if (x >= 3.5f) {
bits = 0x6;
} else if (x > 2.5f) {
bits = 0x5;
} else if (x >= 1.75f) {
bits = 0x4;
} else if (x > 1.25f) {
bits = 0x3;
} else if (x >= 0.75f) {
bits = 0x2;
} else if (x > 0.25f) {
bits = 0x1;
} else {
bits = 0x0;
}
bits |= sign_bit;
}
operator float() {
half converted = as_type<half>(ushort((bits & 7) << 9));
converted *= 16384.0;
converted = bits & 8 ? -converted : converted;
return converted;
}
uint8_t bits;
};

View File

@@ -1,127 +0,0 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"
#define instantiate_quantized(name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
type, \
32, \
uint8_t)
#define instantiate_quantized_batched(name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
batched)
#define instantiate_quantized_aligned(name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
type, \
32, \
uint8_t, \
aligned)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
aligned, \
batched)
#define instantiate_quantized_quad(name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
D, \
batched)
#define instantiate_quantized_split_k(name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
type, \
32, \
uint8_t, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
uint8_t, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

View File

@@ -0,0 +1,82 @@
#pragma once
struct fp8_e4m3 {
template <typename T>
fp8_e4m3(T f) {
// From PyTorch
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
uint32_t fp8_max = 543 << 21;
uint32_t denorm_mask = 141 << 23;
uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
uint32_t sign = f_bits & 0x80000000;
f_bits ^= sign;
if (f_bits >= fp8_max) {
// Default behavior saturates to min/max
bits = 0x7E;
} else {
if (f_bits < (121 << 23)) {
f_bits = as_type<uint32_t>(
as_type<float>(f_bits) + as_type<float>(denorm_mask));
bits = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
f_bits += mant_odd;
bits = static_cast<uint8_t>(f_bits >> 20);
}
}
bits |= static_cast<uint8_t>(sign >> 24);
}
operator float() {
// From PyTorch:
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
uint32_t w = static_cast<uint32_t>(bits) << 24;
uint32_t sign = w & 0x80000000;
uint32_t nonsign = w & 0x7FFFFFFF;
uint32_t renorm_shift = metal::clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
int32_t inf_nan_mask =
(static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return as_type<float>(result);
}
uint8_t bits;
};
struct fp8_e8m0 {
fp8_e8m0(float x) {
if (!metal::isfinite(x)) {
bits = 0xFF;
return;
}
if (x < 0.0f) {
bits = 0x00;
return;
}
float le = metal::log2(x);
int n = int(metal::round(le));
n = n < -127 ? -127 : n;
n = n > 127 ? 127 : n;
bits = static_cast<uint8_t>(n + 127);
}
operator bfloat16_t() {
uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
return as_type<bfloat16_t>(out);
}
operator float() {
return static_cast<float>(this->operator bfloat16_t());
}
uint8_t bits;
};

View File

@@ -3,6 +3,9 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp8.h"
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
@@ -26,15 +29,31 @@ inline constexpr short get_bytes_per_pack() {
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
return T(*(thread fp8_e8m0*)(&s));
}
template <int bits>
struct Quantize {
uint8_t operator()(float x) {
if (bits == 8) {
return fp8_e4m3(x).bits;
} else {
return fp4_e2m1(x).bits;
}
}
};
template <int bits>
struct Dequantize {
float operator()(uint8_t x) {
if (bits == 8) {
return float(*(thread fp8_e4m3*)(&x));
} else {
return float(*(thread fp4_e2m1*)(&x));
}
}
};
template <typename T, typename U, int values_per_thread>
inline void load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) {
@@ -59,80 +78,41 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
}
}
constexpr constant static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
if (simd_gid == 0 && simd_lid < 16) {
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
template <typename U, int values_per_thread>
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut) {
inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
(x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
}
return scale * accum;
}
template <typename U, int values_per_thread>
inline U qdot_safe(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut,
int N) {
inline U
qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
(x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
}
return scale * accum;
}
template <typename U, int values_per_thread>
inline void qouter(
const thread uint8_t* w,
U x,
U scale,
thread U* result,
const threadgroup U* lut) {
inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * scale * lut[w[i] & 0xf];
result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
result[2 * i] += x * scale * Dequantize<4>{}(w[i]);
result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4);
}
}
@@ -155,8 +135,7 @@ template <
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
typename S>
short group_size>
struct QuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
@@ -183,12 +162,12 @@ struct QuantizedBlockLoader {
threadgroup T* dst;
const device uint8_t* src;
const device S* scales;
const device uint8_t* scales;
threadgroup T* lut;
QuantizedBlockLoader(
const device uint8_t* src_,
const device S* scales_,
const device uint8_t* scales_,
const int src_ld_,
threadgroup T* dst_,
threadgroup T* lut_,
@@ -208,7 +187,10 @@ struct QuantizedBlockLoader {
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size),
lut(lut_) {
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
if (simd_group_id == 0 && simd_lane_id < 16) {
lut[simd_lane_id] = static_cast<T>(FP4_LUT[simd_lane_id]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
void load_unsafe() const {
@@ -270,20 +252,17 @@ struct QuantizedBlockLoader {
}
};
template <typename T, int group_size, typename S, int D>
METAL_FUNC void mxfp4_qmv_quad_impl(
template <typename T, int group_size, int bits, int D>
METAL_FUNC void fp_qmv_quad_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 8;
constexpr int values_per_thread = D / QUAD_SIZE;
@@ -295,7 +274,6 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
@@ -311,11 +289,11 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
U s = dequantize_scale<U>(sl[0]);
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
}
@@ -327,18 +305,17 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
}
}
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_fast_impl(
template <typename T, int group_size, int bits>
METAL_FUNC void fp_qmv_fast_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
@@ -353,7 +330,6 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -374,7 +350,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -390,18 +366,17 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
}
}
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_impl(
template <typename T, int group_size, int bits>
METAL_FUNC void fp_qmv_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
@@ -418,7 +393,6 @@ METAL_FUNC void mxfp4_qmv_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -448,8 +422,8 @@ METAL_FUNC void mxfp4_qmv_impl(
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
S s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
uint8_t s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -468,7 +442,7 @@ METAL_FUNC void mxfp4_qmv_impl(
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
}
@@ -497,7 +471,7 @@ METAL_FUNC void mxfp4_qmv_impl(
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -517,7 +491,7 @@ METAL_FUNC void mxfp4_qmv_impl(
U s = dequantize_scale<U>(sl[0]);
result[row] +=
qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
}
}
for (int row = 0; row < results_per_simdgroup; row++) {
@@ -529,18 +503,17 @@ METAL_FUNC void mxfp4_qmv_impl(
}
}
template <typename T, const int group_size, typename S>
METAL_FUNC void mxfp4_qvm_impl(
template <typename T, const int group_size, int bits>
METAL_FUNC void fp_qvm_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const int in_vec_size,
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int pack_factor = get_pack_factor<32>();
constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -561,8 +534,6 @@ METAL_FUNC void mxfp4_qvm_impl(
thread U scale = 0;
thread U x_local = 0;
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
@@ -584,7 +555,7 @@ METAL_FUNC void mxfp4_qvm_impl(
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
(thread uint8_t*)&w_local, x_local, scale, result);
x += block_size;
scales += block_size * out_vec_size_g;
@@ -597,7 +568,7 @@ METAL_FUNC void mxfp4_qvm_impl(
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
(thread uint8_t*)&w_local, x_local, scale, result);
x += block_size;
scales += block_size * out_vec_size_g;
@@ -612,7 +583,7 @@ METAL_FUNC void mxfp4_qvm_impl(
scale = 0;
}
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
(thread uint8_t*)&w_local, x_local, scale, result);
}
// Accumulate in the simdgroup
@@ -633,14 +604,14 @@ METAL_FUNC void mxfp4_qvm_impl(
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_t_impl(
METAL_FUNC void fp_qmm_t_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
@@ -677,8 +648,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
// Set the block
const int K_w = K * bytes_per_pack / pack_factor;
@@ -759,13 +729,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
template <
typename T,
const int group_size,
typename S,
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_n_impl(
METAL_FUNC void fp_qmm_n_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
@@ -803,8 +773,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
BN_padded,
0,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
auto wl = (const device uint8_t*)w;
@@ -891,11 +860,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
}
}
template <typename T, typename S>
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint8_t*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
@@ -926,11 +895,11 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, typename S>
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint8_t*& scales,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T*& y,
@@ -976,10 +945,10 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, typename S, int D, bool batched>
[[kernel]] void mxfp4_qmv_quad(
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void fp_qmv_quad(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -993,9 +962,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
@@ -1013,26 +980,14 @@ template <typename T, int group_size, typename S, int D, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_quad_impl<T, group_size, S, D>(
w,
scales,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid,
simd_gid,
simd_lid,
lut);
fp_qmv_quad_impl<T, group_size, bits, D>(
w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
}
template <typename T, int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv_fast(
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void fp_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1064,15 +1019,14 @@ template <typename T, int group_size, typename S, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv(
template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void fp_qmv(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1104,15 +1058,14 @@ template <typename T, const int group_size, typename S, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qvm(
template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void fp_qvm(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1144,15 +1097,14 @@ template <typename T, const int group_size, typename S, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, typename S, int split_k = 32>
[[kernel]] void mxfp4_qvm_split_k(
template <typename T, const int group_size, int bits, int split_k = 32>
[[kernel]] void fp_qvm_split_k(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1188,32 +1140,22 @@ template <typename T, const int group_size, typename S, int split_k = 32>
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w,
scales,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid,
lut);
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_t(
[[kernel]] void fp_qmm_t(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& K,
@@ -1254,21 +1196,21 @@ template <
s_strides,
tid);
}
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_n(
[[kernel]] void fp_qmm_n(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& K,
@@ -1311,14 +1253,14 @@ template <
tid);
}
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv_fast(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1360,15 +1302,14 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qmv(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1410,15 +1351,14 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qvm(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qvm(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1460,22 +1400,21 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_t(
[[kernel]] void fp_gather_qmm_t(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1526,20 +1465,20 @@ template <
w_strides,
s_strides,
tid);
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_n(
[[kernel]] void fp_gather_qmm_n(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1591,24 +1530,24 @@ template <
w_strides,
s_strides,
tid);
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
int group_size,
typename S,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void mxfp4_gather_qmm_rhs(
[[kernel]] void fp_gather_qmm_rhs(
const device T* x,
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device uint32_t* indices,
device T* y,
const constant int& M,
@@ -1644,8 +1583,7 @@ template <
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1789,3 +1727,78 @@ template <
}
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void fp_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device uint8_t* scales [[buffer(2)]],
uint2 tidx [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr bool use_mx_scale = group_size == 32;
size_t index = tidx.x + grid_dim.x * size_t(tidx.y);
float scale;
float w_thread = w[index];
if (use_mx_scale) {
scale = simd_max(abs(w_thread));
} else {
float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
scale = tidx.x < 16 ? w_max_l : w_max_r;
}
scale /= bits == 4 ? 6.0f : 448.0f;
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.bits;
scale = float(s);
// Write out the scales and biases
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = simd_shuffle_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void fp_dequantize(
const device uint8_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr bool use_mx_scale = group_size == 32;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * pack_factor;
size_t gindex = oindex / group_size;
out += oindex;
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
auto q_scale = ((device ScaleType*)(scales))[gindex];
auto scale = float(q_scale);
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}

View File

@@ -0,0 +1,147 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp_quantized.h"
#define instantiate_quantized(mode, name, type) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4", \
fp_ ## name, \
type, \
32, \
4)
#define instantiate_quantized_batched(mode, name, type, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
batched)
#define instantiate_quantized_aligned(mode, name, type, aligned) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
fp_ ## name, \
type, \
32, \
4, \
aligned)
#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
aligned, \
batched)
#define instantiate_quantized_quad(mode, name, type, D, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
D, \
batched)
#define instantiate_quantized_split_k(mode, name, type, split_k) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
fp_ ## name, \
type, \
32, \
4, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
4, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(mode, name, type) \
instantiate_quantized_batched(mode, name, type, 1) \
instantiate_quantized_batched(mode, name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4, gather_qmv_fast, type) \
instantiate_quantized(mxfp4, gather_qmv, type) \
instantiate_quantized(mxfp4, gather_qvm, type) \
instantiate_quantized(mxfp4, gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
instantiate_kernel( \
#mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_quantize, \
type, \
group_size, \
bits) \
instantiate_kernel( \
#mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_dequantize, \
type, \
group_size, \
bits)
#define instantiate_quantize_dequantize_modes(type) \
instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
instantiate_quantize_dequantize(type, mxfp8, 32, 8)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type) \
instantiate_quantize_dequantize_modes(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

View File

@@ -15,6 +15,15 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
template <typename U>
struct DefaultAccT {
using type = float;
};
template <>
struct DefaultAccT<complex64_t> {
using type = complex64_t;
};
template <
typename T,
const int BM, /* Threadgroup rows (in simdgroups) */
@@ -24,8 +33,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -246,8 +257,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -453,7 +466,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -530,6 +543,7 @@ template <
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
instantiate_gemv_blocks(complex64, complex64_t);
template <
typename T,
@@ -565,7 +579,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -636,6 +650,7 @@ template <
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_bs_blocks(complex64, complex64_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
@@ -672,7 +687,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -738,7 +753,8 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
template <
typename T,
@@ -773,8 +789,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
threadgroup float tgp_memory
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -848,4 +864,5 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on

View File

@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
b = w;
}
template <typename T, typename = void>
struct Init {
static constexpr constant T v = Limits<T>::max;
};
template <typename T>
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
};
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
static constexpr constant T init = Init<T>::v;
METAL_FUNC bool operator()(T a, T b) const {
if constexpr (
metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
bool an = isnan(a);
bool bn = isnan(b);
if (an | bn) {
return (!an) & bn;
}
}
return a < b;
}
};

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -23,10 +23,12 @@
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
// clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
#define instantiate_gemm( \

View File

@@ -60,6 +60,7 @@
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on
instantiate_accum(float32, float, float32, float);
instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on

View File

@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
}
}
template <typename InT>
struct TransformNone<complex64_t, InT> {
static METAL_FUNC complex64_t apply(complex64_t x) {
return x;
}
static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
return x;
}
};
template <
typename T,
typename U,
@@ -731,5 +741,406 @@ struct BlockMMA {
}
};
template <
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType,
typename Epilogue>
struct BlockMMA<
complex64_t,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
lda_tgp,
ldb_tgp,
AccumType,
Epilogue> {
static_assert(
metal::is_same_v<AccumType, float>,
"BlockMMA<complex64_t,...> expects float accumulators");
static_assert(
metal::is_same_v<U, complex64_t>,
"For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / (kFragSize * WM);
// Warp tile size along N
STEEL_CONST short TN = BN / (kFragSize * WN);
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// When indexing complex as float[2]
STEEL_CONST short A_str_m_f = A_str_m * 2;
STEEL_CONST short A_str_k_f = A_str_k * 2;
STEEL_CONST short B_str_k_f = B_str_k * 2;
STEEL_CONST short B_str_n_f = B_str_n * 2;
STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
// Accumulators (real/imag)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
// Offsets within threadgroup
short sm, sn;
short As_offset, Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
sm += tm;
sn += tn;
}
/* Karatsuba MMA: 3 real MMAs per K-chunk */
METAL_FUNC void mma(
const threadgroup complex64_t* As,
const threadgroup complex64_t* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
threadgroup const float* As_f =
reinterpret_cast<threadgroup const float*>(As);
threadgroup const float* Bs_f =
reinterpret_cast<threadgroup const float*>(Bs);
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
simdgroup_barrier(mem_flags::mem_none);
// P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
tile_matmad(P, Ar, Br, P);
tile_matmad(Q, Ai, Bi, Q);
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
Ar.elems()[i] += Ai.elems()[i];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
Br.elems()[i] += Bi.elems()[i];
tile_matmad(R, Ar, Br, R);
// C_r += P - Q ; C_i -= Q
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
const auto p = P.elems()[i];
const auto q = Q.elems()[i];
const auto r = R.elems()[i];
Ctile_r.elems()[i] += (p - q);
Ctile_i.elems()[i] += (r - p - q);
}
// Progress to next simdgroup tile
As_f += tile_stride_a_f;
Bs_f += tile_stride_b_f;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off = (i * TM_stride) * ldd + (j * TN_stride);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
if (stop.y <= 0 || stop.x <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; ++i) {
const int row = i * TM_stride;
if (row >= start.y && row < stop.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; ++j) {
const int off = row * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
const int col = j * TN_stride + k;
if (col >= start.x && col < stop.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
int off = (i * TM_stride) * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
complex64_t out = epilogue_op.apply(
complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
Ctile_r.elems()[i] = out.real;
Ctile_i.elems()[i] = out.imag;
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
complex64_t out = epilogue_op.apply(
complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
complex64_t tmp[kelems];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x &&
(i * TM_stride) < dst_tile_dims.y) {
tmp[k] = C[offset_c + k * fdc];
} else {
tmp[k] = complex64_t(0.0f, 0.0f);
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[off_d + k] =
epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off_d + k] = epilogue_op.apply(
complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -48,7 +48,8 @@ struct TransformAxpby {
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
return static_cast<OutT>(
x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
}
};

View File

@@ -1,6 +1,11 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename Op, int N = WorkPerThread<T>::n>
template <
typename T,
typename Op,
bool BSCALAR,
bool CSCALAR,
int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v(
device const bool* a,
device const T* b,
@@ -11,16 +16,25 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
index *= N;
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
auto bidx = BSCALAR ? 0 : index + i;
auto cidx = CSCALAR ? 0 : index + i;
d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
}
} else {
for (int i = 0; i < N; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
auto bidx = BSCALAR ? 0 : index + i;
auto cidx = CSCALAR ? 0 : index + i;
d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
}
}
}
template <typename T, typename Op, int N = WorkPerThread<T>::n>
template <
typename T,
typename Op,
bool BSCALAR,
bool CSCALAR,
int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
@@ -32,11 +46,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
auto bidx = BSCALAR ? 0 : offset + i;
auto cidx = CSCALAR ? 0 : offset + i;
d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
}
} else {
for (int i = 0; i < N; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
auto bidx = BSCALAR ? 0 : offset + i;
auto cidx = CSCALAR ? 0 : offset + i;
d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
}
}
}

View File

@@ -9,8 +9,12 @@
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_base(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false) \
instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1) \
instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true) \
instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1) \
instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
@@ -21,7 +25,9 @@
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false) \
instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true) \
instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false) \
instantiate_ternary_base(op, tname, type)
#define instantiate_ternary_types(op) \

View File

@@ -9,11 +9,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
index *= N;
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]);
out[index + i] = static_cast<U>(Op()(in[index + i]));
}
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
out[index + i] = static_cast<U>(Op()(in[index + i]));
}
}
}
@@ -28,11 +28,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
}
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
}
}
}
@@ -57,7 +57,7 @@ template <
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
out[out_idx++] = static_cast<U>(Op()(in[idx]));
idx += xstride;
}
}

Some files were not shown because too many files have changed in this diff Show More