Compare commits

...

50 Commits

Author SHA1 Message Date
Awni Hannun
39b04ce638 use faster dequant for fp4 qmv (#2720)
2025-10-31 11:49:59 -07:00
Mike Drob
d9e6349657 fix docs path (#2719) 2025-10-30 19:12:49 -05:00
Angelos Katharopoulos
b901a9f311 Fix the order of hosts in the ring (#2718)
2025-10-30 15:02:39 -07:00
Awni Hannun
68c5fa1c95 fix memory count bug (#2717) 2025-10-30 14:27:15 -07:00
Christopher Webb
793a31eeb6 Fix missing domain_uuid_key in thunderbolt ring setup (#2682) 2025-10-30 13:17:20 -07:00
Mike Drob
74c1ed25bb Migrate CircleCI to GitHub Actions (#2716)
Co-authored-by: Joseph Heck <j_heck@apple.com>
2025-10-30 12:26:55 -05:00
Awni Hannun
ec72b44417 Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
2025-10-28 16:23:12 -07:00
Melissa Kilby
460691a0e8 fix: linux-{fedora}x86_64-build (#2707)
Signed-off-by: Melissa Kilby <mkilby@apple.com>
2025-10-27 16:36:08 -07:00
Awni Hannun
969924cc69 Fp8 conversion (#2686)
* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
2025-10-27 16:35:50 -07:00
Awni Hannun
d1e06117e8 bump python (#2694) 2025-10-27 11:34:31 -07:00
Awni Hannun
539d8322d1 add median op (#2705) 2025-10-27 11:33:42 -07:00
Awni Hannun
c4767d110f fix addmm cpu (#2699) 2025-10-27 11:33:32 -07:00
David Koski
895217f25b optionally load metallib from framework (#2702)
* optionally load metallib from framework

* pre-commit

* adjust logic
2025-10-27 07:52:03 -07:00
Manuel Villanueva
0cfeeb60ca Einsum error msg improvement (#2690)
* Improved error message for Einsum

* Modifications via pre-commit

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-27 06:31:47 -07:00
Ronan Collobert
8f8af61a37 fix warnings showing up with -Wall (#2692) 2025-10-24 11:43:35 -07:00
Manuel Villanueva
233384161e Improved mx.split() docs (#2689)
* Improved mx.split() documentation

* Fix typo in docstring for array split function

* add example

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-24 09:48:41 -07:00
Awni Hannun
5bcf3a6794 format 2025-10-22 16:08:47 -07:00
wickedcoder
7707196297 Merge commit from fork
* add length validation to the header

* fix accessing out of bound index with .at()
2025-10-22 15:31:25 -07:00
wickedcoder
7e3471c987 Merge commit from fork
* add tensor->weights_data validation

* add null pointer check for tensor
2025-10-22 15:31:03 -07:00
Awni Hannun
9f0ba3ddf1 patch bump (#2680) 2025-10-17 12:12:07 -07:00
Awni Hannun
4bce5f9b2d suppress gcc 10.1 warnings (#2679)
* suppress gcc 10.1 warnings

* suppress gcc 10.1 warnings
2025-10-17 12:09:21 -07:00
Anastasiia Filippova
e9eab527eb Nccl timeout (#2673)
* print the error & delete nccl group

* timeout for nccl binding

* typo

* revert error

* fixed a typo
2025-10-14 12:29:54 -07:00
Awni Hannun
36ca62dba8 remove unused unary file (#2672) 2025-10-13 19:36:26 -07:00
Manuel Villanueva
9cbb1b0148 Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.

* Modified sort behavior when running CPU or Metal to match NumPy/JAX

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-13 14:36:45 -07:00
Fabrizio Milo
9bfc476d72 Normalize README bullet formatting (#2671) 2025-10-13 12:13:30 -07:00
Awni Hannun
25e2356316 speed up scalars (#2669) 2025-10-13 12:10:15 -07:00
Awni Hannun
226a1d24e0 Debug cuda conv (#2662)
* use t4

* use t4
2025-10-10 16:12:47 -07:00
Awni Hannun
630350ad3e Precise sigmoid (#2659)
* bump patch

* Sigmoid matches PyTorch and is more precise on tails
2025-10-10 10:05:23 -07:00
Awni Hannun
380aeb58ae enable admm low-precision cpu (#2661) 2025-10-10 09:50:54 -07:00
Awni Hannun
f37389d100 bump patch (#2658) 2025-10-10 08:36:41 -07:00
Awni Hannun
e89e8b4272 Export with callback (#2612)
* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
2025-10-08 19:24:33 -07:00
AN Long
85a8824a8c Fix cumulative operations when axis=None (#2653) 2025-10-08 15:25:38 -07:00
Awni Hannun
f5d4397e5c Fix fast synch when fence is waited before a command buffer is created (#2657) 2025-10-08 11:23:46 -07:00
Awni Hannun
343e33b6d5 fix all_gather vjp (#2654) 2025-10-07 06:05:23 -07:00
Angelos Katharopoulos
0073096dd1 Split name into directories for cuda jit (#2656) 2025-10-07 01:52:58 -07:00
Angelos Katharopoulos
e3d004fed9 Fix and refactor row-reduce (#2650) 2025-10-07 01:51:08 -07:00
Awni Hannun
a393435d28 Speed up compile for node with many parents (#2649) 2025-10-03 19:30:36 -07:00
Awni Hannun
a7a94b29d7 Fix compile when outputs change (#2648) 2025-10-03 08:40:57 -07:00
Daniel Yeh
22a5da76c8 Faster complex matmul (#2571) 2025-10-02 23:33:15 -07:00
Andrey Portnoy
287c63a093 Configure CMake to export compile_commands.json (#2645)
This helps enable LSP for code navigation using clangd.
2025-10-02 15:40:32 -07:00
Awni Hannun
1c9ae1eaa1 cuda fix flaky test (#2646) 2025-10-02 15:40:04 -07:00
Angelos Katharopoulos
c2c3e0b0a2 [CUDA] Add a small column specialization to reduce (#2642) 2025-10-02 14:41:05 -07:00
Awni Hannun
b0cc71ae71 Faster triu, tril, where with scalar (#2644) 2025-10-02 12:21:27 -07:00
Awni Hannun
e88f2d4a8e fix cross entropy axis param (#2641)
* fix cross entropy axis param

* faster grad clipping
2025-10-01 16:49:55 -07:00
Angelos Katharopoulos
9cee557423 Fix status message (#2638) 2025-10-01 16:43:45 -07:00
Awni Hannun
bbf1423953 wait for tasks in cuda (#2636) 2025-09-30 16:08:46 -07:00
Angelos Katharopoulos
eb24267b56 Compile now can attach arbitrary data to an entry (#2634) 2025-09-30 13:33:27 -07:00
Awni Hannun
dc371ae7a5 fix for max block dim (#2631) 2025-09-29 08:59:25 -07:00
AN Long
e76a8dd5c5 Fix incorrect path and typos (#2630) 2025-09-28 06:03:04 -07:00
Cheng
b466dea982 [CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
2025-09-27 11:27:17 +09:00
157 changed files with 4898 additions and 1772 deletions


@@ -26,9 +26,9 @@ jobs:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install python@3.10
brew install doxygen
python3.9 -m venv env
python3.10 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
@@ -140,7 +140,7 @@ jobs:
- run:
name: Install Python package
command: |
uv venv --python 3.9
uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
@@ -273,7 +273,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.9"
default: "3.10"
xcode_version:
type: string
default: "26.0.0"
@@ -328,7 +328,7 @@ jobs:
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -351,7 +351,7 @@ jobs:
parameters:
python_version:
type: string
default: "3.9"
default: "3.10"
build_env:
type: string
default: ""
@@ -387,7 +387,7 @@ jobs:
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
@@ -484,7 +484,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
@@ -503,7 +503,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
@@ -546,13 +546,13 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
@@ -564,14 +564,14 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:


@@ -0,0 +1,24 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
MLX_BUILD_STAGE: 2
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
python -m build -w
if [ -f "python/scripts/repair_cuda.sh" ]; then
bash python/scripts/repair_cuda.sh
fi

.github/actions/build-cuda/action.yml

@@ -0,0 +1,68 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
# this value is dependent on the CUDA tools installed in the setup-linux workflow
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Check if build actually worked
shell: bash
run: python -c "import mlx.core"
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m unittest discover python/tests -v
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- name: Build Python package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: ${{ inputs.nvcc-location }}

.github/actions/build-docs/action.yml

@@ -0,0 +1,38 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

.github/actions/build-linux/action.yml

@@ -0,0 +1,78 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
build-type:
description: 'Build type'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
type: boolean
runs:
using: "composite"
steps:
- name: Set DEBUG
shell: sh
if: inputs.build-type == 'debug'
run: echo "DEBUG=1" >> $GITHUB_ENV
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: sh
run: ./build/tests/tests
- name: Build Python package
if: inputs.build-type == 'release'
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
if [ -f "python/scripts/repair_linux.sh" ]; then
bash python/scripts/repair_linux.sh
fi
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64


@@ -0,0 +1,22 @@
name: 'Build macOS release'
description: 'Build MLX releases for macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
runs:
using: "composite"
steps:
- name: Build Python package(s)
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

.github/actions/build-macos/action.yml

@@ -0,0 +1,124 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
build-jit:
description: 'Whether to build with JIT'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
DEV_RELEASE: 1
run: |
uv pip install --upgrade pip cmake setuptools
uv pip install nanobind==2.4.0 \
numpy torch tensorflow unittest-xml-reporting
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -q "\[WARN\]" stderr.log; then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
if: inputs.run-tests == 'true'
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
if: inputs.build-jit == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
- name: Build macOS 13 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 13.0
- name: Build macOS 14 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
- name: Build macOS 15 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0

.github/actions/setup-linux/action.yml

@@ -0,0 +1,83 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

.github/actions/setup-macos/action.yml

@@ -0,0 +1,31 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
install-mpi:
description: 'Whether to install MPI'
required: false
default: 'true'
type: boolean
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
if: inputs.install-mpi == 'true'
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

.github/dependabot.yml

@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

.github/workflows/documentation.yml

@@ -0,0 +1,28 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

.github/workflows/nightly.yml

@@ -0,0 +1,93 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
env:
MACOSX_DEPLOYMENT_TARGET: "15.0"
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
build_cuda_with_tests:
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7


@@ -1,20 +1,46 @@
on:
pull_request:
branches:
- main
name: Build and Test
on: pull_request
permissions:
contents: read
jobs:
check_lint:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
mac_build_and_test:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_documentation:
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs

.github/workflows/release.yml

@@ -0,0 +1,188 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
build_documentation:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
# TODO: 3.14 had issues finding a compatible tensorflow
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
with:
build-type: release
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: artifacts
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: build_cuda_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: build_linux_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: build_mac_release
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: artifacts
- name: Display structure of downloaded files
run: ls -R artifacts
# - name: Publish package distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1


@@ -1,4 +1,10 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:


@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -87,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
@@ -173,7 +179,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()


@@ -2,7 +2,7 @@
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[**Examples**](#examples)
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
improve MLX with the goal of quickly exploring new ideas.
The design of MLX is inspired by frameworks like
[NumPy](https://numpy.org/doc/stable/index.html),
@@ -91,7 +91,7 @@ Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```
```text
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},


@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16")
dtypes = ("float32", "float16", "complex64")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")


@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
for dtype in ("float32", "float16"):
for dtype in ("float32", "float16", "complex64"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
)
)
plt.close(fig)


@@ -16,7 +16,7 @@ silicon computer is
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- Using a native Python >= 3.10
- macOS >= 13.5
.. note::
@@ -39,7 +39,7 @@ requirements:
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
- Python >= 3.10
CPU-only (Linux)
@@ -55,7 +55,7 @@ To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
- Python >= 3.10
Troubleshooting


@@ -112,6 +112,7 @@ Operations
max
maximum
mean
median
meshgrid
min
minimum


@@ -241,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
status(Status::unscheduled),
inputs(std::move(inputs)) {
init();
}


@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}};
return {Shape{1}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}, {0}};
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};


@@ -11,6 +11,8 @@ namespace mlx::core {
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General,
};
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else {
topt = TernaryOpType::General;
}
@@ -59,6 +69,8 @@ inline void set_ternary_op_output_data(
b.flags());
}
break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General:
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
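The two new cases let a ternary op such as `where` with one broadcast scalar operand take a contiguous fast path instead of the general strided loop. A rough Python rendering of the dispatch logic above; `Arr` and `ternary_op_type` are hypothetical stand-ins for MLX's internal array metadata and are illustrative only (the column-contiguous variant of the all-vector case is omitted):

```python
from collections import namedtuple

# Minimal stand-in for the array metadata consulted by the dispatch:
# data_size is the number of stored elements (1 for a broadcast scalar).
Arr = namedtuple("Arr", ["data_size", "row_contiguous"])

def ternary_op_type(a, b, c):
    # Rough mirror of get_ternary_op_type above, in the same order.
    if a.data_size == 1 and b.data_size == 1 and c.data_size == 1:
        return "ScalarScalarScalar"
    if a.row_contiguous and b.row_contiguous and c.row_contiguous:
        return "VectorVectorVector"
    if b.data_size == 1 and a.row_contiguous and c.row_contiguous:
        return "VectorScalarVector"
    if c.data_size == 1 and a.row_contiguous and b.row_contiguous:
        return "VectorVectorScalar"
    return "General"

vec = Arr(data_size=1024, row_contiguous=True)
scalar = Arr(data_size=1, row_contiguous=False)   # broadcast scalar: stride-0 view
print(ternary_op_type(vec, scalar, vec))  # VectorScalarVector
```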


@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,


@@ -46,7 +46,6 @@ void eig_impl(
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,


@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -49,9 +48,15 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,


@@ -88,4 +88,47 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core
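Per batch entry, the new `complex64_t` specialization performs an ordinary complex GEMM via `cblas_cgemm`. A minimal NumPy sketch of the batched loop, with broadcasting and explicit strides omitted for brevity; the function name is illustrative only:

```python
import numpy as np

def batched_complex_matmul(a, b, alpha=1.0, beta=0.0, out=None):
    # One GEMM per batch entry, as in the cblas_cgemm loop above:
    # C_i <- alpha * A_i @ B_i + beta * C_i
    batch, m, _ = a.shape
    n = b.shape[-1]
    if out is None:
        out = np.zeros((batch, m, n), dtype=np.complex64)
    for i in range(batch):
        out[i] = alpha * (a[i] @ b[i]) + beta * out[i]
    return out

x = (np.random.randn(2, 4, 3) + 1j * np.random.randn(2, 4, 3)).astype(np.complex64)
y = (np.random.randn(2, 3, 5) + 1j * np.random.randn(2, 3, 5)).astype(np.complex64)
assert np.allclose(batched_complex_matmul(x, y), x @ y, atol=1e-5)
```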


@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
const void* a_mask_ptr = nullptr;
const void* b_mask_ptr = nullptr;
const void* out_mask_ptr = nullptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
bool a_mask_bool = false;
bool b_mask_bool = false;
bool out_mask_bool = false;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
@@ -423,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());


@@ -91,7 +91,6 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
@@ -108,6 +107,9 @@ void matmul_general(
} else if (out.dtype() == float64) {
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
@@ -128,10 +130,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;


@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -445,7 +448,6 @@ void mxfp4_qmm(
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -487,7 +489,6 @@ void mxfp4_qmm_t(
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
@@ -1104,4 +1105,44 @@ void fast::Quantize::eval_cpu(
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core


@@ -1,5 +1,6 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -9,7 +10,7 @@
#include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in sims/base.h
// There seems to be a bug in simd/base_simd.h
// __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);


@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;


@@ -15,6 +15,18 @@ namespace mlx::core {
namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -27,7 +39,7 @@ struct StridedIterator {
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
: stride_(stride), ptr_(ptr + offset * stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
std::stable_sort(st, ed, nan_aware_less<T>);
src_it.step();
}
}
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
std::nth_element(st, md, ed, nan_aware_less<T>);
}
}
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b);
});
}
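The NaN handling matches NumPy/JAX, which order NaNs after all other values. A small pure-Python sketch of the same comparator semantics (function names are illustrative, not MLX code):

```python
import math
from functools import cmp_to_key

def nan_aware_less(a, b):
    # NaNs compare as greater than everything, so they land at the end,
    # matching numpy.sort / jax.numpy.sort behavior.
    if math.isnan(a):
        return False
    if math.isnan(b):
        return True
    return a < b

def nan_aware_sort(xs):
    cmp = lambda a, b: -1 if nan_aware_less(a, b) else (1 if nan_aware_less(b, a) else 0)
    return sorted(xs, key=cmp_to_key(cmp))  # sorted() is stable, like std::stable_sort

print(nan_aware_sort([3.0, float("nan"), 1.0, 2.0]))  # [1.0, 2.0, 3.0, nan]
```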


@@ -83,8 +83,6 @@ void svd_impl(
auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not


@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N;
src += N;
dst += N;


@@ -77,7 +77,8 @@ struct Real {
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x));
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
}
SINGLE()
};
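The reworked Sigmoid computes the smaller tail value once from |x| and then selects per lane, rather than evaluating 1/(1 + exp(-x)) in one form for both signs. A NumPy sketch of the same formulation (illustrative only; the function name is not MLX API):

```python
import numpy as np

def sigmoid(x):
    # y = 1 / (1 + exp(|x|)) is the value of the smaller tail; pick it for
    # x < 0 and use 1 - y otherwise, mirroring the simd::select above.
    y = 1.0 / (1.0 + np.exp(np.abs(x)))
    return np.where(x < 0, y, 1.0 - y)

x = np.array([-20.0, -1.0, 0.0, 1.0, 20.0], dtype=np.float32)
print(sigmoid(x))  # ~[2.06e-09, 0.2689, 0.5, 0.7311, 1.0]
```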
@@ -107,4 +108,73 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail
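ToFP8/FromFP8 convert between float32 and FP8 E4M3 (1 sign, 4 exponent, 3 mantissa bits, exponent bias 7, no infinities, saturating at ±448) via bit manipulation. Below is a scalar Python decoder written from the E4M3 format definition rather than transcribed from the SIMD code above, so treat it as an independent sanity check; the function name is illustrative:

```python
def fp8_e4m3_to_float(byte):
    # Decode one FP8 E4M3 (e4m3fn) byte: 1 sign, 4 exponent (bias 7), 3 mantissa.
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")                      # only all-ones encodes NaN; no infinities
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6   # subnormals
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

print(fp8_e4m3_to_float(0x7E))  # 448.0, the saturation value ToFP8 uses above
print(fp8_e4m3_to_float(0x38))  # 1.0
```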


@@ -51,12 +51,19 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -170,7 +177,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)


@@ -30,15 +30,20 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = 0;
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = 0;
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
@@ -86,13 +91,12 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
@@ -126,7 +130,7 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.lock();
}
active_memory_ += size;
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
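
A minimal sketch of the accounting fix above, with stand-in types: small requests get bucketed, so active and peak memory are tracked with the size of the buffer actually handed out rather than the size the caller asked for.

#include <algorithm>
#include <cstddef>

struct FakeBuffer {
  size_t size;  // the bucketed size actually reserved
};

struct Accounting {
  size_t active = 0;
  size_t peak = 0;
  void on_alloc(const FakeBuffer& buf) {
    active += buf.size;  // not the requested size
    peak = std::max(active, peak);
  }
  void on_free(const FakeBuffer& buf) {
    active -= buf.size;
  }
};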

View File

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
encoder.set_output_array(out);
}
auto kernel = mod.get_kernel(kernel_name);
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread);
get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
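
A host-side sketch of the launch-dimension math used above: the block size is the occupancy-suggested maximum for the kernel (max_block_dim), clamped so no more threads than elements are launched, with each thread handling work_per_thread elements. 1D grid only; the names are illustrative, not the library API.

#include <algorithm>
#include <cstddef>
#include <cstdio>

struct LaunchDims {
  unsigned num_blocks;
  unsigned block_dim;
};

LaunchDims launch_dims_1d(size_t size, int work_per_thread, unsigned max_block_dim) {
  size_t nthreads = (size + work_per_thread - 1) / work_per_thread;  // ceil_div
  if (nthreads == 0) {
    nthreads = 1;
  }
  unsigned block_dim = static_cast<unsigned>(std::min<size_t>(max_block_dim, nthreads));
  unsigned num_blocks = static_cast<unsigned>((nthreads + block_dim - 1) / block_dim);
  return {num_blocks, block_dim};
}

int main() {
  auto d = launch_dims_1d(1 << 20, 4, 768);  // e.g. occupancy capped the block at 768
  std::printf("blocks=%u threads=%u\n", d.num_blocks, d.block_dim);
}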

View File

@@ -382,20 +382,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
}
if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (!plan) {
throw std::runtime_error("[conv] Unable to find an execution plan.");
}
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
}

View File

@@ -210,6 +210,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}

View File

@@ -14,10 +14,6 @@ namespace mlx::core::cu {
namespace {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -68,8 +64,8 @@ Device::~Device() {
void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host.
static int current = 0;
// actual calls of CUDA APIs.
static thread_local int current = 0;
if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_;
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
enc.node_count_++;
return;
}
@@ -196,6 +193,7 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
@@ -220,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
active_outputs_.push_back(id);
}
void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
commit();
}
}
void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
@@ -233,6 +225,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
@@ -253,6 +246,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
@@ -295,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
node_count_++;
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
@@ -306,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
insert_graph_dependencies(GraphNode{node, 'G'});
}
int CommandEncoder::get_num_ops() {
return node_count_;
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
if (node_count_ > 0) {
if (use_cuda_graphs() && node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_,
@@ -354,7 +353,6 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
@@ -366,6 +364,7 @@ void CommandEncoder::commit() {
// Put completion handlers in a batch.
worker_.commit(stream_);
node_count_ = 0;
}
void CommandEncoder::synchronize() {

View File

@@ -83,7 +83,7 @@ class CommandEncoder {
}
void add_completed_handler(std::function<void()> task);
void maybe_commit();
int get_num_ops();
void commit();
Device& device() {
@@ -140,7 +140,7 @@ class Device {
Device(const Device&) = delete;
Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls.
// Make this device the current cuda device; this method is thread-safe.
void make_current();
CommandEncoder& get_command_encoder(Stream s);

View File

@@ -2,6 +2,8 @@
#pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -257,8 +259,8 @@ struct Round {
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x)));
return (x < 0) ? 1 - y : y;
T y = 1 / (1 + exp(abs(x)));
return (x < 0) ? y : 1 - y;
}
};
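
A scalar host-side reference for the rewritten Sigmoid above (float only, sketch): the kernel evaluates exp on |x| and folds the sign back in, so both branches below are algebraically 1 / (1 + exp(-x)) and saturate cleanly at 0 and 1 for large |x|.

#include <cmath>
#include <cstdio>

float sigmoid_ref(float x) {
  float y = 1.0f / (1.0f + std::exp(std::fabs(x)));  // = sigmoid(-|x|)
  return (x < 0.0f) ? y : 1.0f - y;
}

int main() {
  const float xs[] = {-100.0f, -2.5f, 0.0f, 2.5f, 100.0f};
  for (float x : xs) {
    std::printf("x=%g  branchy=%g  direct=%g\n",
                x, sigmoid_ref(x), 1.0f / (1.0f + std::exp(-x)));
  }
}
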
@@ -334,4 +336,17 @@ struct Tanh {
}
};
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilies that work under both
// This file must not include any host-only code, utilities that work under both
// host and device can be put here.
//
// See more about the requirements at:
@@ -202,7 +202,7 @@ struct Limits<
}
};
// CUDA 11 does not have host side arithmatic operators for half types.
// CUDA 11 does not have host side arithmetic operators for half types.
template <typename T>
struct Limits<
T,

View File

@@ -5,19 +5,24 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() {
return true;
}
void new_stream(Stream s) {
// Force initalization of CUDA by creating an event, so the CUDA runtime and
// our CUDA event pool get destroyed last.
cu::CudaEvent(cudaEventDefault);
// Force initialization of CUDA, so the CUDA runtime gets destroyed last.
cudaFree(nullptr);
// Make sure the CUDA event pool gets destroyed after the device and stream.

cu::CudaEvent::init_pool();
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
}
@@ -35,7 +40,8 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
auto& stream = arr.primitive().stream();
auto& encoder = cu::get_command_encoder(stream);
// Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) {
// Except for the donated one.
@@ -46,7 +52,14 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) {
encoder.add_temporary(s);
}
encoder.maybe_commit();
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
}
void finalize(Stream s) {
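
A sketch of the flush policy shown in eval() above, using stand-in types: the encoder just counts recorded ops, and the caller commits once the count reaches the per-buffer limit, pairing a "new task" notification with a completion handler that signals the scheduler when the batch finishes. FakeEncoder and Scheduler here are illustrative, not MLX API.

#include <functional>
#include <utility>
#include <vector>

struct Scheduler {
  int pending = 0;
  void notify_new_task() { ++pending; }
  void notify_task_completion() { --pending; }
};

struct FakeEncoder {
  int num_ops = 0;
  std::vector<std::function<void()>> on_complete;
  void add_op() { ++num_ops; }
  void add_completed_handler(std::function<void()> f) { on_complete.push_back(std::move(f)); }
  void commit() {
    // Pretend the GPU finished immediately; run handlers and reset the count.
    for (auto& f : on_complete) f();
    on_complete.clear();
    num_ops = 0;
  }
};

void maybe_flush(FakeEncoder& enc, Scheduler& sched, int max_ops_per_buffer) {
  if (enc.num_ops >= max_ops_per_buffer) {
    sched.notify_new_task();
    enc.add_completed_handler([&sched]() { sched.notify_task_completion(); });
    enc.commit();
  }
}

int main() {
  Scheduler sched;
  FakeEncoder enc;
  for (int i = 0; i < 45; ++i) {
    enc.add_op();
    maybe_flush(enc, sched, /* max_ops_per_buffer */ 20);
  }
}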

View File

@@ -22,11 +22,15 @@ namespace cu {
namespace {
// Manage cached cudaEvent_t objects.
struct CudaEventPool {
static CudaEventHandle create(int flags) {
auto& cache = cache_for(flags);
class CudaEventPool {
public:
CudaEventHandle create(Device& d, int flags) {
if (!on_creation_thread()) {
return CudaEventHandle(d, flags);
}
auto& cache = cache_for(d, flags);
if (cache.empty()) {
return CudaEventHandle(flags);
return CudaEventHandle(d, flags);
} else {
CudaEventHandle ret = std::move(cache.back());
cache.pop_back();
@@ -34,54 +38,89 @@ struct CudaEventPool {
}
}
static void release(CudaEventHandle event) {
cache_for(event.flags).push_back(std::move(event));
void release(CudaEventHandle event) {
if (!on_creation_thread()) {
// Event will be destroyed directly instead of getting moved to cache.
return;
}
cache_for(event.device, event.flags).push_back(std::move(event));
}
static std::vector<CudaEventHandle>& cache_for(int flags) {
static std::map<int, std::vector<CudaEventHandle>> cache;
return cache[flags];
private:
std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
return cache_[d.cuda_device()][flags];
}
bool on_creation_thread() {
return std::this_thread::get_id() == thread_id_;
}
// The CudaEvent may be created and destroyed on different threads (for
// example when waiting on GPU work in a CPU stream). We don't want to make
// the cache thread-safe since that adds overhead, so we simply skip the
// cache when events are used in worker threads.
std::thread::id thread_id_{std::this_thread::get_id()};
// {device: {flags: [events]}}
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
};
CudaEventPool& cuda_event_pool() {
static CudaEventPool pool;
return pool;
}
} // namespace
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
CudaEventHandle::CudaEventHandle(Device& d, int flags)
: device(d), flags(flags) {
device.make_current();
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr);
}
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
CudaEvent::CudaEvent(Device& d, int flags)
: event_(cuda_event_pool().create(d, flags)) {}
CudaEvent::~CudaEvent() {
CudaEventPool::release(std::move(event_));
cuda_event_pool().release(std::move(event_));
}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
event_.device.make_current();
cudaEventSynchronize(event_);
}
void CudaEvent::wait(cudaStream_t stream) {
event_.device.make_current();
cudaStreamWaitEvent(stream, event_);
}
void CudaEvent::record(cudaStream_t stream) {
event_.device.make_current();
cudaEventRecord(event_, stream);
}
bool CudaEvent::completed() const {
// Note: cudaEventQuery can be safely called from any device.
return cudaEventQuery(event_) == cudaSuccess;
}
// static
void CudaEvent::init_pool() {
cuda_event_pool();
}
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
CopyableCudaEvent()
explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
nvtx3::mark("Using slow AtomicEvent");
atomic = std::make_unique<cu::AtomicEvent>();
} else {
cuda = std::make_unique<cu::CopyableCudaEvent>();
cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
}
}
};
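
A sketch of the pooling policy above with a generic resource type: events are cached per (device, flags), but only when running on the thread that owns the pool; any other thread pays the create/destroy cost instead of taking a lock. Handle below stands in for CudaEventHandle; this mirrors the structure for illustration only.

#include <map>
#include <thread>
#include <vector>

struct Handle {
  int device;
  int flags;
};

class Pool {
 public:
  Handle create(int device, int flags) {
    if (std::this_thread::get_id() != owner_) {
      return Handle{device, flags};  // off-thread: bypass the cache entirely
    }
    auto& bucket = cache_[device][flags];
    if (bucket.empty()) {
      return Handle{device, flags};
    }
    Handle h = bucket.back();
    bucket.pop_back();
    return h;
  }

  void release(Handle h) {
    if (std::this_thread::get_id() != owner_) {
      return;  // off-thread: destroy instead of caching
    }
    cache_[h.device][h.flags].push_back(h);
  }

 private:
  std::thread::id owner_{std::this_thread::get_id()};
  std::map<int, std::map<int, std::vector<Handle>>> cache_;  // {device: {flags: [handles]}}
};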

View File

@@ -13,9 +13,12 @@
namespace mlx::core::cu {
class Device;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(int flags);
CudaEventHandle(Device& d, int flags);
Device& device;
int flags;
};
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
// on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent {
public:
explicit CudaEvent(int flags);
CudaEvent(Device& d, int flags);
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
// returns true if record() has not been called.
bool completed() const;
// Internal: make sure event pool is initialized.
static void init_pool();
private:
CudaEventHandle event_;
};

View File

@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
scale_type_ = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
scale_type_ = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@@ -352,6 +355,16 @@ void CublasGemm::execute(
}
}
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
@@ -368,12 +381,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
alpha_ptr,
b, // a and b are swapped
a_desc_,
a,
b_desc_,
&beta,
beta_ptr,
c ? c : out,
c ? c_desc_ : out_desc_,
out,

View File

@@ -115,6 +115,7 @@ class CublasGemm {
uint64_t M_;
uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};

View File

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
using Acc = typename GemvAccType<T>::type;
Acc sum = Acc(0);
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
sum = cg::reduce(warp, sum, cg::plus<Acc>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
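
A plain C++ sketch of the accumulator-selection trait introduced above: half-precision inputs accumulate in float while wider types keep their own precision. "half" below is a stand-in with a float conversion, not __half, and dot() is illustrative rather than the gemv kernel.

#include <cstddef>

struct half {
  float v;
  operator float() const { return v; }
};

template <typename T>
struct AccType {
  using type = T;  // default: accumulate in the input type
};
template <>
struct AccType<half> {
  using type = float;  // fp16/bf16-style inputs accumulate in float
};

template <typename T>
typename AccType<T>::type dot(const T* a, const T* b, size_t n) {
  using Acc = typename AccType<T>::type;
  Acc sum = Acc(0);
  for (size_t i = 0; i < n; ++i) {
    sum += static_cast<Acc>(a[i]) * static_cast<Acc>(b[i]);
  }
  return sum;
}
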
@@ -107,7 +138,7 @@ void gemv(
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;

View File

@@ -99,6 +99,30 @@ const std::filesystem::path& ptx_cache_dir() {
return cache;
}
std::filesystem::path get_ptx_path(
const std::filesystem::path& cache_dir,
const std::string& module_name) {
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (module_name.size() <= max_file_name_length) {
return cache_dir / (module_name + ".ptx");
}
auto ptx_path = cache_dir;
int offset = 0;
while (module_name.size() - offset > max_file_name_length) {
ptx_path /= module_name.substr(offset, max_file_name_length);
offset += max_file_name_length;
}
ptx_path /= module_name.substr(offset) + ".ptx";
return ptx_path;
}
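
A standalone sketch of the cache-path scheme above: very long module names are split into nested directories so no single path component exceeds the filesystem's file-name limit, and the last chunk gets the .ptx extension. max_len is a stand-in for the platform-dependent constant.

#include <filesystem>
#include <iostream>
#include <string>

std::filesystem::path ptx_path_for(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    size_t max_len = 245) {
  if (module_name.size() <= max_len) {
    return cache_dir / (module_name + ".ptx");
  }
  auto path = cache_dir;
  size_t offset = 0;
  while (module_name.size() - offset > max_len) {
    path /= module_name.substr(offset, max_len);
    offset += max_len;
  }
  path /= module_name.substr(offset) + ".ptx";
  return path;
}

int main() {
  std::string long_name(600, 'k');
  std::cout << ptx_path_for("/tmp/cache", long_name) << "\n";
}
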
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
@@ -109,7 +133,7 @@ bool read_cached_ptx(
return false;
}
auto ptx_path = cache_dir / (module_name + ".ptx");
auto ptx_path = get_ptx_path(cache_dir, module_name);
std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error);
if (error) {
@@ -122,7 +146,7 @@ bool read_cached_ptx(
ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
@@ -144,16 +168,26 @@ void write_cached_ptx(
return;
}
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
auto ptx_path = get_ptx_path(cache_dir, module_name);
// Ensure that the directory exists
auto parent = ptx_path.parent_path();
if (parent != cache_dir) {
std::filesystem::create_directories(parent);
}
// Write the compiled code and mangled names
std::ofstream ptx_file(ptx_path, std::ios::binary);
if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size());
}
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
// Write the generated code
std::ofstream source_file(ptx_path.replace_extension(".cu"));
source_file << source_code;
}
@@ -297,7 +331,8 @@ void load_module(
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -314,7 +349,7 @@ void load_module(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels[name] = std::make_pair(kernel, false);
kernels[name] = std::make_tuple(kernel, false, 0);
}
}
@@ -358,7 +393,7 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
@@ -369,14 +404,22 @@ CUfunction JitModule::get_kernel(
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
auto kernel = std::get<0>(it->second);
if (!std::get<1>(it->second)) {
if (configure_kernel) {
configure_kernel(it->second.first);
configure_kernel(kernel);
}
it->second.second = true;
std::get<1>(it->second) = true;
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
}
return it->second.first;
return {kernel, std::get<2>(it->second)};
}
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

View File

@@ -99,10 +99,13 @@ class JitModule {
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
std::pair<CUfunction, uint> get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread) {
int work_per_thread /* = 1 */,
uint max_block_dim /* = 1024 */) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
// This file includes host-only utilies for writing CUDA kernels, the difference
// from backend/cuda/device/utils.cuh is that the latter file only include
// device-only code.
// This file includes host-only utilities for writing CUDA kernels; the
// difference from backend/cuda/device/utils.cuh is that the latter file only
// includes device-only code.
#pragma once
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1);
int work_per_thread = 1,
uint max_block_dim = 1024);
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
inline std::tuple<dim3, uint> get_launch_args(
const array& arr,
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024) {
return get_launch_args(
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
arr.size(),
arr.shape(),
arr.strides(),
large,
work_per_thread,
max_block_dim);
}
} // namespace mlx::core

View File

@@ -93,6 +93,10 @@ void gemm_and_bias(
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
if (a.dtype() == complex64) {
throw std::runtime_error(
"[gemm_and_bias] complex64 bias epilogue isnt supported in cublasLtMatmul.");
}
gemm.set_bias(encoder, *bias);
}
gemm.run(
@@ -157,7 +161,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,

View File

@@ -306,7 +306,7 @@ void affine_dequantize(
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;

View File

@@ -0,0 +1,19 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"
namespace mlx::core {
void fast::ConvertFP8::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
auto& in = inputs[0];
auto& out = outputs[0];
auto& s = out.primitive().stream();
if (to_fp8_) {
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
} else {
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,83 @@
#pragma once
struct __nv_fp8_e8m0 {
__device__ __nv_fp8_e8m0(float x) {
if (!std::isfinite(x)) {
__x = 0xFF;
return;
}
if (x < 0.0f) {
__x = 0x00;
return;
}
float le = std::log2f(x);
int n = static_cast<int>(std::nearbyintf(le));
n = n < -127 ? -127 : n;
n = n > 127 ? 127 : n;
__x = static_cast<uint8_t>(n + 127);
}
__device__ operator float() {
if (__x == 0xFF) {
return std::numeric_limits<float>::quiet_NaN();
}
return std::ldexp(1.0f, static_cast<int>(__x) - 127);
}
uint8_t __x{0};
};
struct __nv_fp4_e2m1 {
__device__ __nv_fp4_e2m1(float x) {
if (std::isnan(x)) {
__x = 0x7;
return;
}
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
x = std::abs(x);
if (x > 5.0f) {
__x = 0x7;
} else if (x >= 3.5f) {
__x = 0x6;
} else if (x > 2.5f) {
__x = 0x5;
} else if (x >= 1.75f) {
__x = 0x4;
} else if (x > 1.25f) {
__x = 0x3;
} else if (x >= 0.75f) {
__x = 0x2;
} else if (x > 0.25f) {
__x = 0x1;
} else {
__x = 0x0;
}
__x |= sign_bit;
}
__device__ operator float() {
static const float LUT[16] = {
0.0f,
0.5f,
1.0f,
1.5f,
2.0f,
3.0f,
4.0f,
6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
return LUT[__x];
}
uint8_t __x{0};
};
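
A host-side sketch of the two fallback formats above: e8m0 keeps only a biased power-of-two exponent (used for mx scales), and e2m1 is a 4-bit code whose 16 values are exactly the LUT in the header. This mirrors the logic for illustration only; it is not the CUDA intrinsic types.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

uint8_t to_e8m0(float x) {
  if (!std::isfinite(x)) return 0xFF;  // NaN marker
  if (x <= 0.0f) return 0x00;          // clamp to 2^-127
  int n = static_cast<int>(std::nearbyintf(std::log2f(x)));
  n = n < -127 ? -127 : (n > 127 ? 127 : n);
  return static_cast<uint8_t>(n + 127);
}

float from_e8m0(uint8_t v) {
  if (v == 0xFF) return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(v) - 127);
}

float from_e2m1(uint8_t v) {
  static const float lut[16] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
                                -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
  return lut[v & 0xF];
}

int main() {
  std::printf("e8m0(0.03125) -> %g\n", from_e8m0(to_e8m0(0.03125f)));  // 2^-5
  std::printf("e2m1(0x5) -> %g\n", from_e2m1(0x5));                    // 3.0
}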

View File

@@ -0,0 +1,216 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>
namespace mlx::core {
namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits>
struct Dequantize {
__device__ float operator()(uint8_t x) {
if constexpr (bits == 8) {
return float(*(__nv_fp8_e4m3*)(&x));
} else {
return float(*(__nv_fp4_e2m1*)(&x));
}
}
};
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) {
return;
}
float w_thread = w[index];
cg::greater<float> max_op;
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.__x;
scale = float(s);
// Write out the scales
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = warp.shfl_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
}
template <typename T, int group_size, int bits, bool use_mx_scale>
__global__ void
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
using ScaleType =
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
auto scale = float(((ScaleType*)(scales))[gindex]);
out += oindex;
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}
} // namespace cu
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>;
}
bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<uint8_t>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not quantize input with type float64.");
}
});
}
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s) {
constexpr int uint8_per_uint32 = 4;
int packs_per_int = 8 / bits;
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
if (bits == 8) {
kernel = cu::fp_dequantize<T, 32, 8, true>;
} else if (group_size == 16) {
kernel = cu::fp_dequantize<T, 16, 4, false>;
}
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
w.data<T>(),
w.size());
} else {
throw std::runtime_error(
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
}
});
}
} // namespace mlx::core
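
A host sketch of the per-group arithmetic in fp_quantize above (mx-scaled fp4 path, group_size = 32 by default): each group shares one 8-bit scale equal to max|w| divided by 6 (the largest e2m1 value), snapped to a power of two and stored as e8m0; elements are divided by the snapped scale and packed two 4-bit codes per byte. encode4 follows the fallback e2m1 cast above; names are illustrative.

#include <cmath>
#include <cstdint>
#include <vector>

uint8_t encode4(float x) {  // e2m1 cast following the fallback header above
  uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
  x = std::fabs(x);
  uint8_t m;
  if (x > 5.0f)        m = 0x7;
  else if (x >= 3.5f)  m = 0x6;
  else if (x > 2.5f)   m = 0x5;
  else if (x >= 1.75f) m = 0x4;
  else if (x > 1.25f)  m = 0x3;
  else if (x >= 0.75f) m = 0x2;
  else if (x > 0.25f)  m = 0x1;
  else                 m = 0x0;
  return sign | m;
}

struct QuantizedGroup {
  uint8_t scale_e8m0;           // biased power-of-two exponent
  std::vector<uint8_t> packed;  // group_size / 2 bytes of 4-bit codes
};

QuantizedGroup quantize_group_fp4(const float* w, int group_size) {
  float amax = 0.0f;
  for (int i = 0; i < group_size; ++i) {
    amax = std::fmax(amax, std::fabs(w[i]));
  }
  float scale = amax / 6.0f;
  int e = scale > 0.0f ? static_cast<int>(std::nearbyintf(std::log2f(scale))) : -127;
  e = e < -127 ? -127 : (e > 127 ? 127 : e);
  float snapped = std::ldexp(1.0f, e);  // mx scale is a power of two
  QuantizedGroup out;
  out.scale_e8m0 = static_cast<uint8_t>(e + 127);
  out.packed.resize(group_size / 2);
  for (int i = 0; i < group_size; i += 2) {
    uint8_t lo = encode4(w[i] / snapped);
    uint8_t hi = encode4(w[i + 1] / snapped);
    out.packed[i / 2] = static_cast<uint8_t>(lo | (hi << 4));
  }
  return out;
}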

View File

@@ -57,23 +57,30 @@ void fast::Quantize::eval_gpu(
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
}
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
}
}
}

View File

@@ -24,4 +24,22 @@ void affine_dequantize(
cu::CommandEncoder& enc,
const Stream& s);
void fp_quantize(
const array& w,
array& wq,
array& scales,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
void fp_dequantize(
const array& wq,
const array& scales,
array& w,
int group_size,
int bits,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
}
}
template <typename T, typename U, typename Op, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
size_t total) {
Op op;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
const auto idx = grid.thread_rank() * N_READS;
const auto before_axis = idx / args.reduction_stride;
const auto after_axis = idx % args.reduction_stride;
const auto offset =
before_axis * args.reduction_stride * args.reduction_size + after_axis;
if (idx >= total) {
return;
}
in += offset;
out += idx;
AlignedVector<U, N_READS> accumulator;
for (int i = 0; i < N_READS; i++) {
accumulator[i] = ReduceInit<Op, T>::value();
}
for (int i = 0; i < args.reduction_size; i++) {
auto values = load_vector<N_READS>(in, 0);
for (int j = 0; j < N_READS; j++) {
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
}
in += args.reduction_stride;
}
store_vector(out, 0, accumulator);
}
} // namespace cu
inline auto output_grid_for_col_reduce(
@@ -206,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -230,12 +271,55 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args);
kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
});
});
});
}
void col_reduce_small(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
block,
0,
in.data<T>(),
out.data<U>(),
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -258,6 +342,13 @@ void col_reduce(
// Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes);
// Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}
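
A CPU reference for the indexing used by col_reduce_small above (taken when there is a single, short reduction axis whose stride is vector-aligned): output element idx sits "stride" apart from its reduction neighbours, so its first input is at (idx / stride) * stride * reduction_size + idx % stride, and successive inputs are stride apart. Plain sum, float only; a sketch, not the kernel.

#include <cstddef>
#include <vector>

std::vector<float> col_reduce_sum_ref(
    const std::vector<float>& in, size_t reduction_size, size_t stride) {
  size_t out_size = in.size() / reduction_size;
  std::vector<float> out(out_size, 0.0f);
  for (size_t idx = 0; idx < out_size; ++idx) {
    size_t before = idx / stride;
    size_t after = idx % stride;
    size_t offset = before * stride * reduction_size + after;
    for (size_t i = 0; i < reduction_size; ++i) {
      out[idx] += in[offset + i * stride];
    }
  }
  return out;
}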

View File

@@ -7,8 +7,6 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -83,7 +81,8 @@ struct RowReduceArgs {
};
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
__global__ void
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[M][N];
U accs[M];
AlignedVector<T, N> vals[M];
AlignedVector<U, M> accs;
for (int i = 0; i < M; i++) {
accs[i] = init;
}
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
in += start_row * size + block.thread_rank() * N;
out += start_row;
if (size % N == 0) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
} else {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
vals[k] = load_vector<N>(in + k * size, 0);
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
in += block.size() * N;
}
if (final_offset < size) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + final_offset,
vals[k],
size,
cast_to<T>(init));
for (int i = 0; i < N; i++) {
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
? in[k * size + i]
: cast_to<T>(init);
}
}
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
}
__shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs, shared_accumulators, op, init);
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
for (int i = 0; i < M; i++) {
out[i] = accs[i];
}
store_vector(out, 0, accs);
} else {
short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) {
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM,
int N_READS = 4>
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_looped(
T* in,
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
const size_t full_blocks = args.row_size / (block.size() * N_READS);
const size_t final_offset = full_blocks * (block.size() * N_READS);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < full_blocks; r++) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
cast_to<T>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
{
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
__shared__ U shared_accumulators[32];
@@ -234,8 +236,6 @@ void row_reduce_simple(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
@@ -250,14 +250,15 @@ void row_reduce_simple(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
@@ -267,6 +268,7 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
}
T* indata = const_cast<T*>(in.data<T>());
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -282,8 +284,6 @@ void row_reduce_looped(
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -295,34 +295,27 @@ void row_reduce_looped(
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims
args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
});
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
});
});
}
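
A sketch of the block-size heuristic used by both row-reduce launches above: aim for roughly one warp per 4 * WARP_SIZE * N_READS elements of a row, clamped to [1, 32] warps (32 to 1024 threads). Constants follow the code; the function name is illustrative.

#include <algorithm>
#include <cstddef>
#include <cstdio>

int row_reduce_threads(size_t row_size, int n_reads, int warp_size = 32) {
  size_t reductions = (row_size + n_reads - 1) / n_reads;
  int warps = static_cast<int>((reductions + warp_size - 1) / warp_size);
  warps /= 4;
  warps = std::max(std::min(warps, 32), 1);
  return warps * warp_size;
}

int main() {
  std::printf("%d %d %d\n",
              row_reduce_threads(128, 8),       // tiny row -> 32 threads
              row_reduce_threads(4096, 8),      // medium row -> 128 threads
              row_reduce_threads(1 << 20, 8));  // long row -> capped at 1024
}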

View File

@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) {
if (topt == TernaryOpType::VectorVectorVector ||
topt == TernaryOpType::ScalarScalarScalar) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
} else {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
ndim);
}
});
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
}
});
}

View File

@@ -1,284 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
std::is_same_v<Op, Sigmoid>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
rest,
const_param(shape),
const_param(strides),
ndim);
}
});
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, name(), s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break;
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core

View File

@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, ToFP8>) {
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
}
if (std::is_same_v<Op, FromFP8>) {
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
}
return false;
}

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file include utilies that are used by C++ code (i.e. .cpp files).
// This file includes utilities that are used by C++ code (i.e. .cpp files).
#pragma once
@@ -12,6 +12,7 @@ namespace mlx::core {
namespace cu {
class Device;
}
struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
explicit CudaStream(cu::Device& device);
};
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim;
}
} // namespace mlx::core

View File

@@ -5,9 +5,9 @@
namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
Worker::Worker(Device& d)
: signal_stream_(d),
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {

View File

@@ -15,7 +15,7 @@ namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
public:
Worker();
explicit Worker(Device& d);
~Worker();
Worker(const Worker&) = delete;

View File

@@ -29,7 +29,7 @@ make_jit_source(
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp4_quantized kernels/quantized_utils.h)
make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
kernels/fp4.h)
make_jit_source(gemv_masked)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)

View File

@@ -32,7 +32,6 @@ namespace metal {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(
vm_page_size,
[](MTL::Buffer* buf) { return buf->length(); },
@@ -41,7 +40,8 @@ MetalAllocator::MetalAllocator()
residency_set_.erase(buf);
}
buf->release();
}) {
}),
residency_set_(device_) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =

View File

@@ -65,7 +65,6 @@ class MetalAllocator : public allocator::Allocator {
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
size_t num_resources_{0};
size_t resource_limit_{0};

View File

@@ -327,6 +327,10 @@ CustomKernelFunction metal_kernel(
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// silence some warnings
(void)is_precompiled_;
(void)shared_memory_;
auto& s = stream();
std::vector<array> copies;

View File

@@ -72,6 +72,19 @@ MTL::Library* try_load_bundle(
}
return nullptr;
}
MTL::Library* try_load_framework(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
std::string resource_path = std::string(url->fileSystemRepresentation()) +
"/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
return nullptr;
}
#endif
// Firstly, search for the metallib in the same path as this binary
@@ -103,6 +116,17 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
return {library, nullptr};
}
}
// if SWIFTPM_BUNDLE is a framework identifier, try loading from that
auto frameworks = NS::Bundle::allFrameworks();
for (int i = 0, c = (int)frameworks->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(frameworks->object(i));
if (!strcmp(bundle->bundleIdentifier()->utf8String(), SWIFTPM_BUNDLE)) {
library = try_load_framework(device, bundle->resourceURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
}
}
#endif
return {nullptr, nullptr};
}
@@ -471,6 +495,10 @@ void Device::end_encoding(int index) {
CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence());
}
@@ -724,7 +752,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
mtl_linked_funcs->release();
// Add kernel to cache
auto inserted = kernel_map_.insert({hash_name, kernel});
kernel_map_.insert({hash_name, kernel});
return kernel;
}

View File

@@ -71,7 +71,7 @@ void eval(array& arr) {
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
@@ -82,7 +82,7 @@ void finalize(Stream s) {
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}

View File

@@ -150,7 +150,6 @@ FFTPlan plan_fft(int n) {
}
// See if we can use Rader's algorithm to Stockham decompose n - 1
auto rader_factors = prime_factors(factor - 1);
int last_factor = -1;
for (int rf : rader_factors) {
// We don't nest Rader's algorithm so if `factor - 1`
// isn't Stockham decomposable we give up and do Bluestein's.
@@ -313,8 +312,6 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
int length = 2 * n - 1;
std::vector<std::complex<float>> w_k_vec(n);
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
@@ -484,8 +481,6 @@ void four_step_fft(
std::vector<array>& copies,
const Stream& s,
bool in_place) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
// Fast no transpose implementation for powers of 2.
FourStepParams four_step_params = {
@@ -786,7 +781,6 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);

View File

@@ -378,7 +378,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
// Need placeholders so Metal doesn't complain
int shape_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
@@ -393,7 +393,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
// Need placeholders so Metal doesn't complain
int shape_ = 0;
int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);

View File

@@ -24,7 +24,7 @@ const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
const char* fp4_quantized();
const char* fp_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();

View File

@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v2", "ternary_v2"},
const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source += get_template_definition(
"v2_" + lib_name, "ternary_v2", t_str, op, false, false);
kernel_source += get_template_definition(
"sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
kernel_source += get_template_definition(
"vs2_" + lib_name, "ternary_v2", t_str, op, false, true);
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
kernel_source += get_template_definition(
"vn_" + lib_name, "ternary_v", t_str, op, false, false);
kernel_source += get_template_definition(
"svn_" + lib_name, "ternary_v", t_str, op, true, false);
kernel_source += get_template_definition(
"vsn_" + lib_name, "ternary_v", t_str, op, false, true);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition(
"v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
kernel_source += get_template_definition(
"vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
@@ -814,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
metal::utils(),
metal::gemm(),
metal::quantized_utils(),
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
(mode == "affine") ? metal::quantized() : metal::fp_quantized(),
template_def);
return kernel_source;
});
@@ -841,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
std::string kernel_source;
concatenate(
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
if (mode == "affine") {
concatenate(
kernel_source,
metal::quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
} else {
concatenate(
kernel_source,
metal::fp4_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
"uint8_t",
bm,
bn,
bk,
wm,
wn,
transpose));
}
bool is_affine = mode == "affine";
concatenate(
kernel_source,
is_affine ? metal::quantized() : metal::fp_quantized(),
get_template_definition(
lib_name,
(is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);

View File

@@ -6,6 +6,7 @@ set(BASE_HEADERS
defines.h
erf.h
expm1f.h
fp8.h
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
@@ -109,7 +110,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)

View File

@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr threadgroup complex64_t& operator+=(
threadgroup complex64_t& a,
complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}

View File

@@ -0,0 +1,59 @@
#pragma once
constexpr constant static float FP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
struct fp4_e2m1 {
fp4_e2m1(float x) {
if (metal::isnan(x)) {
bits = 0x7;
return;
}
const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
x = metal::abs(x);
if (x > 5.0f) {
bits = 0x7;
} else if (x >= 3.5f) {
bits = 0x6;
} else if (x > 2.5f) {
bits = 0x5;
} else if (x >= 1.75f) {
bits = 0x4;
} else if (x > 1.25f) {
bits = 0x3;
} else if (x >= 0.75f) {
bits = 0x2;
} else if (x > 0.25f) {
bits = 0x1;
} else {
bits = 0x0;
}
bits |= sign_bit;
}
operator float() {
half converted = as_type<half>(ushort((bits & 7) << 9));
converted *= 16384.0;
converted = bits & 8 ? -converted : converted;
return converted;
}
uint8_t bits;
};
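For reference, fp4 e2m1 has one sign bit, two exponent bits and one mantissa bit, so the only representable magnitudes are the eight LUT entries above; the encoder's thresholds implement round-to-nearest with ties going to the value whose mantissa bit is 0, and NaN saturates to the largest magnitude because the format has no NaN encoding. The operator float() above recovers the same values by placing the 3-bit code into a small half-precision pattern and scaling by 2^14. A host-side reference sketch (an illustrative assumption, not part of the diff; it ignores the tie-to-even detail and the NaN saturation):
#include <cmath>
#include <cstdint>
// The eight representable fp4 e2m1 magnitudes, indexed by the 3 low bits.
static const float kFP4Values[8] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
uint8_t fp4_encode_ref(float x) {
  uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
  float a = std::fabs(x);
  uint8_t best = 0;
  for (uint8_t i = 1; i < 8; ++i) {
    // Pick the nearest representable magnitude (ties resolved arbitrarily).
    if (std::fabs(kFP4Values[i] - a) < std::fabs(kFP4Values[best] - a)) {
      best = i;
    }
  }
  return sign | best;
}
float fp4_decode_ref(uint8_t bits) {
  // Decoding is a plain table lookup plus the sign bit.
  float v = kFP4Values[bits & 0x7];
  return (bits & 0x8) ? -v : v;
}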

View File

@@ -1,127 +0,0 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"
#define instantiate_quantized(name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
type, \
32, \
uint8_t)
#define instantiate_quantized_batched(name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
batched)
#define instantiate_quantized_aligned(name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
type, \
32, \
uint8_t, \
aligned)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
aligned, \
batched)
#define instantiate_quantized_quad(name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
D, \
batched)
#define instantiate_quantized_split_k(name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
type, \
32, \
uint8_t, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
uint8_t, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

View File

@@ -0,0 +1,82 @@
#pragma once
struct fp8_e4m3 {
template <typename T>
fp8_e4m3(T f) {
// From PyTorch
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
uint32_t fp8_max = 543 << 21;
uint32_t denorm_mask = 141 << 23;
uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
uint32_t sign = f_bits & 0x80000000;
f_bits ^= sign;
if (f_bits >= fp8_max) {
// Default behavior saturates to min/max
bits = 0x7E;
} else {
if (f_bits < (121 << 23)) {
f_bits = as_type<uint32_t>(
as_type<float>(f_bits) + as_type<float>(denorm_mask));
bits = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
f_bits += mant_odd;
bits = static_cast<uint8_t>(f_bits >> 20);
}
}
bits |= static_cast<uint8_t>(sign >> 24);
}
operator float() {
// From PyTorch:
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
uint32_t w = static_cast<uint32_t>(bits) << 24;
uint32_t sign = w & 0x80000000;
uint32_t nonsign = w & 0x7FFFFFFF;
uint32_t renorm_shift = metal::clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
int32_t inf_nan_mask =
(static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return as_type<float>(result);
}
uint8_t bits;
};
struct fp8_e8m0 {
fp8_e8m0(float x) {
if (!metal::isfinite(x)) {
bits = 0xFF;
return;
}
if (x < 0.0f) {
bits = 0x00;
return;
}
float le = metal::log2(x);
int n = int(metal::round(le));
n = n < -127 ? -127 : n;
n = n > 127 ? 127 : n;
bits = static_cast<uint8_t>(n + 127);
}
operator bfloat16_t() {
uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
return as_type<bfloat16_t>(out);
}
operator float() {
return static_cast<float>(this->operator bfloat16_t());
}
uint8_t bits;
};
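Two 8-bit formats are added here: e4m3 (four exponent bits, three mantissa bits, finite range up to ±448; used for the 8-bit elements and, per the fp_quantize kernel later in this diff, for the nvfp4 scales) and e8m0 (exponent only, so it represents exactly the powers of two 2^-127 … 2^127 with 0xFF reserved for non-finite inputs; used for the mx block scales). A host-side reference for e8m0 (an illustrative assumption, not part of the diff):
#include <algorithm>
#include <cmath>
#include <cstdint>
uint8_t e8m0_encode_ref(float x) {
  if (!std::isfinite(x)) {
    return 0xFF;  // NaN code, as in the kernel
  }
  if (x <= 0.0f) {
    return 0x00;  // clamps to 2^-127; the kernel sends negatives here too
  }
  int n = static_cast<int>(std::lround(std::log2(x)));
  n = std::clamp(n, -127, 127);
  return static_cast<uint8_t>(n + 127);
}
float e8m0_decode_ref(uint8_t bits) {
  // 2^(bits - 127); 0xFF decodes to infinity, matching the bfloat16 bit trick.
  return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}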

View File

@@ -3,6 +3,9 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp8.h"
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
@@ -26,15 +29,31 @@ inline constexpr short get_bytes_per_pack() {
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
return T(*(thread fp8_e8m0*)(&s));
}
template <int bits>
struct Quantize {
uint8_t operator()(float x) {
if (bits == 8) {
return fp8_e4m3(x).bits;
} else {
return fp4_e2m1(x).bits;
}
}
};
template <int bits>
struct Dequantize {
float operator()(uint8_t x) {
if (bits == 8) {
return float(*(thread fp8_e4m3*)(&x));
} else {
return float(*(thread fp4_e2m1*)(&x));
}
}
};
template <typename T, typename U, int values_per_thread>
inline void load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) {
@@ -59,80 +78,41 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
}
}
constexpr constant static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
if (simd_gid == 0 && simd_lid < 16) {
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
template <typename U, int values_per_thread>
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut) {
inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
(x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
}
return scale * accum;
}
template <typename U, int values_per_thread>
inline U qdot_safe(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut,
int N) {
inline U
qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
(x_thread[4 * i] * Dequantize<4>{}(ws[i]) +
x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +
x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +
x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));
}
return scale * accum;
}
template <typename U, int values_per_thread>
inline void qouter(
const thread uint8_t* w,
U x,
U scale,
thread U* result,
const threadgroup U* lut) {
inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * scale * lut[w[i] & 0xf];
result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
result[2 * i] += x * scale * Dequantize<4>{}(w[i]);
result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4);
}
}
@@ -155,8 +135,7 @@ template <
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
typename S>
short group_size>
struct QuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
@@ -183,12 +162,12 @@ struct QuantizedBlockLoader {
threadgroup T* dst;
const device uint8_t* src;
const device S* scales;
const device uint8_t* scales;
threadgroup T* lut;
QuantizedBlockLoader(
const device uint8_t* src_,
const device S* scales_,
const device uint8_t* scales_,
const int src_ld_,
threadgroup T* dst_,
threadgroup T* lut_,
@@ -208,7 +187,10 @@ struct QuantizedBlockLoader {
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size),
lut(lut_) {
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
if (simd_group_id == 0 && simd_lane_id < 16) {
lut[simd_lane_id] = static_cast<T>(FP4_LUT[simd_lane_id]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
void load_unsafe() const {
@@ -270,20 +252,17 @@ struct QuantizedBlockLoader {
}
};
template <typename T, int group_size, typename S, int D>
METAL_FUNC void mxfp4_qmv_quad_impl(
template <typename T, int group_size, int bits, int D>
METAL_FUNC void fp_qmv_quad_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 8;
constexpr int values_per_thread = D / QUAD_SIZE;
@@ -295,7 +274,6 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
@@ -311,11 +289,11 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
U s = dequantize_scale<U>(sl[0]);
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
}
@@ -327,18 +305,17 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
}
}
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_fast_impl(
template <typename T, int group_size, int bits>
METAL_FUNC void fp_qmv_fast_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
@@ -353,7 +330,6 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -374,7 +350,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -390,18 +366,17 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
}
}
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_impl(
template <typename T, int group_size, int bits>
METAL_FUNC void fp_qmv_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
@@ -418,7 +393,6 @@ METAL_FUNC void mxfp4_qmv_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -448,8 +422,8 @@ METAL_FUNC void mxfp4_qmv_impl(
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
S s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
uint8_t s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -468,7 +442,7 @@ METAL_FUNC void mxfp4_qmv_impl(
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
}
@@ -497,7 +471,7 @@ METAL_FUNC void mxfp4_qmv_impl(
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -517,7 +491,7 @@ METAL_FUNC void mxfp4_qmv_impl(
U s = dequantize_scale<U>(sl[0]);
result[row] +=
qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
}
}
for (int row = 0; row < results_per_simdgroup; row++) {
@@ -529,18 +503,17 @@ METAL_FUNC void mxfp4_qmv_impl(
}
}
template <typename T, const int group_size, typename S>
METAL_FUNC void mxfp4_qvm_impl(
template <typename T, const int group_size, int bits>
METAL_FUNC void fp_qvm_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const int in_vec_size,
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int pack_factor = get_pack_factor<32>();
constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -561,8 +534,6 @@ METAL_FUNC void mxfp4_qvm_impl(
thread U scale = 0;
thread U x_local = 0;
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
@@ -584,7 +555,7 @@ METAL_FUNC void mxfp4_qvm_impl(
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
(thread uint8_t*)&w_local, x_local, scale, result);
x += block_size;
scales += block_size * out_vec_size_g;
@@ -597,7 +568,7 @@ METAL_FUNC void mxfp4_qvm_impl(
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
(thread uint8_t*)&w_local, x_local, scale, result);
x += block_size;
scales += block_size * out_vec_size_g;
@@ -612,7 +583,7 @@ METAL_FUNC void mxfp4_qvm_impl(
scale = 0;
}
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
(thread uint8_t*)&w_local, x_local, scale, result);
}
// Accumulate in the simdgroup
@@ -633,14 +604,14 @@ METAL_FUNC void mxfp4_qvm_impl(
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_t_impl(
METAL_FUNC void fp_qmm_t_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
@@ -677,8 +648,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
// Set the block
const int K_w = K * bytes_per_pack / pack_factor;
@@ -759,13 +729,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
template <
typename T,
const int group_size,
typename S,
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_n_impl(
METAL_FUNC void fp_qmm_n_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
@@ -803,8 +773,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
BN_padded,
0,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
auto wl = (const device uint8_t*)w;
@@ -891,11 +860,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
}
}
template <typename T, typename S>
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint8_t*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
@@ -926,11 +895,11 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, typename S>
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint8_t*& scales,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T*& y,
@@ -976,10 +945,10 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, typename S, int D, bool batched>
[[kernel]] void mxfp4_qmv_quad(
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void fp_qmv_quad(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -993,9 +962,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
@@ -1013,26 +980,14 @@ template <typename T, int group_size, typename S, int D, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_quad_impl<T, group_size, S, D>(
w,
scales,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid,
simd_gid,
simd_lid,
lut);
fp_qmv_quad_impl<T, group_size, bits, D>(
w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
}
template <typename T, int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv_fast(
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void fp_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1064,15 +1019,14 @@ template <typename T, int group_size, typename S, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv(
template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void fp_qmv(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1104,15 +1058,14 @@ template <typename T, const int group_size, typename S, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qvm(
template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void fp_qvm(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1144,15 +1097,14 @@ template <typename T, const int group_size, typename S, bool batched>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, typename S, int split_k = 32>
[[kernel]] void mxfp4_qvm_split_k(
template <typename T, const int group_size, int bits, int split_k = 32>
[[kernel]] void fp_qvm_split_k(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1188,32 +1140,22 @@ template <typename T, const int group_size, typename S, int split_k = 32>
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w,
scales,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid,
lut);
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_t(
[[kernel]] void fp_qmm_t(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& K,
@@ -1254,21 +1196,21 @@ template <
s_strides,
tid);
}
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_n(
[[kernel]] void fp_qmm_n(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& K,
@@ -1311,14 +1253,14 @@ template <
tid);
}
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv_fast(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1360,15 +1302,14 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qmv(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1410,15 +1351,14 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qvm(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qvm(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1460,22 +1400,21 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_t(
[[kernel]] void fp_gather_qmm_t(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1526,20 +1465,20 @@ template <
w_strides,
s_strides,
tid);
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_n(
[[kernel]] void fp_gather_qmm_n(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1591,24 +1530,24 @@ template <
w_strides,
s_strides,
tid);
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
int group_size,
typename S,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void mxfp4_gather_qmm_rhs(
[[kernel]] void fp_gather_qmm_rhs(
const device T* x,
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device uint32_t* indices,
device T* y,
const constant int& M,
@@ -1644,8 +1583,7 @@ template <
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1789,3 +1727,78 @@ template <
}
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void fp_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device uint8_t* scales [[buffer(2)]],
uint2 tidx [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr bool use_mx_scale = group_size == 32;
size_t index = tidx.x + grid_dim.x * size_t(tidx.y);
float scale;
float w_thread = w[index];
if (use_mx_scale) {
scale = simd_max(abs(w_thread));
} else {
float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
scale = tidx.x < 16 ? w_max_l : w_max_r;
}
scale /= bits == 4 ? 6.0f : 448.0f;
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
auto s = ScaleType(scale);
uint8_t q_scale = s.bits;
scale = float(s);
// Write out the scale for this group
size_t gindex = index / group_size;
if (index % group_size == 0) {
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
if (bits == 4) {
uint8_t sval = simd_shuffle_down(output, 1);
output |= sval << bits;
}
constexpr int pack_factor = bits == 8 ? 1 : 2;
if (index % pack_factor == 0) {
out[index / pack_factor] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void fp_dequantize(
const device uint8_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
device T* out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr bool use_mx_scale = group_size == 32;
constexpr int pack_factor = bits == 8 ? 1 : 2;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * pack_factor;
size_t gindex = oindex / group_size;
out += oindex;
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
auto q_scale = ((device ScaleType*)(scales))[gindex];
auto scale = float(q_scale);
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
}
}
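Putting the pieces together, fp_quantize works as follows: every group_size consecutive weights (32 for the mx modes, 16 for nvfp4) share one 8-bit scale, computed as the group's max |w| divided by the largest element magnitude (6 for fp4 e2m1, 448 for fp8 e4m3), encoded as e8m0 for the mx modes or e4m3 for nvfp4; each weight is then quantized against the decoded scale, and 4-bit codes are packed two per byte. A host-side reference for the mxfp4 case, reusing the fp4_encode_ref and e8m0 helpers sketched earlier (an illustrative assumption, not the kernel; it leaves the codes unpacked):
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>
void mxfp4_quantize_ref(
    const std::vector<float>& w,     // length assumed divisible by 32
    std::vector<uint8_t>& q,         // one 4-bit code per weight, unpacked
    std::vector<uint8_t>& scales) {  // one e8m0 byte per group of 32
  const int group_size = 32;
  const float elem_max = 6.0f;  // largest fp4 e2m1 magnitude
  for (size_t g = 0; g < w.size(); g += group_size) {
    float amax = 0.0f;
    for (int i = 0; i < group_size; ++i) {
      amax = std::max(amax, std::fabs(w[g + i]));
    }
    uint8_t qs = e8m0_encode_ref(amax / elem_max);
    scales.push_back(qs);
    // Quantize against the decoded (quantized) scale, as the kernel does.
    // An e8m0 scale is never zero; the kernel's s == 0 guard matters for the
    // e4m3 scale used by nvfp4.
    float s = e8m0_decode_ref(qs);
    for (int i = 0; i < group_size; ++i) {
      q.push_back(fp4_encode_ref(w[g + i] / s));
    }
  }
}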

View File

@@ -0,0 +1,147 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp_quantized.h"
#define instantiate_quantized(mode, name, type) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4", \
fp_ ## name, \
type, \
32, \
4)
#define instantiate_quantized_batched(mode, name, type, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
batched)
#define instantiate_quantized_aligned(mode, name, type, aligned) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
fp_ ## name, \
type, \
32, \
4, \
aligned)
#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
aligned, \
batched)
#define instantiate_quantized_quad(mode, name, type, D, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
D, \
batched)
#define instantiate_quantized_split_k(mode, name, type, split_k) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
fp_ ## name, \
type, \
32, \
4, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
4, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(mode, name, type) \
instantiate_quantized_batched(mode, name, type, 1) \
instantiate_quantized_batched(mode, name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4, gather_qmv_fast, type) \
instantiate_quantized(mxfp4, gather_qmv, type) \
instantiate_quantized(mxfp4, gather_qvm, type) \
instantiate_quantized(mxfp4, gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
instantiate_kernel( \
#mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_quantize, \
type, \
group_size, \
bits) \
instantiate_kernel( \
#mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_dequantize, \
type, \
group_size, \
bits)
#define instantiate_quantize_dequantize_modes(type) \
instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
instantiate_quantize_dequantize(type, mxfp8, 32, 8)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type) \
instantiate_quantize_dequantize_modes(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on
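Expanding one invocation of these macros by hand (for illustration) shows the naming scheme: the host-visible kernel names keep the per-mode prefix while the instantiated template is now the mode-agnostic fp_ variant, e.g.
instantiate_quantized_batched(mxfp4, qmv_fast, float16_t, 1)
// expands to
instantiate_kernel(
    "mxfp4_qmv_fast_float16_t_gs_32_b_4_batch_1", fp_qmv_fast, float16_t, 32, 4, 1)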

View File

@@ -15,6 +15,15 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
template <typename U>
struct DefaultAccT {
using type = float;
};
template <>
struct DefaultAccT<complex64_t> {
using type = complex64_t;
};
template <
typename T,
const int BM, /* Threadgroup rows (in simdgroups) */
@@ -24,8 +33,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -246,8 +257,10 @@ template <
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float>
typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
@@ -453,7 +466,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -530,6 +543,7 @@ template <
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
instantiate_gemv_blocks(complex64, complex64_t);
template <
typename T,
@@ -565,7 +579,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -636,6 +650,7 @@ template <
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_bs_blocks(complex64, complex64_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
@@ -672,7 +687,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets
@@ -738,7 +753,8 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
template <
typename T,
@@ -773,8 +789,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
threadgroup float tgp_memory
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
@@ -848,4 +864,5 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on

View File

@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
b = w;
}
template <typename T, typename = void>
struct Init {
static constexpr constant T v = Limits<T>::max;
};
template <typename T>
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
};
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
static constexpr constant T init = Init<T>::v;
METAL_FUNC bool operator()(T a, T b) const {
if constexpr (
metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
bool an = isnan(a);
bool bn = isnan(b);
if (an | bn) {
return (!an) & bn;
}
}
return a < b;
}
};
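The new comparator makes the sort NaN-aware: (!an) & bn is true only when a is a number and b is NaN, so NaNs compare greater than every value and land at the tail of an ascending sort, and initializing padding slots with quiet NaN (instead of Limits<T>::max) keeps them behind real data as well. A host-side analogue of the same ordering (an illustrative assumption, not the Metal code):
#include <algorithm>
#include <cmath>
#include <vector>
bool less_nan_last(float a, float b) {
  bool an = std::isnan(a);
  bool bn = std::isnan(b);
  if (an || bn) {
    return !an && bn;  // only "number < NaN" holds, so NaNs sort last
  }
  return a < b;
}
// std::vector<float> v = {2.0f, NAN, -1.0f, 0.5f};
// std::sort(v.begin(), v.end(), less_nan_last);  // -> -1, 0.5, 2, nan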

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

View File

@@ -3,9 +3,8 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@@ -23,10 +23,12 @@
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
// clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
#define instantiate_gemm( \

View File

@@ -60,6 +60,7 @@
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
#define instantiate_accum(oname, otype, aname, atype) \
instantiate_kernel( \
@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on
instantiate_accum(float32, float, float32, float);
instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on

View File

@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
}
}
template <typename InT>
struct TransformNone<complex64_t, InT> {
static METAL_FUNC complex64_t apply(complex64_t x) {
return x;
}
static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
return x;
}
};
template <
typename T,
typename U,
@@ -731,5 +741,406 @@ struct BlockMMA {
}
};
template <
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType,
typename Epilogue>
struct BlockMMA<
complex64_t,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
lda_tgp,
ldb_tgp,
AccumType,
Epilogue> {
static_assert(
metal::is_same_v<AccumType, float>,
"BlockMMA<complex64_t,...> expects float accumulators");
static_assert(
metal::is_same_v<U, complex64_t>,
"For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along N
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / (kFragSize * WM);
// Warp tile size along N
STEEL_CONST short TN = BN / (kFragSize * WN);
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// When indexing complex as float[2]
STEEL_CONST short A_str_m_f = A_str_m * 2;
STEEL_CONST short A_str_k_f = A_str_k * 2;
STEEL_CONST short B_str_k_f = B_str_k * 2;
STEEL_CONST short B_str_n_f = B_str_n * 2;
STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
// Accumulators (real/imag)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
// Offsets within threadgroup
short sm, sn;
short As_offset, Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
sm += tm;
sn += tn;
}
/* Karatsuba MMA: 3 real MMAs per K-chunk */
METAL_FUNC void mma(
const threadgroup complex64_t* As,
const threadgroup complex64_t* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
threadgroup const float* As_f =
reinterpret_cast<threadgroup const float*>(As);
threadgroup const float* Bs_f =
reinterpret_cast<threadgroup const float*>(Bs);
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
simdgroup_barrier(mem_flags::mem_none);
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
simdgroup_barrier(mem_flags::mem_none);
// P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
tile_matmad(P, Ar, Br, P);
tile_matmad(Q, Ai, Bi, Q);
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
Ar.elems()[i] += Ai.elems()[i];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
Br.elems()[i] += Bi.elems()[i];
tile_matmad(R, Ar, Br, R);
// C_r += P - Q ; C_i += R - P - Q
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
const auto p = P.elems()[i];
const auto q = Q.elems()[i];
const auto r = R.elems()[i];
Ctile_r.elems()[i] += (p - q);
Ctile_i.elems()[i] += (r - p - q);
}
// Progress to next simdgroup tile
As_f += tile_stride_a_f;
Bs_f += tile_stride_b_f;
}
}
/* Store results from the simdgroup_matrix accumulators into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off = (i * TM_stride) * ldd + (j * TN_stride);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
D += sm * ldd + sn;
start -= short2(sn, sm);
stop -= short2(sn, sm);
if (stop.y <= 0 || stop.x <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; ++i) {
const int row = i * TM_stride;
if (row >= start.y && row < stop.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; ++j) {
const int off = row * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
const int col = j * TN_stride + k;
if (col >= start.x && col < stop.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
int off = (i * TM_stride) * ldd + (j * TN_stride);
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
}
}
}
}
}
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
complex64_t out = epilogue_op.apply(
complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
Ctile_r.elems()[i] = out.real;
Ctile_i.elems()[i] = out.imag;
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
complex64_t out = epilogue_op.apply(
complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread auto& r = Ctile_r.frag_at(i, j);
thread auto& im = Ctile_i.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
complex64_t tmp[kelems];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x &&
(i * TM_stride) < dst_tile_dims.y) {
tmp[k] = C[offset_c + k * fdc];
} else {
tmp[k] = complex64_t(0.0f, 0.0f);
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
r[k] = out.real;
im[k] = out.imag;
}
}
}
}
/* Store results from the simdgroup_matrix accumulators into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[off_d + k] =
epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in Cr, Ci
thread const auto& r = Ctile_r.frag_at(i, j);
thread const auto& im = Ctile_i.frag_at(i, j);
int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int off_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[off_d + k] = epilogue_op.apply(
complex64_t(r[k], im[k]), C[off_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx
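The complex BlockMMA specialization above uses the Karatsuba-style three-multiplication trick noted in its comments: per K-chunk it forms P = Ar*Br, Q = Ai*Bi, and R = (Ar+Ai)*(Br+Bi), then accumulates C_r += P - Q and C_i += R - P - Q, saving one real MMA per complex product at the cost of two extra additions. A scalar C++ check of the same identity (illustrative only; std::complex stands in for complex64_t):

    #include <cassert>
    #include <cmath>
    #include <complex>

    int main() {
      // One complex product a * b, computed with three real multiplies.
      float ar = 1.5f, ai = -2.0f, br = 0.25f, bi = 3.0f;
      float p = ar * br;                // P = Ar * Br
      float q = ai * bi;                // Q = Ai * Bi
      float r = (ar + ai) * (br + bi);  // R = (Ar + Ai) * (Br + Bi)

      std::complex<float> ref =
          std::complex<float>(ar, ai) * std::complex<float>(br, bi);
      assert(std::abs((p - q) - ref.real()) < 1e-6f);      // real part: P - Q
      assert(std::abs((r - p - q) - ref.imag()) < 1e-6f);  // imag part: R - P - Q
      return 0;
    }

The kernel applies the same identity fragment-wise through tile_matmad; the real and imaginary planes are read from threadgroup memory as interleaved floats, which is why the *_f strides above are simply the complex strides doubled.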


@@ -48,7 +48,8 @@ struct TransformAxpby {
}
METAL_FUNC OutT apply(InT x, OutT c) const {
-return static_cast<OutT>(x * alpha + (beta * c));
+return static_cast<OutT>(
+    x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
}
};
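The TransformAxpby change above converts the float alpha and beta scalars to the input and output element types before multiplying, which keeps the epilogue well-formed when InT/OutT are complex64_t (where a mixed complex-by-float product may not be defined). A host-side C++ sketch of the fixed arithmetic, using std::complex<float> as a stand-in for complex64_t:

    #include <complex>

    // Sketch of the axpby epilogue after the fix: the scalars are cast to the
    // element types before the multiply. Names here are illustrative.
    template <typename OutT, typename InT>
    struct AxpbySketch {
      float alpha;
      float beta;
      OutT apply(InT x, OutT c) const {
        return static_cast<OutT>(
            x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
      }
    };

    // Example: out = 2 * x + 0.5 * c over complex values.
    //   AxpbySketch<std::complex<float>, std::complex<float>> t{2.0f, 0.5f};
    //   auto y = t.apply({1.0f, 1.0f}, {4.0f, 0.0f});  // (4, 2)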


@@ -1,6 +1,11 @@
// Copyright © 2024 Apple Inc.
-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v(
device const bool* a,
device const T* b,
@@ -11,16 +16,25 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
index *= N;
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
-d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+auto bidx = BSCALAR ? 0 : index + i;
+auto cidx = CSCALAR ? 0 : index + i;
+d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
}
} else {
for (int i = 0; i < N; ++i) {
-d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+auto bidx = BSCALAR ? 0 : index + i;
+auto cidx = CSCALAR ? 0 : index + i;
+d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
}
}
}
-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
@@ -32,11 +46,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
-d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+auto bidx = BSCALAR ? 0 : offset + i;
+auto cidx = CSCALAR ? 0 : offset + i;
+d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
}
} else {
for (int i = 0; i < N; ++i) {
-d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+auto bidx = BSCALAR ? 0 : offset + i;
+auto cidx = CSCALAR ? 0 : offset + i;
+d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
}
}
}
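The rewritten ternary_v and ternary_v2 kernels above take compile-time BSCALAR/CSCALAR flags and pin the corresponding operand's index to 0, so a scalar b or c is read in place rather than being materialized as a full broadcast array. A minimal C++ sketch of the same indexing trick, using a select/where op for illustration:

    #include <cstddef>

    // d[i] = a[i] ? b[...] : c[...], where b and/or c may be a single
    // broadcast scalar selected at compile time.
    template <bool BSCALAR, bool CSCALAR, typename T>
    void ternary_select(const bool* a, const T* b, const T* c, T* d, size_t size) {
      for (size_t i = 0; i < size; ++i) {
        const size_t bidx = BSCALAR ? 0 : i;  // scalar operand: always element 0
        const size_t cidx = CSCALAR ? 0 : i;
        d[i] = a[i] ? b[bidx] : c[cidx];
      }
    }

    // ternary_select<false, true>(cond, values, &fill_value, out, n) reads
    // fill_value once per element without a broadcast copy of c.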


@@ -9,8 +9,12 @@
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_base(op, tname, type) \
-instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
-instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
+instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1) \
+instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false) \
+instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1) \
+instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true) \
+instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1) \
+instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
@@ -21,7 +25,9 @@
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
-instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
+instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false) \
+instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true) \
+instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false) \
instantiate_ternary_base(op, tname, type)
#define instantiate_ternary_types(op) \

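The instantiation hunk above registers three contiguous-kernel families per op and dtype: "v"/"v2"/"vn" for dense b and c, "vs"/"vs2"/"vsn" for a scalar c (BSCALAR=false, CSCALAR=true), and "sv"/"sv2"/"svn" for a scalar b (BSCALAR=true, CSCALAR=false). A hypothetical helper that maps operand shapes to that suffix, shown only to spell out the naming convention (the actual host-side dispatch code is not part of this diff):

    #include <string>

    // Hypothetical: choose the contiguous ternary kernel suffix from which
    // of the b/c operands is a broadcast scalar.
    std::string ternary_kernel_suffix(bool b_is_scalar, bool c_is_scalar) {
      if (b_is_scalar && !c_is_scalar) {
        return "sv";  // scalar b, dense c
      }
      if (!b_is_scalar && c_is_scalar) {
        return "vs";  // dense b, scalar c
      }
      return "v";  // both dense (a both-scalar variant is not instantiated above)
    }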

@@ -9,11 +9,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
index *= N;
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
-out[index + i] = Op()(in[index + i]);
+out[index + i] = static_cast<U>(Op()(in[index + i]));
}
} else {
for (int i = 0; i < N; ++i) {
-out[index + i] = Op()(in[index + i]);
+out[index + i] = static_cast<U>(Op()(in[index + i]));
}
}
}
@@ -28,11 +28,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
-out[offset + i] = Op()(in[offset + i]);
+out[offset + i] = static_cast<U>(Op()(in[offset + i]));
}
} else {
for (int i = 0; i < N; ++i) {
-out[offset + i] = Op()(in[offset + i]);
+out[offset + i] = static_cast<U>(Op()(in[offset + i]));
}
}
}
@@ -57,7 +57,7 @@ template <
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
-out[out_idx++] = Op()(in[idx]);
+out[out_idx++] = static_cast<U>(Op()(in[idx]));
idx += xstride;
}
}
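The unary kernel change above wraps each op result in static_cast<U>, making the conversion from the compute type T to the output type U explicit rather than relying on an implicit conversion that may be narrowing or, for some type pairs, unavailable. A small C++ sketch of the pattern, with a sign op computing in float and an int output as a stand-in:

    #include <cstddef>

    struct Sign {
      float operator()(float x) const {
        return static_cast<float>((x > 0.0f) - (x < 0.0f));  // compute in T = float
      }
    };

    // out[i] = static_cast<U>(Op()(in[i])): the T -> U output conversion is
    // spelled out, mirroring unary_v / unary_v2 / unary_g above.
    template <typename T, typename U, typename Op>
    void apply_unary(const T* in, U* out, size_t size) {
      for (size_t i = 0; i < size; ++i) {
        out[i] = static_cast<U>(Op()(in[i]));
      }
    }

    // Usage: apply_unary<float, int, Sign>(x, y, n) stores -1 / 0 / 1 as ints.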

Some files were not shown because too many files have changed in this diff.