Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: b901a9f311...sign-warns (30 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 24828b1b2f |  |
|  | 9f649b5658 |  |
|  | 18aa921388 |  |
|  | 8d13a0bc6b |  |
|  | ac75c87fd7 |  |
|  | 7107802e09 |  |
|  | c5913131cf |  |
|  | 19ab7911f6 |  |
|  | 4a1b1796b7 |  |
|  | b48d298205 |  |
|  | 8277e71ea9 |  |
|  | b0d985416a |  |
|  | 8d10f3ec75 |  |
|  | 6343622c67 |  |
|  | 979abf462b |  |
|  | 981d2fdaf0 |  |
|  | 5a306d3495 |  |
|  | 5baa361779 |  |
|  | 1bac0db7e3 |  |
|  | a1212b4e44 |  |
|  | 45a8b226af |  |
|  | 76ef1e98f3 |  |
|  | 63d91557e0 |  |
|  | 310e501e6a |  |
|  | cacc3ab7fd |  |
|  | 53525cba23 |  |
|  | 3d67b717a0 |  |
|  | 953b2f5be2 |  |
|  | 26f7155537 |  |
|  | 66fcb9fe94 |  |
24 .github/actions/build-cuda-release/action.yml (vendored)
@@ -1,24 +0,0 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'

inputs:
  nvcc-location:
    description: 'Location of nvcc compiler'
    required: true

runs:
  using: "composite"
  steps:
    - name: Build package
      shell: bash
      env:
        MLX_BUILD_STAGE: 2
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
      run: |
        pip install auditwheel build patchelf setuptools
        python setup.py clean --all
        python -m build -w

        if [ -f "python/scripts/repair_cuda.sh" ]; then
          bash python/scripts/repair_cuda.sh
        fi

68 .github/actions/build-cuda/action.yml (vendored)
@@ -1,68 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'

inputs:
  build-type:
    description: 'Build type (debug, release)'
    required: false
    default: 'debug'
  run-tests:
    description: 'Whether to run tests'
    required: false
    default: 'true'
  nvcc-location:
    description: 'Location of nvcc compiler'
    required: true
    default: '/usr/local/cuda-12.9/bin/nvcc'
    # this value is dependent on the CUDA tools installed in the setup-linux workflow

runs:
  using: "composite"
  steps:
    - name: Install Python package
      shell: bash
      env:
        DEBUG: 1
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
      run: pip install -e ".[dev]" -v

    - name: Check if build actually worked
      shell: bash
      run: python -c "import mlx.core"

    - name: Run Python tests - CPU
      if: inputs.run-tests == 'true'
      shell: bash
      env:
        LOW_MEMORY: 1
        DEVICE: cpu
      run: python -m unittest discover python/tests -v

    - name: Run Python tests - GPU
      if: inputs.run-tests == 'true'
      shell: bash
      env:
        LOW_MEMORY: 1
        DEVICE: gpu
      run: python -m tests discover python/tests -v

    - name: Build CPP only
      if: inputs.build-type == 'debug'
      shell: bash
      run: |
        cmake . -B build \
          -DMLX_BUILD_CUDA=ON \
          -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
          -DCMAKE_BUILD_TYPE=DEBUG
        cmake --build build -j $(nproc)

    - name: Run CPP tests
      if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
      shell: bash
      run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"

    - name: Build Python package
      if: inputs.build-type == 'release'
      uses: ./.github/actions/build-cuda-release
      with:
        nvcc-location: ${{ inputs.nvcc-location }}

38 .github/actions/build-docs/action.yml (vendored)
@@ -1,38 +0,0 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'

runs:
  using: "composite"
  steps:
    - name: Setup machine
      uses: ./.github/actions/setup-macos

    - name: Install dependencies
      shell: sh
      run: |
        brew install doxygen
        uv pip install --upgrade pip cmake
        uv pip install -r docs/requirements.txt
        uv pip install . -v

    - name: Build documentation
      shell: bash
      run: |
        source .venv/bin/activate
        cd docs
        doxygen
        make html O=-W

    - name: Create artifact tar
      shell: sh
      run: tar -cf artifact.tar --cd docs/build/html -L .

    # Do it manually because upload-pages-artifact requires gtar
    - name: Upload artifact
      id: upload-artifact
      uses: actions/upload-artifact@v5
      with:
        name: github-pages
        path: artifact.tar
        retention-days: 1
        if-no-files-found: error

78 .github/actions/build-linux/action.yml (vendored)
@@ -1,78 +0,0 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'

inputs:
  build-type:
    description: 'Build type'
    required: false
    default: 'debug'
    type: choice
    options:
      - debug
      - release
  run-tests:
    description: 'Whether to run tests'
    required: false
    default: 'true'
    type: boolean

runs:
  using: "composite"
  steps:
    - name: Set DEBUG
      shell: sh
      if: inputs.build-type == 'debug'
      run: echo "DEBUG=1" >> $GITHUB_ENV

    - name: Install Python package
      shell: sh
      env:
        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
      run: pip install -e ".[dev]" -v

    - name: Generate package stubs
      shell: sh
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Run Python tests
      if: inputs.run-tests == 'true'
      shell: bash
      run: |
        python -m unittest discover python/tests -v
        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if grep -Fq '[WARN]' stderr.log ; then
          grep -F '[WARN]' stderr.log
          echo "Distributed ring test failed";
          exit 1;
        fi

    - name: Build CPP only
      if: inputs.build-type == 'debug'
      shell: bash
      run: |
        mkdir -p build && cd build
        cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
        make -j $(nproc)

    - name: Run CPP tests
      if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
      shell: sh
      run: ./build/tests/tests

    - name: Build Python package
      if: inputs.build-type == 'release'
      shell: bash
      run: |
        pip install auditwheel patchelf build
        python setup.py clean --all
        MLX_BUILD_STAGE=1 python -m build -w
        if [ -f "python/scripts/repair_linux.sh" ]; then
          bash python/scripts/repair_linux.sh
        fi

        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64

22 .github/actions/build-macos-release/action.yml (vendored)
@@ -1,22 +0,0 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'

inputs:
  macos-target:
    description: 'macOS build target'
    required: false
    default: '15.0'

runs:
  using: "composite"
  steps:
    - name: Build Python package(s)
      shell: bash
      env:
        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
      run: |
        uv pip install build
        uv run --no-project setup.py clean --all
        MLX_BUILD_STAGE=1 uv run -m build -w
        uv run --no-project setup.py clean --all
        MLX_BUILD_STAGE=2 uv run -m build -w

124 .github/actions/build-macos/action.yml (vendored)
@@ -1,124 +0,0 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'

inputs:
  build-type:
    description: 'Build type (debug, release)'
    required: false
    default: 'debug'
    type: choice
    options:
      - debug
      - release
  run-tests:
    description: 'Whether to run tests'
    required: false
    default: 'true'
  build-jit:
    description: 'Whether to build with JIT'
    required: false
    default: 'true'

runs:
  using: "composite"
  steps:
    - name: Install dependencies
      shell: sh
      env:
        DEBUG: 1
        DEV_RELEASE: 1
      run: |
        uv pip install --upgrade pip cmake setuptools
        uv pip install nanobind==2.4.0 \
          numpy torch tensorflow unittest-xml-reporting
        uv pip install -e . -v

    - name: Generate package stubs
      shell: bash
      run: |
        uv pip install typing_extensions
        uv run --no-project setup.py generate_stubs

    - name: Run Python tests
      if: inputs.run-tests == 'true'
      shell: bash
      env:
        LOW_MEMORY: 1
      run: |
        DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
        mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi

    - name: Build example extension
      if: inputs.run-tests == 'true'
      shell: bash
      run: |
        cd examples/extensions
        uv pip install -r requirements.txt
        uv run --no-project setup.py build_ext --inplace
        uv run --no-project test.py

    - name: Build CPP only
      if: inputs.build-type == 'debug'
      shell: bash
      run: |
        mkdir -p build
        cd build
        cmake ..
        make -j $(sysctl -n hw.ncpu)

    - name: Run CPP tests
      if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
      run: ./build/tests/tests

    - name: Build small binary with JIT
      if: inputs.build-jit == 'true'
      shell: bash
      run: |
        mkdir -p build
        cd build
        cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
          -DBUILD_SHARED_LIBS=ON \
          -DMLX_BUILD_CPU=OFF \
          -DMLX_BUILD_SAFETENSORS=OFF \
          -DMLX_BUILD_GGUF=OFF \
          -DMLX_METAL_JIT=ON
        make -j $(sysctl -n hw.ncpu)

    - name: Run Python tests with JIT
      if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
      shell: bash
      env:
        LOW_MEMORY: 1
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
      run: |
        CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
          uv pip install -e . -v
        uv run -m xmlrunner discover \
          -v python/tests \
          -o test-results/gpu_jit

    - name: Build macOS 13 package
      if: inputs.build-type == 'release'
      uses: ./.github/actions/build-macos-release
      with:
        macos-target: 13.0
    - name: Build macOS 14 package
      if: inputs.build-type == 'release'
      uses: ./.github/actions/build-macos-release
      with:
        macos-target: 14.0
    - name: Build macOS 15 package
      if: inputs.build-type == 'release'
      uses: ./.github/actions/build-macos-release
      with:
        macos-target: 15.0

83 .github/actions/setup-linux/action.yml (vendored)
@@ -1,83 +0,0 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'

inputs:
  runner-type:
    description: 'Whether to set this up as a linux or CUDA runner'
    required: false
    default: 'linux'
    type: choice
    options:
      - linux
      - cuda
  python-version:
    description: 'Version of python to set up'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Free disk space
      shell: sh
      if: inputs.runner-type == 'linux'
      run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"

    - name: Install common dependencies
      env:
        TZ: Etc/UTC
      shell: bash
      run: |
        sudo apt-get update
        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
        sudo apt autoremove -y

    - uses: actions/setup-python@v6
      with:
        python-version: ${{ inputs.python-version }}
        cache: 'pip'

    - name: setup python venv
      shell: bash
      run: |
        python -m venv .venv
        source .venv/bin/activate
        echo PATH=$PATH >> $GITHUB_ENV
        pip install --upgrade pip cmake

    - name: Install MPI
      if: inputs.runner-type == 'linux'
      shell: bash
      run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev

    - name: Network CUDA installation from packages
      id: install-cuda
      if: inputs.runner-type == 'cuda'
      env:
        TZ: Etc/UTC
      shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
      run: |
        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
        sudo dpkg -i cuda-keyring_1.1-1_all.deb
        sudo apt-get update
        sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
        # Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
        # cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
        # Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
        # This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH

    - name: Package and Driver Report
      if: inputs.runner-type == 'cuda'
      shell: bash
      run: |
        sudo apt-get install -y ubuntu-drivers-common dkms
        echo "NVIDIA Driver Packages Available:"
        sudo ubuntu-drivers list --gpgpu
        echo "NVIDIA Driver Version:"
        cat /proc/driver/nvidia/version || echo "nvidia driver not found"
        echo "Installed NVIDIA and CUDA packages:"
        dpkg -l | egrep "cuda|nvidia" -i
        echo "DKMS Status:"
        dkms status || echo "dkms not found"
        echo "NVIDIA-SMI Status:"
        nvidia-smi || echo "nvidia-smi not found"

31 .github/actions/setup-macos/action.yml (vendored)
@@ -1,31 +0,0 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'

inputs:
  install-mpi:
    description: 'Whether to install MPI'
    required: false
    default: 'true'
    type: boolean
  python-version:
    description: 'Python version to use'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Install Homebrew packages
      shell: sh
      if: inputs.install-mpi == 'true'
      run: /opt/homebrew/bin/brew install openmpi

    - name: Verify MetalToolchain installed
      shell: bash
      run: xcodebuild -showComponent MetalToolchain

    - name: Setup uv
      uses: astral-sh/setup-uv@v6
      with:
        python-version: ${{ inputs.python-version }}
        activate-environment: true

6 .github/dependabot.yml (vendored)
@@ -1,6 +0,0 @@
version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "weekly"

28 .github/workflows/documentation.yml (vendored)
@@ -1,28 +0,0 @@
name: Documentation

on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build:
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/build-docs

  deploy:
    needs: build
    permissions:
      pages: write
      id-token: write
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4

93 .github/workflows/nightly.yml (vendored)
@@ -1,93 +0,0 @@
name: Nightly Build

on:
  schedule:
    - cron: 33 6 * * 1-5
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build_linux_release:
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.10", "3.14"]
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux
        with:
          build-type: release
          run-tests: false
      - name: Upload mlx artifacts
        uses: actions/upload-artifact@v5
        with:
          name: linux-wheels-${{ matrix.python_version }}
          path: wheelhouse/mlx-*.whl
          retention-days: 7
      - name: Upload mlx-cpu artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cpu
          path: wheelhouse/mlx_cpu-*.whl
          retention-days: 7

  build_linux_with_tests:
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux

  build_mac_release:
    strategy:
      matrix:
        python-version: ["3.10", "3.13"]
        # TODO: 3.14 had issues finding a compatible tensorflow
    env:
      MACOSX_DEPLOYMENT_TARGET: "15.0"
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
      - uses: ./.github/actions/build-macos

  build_cuda_with_tests:
    runs-on: gpu-t4-4-core
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          runner-type: 'cuda'
      - uses: ./.github/actions/build-cuda

  build_cuda_release:
    runs-on: ubuntu-22-large
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          runner-type: 'cuda'
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl
          retention-days: 7
56 .github/workflows/pull_request.yml (vendored)
@@ -1,46 +1,20 @@
name: Build and Test

on: pull_request

permissions:
  contents: read
on:
  pull_request:
    branches:
      - main

jobs:
  check_lint:
    runs-on: ubuntu-22.04
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
      - uses: pre-commit/action@v3.0.1

  linux_build_and_test:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux

  mac_build_and_test:
    runs-on: [self-hosted, macos]
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-macos
      - uses: ./.github/actions/build-macos

  cuda_build_and_test:
    runs-on: gpu-t4-4-core
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          runner-type: 'cuda'
      - uses: ./.github/actions/build-cuda

  build_documentation:
    runs-on: [self-hosted, macos]
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/build-docs
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit black isort clang-format
      - name: Run lint
        run: |
          pre-commit run --all-files
188 .github/workflows/release.yml (vendored)
@@ -1,188 +0,0 @@
name: PyPI Release

on:
  push:
    tags:
      - 'v*'
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build_documentation:
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/build-docs

  deploy_documentation:
    needs: build_documentation
    permissions:
      pages: write
      id-token: write
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4

  build_linux_release:
    strategy:
      matrix:
        python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
    runs-on: ubuntu-22.04
    env:
      PYPI_RELEASE: 1
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux
        with:
          build-type: release
          run-tests: false
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
          name: linux-wheels-${{ matrix.python_version }}
          path: wheelhouse/mlx-*.whl
      - name: Upload CPU artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cpu
          path: wheelhouse/mlx_cpu-*.whl

  build_mac_release:
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13"]
        # TODO: 3.14 had issues finding a compatible tensorflow
    runs-on: [self-hosted, macos]
    env:
      PYPI_RELEASE: 1
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
      - uses: ./.github/actions/build-macos
        with:
          build-type: release
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
          name: mac-wheels-${{ matrix.python-version }}
          path: dist/mlx-*.whl
      - name: Upload Metal artifacts
        if: matrix.python-version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          name: mlx-metal
          path: dist/mlx_metal-*.whl

  build_cuda_release:
    runs-on: ubuntu-22-large
    env:
      PYPI_RELEASE: 1
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          runner-type: 'cuda'
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl

  pypi-publish:
    name: Upload release to PyPI
    runs-on: ubuntu-latest
    needs: [build_linux_release, build_mac_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx
    steps:
      - uses: actions/download-artifact@v6
        with:
          pattern: linux-wheels-*
          merge-multiples: true
          path: artifacts
      - uses: actions/download-artifact@v6
        with:
          pattern: mac-wheels-*
          merge-multiples: true
          path: artifacts
      - name: Display structure of downloaded files
        run: ls -R artifacts
      # - name: Publish package distributions to PyPI
      #   uses: pypa/gh-action-pypi-publish@release/v1

  pypi-publish-cuda:
    name: Upload CUDA release to PyPI
    runs-on: ubuntu-latest
    needs: build_cuda_release
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-cuda
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-cuda
          path: artifacts
      - name: Display structure of downloaded files
        run: ls -R artifacts
      # - name: Publish package distributions to PyPI
      #   uses: pypa/gh-action-pypi-publish@release/v1

  pypi-publish-cpu:
    name: Upload CPU release to PyPI
    runs-on: ubuntu-latest
    needs: build_linux_release
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-cpu
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-cpu
          path: artifacts
      - name: Display structure of downloaded files
        run: ls -R artifacts
      # - name: Publish package distributions to PyPI
      #   uses: pypa/gh-action-pypi-publish@release/v1

  pypi-publish-metal:
    name: Upload Metal release to PyPI
    runs-on: ubuntu-latest
    needs: build_mac_release
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-metal
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-metal
          path: artifacts
      - name: Display structure of downloaded files
        run: ls -R artifacts
      # - name: Publish package distributions to PyPI
      #   uses: pypa/gh-action-pypi-publish@release/v1
@@ -1,10 +1,4 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-yaml
      # - id: end-of-file-fixer
      # - id: trailing-whitespace
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v19.1.7
    hooks:
@@ -20,9 +20,13 @@ project(
  LANGUAGES C CXX
  VERSION ${MLX_PROJECT_VERSION})

if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
  add_compile_options(-Wall -Wextra)
endif()

# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -88,11 +92,6 @@ cmake_policy(SET CMP0135 NEW)

add_library(mlx)

# Supress warnings: note: parameter passing for argument of type
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)

if(MLX_BUILD_CUDA)
  enable_language(CUDA)
endif()
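The build-system changes above drive the rest of this compare: `-Wall -Wextra` is enabled for AppleClang, the CI actions pass `-DCMAKE_COMPILE_WARNING_AS_ERROR=ON`, and the standard is bumped to C++20, which provides `std::ssize`. As a rough illustration of the warning class being fixed below (a standalone sketch, not MLX code; the function and variable names are made up):

```cpp
#include <cstdint>
#include <iterator>
#include <vector>

// With -Wall -Wextra promoted to errors, a signed index compared against an
// unsigned size() fails the build with -Wsign-compare. C++20 std::ssize
// returns a signed count, so the comparison stays signed-vs-signed.
int64_t total_elements(const std::vector<std::vector<float>>& batches) {
  int64_t total = 0;
  // for (int i = 0; i < batches.size(); ++i)      // warns: sign-compare
  for (int i = 0; i < std::ssize(batches); ++i) {  // clean under -Wall -Wextra
    total += std::ssize(batches[i]);
  }
  return total;
}
```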
@@ -14,14 +14,17 @@ void array_basics() {
  // Get the value out of it:
  auto s = x.item<float>();
  assert(s == 1.0);
  (void)s;

  // Scalars have a size of 1:
  size_t size = x.size();
  int64_t size = x.size();
  assert(size == 1);
  (void)size;

  // Scalars have 0 dimensions:
  int ndim = x.ndim();
  assert(ndim == 0);
  (void)ndim;

  // The shape should be an empty vector:
  auto shape = x.shape();
@@ -30,6 +33,7 @@ void array_basics() {
  // The datatype should be float32:
  auto dtype = x.dtype();
  assert(dtype == mx::float32);
  (void)dtype;

  // Specify the dtype when constructing the array:
  x = mx::array(1, mx::int32);
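The `(void)` casts added to the tutorial exist because these variables are only consumed by `assert`, which compiles to nothing under `NDEBUG`, at which point `-Wall -Wextra` reports them as unused. A minimal standalone sketch of the two usual spellings (illustrative only, not MLX code):

```cpp
#include <cassert>

void check_rank(int ndim) {
  int rank = ndim;       // only used by the assert below
  assert(rank >= 0);
  (void)rank;            // keeps -Wunused-variable quiet in NDEBUG builds

  [[maybe_unused]] bool scalar = (ndim == 0);  // C++17 attribute, same effect
  assert(scalar || ndim > 0);
}
```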
@@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
    const std::shared_ptr<Primitive>& primitive,
    const std::vector<array>& inputs) {
  std::vector<array> outputs;
  for (size_t i = 0; i < shapes.size(); ++i) {
  for (int i = 0; i < std::ssize(shapes); ++i) {
    outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
  }
  // For each node in |outputs|, its siblings are the other nodes.
  for (size_t i = 0; i < outputs.size(); ++i) {
  for (int i = 0; i < std::ssize(outputs); ++i) {
    auto siblings = outputs;
    siblings.erase(siblings.begin() + i);
    outputs[i].set_siblings(std::move(siblings), i);
@@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
  array_desc_->data_size = size();
  array_desc_->flags.contiguous = true;
  array_desc_->flags.row_contiguous = true;
  auto max_dim = std::max_element(shape().begin(), shape().end());
  array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
  auto max_dim =
      static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
  array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
}

void array::set_data(
@@ -192,7 +193,7 @@ array::~array() {
  }

  // Break circular reference for non-detached arrays with siblings
  if (auto n = siblings().size(); n > 0) {
  if (auto n = std::ssize(siblings()); n > 0) {
    bool do_detach = true;
    // If all siblings have siblings.size() references except
    // the one we are currently destroying (which has siblings.size() + 1)
@@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() {
    ad.inputs.clear();
    for (auto& [_, a] : input_map) {
      bool is_deletable =
          (a.array_desc_.use_count() <= a.siblings().size() + 1);
          (a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
      // An array with siblings is deletable only if all of its siblings
      // are deletable
      for (auto& s : a.siblings()) {
@@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() {
        }
        int is_input = (input_map.find(s.id()) != input_map.end());
        is_deletable &=
            s.array_desc_.use_count() <= a.siblings().size() + is_input;
            s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
      }
      if (is_deletable) {
        for_deletion.push_back(std::move(a.array_desc_));
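The reference-count checks above mix `std::shared_ptr::use_count()` (signed `long`) with `std::vector::size()` (unsigned), which is exactly what `-Wsign-compare` flags; wrapping the size in `std::ssize` keeps the whole comparison signed. A hedged, self-contained sketch of the pattern (the types and names are placeholders, not the MLX ones):

```cpp
#include <iterator>
#include <memory>
#include <vector>

// use_count() is signed, size() is unsigned: the plain comparison would
// convert implicitly and warn. std::ssize makes both operands signed.
bool safe_to_delete(const std::shared_ptr<int>& desc,
                    const std::vector<int>& siblings) {
  return desc.use_count() <= std::ssize(siblings) + 1;
}
```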
14 mlx/array.h
@@ -81,22 +81,22 @@ class array {
  }

  /** The size of the array's datatype in bytes. */
  size_t itemsize() const {
  int itemsize() const {
    return size_of(dtype());
  }

  /** The number of elements in the array. */
  size_t size() const {
  int64_t size() const {
    return array_desc_->size;
  }

  /** The number of bytes in the array. */
  size_t nbytes() const {
  int64_t nbytes() const {
    return size() * itemsize();
  }

  /** The number of dimensions of the array. */
  size_t ndim() const {
  int ndim() const {
    return array_desc_->shape.size();
  }

@@ -329,7 +329,7 @@ class array {
   * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
   * Note, ``data_size`` is in units of ``item_size`` (not bytes).
   **/
  size_t data_size() const {
  int64_t data_size() const {
    return array_desc_->data_size;
  }

@@ -340,7 +340,7 @@ class array {
    return array_desc_->data->buffer;
  }

  size_t buffer_size() const {
  int64_t buffer_size() const {
    return allocator::allocator().size(buffer());
  }

@@ -530,7 +530,7 @@ array::array(
    Shape shape,
    Dtype dtype /* = TypeToDtype<T>() */)
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  if (data.size() != size()) {
  if (std::ssize(data) != size()) {
    throw std::invalid_argument(
        "Data size and provided shape mismatch in array construction.");
  }
@@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {

  // Compute the flags given the shape and strides
  bool row_contiguous = true, col_contiguous = true;
  size_t r = 1, c = 1;
  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
  int64_t r = 1, c = 1;
  for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) {
    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
    r *= shape_[i];
@@ -60,7 +60,8 @@ void CustomTransforms::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
  for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs);
       i < std::ssize(outputs);
       i++, j++) {
    outputs[i].copy_shared_buffer(inputs[j]);
  }
@@ -70,7 +71,7 @@ void Depends::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0; i < outputs.size(); i++) {
  for (int i = 0; i < std::ssize(outputs); i++) {
    outputs[i].copy_shared_buffer(inputs[i]);
  }
}
@@ -206,11 +207,11 @@ void Split::eval(

  auto compute_new_flags = [](const auto& shape,
                              const auto& strides,
                              size_t in_data_size,
                              int64_t in_data_size,
                              auto flags) {
    size_t data_size = 1;
    size_t f_stride = 1;
    size_t b_stride = 1;
    int64_t data_size = 1;
    int64_t f_stride = 1;
    int64_t b_stride = 1;
    flags.row_contiguous = true;
    flags.col_contiguous = true;
    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
@@ -240,7 +241,7 @@ void Split::eval(

  std::vector<int> indices(1, 0);
  indices.insert(indices.end(), indices_.begin(), indices_.end());
  for (int i = 0; i < indices.size(); i++) {
  for (int i = 0; i < std::ssize(indices); i++) {
    size_t offset = indices[i] * in.strides()[axis_];
    auto [new_flags, data_size] = compute_new_flags(
        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
@@ -254,7 +255,7 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
  const auto& in = inputs[0];
  Strides strides;
  for (int i = 0, j = 0; i < in.ndim(); ++i) {
    if (j < axes_.size() && i == axes_[j]) {
    if (j < std::ssize(axes_) && i == axes_[j]) {
      j++;
    } else {
      strides.push_back(in.strides(i));
@@ -272,7 +273,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  Strides out_strides(out.ndim());
  auto& in = inputs[0];
  for (int ax = 0; ax < axes_.size(); ++ax) {
  for (int ax = 0; ax < std::ssize(axes_); ++ax) {
    out_strides[ax] = in.strides()[axes_[ax]];
  }
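The `AsStrided` change above also shows why the reverse loop switches to a signed bound: with an unsigned index, `strides_.size() - 1` wraps around when the container is empty and `i >= 0` never becomes false. A small standalone sketch of that hazard and the signed fix (illustrative names, not MLX code):

```cpp
#include <iterator>
#include <vector>

// Reverse iteration: an unsigned `size() - 1` wraps to a huge value for an
// empty vector and `i >= 0` is always true. A signed index derived from
// std::ssize terminates correctly in every case.
long last_non_unit(const std::vector<long>& strides) {
  for (long i = std::ssize(strides) - 1; i >= 0; --i) {
    if (strides[i] != 1) {
      return strides[i];
    }
  }
  return 1;  // empty vector or all-unit strides
}
```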
@@ -120,7 +120,7 @@ void compiled_allocate_outputs(
  Strides strides;
  size_t data_size;
  array::Flags flags;
  for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
  for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
    auto& in = inputs[i];
    // Conditions for donation
    // - Correct size
@@ -138,7 +138,7 @@ void compiled_allocate_outputs(
      data_size = in.data_size();
    }
  }
  for (; o < outputs.size(); ++o) {
  for (; o < std::ssize(outputs); ++o) {
    outputs[o].set_data(
        allocator::malloc(data_size * outputs[o].itemsize()),
        data_size,
@@ -147,7 +147,7 @@ void compiled_allocate_outputs(
    }
  } else {
    int o = 0;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
    for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Row contiguous
@@ -162,7 +162,7 @@ void compiled_allocate_outputs(
        o++;
      }
    }
    for (; o < outputs.size(); ++o) {
    for (; o < std::ssize(outputs); ++o) {
      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
    }
  }
@@ -193,7 +193,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

    // Broadcast the inputs to the output shape.
    Strides xstrides;
    size_t j = 0;
    int j = 0;
    for (; j < shape.size() - x.ndim(); ++j) {
      if (shape[j] == 1) {
        xstrides.push_back(out.strides()[j]);
@@ -201,7 +201,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
        xstrides.push_back(0);
      }
    }
    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
    for (int i = 0; i < x.ndim(); ++i, ++j) {
      if (x.shape(i) == 1) {
        if (shape[j] == 1) {
          xstrides.push_back(out.strides()[j]);
@@ -224,13 +224,13 @@ bool compiled_use_large_index(
    const std::vector<array>& outputs,
    bool contiguous) {
  if (contiguous) {
    size_t max_size = 0;
    int64_t max_size = 0;
    for (const auto& in : inputs) {
      max_size = std::max(max_size, in.data_size());
    }
    return max_size > UINT32_MAX;
  } else {
    size_t max_size = 0;
    int64_t max_size = 0;
    for (const auto& o : outputs) {
      max_size = std::max(max_size, o.size());
    }
@@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {

namespace mlx::core {

void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
void Load::eval_cpu(const std::vector<array>& /* inputs */, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  auto read_task = [out_ptr = out.data<char>(),
                    size = out.size(),
@@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(

ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&
  if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
      x.flags().contiguous) {
    return ContiguousAllReduce;
  }
@@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
  // Merge consecutive axes
  Shape shape = {x.shape(axes[0])};
  Strides strides = {x.strides()[axes[0]]};
  for (int i = 1; i < axes.size(); i++) {
  for (int i = 1; i < std::ssize(axes); i++) {
    if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
      shape.back() *= x.shape(axes[i]);
      strides.back() = x.strides()[axes[i]];
@@ -24,8 +24,8 @@ std::tuple<int64_t, Strides> prepare_slice(
void shared_buffer_slice(
    const array& in,
    const Strides& out_strides,
    size_t data_offset,
    size_t data_size,
    int64_t data_offset,
    int64_t data_size,
    array& out) {
  // Compute row/col contiguity
  auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
@@ -61,7 +61,7 @@ void slice(
  if (data_end < 0) {
    data_end += in.data_size();
  }
  size_t data_size = (data_end - data_offset);
  int64_t data_size = (data_end - data_offset);
  shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

@@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
  if (shape[0] != 1) {
    to_collapse.push_back(0);
  }
  size_t size = shape[0];
  int64_t size = shape[0];
  for (int i = 1; i < shape.size(); i++) {
    bool contiguous = true;
    size *= shape[i];
@@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
        current_shape *= shape[to_collapse[k]];
      }
      out_shape.push_back(current_shape);
      for (int j = 0; j < strides.size(); j++) {
      for (int j = 0; j < std::ssize(strides); j++) {
        const auto& st = strides[j];
        out_strides[j].push_back(st[to_collapse[k - 1]]);
      }
@@ -162,7 +162,7 @@ struct ContiguousIterator {
};

inline auto check_contiguity(const Shape& shape, const Strides& strides) {
  size_t no_broadcast_data_size = 1;
  int64_t no_broadcast_data_size = 1;
  int64_t f_stride = 1;
  int64_t b_stride = 1;
  bool is_row_contiguous = true;
@@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
}

inline bool is_donatable(const array& in, const array& out) {
  constexpr size_t donation_extra = 16384;
  constexpr int64_t donation_extra = 16384;

  return in.is_donatable() && in.itemsize() == out.itemsize() &&
      in.buffer_size() <= out.nbytes() + donation_extra;
@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size, Stream stream) {
void arange(T start, T next, array& out, int64_t size, Stream stream) {
  auto ptr = out.data<T>();
  auto step_size = next - start;
  auto& encoder = cpu::get_command_encoder(stream);
@@ -19,12 +19,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto in_ptr = in.data<InT>();
  auto out_ptr = out.data<uint32_t>();

  for (uint32_t i = 0; i < out.size(); ++i) {
  for (int64_t i = 0; i < out.size(); ++i) {
    auto loc = elem_to_loc(i, shape, strides);
    auto local_in_ptr = in_ptr + loc;
    uint32_t ind_v = 0;
    InT v = (*local_in_ptr);
    for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
    for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
      op(j, (*local_in_ptr), &ind_v, &v);
    }
    out_ptr[i] = ind_v;
@@ -17,7 +17,12 @@ namespace mlx::core {
namespace {

template <typename Op>
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
void binary(
    const array& a,
    const array& b,
    array& out,
    Op /* op */,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
@@ -81,7 +86,7 @@ void comparison_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Op /* op */,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
@@ -146,7 +151,7 @@ void binary_float(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Op /* op */,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
@@ -187,7 +192,7 @@ void binary_int(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Op /* op */,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
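The `Op op` → `Op /* op */` edits address `-Wunused-parameter`: these wrappers dispatch on the `Op` type rather than on the functor value, so the parameter name is dropped while the signature is kept. A generic sketch of the two common ways to do this (placeholder names, not the MLX declarations):

```cpp
// Keep a parameter in the signature (for API symmetry or overload selection)
// without triggering -Wunused-parameter under -Wall -Wextra.
template <typename Op>
int dispatch(int n, Op /* op */) {                 // name commented out
  return n;                                        // Op is used only as a tag
}

template <typename Op>
int dispatch_attr(int n, [[maybe_unused]] Op op) { // C++17 attribute form
  return n;
}
```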
@@ -99,7 +99,7 @@ void binary_op_dispatch_dims(
  ContiguousIterator a_it(shape, a_strides, ndim - 2);
  ContiguousIterator b_it(shape, b_strides, ndim - 2);
  auto stride = out_strides[ndim - 3];
  for (size_t elem = 0; elem < a.size(); elem += stride) {
  for (int64_t elem = 0; elem < std::ssize(a); elem += stride) {
    binary_op_dims<T, U, Op, 2>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,
@@ -137,21 +137,21 @@ void binary_op(
  if (bopt == BinaryOpType::ScalarScalar) {
    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
  } else if (bopt == BinaryOpType::ScalarVector) {
    for (size_t i = 0; i < b.data_size(); ++i) {
    for (int64_t i = 0; i < b.data_size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
      b_ptr++;
    }
  } else if (bopt == BinaryOpType::VectorScalar) {
    for (size_t i = 0; i < a.data_size(); ++i) {
    for (int64_t i = 0; i < a.data_size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
      a_ptr++;
    }
  } else { // VectorVector
    for (size_t i = 0; i < a.size(); ++i) {
    for (int64_t i = 0; i < a.size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
                    N = a.shape(-1),
                    size = a.size()]() mutable {
    char uplo = (upper) ? 'L' : 'U';
    size_t num_matrices = size / (N * N);
    for (int i = 0; i < num_matrices; i++) {
    int64_t num_matrices = size / (N * N);
    for (int64_t i = 0; i < num_matrices; i++) {
      // Compute Cholesky factorization.
      int info;
      potrf<T>(
@@ -49,7 +49,7 @@ static CompilerCache& cache() {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
bool compile_available_for_device(const Device& /* device */) {
  return true;
}
@@ -168,7 +168,7 @@ inline void build_kernel(
  // Add the input arguments
  int cnt = 0;
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
  for (int i = 0; i < std::ssize(inputs); ++i) {
    // Skip constants from the input list
    if (is_constant(i)) {
      continue;
@@ -238,7 +238,7 @@ inline void build_kernel(
    } else {
      os << x.primitive().name();
      os << "()(";
      for (int i = 0; i < x.inputs().size() - 1; i++) {
      for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
        os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
      }
      os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
@@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu(
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& /* wt_dilation */,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
@@ -1003,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu(
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& /* wt_dilation */,
    const bool flip,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
@@ -1023,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu(
  // Pad input
  Shape padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
  for (int i = 0; i < iDim.size(); i++) {
    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
  }
  padded_shape.back() = C;
@@ -1054,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu(
  // Make strided view
  Shape strided_shape(oDim.size() + wDim.size() + 2);
  strided_shape.front() = N;
  for (size_t i = 0; i < oDim.size(); i++) {
  for (int i = 0; i < oDim.size(); i++) {
    strided_shape[i + 1] = oDim[i];
  }
  for (size_t i = 0; i < wDim.size(); i++) {
  for (int i = 0; i < wDim.size(); i++) {
    strided_shape[i + 1 + oDim.size()] = wDim[i];
  }
  strided_shape.back() = C;

  Strides strided_strides(in.shape().size() * 2 - 2);
  strided_strides[0] = in_padded.strides()[0];
  for (size_t i = 0; i < wt_strides.size(); i++) {
  for (int i = 0; i < std::ssize(wt_strides); i++) {
    strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
  }
  for (size_t i = 1; i < in_padded.strides().size(); i++) {
  for (int i = 1; i < std::ssize(in_padded.strides()); i++) {
    strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
  }

@@ -90,6 +90,7 @@ void Recv::eval_cpu(
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);
  (void)inputs;

  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
  distributed::detail::recv(group(), outputs[0], src_, stream());
@@ -70,7 +70,7 @@ void eig_impl(
    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
    for (size_t i = 0; i < size / (N * N); ++i) {
    for (int64_t i = 0; i < size / (N * N); ++i) {
      geev<T>(
          &jobl,
          &jobr,
@@ -165,7 +165,7 @@ void eigh_impl(
    EighWork<T> work(jobz, uplo, N);

    // Work loop
    for (size_t i = 0; i < size / (N * N); ++i) {
    for (int64_t i = 0; i < size / (N * N); ++i) {
      work.run(vec_ptr, eig_ptr);
      vec_ptr += N * N;
      eig_ptr += N;
@@ -20,8 +20,8 @@ struct CommandEncoder {
  CommandEncoder(CommandEncoder&&) = delete;
  CommandEncoder& operator=(CommandEncoder&&) = delete;

  void set_input_array(const array& a) {}
  void set_output_array(array& a) {}
  void set_input_array(const array& /* a */) {}
  void set_output_array(array& /* a */) {}

  // Hold onto a temporary until any already scheduled tasks which use it as
  // an input are complete.
@@ -12,12 +12,12 @@ void matmul(
    T* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    int64_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
@@ -34,7 +34,7 @@ void matmul_bnns(
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    size_t /* ldc */,
    float alpha,
    float beta,
    size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
  if (beta != 1.0 && beta != 0.0) {
    // scale the output
    for (auto i = 0; i < batch_size * M * N; ++i) {
    for (size_t i = 0; i < batch_size * M * N; ++i) {
      out[i] *= beta;
    }
    beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
  auto bnns_filter =
      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);

  for (int i = 0; i < batch_size; ++i) {
  for (size_t i = 0; i < batch_size; ++i) {
    BNNSFilterApplyTwoInput(
        bnns_filter,
        reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
    float16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    int64_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
    bfloat16_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    int64_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
@@ -13,20 +13,20 @@ void matmul<float>(
    float* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    int64_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];
  int64_t M = a_shape[ndim - 2];
  int64_t N = b_shape[ndim - 1];
  int64_t K = a_shape[ndim - 1];

  for (int i = 0; i < batch_size; ++i) {
    cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
    double* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    int64_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];
  int64_t M = a_shape[ndim - 2];
  int64_t N = b_shape[ndim - 1];
  int64_t K = a_shape[ndim - 1];

  for (int i = 0; i < batch_size; ++i) {
    cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
    complex64_t* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    int64_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides) {
  auto ndim = a_shape.size();
  size_t M = a_shape[ndim - 2];
  size_t N = b_shape[ndim - 1];
  size_t K = a_shape[ndim - 1];
  int64_t M = a_shape[ndim - 2];
  int64_t N = b_shape[ndim - 1];
  int64_t K = a_shape[ndim - 1];
  auto calpha = static_cast<complex64_t>(alpha);
  auto cbeta = static_cast<complex64_t>(beta);

@@ -11,9 +11,9 @@ namespace mlx::core {

// n = 2^k component
template <typename T>
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
  for (int b = 0; b < size / n; b++) {
    size_t loc = b * n;
    int64_t loc = b * n;
    T* data_ptr = out + loc;
    int h = 1;
    int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {

// m component
template <typename T>
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
  auto h_matrices = hadamard_matrices();
  auto& matrix = h_matrices[m];
  auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
  std::vector<bool> hmat_vec;
  while (end != std::string_view::npos) {
    auto row = matrix.substr(start, end - start);
    for (int i = 0; i < row.length(); i++) {
    for (int i = 0; i < std::ssize(row); i++) {
      hmat_vec.push_back(row[i] == '+');
    }
    start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
  }

  for (int b = 0; b < size / m / n; b++) {
    size_t loc = b * n * m;
    int64_t loc = b * n * m;
    T* data_ptr = out + loc;
    for (int i = 0; i < n; i++) {
      std::vector<float> out(m);
@@ -78,7 +78,7 @@ void gather(
    can_copy = true;

    // Ignore leading 1s
    int i = 0;
    int64_t i = 0;
    for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
      ;

@@ -91,7 +91,7 @@ void gather(
    can_copy = true;

    // Ignore trailing 1s
    int i = slice_sizes.size() - 1;
    int64_t i = slice_sizes.size() - 1;
    for (; i >= 0 && slice_sizes[i] == 1; --i)
      ;

@@ -101,11 +101,11 @@ void gather(
      can_copy = (src.shape(i) == slice_sizes[i]);
    }
  }
  size_t slice_size = 1;
  int64_t slice_size = 1;
  for (auto s : slice_sizes) {
    slice_size *= s;
  }
  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
  int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
  const T* src_ptr = src.data<T>();
  T* dst_ptr = out.data<T>();

@@ -115,10 +115,10 @@ void gather(
    src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
  }

  size_t out_idx = 0;
  for (int idx = 0; idx < ind_size; idx++) {
    size_t src_idx = 0;
    for (int ii = 0; ii < inds.size(); ++ii) {
  int64_t out_idx = 0;
  for (int64_t idx = 0; idx < ind_size; idx++) {
    int64_t src_idx = 0;
    for (int ii = 0; ii < std::ssize(inds); ++ii) {
      auto ax = axes[ii];
      auto idx_loc = its[ii].loc;
      its[ii].step();
@@ -134,7 +134,7 @@ void gather(
          src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
      out_idx += slice_size;
    } else {
      for (int jj = 0; jj < slice_size; jj++) {
      for (int64_t jj = 0; jj < slice_size; jj++) {
        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
        src_it.step();
      }
@@ -403,11 +403,11 @@ void scatter(
    const std::vector<int>& axes) {
  int nind = inds.size();
  auto inds_ndim = updates.ndim() - out.ndim();
  size_t n_updates = nind ? inds[0].size() : 1;
  int64_t n_updates = nind ? inds[0].size() : 1;

  Shape update_shape(
      updates.shape().begin() + inds_ndim, updates.shape().end());
  size_t update_size = 1;
  int64_t update_size = 1;
  for (auto us : update_shape) {
    update_size *= us;
  }
@@ -418,9 +418,9 @@ void scatter(

  auto out_ptr = out.data<InT>();
  auto upd_ptr = updates.data<InT>();
  for (int i = 0; i < n_updates; ++i) {
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < inds.size(); ++j) {
|
||||
for (int64_t i = 0; i < n_updates; ++i) {
|
||||
int64_t out_offset = 0;
|
||||
for (int j = 0; j < std::ssize(inds); ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
@@ -429,7 +429,7 @@ void scatter(
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
for (int64_t j = 0; j < update_size; ++j) {
|
||||
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
|
||||
@@ -122,7 +122,7 @@ void inverse_impl(
|
||||
stream);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
const int64_t num_matrices = a.size() / (N * N);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(inv);
|
||||
@@ -130,13 +130,13 @@ void inverse_impl(
|
||||
auto inv_ptr = inv.data<T>();
|
||||
if (tri) {
|
||||
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
tri_inv<T>(inv_ptr + N * N * i, N, upper);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([inv_ptr, N, num_matrices]() {
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
general_inv<T>(inv_ptr + N * N * i, N);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -25,7 +25,7 @@ inline void mask_matrix(
|
||||
const int64_t Y_data_str,
|
||||
const int64_t X_mask_str,
|
||||
const int64_t Y_mask_str,
|
||||
const size_t mask_offset) {
|
||||
const int64_t mask_offset) {
|
||||
int tX = (X + block_size - 1) / block_size;
|
||||
int tY = (Y + block_size - 1) / block_size;
|
||||
|
||||
@@ -61,13 +61,13 @@ inline void segmented_mm(
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides,
|
||||
size_t num_segments,
|
||||
int64_t num_segments,
|
||||
const Shape& segments_shape,
|
||||
const Strides& segments_strides) {
|
||||
int ndim = a_shape.size();
|
||||
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [b_transposed, ldb, b, b_copied] =
|
||||
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
int64_t M = a.shape(-2);
|
||||
int64_t N = b.shape(-1);
|
||||
int64_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
int batch_idx,
|
||||
int X,
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str,
|
||||
int64_t X_data_str,
|
||||
int64_t Y_data_str,
|
||||
const Shape& mask_shape,
|
||||
const Strides& mask_strides,
|
||||
bool is_bool) {
|
||||
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto a_ptr = a.data<float>();
|
||||
auto b_ptr = b.data<float>();
|
||||
auto out_ptr = out.data<float>();
|
||||
size_t num_matrices = out.size() / (M * size_t(N));
|
||||
int64_t num_matrices = out.size() / (M * int64_t(N));
|
||||
auto ldc = out.shape(-1);
|
||||
|
||||
encoder.dispatch([a_ptr,
|
||||
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
int64_t M = a.shape(-2);
|
||||
int64_t N = b.shape(-1);
|
||||
int64_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Get batch dims
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
size_t matrix_stride_out = M * N;
|
||||
int64_t matrix_stride_out = M * N;
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
|
||||
@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
|
||||
auto compute_offset =
|
||||
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
|
||||
int64_t offset_ = 0;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(axes); ++i) {
|
||||
offset_ += indices[i] * strides[axes[i]];
|
||||
}
|
||||
offset[0] = offset_;
|
||||
@@ -124,6 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
(void)inputs;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
@@ -193,9 +194,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(inputs); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
int64_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
@@ -205,7 +206,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
constexpr int64_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
@@ -254,8 +255,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
copy_cpu(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes_.size(); i++) {
|
||||
int64_t data_offset = 0;
|
||||
for (int i = 0; i < std::ssize(axes_); i++) {
|
||||
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size_[i];
|
||||
}
|
||||
@@ -274,10 +275,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
size_t num_keys = keys.size() / 2;
|
||||
int64_t num_keys = keys.size() / 2;
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
int64_t elems_per_key = out.size() / num_keys;
|
||||
int64_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
@@ -291,8 +292,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
num_keys,
|
||||
kshape = keys.shape(),
|
||||
kstrides = keys.strides()]() mutable {
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
uintptr_t half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
|
||||
@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int lda = M;
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
int64_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// Copy A to inplace input and make it col-contiguous
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
auto work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
// Solve
|
||||
geqrf<T>(
|
||||
&M,
|
||||
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
}
|
||||
allocator::free(work);
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
/// num_reflectors x N
|
||||
for (int j = 0; j < num_reflectors; ++j) {
|
||||
for (int k = 0; k < j; ++k) {
|
||||
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
// Compute Q
|
||||
orgqr<T>(
|
||||
&M,
|
||||
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
&info);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
// M x num_reflectors
|
||||
for (int j = 0; j < M; ++j) {
|
||||
for (int k = 0; k < num_reflectors; ++k) {
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/backend/cpu/unary.h"
|
||||
#include "mlx/backend/cpu/unary_ops.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -1105,44 +1102,4 @@ void fast::Quantize::eval_cpu(
|
||||
});
|
||||
}
|
||||
|
||||
void fast::ConvertFP8::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
set_unary_output_data(in, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
to_fp8 = to_fp8_]() mutable {
|
||||
if (to_fp8) {
|
||||
switch (in.dtype()) {
|
||||
case float16:
|
||||
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
|
||||
break;
|
||||
default:
|
||||
unary_op<float, uint8_t>(in, out, detail::ToFP8());
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
|
||||
break;
|
||||
default:
|
||||
unary_op<uint8_t, float>(in, out, detail::FromFP8());
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
@@ -201,15 +200,6 @@ SIMD_DEFAULT_COMPARISONS(<=)
|
||||
SIMD_DEFAULT_COMPARISONS(==)
|
||||
SIMD_DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> clz(Simd<T, N> x) {
|
||||
auto a = *(uint32x4_t*)(&x);
|
||||
auto b = *((uint32x4_t*)(&x) + 1);
|
||||
a = vclzq_u32(a);
|
||||
b = vclzq_u32(b);
|
||||
return asd::make_uint8(a, b);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
return asd::atan2(a.value, b.value);
|
||||
@@ -263,12 +253,12 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
|
||||
} else {
|
||||
Simd<T, N> res = 1;
|
||||
// Raising an integer to a negative power is undefined
|
||||
if (any(exp < 0)) {
|
||||
if (any(exp < static_cast<T>(0))) {
|
||||
return 0;
|
||||
}
|
||||
while (any(exp > 0)) {
|
||||
while (any(exp > static_cast<T>(0))) {
|
||||
res = select((exp & 1) != 0, res * base, res);
|
||||
base = select(exp > 0, base * base, base);
|
||||
base = select(exp > static_cast<T>(0), base * base, base);
|
||||
exp = exp >> 1;
|
||||
}
|
||||
return res;
|
||||
|
||||
@@ -171,11 +171,6 @@ DEFAULT_BINARY(&)
|
||||
DEFAULT_BINARY(&&)
|
||||
DEFAULT_BINARY(||)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> clz(Simd<T, 1> x_) {
|
||||
return __builtin_clz(x_.value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
|
||||
@@ -79,7 +79,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
|
||||
|
||||
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
|
||||
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
|
||||
auto poly_mask = (emm2 & 2) != 0;
|
||||
auto poly_mask =
|
||||
(emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
|
||||
|
||||
// The magic pass: "Extended precision modular arithmetic"
|
||||
// x = ((x - y * DP1) - y * DP2) - y * DP3
|
||||
@@ -87,8 +88,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
|
||||
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
|
||||
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
|
||||
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
|
||||
|
||||
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
|
||||
// and the second polynom (Pi/4 <= x <= 0) in y2
|
||||
|
||||
@@ -120,8 +120,8 @@ template <typename T>
|
||||
void sort(array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
int64_t in_size = out.size();
|
||||
int64_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
@@ -151,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
int64_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
|
||||
@@ -214,8 +214,8 @@ template <typename T>
|
||||
void partition(array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
int64_t in_size = out.size();
|
||||
int64_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
@@ -248,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
|
||||
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
int64_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
|
||||
@@ -27,7 +27,7 @@ void svd_impl(
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
int64_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
@@ -121,7 +121,7 @@ void svd_impl(
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
@@ -153,10 +153,10 @@ void svd_impl(
|
||||
|
||||
template <typename T>
|
||||
void compute_svd(
|
||||
const array& a,
|
||||
bool compute_uv,
|
||||
std::vector<array>& outputs,
|
||||
Stream stream) {}
|
||||
const array& /* a */,
|
||||
bool /* compute_uv */,
|
||||
std::vector<array>& /* outputs */,
|
||||
Stream /* stream */) {}
|
||||
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
|
||||
@@ -136,7 +136,7 @@ void ternary_op(
|
||||
if (topt == TernaryOpType::ScalarScalarScalar) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
} else if (topt == TernaryOpType::VectorVectorVector) {
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
for (int64_t i = 0; i < out.size(); ++i) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
|
||||
for (int64_t i = 0; i < shape; i += 1) {
|
||||
out[i] = Op{}(*a);
|
||||
a += stride;
|
||||
}
|
||||
@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
|
||||
auto ndim = a.ndim();
|
||||
if (a.flags().contiguous) {
|
||||
auto size = a.data_size();
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
|
||||
simd::store(dst, Op{}(simd::load<T, N>(src)));
|
||||
size -= N;
|
||||
src += N;
|
||||
dst += N;
|
||||
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
|
||||
src++;
|
||||
}
|
||||
} else {
|
||||
size_t shape = ndim > 0 ? a.shape().back() : 1;
|
||||
size_t stride = ndim > 0 ? a.strides().back() : 1;
|
||||
int64_t shape = ndim > 0 ? a.shape().back() : 1;
|
||||
int64_t stride = ndim > 0 ? a.strides().back() : 1;
|
||||
if (ndim <= 1) {
|
||||
unary_op<T, U, Op>(src, dst, shape, stride);
|
||||
return;
|
||||
}
|
||||
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
|
||||
for (size_t elem = 0; elem < a.size(); elem += shape) {
|
||||
for (int64_t elem = 0; elem < a.size(); elem += shape) {
|
||||
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
|
||||
@@ -108,73 +108,4 @@ struct Square {
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
template <int N>
|
||||
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
|
||||
return *(Simd<float, N>*)(&x);
|
||||
}
|
||||
template <int N>
|
||||
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
|
||||
return *(Simd<uint32_t, N>*)(&x);
|
||||
}
|
||||
|
||||
struct ToFP8 {
|
||||
template <typename T, int N>
|
||||
Simd<uint8_t, N> operator()(Simd<T, N> f) {
|
||||
uint32_t fp8_max = 543 << 21;
|
||||
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
|
||||
Simd<uint32_t, N> f_bits;
|
||||
Simd<float, N> f32 = f;
|
||||
f_bits = fp32_to_bits(f32);
|
||||
Simd<uint8_t, N> result = 0u;
|
||||
auto sign = f_bits & 0x80000000;
|
||||
f_bits = f_bits ^ sign;
|
||||
|
||||
auto f_bits_low =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
|
||||
|
||||
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
|
||||
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
|
||||
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
|
||||
|
||||
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
|
||||
result = select(f_bits < (121 << 23), result_low, result_high);
|
||||
|
||||
auto result_sat = Simd<uint8_t, N>(0x7E);
|
||||
result = select(f_bits >= fp8_max, result_sat, result);
|
||||
return result | Simd<uint8_t, N>(sign >> 24);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint8_t operator()(T x) {
|
||||
return (*this)(Simd<T, 1>(x)).value;
|
||||
}
|
||||
};
|
||||
|
||||
struct FromFP8 {
|
||||
template <int N>
|
||||
Simd<float, N> operator()(Simd<uint8_t, N> x) {
|
||||
auto w = Simd<uint32_t, N>(x) << 24;
|
||||
auto sign = w & 0x80000000;
|
||||
auto nonsign = w & 0x7FFFFFFF;
|
||||
|
||||
auto renorm_shift = clz(nonsign);
|
||||
renorm_shift = simd::select(
|
||||
renorm_shift > Simd<uint32_t, N>{4},
|
||||
renorm_shift - Simd<uint32_t, N>{4},
|
||||
Simd<uint32_t, N>{0});
|
||||
|
||||
Simd<int32_t, N> inf_nan_mask =
|
||||
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
|
||||
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
|
||||
auto result = sign |
|
||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
return fp32_from_bits(result);
|
||||
}
|
||||
float operator()(uint8_t x) {
|
||||
return (*this)(Simd<uint8_t, 1>(x)).value;
|
||||
}
|
||||
};
|
||||
} // namespace mlx::core::detail
|
||||
|
||||
@@ -51,19 +51,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
|
||||
|
||||
# fp4 is not available on < 12.8
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
|
||||
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
|
||||
endif()
|
||||
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||
@@ -177,6 +170,11 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
# Supress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||
|
||||
@@ -97,6 +97,7 @@ CudaAllocator::CudaAllocator()
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
auto orig_size = size;
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size <= small_block_size) {
|
||||
size = 8;
|
||||
@@ -130,7 +131,7 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
active_memory_ += buf->size;
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit.
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
@@ -336,17 +334,4 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
struct ToFP8 {
|
||||
template <typename T>
|
||||
__device__ uint8_t operator()(T x) {
|
||||
return __nv_fp8_e4m3(x).__x;
|
||||
}
|
||||
};
|
||||
|
||||
struct FromFP8 {
|
||||
__device__ float operator()(uint8_t x) {
|
||||
return float(*(__nv_fp8_e4m3*)(&x));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -35,9 +35,9 @@ std::vector<array> precompiled_cuda_kernel(
|
||||
const std::vector<ScalarArg>&,
|
||||
std::tuple<int, int, int>,
|
||||
std::tuple<int, int, int>,
|
||||
int shared_memory,
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
int /* shared_memory */,
|
||||
std::optional<float> /* init_value */,
|
||||
bool /* ensure_row_contiguous */,
|
||||
StreamOrDevice) {
|
||||
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
|
||||
}
|
||||
|
||||
@@ -306,7 +306,7 @@ void affine_dequantize(
|
||||
enc.set_input_array(scales);
|
||||
enc.set_input_array(biases);
|
||||
enc.set_output_array(w);
|
||||
dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
|
||||
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
void fast::ConvertFP8::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (to_fp8_) {
|
||||
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
|
||||
} else {
|
||||
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -1,83 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
struct __nv_fp8_e8m0 {
|
||||
__device__ __nv_fp8_e8m0(float x) {
|
||||
if (!std::isfinite(x)) {
|
||||
__x = 0xFF;
|
||||
return;
|
||||
}
|
||||
if (x < 0.0f) {
|
||||
__x = 0x00;
|
||||
return;
|
||||
}
|
||||
float le = std::log2f(x);
|
||||
int n = static_cast<int>(std::nearbyintf(le));
|
||||
|
||||
n = n < -127 ? -127 : n;
|
||||
n = n > 127 ? 127 : n;
|
||||
__x = static_cast<uint8_t>(n + 127);
|
||||
}
|
||||
|
||||
__device__ operator float() {
|
||||
if (__x == 0xFF) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
return std::ldexp(1.0f, static_cast<int>(__x) - 127);
|
||||
}
|
||||
|
||||
uint8_t __x{0};
|
||||
};
|
||||
|
||||
struct __nv_fp4_e2m1 {
|
||||
__device__ __nv_fp4_e2m1(float x) {
|
||||
if (std::isnan(x)) {
|
||||
__x = 0x7;
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
|
||||
x = std::abs(x);
|
||||
|
||||
if (x > 5.0f) {
|
||||
__x = 0x7;
|
||||
} else if (x >= 3.5f) {
|
||||
__x = 0x6;
|
||||
} else if (x > 2.5f) {
|
||||
__x = 0x5;
|
||||
} else if (x >= 1.75f) {
|
||||
__x = 0x4;
|
||||
} else if (x > 1.25f) {
|
||||
__x = 0x3;
|
||||
} else if (x >= 0.75f) {
|
||||
__x = 0x2;
|
||||
} else if (x > 0.25f) {
|
||||
__x = 0x1;
|
||||
} else {
|
||||
__x = 0x0;
|
||||
}
|
||||
__x |= sign_bit;
|
||||
}
|
||||
|
||||
__device__ operator float() {
|
||||
static const float LUT[16] = {
|
||||
0.0f,
|
||||
0.5f,
|
||||
1.0f,
|
||||
1.5f,
|
||||
2.0f,
|
||||
3.0f,
|
||||
4.0f,
|
||||
6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f};
|
||||
|
||||
return LUT[__x];
|
||||
}
|
||||
uint8_t __x{0};
|
||||
};
|
||||
@@ -1,216 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/quantized/quantized.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cuda_fp4.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace cu {
|
||||
|
||||
template <int bits>
|
||||
struct Quantize {
|
||||
__device__ uint8_t operator()(float x) {
|
||||
if constexpr (bits == 8) {
|
||||
return __nv_fp8_e4m3(x).__x;
|
||||
} else {
|
||||
return __nv_fp4_e2m1(x).__x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int bits>
|
||||
struct Dequantize {
|
||||
__device__ float operator()(uint8_t x) {
|
||||
if constexpr (bits == 8) {
|
||||
return float(*(__nv_fp8_e4m3*)(&x));
|
||||
} else {
|
||||
return float(*(__nv_fp4_e2m1*)(&x));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale>
|
||||
__global__ void
|
||||
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
||||
auto block_size = cg::this_thread_block().dim_threads();
|
||||
auto block_idx = cg::this_thread_block().group_index();
|
||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||
|
||||
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
|
||||
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
|
||||
|
||||
auto grid_dim_x =
|
||||
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
||||
size_t index = tidx + grid_dim_x * size_t(tidy);
|
||||
if (index >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
float w_thread = w[index];
|
||||
|
||||
cg::greater<float> max_op;
|
||||
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
|
||||
|
||||
float scale = cg::reduce(warp, abs(w_thread), max_op);
|
||||
scale /= bits == 4 ? 6.0f : 448.0f;
|
||||
// Convert to mx scale or nv scale
|
||||
using ScaleType =
|
||||
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
||||
auto s = ScaleType(scale);
|
||||
uint8_t q_scale = s.__x;
|
||||
scale = float(s);
|
||||
|
||||
// Write out the scales
|
||||
size_t gindex = index / group_size;
|
||||
if (index % group_size == 0) {
|
||||
scales[gindex] = q_scale;
|
||||
}
|
||||
|
||||
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
||||
if (bits == 4) {
|
||||
uint8_t sval = warp.shfl_down(output, 1);
|
||||
output |= sval << bits;
|
||||
}
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
if (index % pack_factor == 0) {
|
||||
out[index / pack_factor] = output;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool use_mx_scale>
|
||||
__global__ void
|
||||
fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
|
||||
auto block_size = cg::this_thread_block().dim_threads();
|
||||
auto block_idx = cg::this_thread_block().group_index();
|
||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||
|
||||
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
|
||||
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
|
||||
|
||||
auto grid_dim_x =
|
||||
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
|
||||
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
size_t offset = tidx + grid_dim_x * size_t(tidy);
|
||||
size_t oindex = offset * pack_factor;
|
||||
|
||||
if (oindex >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t gindex = oindex / group_size;
|
||||
using ScaleType =
|
||||
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
||||
auto scale = float(((ScaleType*)(scales))[gindex]);
|
||||
|
||||
out += oindex;
|
||||
|
||||
uint val = w[offset];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < pack_factor; i++) {
|
||||
uint8_t d;
|
||||
if (bits == 4) {
|
||||
d = (val >> (bits * i)) & 0x0f;
|
||||
} else if (bits == 8) {
|
||||
d = val;
|
||||
}
|
||||
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void fp_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
enc.set_input_array(w);
|
||||
enc.set_output_array(wq);
|
||||
enc.set_output_array(scales);
|
||||
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if constexpr (!std::is_same_v<T, double>) {
|
||||
auto kernel = cu::fp_quantize<T, 32, 4, true>;
|
||||
if (bits == 8) {
|
||||
kernel = cu::fp_quantize<T, 32, 8, true>;
|
||||
} else if (group_size == 16) {
|
||||
kernel = cu::fp_quantize<T, 16, 4, false>;
|
||||
}
|
||||
bool large = w.size() > UINT_MAX;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(w.size(), w.shape(), w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
w.data<T>(),
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<uint8_t>(),
|
||||
w.size());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[Quantize::eval_gpu] Can not quantize input with type float64.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void fp_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
array& w,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
constexpr int uint8_per_uint32 = 4;
|
||||
int packs_per_int = 8 / bits;
|
||||
|
||||
size_t size = w.size() / packs_per_int;
|
||||
bool large = size > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
|
||||
enc.set_input_array(wq);
|
||||
enc.set_input_array(scales);
|
||||
enc.set_output_array(w);
|
||||
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if constexpr (!std::is_same_v<T, double>) {
|
||||
auto kernel = cu::fp_dequantize<T, 32, 4, true>;
|
||||
if (bits == 8) {
|
||||
kernel = cu::fp_dequantize<T, 32, 8, true>;
|
||||
} else if (group_size == 16) {
|
||||
kernel = cu::fp_dequantize<T, 16, 4, false>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
w.data<T>(),
|
||||
w.size());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -57,30 +57,23 @@ void fast::Quantize::eval_gpu(
|
||||
if (dequantize_) {
|
||||
auto wq = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
auto& w = outputs[0];
|
||||
|
||||
w.set_data(allocator::malloc(w.nbytes()));
|
||||
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
|
||||
}
|
||||
} else {
|
||||
auto w = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto& wq = outputs[0];
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
|
||||
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto& biases = outputs[2];
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
|
||||
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,22 +24,4 @@ void affine_dequantize(
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void fp_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void fp_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
array& w,
|
||||
int group_size,
|
||||
int bits,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -108,12 +108,6 @@ constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, ToFP8>) {
|
||||
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, FromFP8>) {
|
||||
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
constexpr int64_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
|
||||
@@ -11,7 +11,7 @@ void slice_gpu(
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ void pad_gpu(
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(axes); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ make_jit_source(
|
||||
kernels/bf16_math.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
@@ -81,8 +81,7 @@ if(MLX_METAL_JIT)
|
||||
|
||||
make_jit_source(quantized_utils)
|
||||
make_jit_source(quantized kernels/quantized_utils.h)
|
||||
make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
|
||||
kernels/fp4.h)
|
||||
make_jit_source(fp4_quantized kernels/quantized_utils.h)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
|
||||
|
||||
@@ -109,7 +109,7 @@ inline void build_kernel(
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(inputs); ++i) {
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
@@ -134,7 +134,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
@@ -174,7 +174,7 @@ inline void build_kernel(
|
||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
}
|
||||
os += " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" index_{0} += ", xname);
|
||||
if (dynamic_dims) {
|
||||
@@ -195,7 +195,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os += fmt::format(
|
||||
@@ -214,7 +214,7 @@ inline void build_kernel(
|
||||
} else {
|
||||
os += x.primitive().name();
|
||||
os += "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
@@ -227,7 +227,7 @@ inline void build_kernel(
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
@@ -396,7 +396,7 @@ void Compiled::eval_gpu(
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
Strides in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(inputs); i++) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -990,7 +990,7 @@ void conv_3D_gpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
std::vector<array>& /* copies */) {
|
||||
// Make conv params
|
||||
MLXConvParams<3> conv_params{
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
|
||||
@@ -68,7 +68,7 @@ std::string write_signature(
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(inputs); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
@@ -109,7 +109,7 @@ std::string write_signature(
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(output_names); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
@@ -126,8 +126,8 @@ std::string write_signature(
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
if (index < std::ssize(inputs) + std::ssize(output_names) - 1 ||
|
||||
std::ssize(attributes) > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
@@ -138,7 +138,7 @@ std::string write_signature(
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
if (index < std::ssize(attributes) - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
@@ -381,7 +381,7 @@ void CustomKernel::eval_gpu(
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
int index = 0;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(checked_inputs); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
auto& shape_info = shape_infos_[i];
|
||||
compute_encoder.set_input_array(in, index);
|
||||
@@ -408,7 +408,7 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
auto tg_size = tx * ty * tz;
|
||||
unsigned long tg_size = tx * ty * tz;
|
||||
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (tg_size > max_tg_size) {
|
||||
std::ostringstream msg;
|
||||
|
||||
@@ -127,6 +127,9 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
(void)device;
|
||||
(void)lib_name;
|
||||
#endif
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
@@ -713,7 +716,7 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
|
||||
|
||||
std::vector<NS::Object*> objs(funcs.size());
|
||||
for (int i = 0; i < funcs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(funcs); i++) {
|
||||
objs[i] = funcs[i];
|
||||
}
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ struct DeviceStream {
|
||||
// Data updated between command buffers
|
||||
MTL::CommandBuffer* buffer{nullptr};
|
||||
int buffer_ops{0};
|
||||
size_t buffer_sizes{0};
|
||||
int64_t buffer_sizes{0};
|
||||
|
||||
// The command encoder, fence, and temporaries are updated between command
|
||||
// encoders
|
||||
|
||||
@@ -76,7 +76,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
auto command_buffer = d.get_command_buffer(idx);
|
||||
command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
}
|
||||
|
||||
void Fence::update(Stream stream, const array& x) {
|
||||
@@ -124,7 +124,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(f.fence), f.count);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -60,7 +60,7 @@ struct FourStepParams {
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
int64_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
@@ -93,7 +93,7 @@ std::vector<int> plan_stockham_fft(int n) {
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(radices); i++) {
|
||||
int radix = radices[i];
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
@@ -181,7 +181,7 @@ int compute_elems_per_thread(FFTPlan plan) {
|
||||
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
|
||||
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(steps); i++) {
|
||||
int radix = radices[i % radices.size()];
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radix);
|
||||
@@ -260,7 +260,7 @@ int primitive_root(int n) {
|
||||
|
||||
std::tuple<array, array, array> compute_raders_constants(
|
||||
int rader_n,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
int proot = primitive_root(rader_n);
|
||||
// Fermat's little theorem
|
||||
int inv = mod_exp(proot, rader_n - 2, rader_n);
|
||||
@@ -508,7 +508,7 @@ void four_step_fft(
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
int64_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
@@ -612,11 +612,11 @@ void fft_op(
|
||||
|
||||
// Start of radix/rader step constants
|
||||
int index = 4;
|
||||
for (int i = 0; i < plan.stockham.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(plan.stockham); i++) {
|
||||
func_consts.push_back(make_int(&plan.stockham[i], index));
|
||||
index += 1;
|
||||
}
|
||||
for (int i = 0; i < plan.rader.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(plan.rader); i++) {
|
||||
func_consts.push_back(make_int(&plan.rader[i], index));
|
||||
index += 1;
|
||||
}
|
||||
@@ -771,8 +771,8 @@ void nd_fft_op(
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
for (int i = std::ssize(axes) - 1; i >= 0; i--) {
|
||||
int reverse_index = std::ssize(axes) - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
@@ -780,8 +780,8 @@ void nd_fft_op(
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
bool step_real = (real && index == std::ssize(axes) - 1);
|
||||
const array& in_arr = i == std::ssize(axes) - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ std::string gen_hadamard_codelet(int m) {
|
||||
while (end != std::string_view::npos) {
|
||||
source << " tmp[" << index << "] = ";
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
for (int i = 0; i < std::ssize(row); i++) {
|
||||
source << " " << row[i] << " x[" << i << "]";
|
||||
}
|
||||
source << ";" << std::endl;
|
||||
|
||||
@@ -52,7 +52,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t slice_size = 1;
|
||||
int64_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
slice_size *= s;
|
||||
}
|
||||
@@ -94,8 +94,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
|
||||
size_t dim_y = indices.size();
|
||||
int64_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
|
||||
int64_t dim_y = indices.size();
|
||||
auto group_dims = get_block_dims(dim_x, dim_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
|
||||
|
||||
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
int64_t ndim = src.ndim();
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather{0}{1}_{2}_{3}_{4}",
|
||||
@@ -149,8 +149,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Launch 3D grid of threads
|
||||
// First two dimensions for the indices, the last one for the slice
|
||||
size_t dim0 = 1;
|
||||
size_t dim1 = 1;
|
||||
int64_t dim0 = 1;
|
||||
int64_t dim1 = 1;
|
||||
if (nidx) {
|
||||
if (inputs[1].ndim() >= 1) {
|
||||
dim0 = inputs[1].shape(0);
|
||||
@@ -159,13 +159,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dim1 = inputs[1].size() / dim0;
|
||||
}
|
||||
}
|
||||
size_t dim2 = slice_size;
|
||||
int64_t dim2 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
std::vector<int64_t> idx_strides;
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
int64_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
|
||||
auto idx_to_out = idx_size / out.size();
|
||||
int nwork;
|
||||
@@ -345,7 +345,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
size_t nthreads = upd.size();
|
||||
int64_t nthreads = upd.size();
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -354,8 +354,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
int64_t upd_ndim = upd.ndim();
|
||||
int64_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
@@ -391,7 +391,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_bytes(upd_size, 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
int64_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
@@ -448,7 +448,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
int64_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
@@ -486,8 +486,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
int64_t size_pre = 1;
|
||||
int64_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
@@ -541,7 +541,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
int64_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
@@ -602,8 +602,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
int64_t size_pre = 1;
|
||||
int64_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ const char* hadamard();
|
||||
const char* logsumexp();
|
||||
const char* quantized_utils();
|
||||
const char* quantized();
|
||||
const char* fp_quantized();
|
||||
const char* fp4_quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* scatter_axis();
|
||||
|
||||
@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::quantized_utils(),
|
||||
(mode == "affine") ? metal::quantized() : metal::fp_quantized(),
|
||||
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
|
||||
template_def);
|
||||
return kernel_source;
|
||||
});
|
||||
@@ -856,13 +856,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
|
||||
bool is_affine = mode == "affine";
|
||||
if (mode == "affine") {
|
||||
concatenate(
|
||||
kernel_source,
|
||||
is_affine ? metal::quantized() : metal::fp_quantized(),
|
||||
metal::quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
(is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
|
||||
mode + "_gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
@@ -872,6 +872,23 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
} else {
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::fp4_quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
mode + "_gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
"uint8_t",
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
|
||||
@@ -6,7 +6,6 @@ set(BASE_HEADERS
|
||||
defines.h
|
||||
erf.h
|
||||
expm1f.h
|
||||
fp8.h
|
||||
utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
@@ -110,8 +109,7 @@ if(NOT MLX_METAL_JIT)
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h)
|
||||
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||
build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
|
||||
${STEEL_HEADERS})
|
||||
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
constexpr constant static float FP4_LUT[16] = {
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
+1.5f,
|
||||
+2.0f,
|
||||
+3.0f,
|
||||
+4.0f,
|
||||
+6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f};
|
||||
|
||||
struct fp4_e2m1 {
|
||||
fp4_e2m1(float x) {
|
||||
if (metal::isnan(x)) {
|
||||
bits = 0x7;
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
|
||||
x = metal::abs(x);
|
||||
|
||||
if (x > 5.0f) {
|
||||
bits = 0x7;
|
||||
} else if (x >= 3.5f) {
|
||||
bits = 0x6;
|
||||
} else if (x > 2.5f) {
|
||||
bits = 0x5;
|
||||
} else if (x >= 1.75f) {
|
||||
bits = 0x4;
|
||||
} else if (x > 1.25f) {
|
||||
bits = 0x3;
|
||||
} else if (x >= 0.75f) {
|
||||
bits = 0x2;
|
||||
} else if (x > 0.25f) {
|
||||
bits = 0x1;
|
||||
} else {
|
||||
bits = 0x0;
|
||||
}
|
||||
bits |= sign_bit;
|
||||
}
|
||||
|
||||
operator float() {
|
||||
return FP4_LUT[bits];
|
||||
}
|
||||
|
||||
uint8_t bits;
|
||||
};
|
||||
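The fp4_e2m1 struct removed above encodes a float into 4-bit E2M1 by comparing the magnitude against the midpoints between the eight representable values {0, 0.5, 1, 1.5, 2, 3, 4, 6}, putting the sign in bit 3, and decoding through a 16-entry LUT. A host-side C++ sketch of the same round trip, mirroring the thresholds shown (illustration only; the Metal struct is the source of truth):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Plain C++ mirror of fp4_e2m1: round at the midpoints between the eight
// representable magnitudes, sign in bit 3, decode via the same LUT.
static const float kFP4Values[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

uint8_t fp4_encode(float x) {
  if (std::isnan(x)) return 0x7;            // NaN saturates to +6.0, as above
  uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
  x = std::fabs(x);
  uint8_t m;
  if (x > 5.0f)        m = 0x7;  // 6.0
  else if (x >= 3.5f)  m = 0x6;  // 4.0
  else if (x > 2.5f)   m = 0x5;  // 3.0
  else if (x >= 1.75f) m = 0x4;  // 2.0
  else if (x > 1.25f)  m = 0x3;  // 1.5
  else if (x >= 0.75f) m = 0x2;  // 1.0
  else if (x > 0.25f)  m = 0x1;  // 0.5
  else                 m = 0x0;  // 0.0
  return m | sign;
}

float fp4_decode(uint8_t bits) { return kFP4Values[bits & 0xF]; }

int main() {
  for (float v : {0.3f, -1.4f, 2.6f, 100.0f}) {
    uint8_t q = fp4_encode(v);
    std::printf("%7.2f -> 0x%X -> %4.1f\n", v, q, fp4_decode(q));
  }
}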
@@ -3,9 +3,6 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fp4.h"
|
||||
#include "mlx/backend/metal/kernels/fp8.h"
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
@@ -62,10 +59,28 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
||||
}
|
||||
}
|
||||
|
||||
constexpr constant static float MXFP4_LUT[16] = {
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
+1.5f,
|
||||
+2.0f,
|
||||
+3.0f,
|
||||
+4.0f,
|
||||
+6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f};
|
||||
|
||||
template <typename T>
|
||||
void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||
if (simd_gid == 0 && simd_lid < 16) {
|
||||
lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
|
||||
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
@@ -140,7 +155,8 @@ template <
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short group_size>
|
||||
short group_size,
|
||||
typename S>
|
||||
struct QuantizedBlockLoader {
|
||||
static_assert(
|
||||
BCOLS <= group_size,
|
||||
@@ -167,12 +183,12 @@ struct QuantizedBlockLoader {
|
||||
|
||||
threadgroup T* dst;
|
||||
const device uint8_t* src;
|
||||
const device uint8_t* scales;
|
||||
const device S* scales;
|
||||
threadgroup T* lut;
|
||||
|
||||
QuantizedBlockLoader(
|
||||
const device uint8_t* src_,
|
||||
const device uint8_t* scales_,
|
||||
const device S* scales_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
threadgroup T* lut_,
|
||||
@@ -192,7 +208,7 @@ struct QuantizedBlockLoader {
|
||||
bj * bytes_per_pack),
|
||||
scales(scales_ + bi * src_ld / group_size),
|
||||
lut(lut_) {
|
||||
load_fp4_lut(lut, simd_group_id, simd_lane_id);
|
||||
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
|
||||
}
|
||||
|
||||
void load_unsafe() const {
|
||||
@@ -254,10 +270,10 @@ struct QuantizedBlockLoader {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int group_size, int bits, int D>
|
||||
METAL_FUNC void fp_qmv_quad_impl(
|
||||
template <typename T, int group_size, typename S, int D>
|
||||
METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
constant int& in_vec_size,
|
||||
@@ -279,7 +295,7 @@ METAL_FUNC void fp_qmv_quad_impl(
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_quadgroup] = {0};
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
@@ -295,7 +311,7 @@ METAL_FUNC void fp_qmv_quad_impl(
|
||||
|
||||
for (int row = 0; row < results_per_quadgroup; row++) {
|
||||
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
|
||||
const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
|
||||
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
|
||||
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
if (row * quads_per_simd + out_row < out_vec_size) {
|
||||
@@ -311,10 +327,10 @@ METAL_FUNC void fp_qmv_quad_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void fp_qmv_fast_impl(
|
||||
template <typename T, int group_size, typename S>
|
||||
METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -337,7 +353,7 @@ METAL_FUNC void fp_qmv_fast_impl(
|
||||
typedef float U;
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
@@ -374,10 +390,10 @@ METAL_FUNC void fp_qmv_fast_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void fp_qmv_impl(
|
||||
template <typename T, int group_size, typename S>
|
||||
METAL_FUNC void mxfp4_qmv_impl(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -402,7 +418,7 @@ METAL_FUNC void fp_qmv_impl(
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
@@ -432,7 +448,7 @@ METAL_FUNC void fp_qmv_impl(
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
uint8_t s = sl[0];
|
||||
S s = sl[0];
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
|
||||
@@ -513,10 +529,10 @@ METAL_FUNC void fp_qmv_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, int bits>
|
||||
METAL_FUNC void fp_qvm_impl(
|
||||
template <typename T, const int group_size, typename S>
|
||||
METAL_FUNC void mxfp4_qvm_impl(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const int in_vec_size,
|
||||
@@ -545,7 +561,7 @@ METAL_FUNC void fp_qvm_impl(
|
||||
thread U scale = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
|
||||
@@ -617,14 +633,14 @@ METAL_FUNC void fp_qvm_impl(
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
typename S,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
METAL_FUNC void fp_qmm_t_impl(
|
||||
METAL_FUNC void mxfp4_qmm_t_impl(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
threadgroup T* Xs,
|
||||
@@ -661,7 +677,8 @@ METAL_FUNC void fp_qmm_t_impl(
|
||||
BK_padded,
|
||||
1,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size>;
|
||||
group_size,
|
||||
S>;
|
||||
|
||||
// Set the block
|
||||
const int K_w = K * bytes_per_pack / pack_factor;
|
||||
@@ -742,13 +759,13 @@ METAL_FUNC void fp_qmm_t_impl(
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
typename S,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
METAL_FUNC void fp_qmm_n_impl(
|
||||
METAL_FUNC void mxfp4_qmm_n_impl(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
threadgroup T* Xs,
|
||||
@@ -786,7 +803,8 @@ METAL_FUNC void fp_qmm_n_impl(
|
||||
BN_padded,
|
||||
0,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size>;
|
||||
group_size,
|
||||
S>;
|
||||
|
||||
auto wl = (const device uint8_t*)w;
|
||||
|
||||
@@ -873,11 +891,11 @@ METAL_FUNC void fp_qmm_n_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename S>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
const device uint32_t*& w,
|
||||
const device uint8_t*& scales,
|
||||
const device S*& scales,
|
||||
device T*& y,
|
||||
int output_stride,
|
||||
const constant int& x_batch_ndims,
|
||||
@@ -908,11 +926,11 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename S>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
const device uint32_t*& w,
|
||||
const device uint8_t*& scales,
|
||||
const device S*& scales,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
device T*& y,
|
||||
@@ -958,10 +976,10 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, int D, bool batched>
|
||||
[[kernel]] void fp_qmv_quad(
|
||||
template <typename T, int group_size, typename S, int D, bool batched>
|
||||
[[kernel]] void mxfp4_qmv_quad(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -996,7 +1014,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
fp_qmv_quad_impl<T, group_size, bits, D>(
|
||||
mxfp4_qmv_quad_impl<T, group_size, S, D>(
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
@@ -1011,10 +1029,10 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
||||
lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool batched>
|
||||
[[kernel]] void fp_qmv_fast(
|
||||
template <typename T, int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qmv_fast(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1047,14 +1065,14 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
fp_qmv_fast_impl<T, group_size, bits>(
|
||||
mxfp4_qmv_fast_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, int bits, bool batched>
|
||||
[[kernel]] void fp_qmv(
|
||||
template <typename T, const int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qmv(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1087,14 +1105,14 @@ template <typename T, const int group_size, int bits, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
fp_qmv_impl<T, group_size, bits>(
|
||||
mxfp4_qmv_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, int bits, bool batched>
|
||||
[[kernel]] void fp_qvm(
|
||||
template <typename T, const int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qvm(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1127,14 +1145,14 @@ template <typename T, const int group_size, int bits, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
fp_qvm_impl<T, group_size, bits>(
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, int bits, int split_k = 32>
|
||||
[[kernel]] void fp_qvm_split_k(
|
||||
template <typename T, const int group_size, typename S, int split_k = 32>
|
||||
[[kernel]] void mxfp4_qvm_split_k(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1171,7 +1189,7 @@ template <typename T, const int group_size, int bits, int split_k = 32>
|
||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||
|
||||
threadgroup float lut[16];
|
||||
fp_qvm_impl<T, group_size, bits>(
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
@@ -1187,15 +1205,15 @@ template <typename T, const int group_size, int bits, int split_k = 32>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
typename S,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void fp_qmm_t(
|
||||
[[kernel]] void mxfp4_qmm_t(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& K,
|
||||
@@ -1236,21 +1254,21 @@ template <
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
||||
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
typename S,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void fp_qmm_n(
|
||||
[[kernel]] void mxfp4_qmm_n(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& K,
|
||||
@@ -1293,14 +1311,14 @@ template <
|
||||
tid);
|
||||
}
|
||||
|
||||
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
|
||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void fp_gather_qmv_fast(
|
||||
template <typename T, int group_size, typename S>
|
||||
[[kernel]] void mxfp4_gather_qmv_fast(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1343,14 +1361,14 @@ template <typename T, int group_size, int bits>
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
fp_qmv_fast_impl<T, group_size, bits>(
|
||||
mxfp4_qmv_fast_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void fp_gather_qmv(
|
||||
template <typename T, int group_size, typename S>
|
||||
[[kernel]] void mxfp4_gather_qmv(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1393,14 +1411,14 @@ template <typename T, int group_size, int bits>
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
fp_qmv_impl<T, group_size, bits>(
|
||||
mxfp4_qmv_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void fp_gather_qvm(
|
||||
template <typename T, int group_size, typename S>
|
||||
[[kernel]] void mxfp4_gather_qvm(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1443,21 +1461,21 @@ template <typename T, int group_size, int bits>
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
fp_qvm_impl<T, group_size, bits>(
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
typename S,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void fp_gather_qmm_t(
|
||||
[[kernel]] void mxfp4_gather_qmm_t(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1508,20 +1526,20 @@ template <
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
||||
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const int bits,
|
||||
typename S,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void fp_gather_qmm_n(
|
||||
[[kernel]] void mxfp4_gather_qmm_n(
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1573,24 +1591,24 @@ template <
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
|
||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
int bits,
|
||||
typename S,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose>
|
||||
[[kernel]] void fp_gather_qmm_rhs(
|
||||
[[kernel]] void mxfp4_gather_qmm_rhs(
|
||||
const device T* x,
|
||||
const device uint32_t* w,
|
||||
const device uint8_t* scales,
|
||||
const device S* scales,
|
||||
const device uint32_t* indices,
|
||||
device T* y,
|
||||
const constant int& M,
|
||||
@@ -1626,7 +1644,8 @@ template <
|
||||
transpose ? BK_padded : BN_padded,
|
||||
transpose,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size>;
|
||||
group_size,
|
||||
S>;
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
|
||||
@@ -1770,100 +1789,3 @@ template <
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int bits>
|
||||
struct Quantize {
|
||||
uint8_t operator()(float x) {
|
||||
if (bits == 8) {
|
||||
return fp8_e4m3(x).bits;
|
||||
} else {
|
||||
return fp4_e2m1(x).bits;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int bits>
|
||||
struct Dequantize {
|
||||
float operator()(uint8_t x) {
|
||||
if (bits == 8) {
|
||||
return float(*(thread fp8_e4m3*)(&x));
|
||||
} else {
|
||||
return float(*(thread fp4_e2m1*)(&x));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void fp_quantize(
|
||||
const device T* w [[buffer(0)]],
|
||||
device uint8_t* out [[buffer(1)]],
|
||||
device uint8_t* scales [[buffer(2)]],
|
||||
uint2 tidx [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
constexpr bool use_mx_scale = group_size == 32;
|
||||
size_t index = tidx.x + grid_dim.x * size_t(tidx.y);
|
||||
|
||||
float scale;
|
||||
float w_thread = w[index];
|
||||
if (use_mx_scale) {
|
||||
scale = simd_max(abs(w_thread));
|
||||
} else {
|
||||
float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);
|
||||
float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);
|
||||
scale = tidx.x < 16 ? w_max_l : w_max_r;
|
||||
}
|
||||
scale /= bits == 4 ? 6.0f : 448.0f;
|
||||
|
||||
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
|
||||
auto s = ScaleType(scale);
|
||||
uint8_t q_scale = s.bits;
|
||||
scale = float(s);
|
||||
|
||||
// Write out the scales and biases
|
||||
size_t gindex = index / group_size;
|
||||
if (index % group_size == 0) {
|
||||
scales[gindex] = q_scale;
|
||||
}
|
||||
|
||||
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
|
||||
if (bits == 4) {
|
||||
uint8_t sval = simd_shuffle_down(output, 1);
|
||||
output |= sval << bits;
|
||||
}
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
if (index % pack_factor == 0) {
|
||||
out[index / pack_factor] = output;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void fp_dequantize(
|
||||
const device uint8_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
device T* out [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
constexpr bool use_mx_scale = group_size == 32;
|
||||
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
size_t oindex = offset * pack_factor;
|
||||
size_t gindex = oindex / group_size;
|
||||
|
||||
out += oindex;
|
||||
|
||||
using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;
|
||||
auto q_scale = ((device ScaleType*)(scales))[gindex];
|
||||
auto scale = float(q_scale);
|
||||
|
||||
uint val = w[offset];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < pack_factor; i++) {
|
||||
uint8_t d;
|
||||
if (bits == 4) {
|
||||
d = (val >> (bits * i)) & 0x0f;
|
||||
} else if (bits == 8) {
|
||||
d = val;
|
||||
}
|
||||
out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));
|
||||
}
|
||||
}
|
||||
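The removed fp_quantize kernel above quantizes one value per thread: each group of 32 weights shares a power-of-two (E8M0) scale derived from the group maximum divided by 6 (the largest FP4 magnitude), and two 4-bit codes are packed per byte. A rough host-side sketch of that arithmetic for a single group, assuming the fp4_encode helper from the earlier sketch; this is not the MLX implementation:

#include <algorithm>
#include <cmath>
#include <cstdint>

uint8_t fp4_encode(float x);  // from the sketch above

// Quantize one group of 32 floats to MXFP4: a shared E8M0 scale plus
// sixteen bytes of packed 4-bit codes (two values per byte).
void mxfp4_quantize_group(const float w[32], uint8_t packed[16], uint8_t& scale_bits) {
  float amax = 0.0f;
  for (int i = 0; i < 32; ++i) amax = std::max(amax, std::fabs(w[i]));

  // E8M0 scale: round log2(amax / 6) to an integer exponent, bias by 127.
  float scale = amax / 6.0f;
  int e = (scale > 0.0f) ? static_cast<int>(std::round(std::log2(scale))) : -127;
  e = std::clamp(e, -127, 127);
  scale_bits = static_cast<uint8_t>(e + 127);
  float s = std::ldexp(1.0f, e);  // decoded scale actually used for quantization

  for (int i = 0; i < 32; i += 2) {
    uint8_t lo = fp4_encode(s == 0.0f ? 0.0f : w[i] / s);
    uint8_t hi = fp4_encode(s == 0.0f ? 0.0f : w[i + 1] / s);
    packed[i / 2] = lo | (hi << 4);  // odd element goes in the high nibble
  }
}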
127
mlx/backend/metal/kernels/fp4_quantized.metal
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||
#include "mlx/backend/metal/kernels/fp4_quantized.h"
|
||||
|
||||
#define instantiate_quantized(name, type) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4", \
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t)
|
||||
|
||||
#define instantiate_quantized_batched(name, type, batched) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_batch_" #batched, \
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_aligned(name, type, aligned) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
aligned)
|
||||
|
||||
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
aligned, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_quad(name, type, D, batched) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
D, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_split_k(name, type, split_k) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
|
||||
func, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
transpose)
|
||||
|
||||
#define instantiate_quantized_batched_wrap(name, type) \
|
||||
instantiate_quantized_batched(name, type, 1) \
|
||||
instantiate_quantized_batched(name, type, 0)
|
||||
|
||||
#define instantiate_quantized_all_batched(type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
|
||||
|
||||
#define instantiate_quantized_all_single(type) \
|
||||
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
|
||||
instantiate_quantized(mxfp4_gather_qmv, type) \
|
||||
instantiate_quantized(mxfp4_gather_qvm, type) \
|
||||
instantiate_quantized(mxfp4_gather_qmm_n, type)
|
||||
|
||||
#define instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
|
||||
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
|
||||
|
||||
#define instantiate_quantized_all_quad(type) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
|
||||
|
||||
#define instantiate_quantized_all_splitk(type) \
|
||||
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
|
||||
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
|
||||
|
||||
#define instantiate_quantized_all_rhs(type) \
|
||||
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
|
||||
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
|
||||
|
||||
#define instantiate_quantized_types(type) \
|
||||
instantiate_quantized_all_batched(type) \
|
||||
instantiate_quantized_all_quad(type) \
|
||||
instantiate_quantized_all_splitk(type) \
|
||||
instantiate_quantized_all_single(type) \
|
||||
instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_all_rhs(type)
|
||||
|
||||
instantiate_quantized_types(float)
|
||||
instantiate_quantized_types(bfloat16_t)
|
||||
instantiate_quantized_types(float16_t)
|
||||
// clang-format on
|
||||
@@ -1,88 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
inline float fp32_from_bits(uint32_t bits) {
|
||||
return *(reinterpret_cast<thread float*>(&bits));
|
||||
}
|
||||
inline float fp32_to_bits(float x) {
|
||||
return *(reinterpret_cast<thread uint32_t*>(&x));
|
||||
}
|
||||
|
||||
struct fp8_e4m3 {
|
||||
template <typename T>
|
||||
fp8_e4m3(T f) {
|
||||
// From PyTorch
|
||||
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
|
||||
uint32_t fp8_max = 543 << 21;
|
||||
uint32_t denorm_mask = 141 << 23;
|
||||
uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
|
||||
uint32_t sign = f_bits & 0x80000000;
|
||||
f_bits ^= sign;
|
||||
if (f_bits >= fp8_max) {
|
||||
// Default behavior saturates to min/max
|
||||
bits = 0x7E;
|
||||
} else {
|
||||
if (f_bits < (121 << 23)) {
|
||||
f_bits =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
bits = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 20) & 1;
|
||||
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
|
||||
f_bits += mant_odd;
|
||||
bits = static_cast<uint8_t>(f_bits >> 20);
|
||||
}
|
||||
}
|
||||
bits |= static_cast<uint8_t>(sign >> 24);
|
||||
}
|
||||
|
||||
operator float() {
|
||||
// From PyTorch:
|
||||
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
|
||||
uint32_t w = static_cast<uint32_t>(bits) << 24;
|
||||
uint32_t sign = w & 0x80000000;
|
||||
uint32_t nonsign = w & 0x7FFFFFFF;
|
||||
|
||||
uint32_t renorm_shift = metal::clz(nonsign);
|
||||
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
|
||||
|
||||
int32_t inf_nan_mask =
|
||||
(static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
|
||||
int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
|
||||
uint32_t result = sign |
|
||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
return fp32_from_bits(result);
|
||||
}
|
||||
|
||||
uint8_t bits;
|
||||
};
|
||||
|
||||
struct fp8_e8m0 {
|
||||
fp8_e8m0(float x) {
|
||||
if (!metal::isfinite(x)) {
|
||||
bits = 0xFF;
|
||||
return;
|
||||
}
|
||||
if (x < 0.0f) {
|
||||
bits = 0x00;
|
||||
return;
|
||||
}
|
||||
float le = metal::log2(x);
|
||||
int n = int(metal::round(le));
|
||||
|
||||
n = n < -127 ? -127 : n;
|
||||
n = n > 127 ? 127 : n;
|
||||
bits = static_cast<uint8_t>(n + 127);
|
||||
}
|
||||
|
||||
operator float() {
|
||||
if (bits == 0xFF) {
|
||||
return metal::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
|
||||
}
|
||||
|
||||
uint8_t bits;
|
||||
};
|
||||
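The removed fp8_e4m3 and fp8_e8m0 structs above are the two 8-bit formats used by the block-quantized modes: E4M3 carries a sign, exponent, and mantissa and saturates at the format maximum, while E8M0 is a pure power-of-two exponent with 0xFF reserved for NaN. A small C++ sketch decoding an E8M0 byte the same way the Metal struct does (illustration only):

#include <cmath>
#include <cstdint>
#include <limits>

// E8M0 stores only an exponent: value = 2^(bits - 127), with 0xFF meaning NaN.
float e8m0_to_float(uint8_t bits) {
  if (bits == 0xFF) return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

// e8m0_to_float(127) == 1.0f, e8m0_to_float(130) == 8.0f,
// e8m0_to_float(120) == 1.0f / 128.0f.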
@@ -1,147 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||
#include "mlx/backend/metal/kernels/fp_quantized.h"
|
||||
|
||||
#define instantiate_quantized(mode, name, type) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4", \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4)
|
||||
|
||||
#define instantiate_quantized_batched(mode, name, type, batched) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_aligned(mode, name, type, aligned) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
aligned)
|
||||
|
||||
#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
aligned, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_quad(mode, name, type, D, batched) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
D, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_split_k(mode, name, type, split_k) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
|
||||
func, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
transpose)
|
||||
|
||||
#define instantiate_quantized_batched_wrap(mode, name, type) \
|
||||
instantiate_quantized_batched(mode, name, type, 1) \
|
||||
instantiate_quantized_batched(mode, name, type, 0)
|
||||
|
||||
#define instantiate_quantized_all_batched(type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
|
||||
|
||||
#define instantiate_quantized_all_single(type) \
|
||||
instantiate_quantized(mxfp4, gather_qmv_fast, type) \
|
||||
instantiate_quantized(mxfp4, gather_qmv, type) \
|
||||
instantiate_quantized(mxfp4, gather_qvm, type) \
|
||||
instantiate_quantized(mxfp4, gather_qmm_n, type)
|
||||
|
||||
#define instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
|
||||
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
|
||||
|
||||
#define instantiate_quantized_all_quad(type) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
|
||||
|
||||
#define instantiate_quantized_all_splitk(type) \
|
||||
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
|
||||
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
|
||||
|
||||
#define instantiate_quantized_all_rhs(type) \
|
||||
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
|
||||
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
|
||||
|
||||
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
|
||||
instantiate_kernel( \
|
||||
#mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||
fp_quantize, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits) \
|
||||
instantiate_kernel( \
|
||||
#mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||
fp_dequantize, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits)
|
||||
|
||||
#define instantiate_quantize_dequantize_modes(type) \
|
||||
instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
|
||||
instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
|
||||
instantiate_quantize_dequantize(type, mxfp8, 32, 8)
|
||||
|
||||
#define instantiate_quantized_types(type) \
|
||||
instantiate_quantized_all_batched(type) \
|
||||
instantiate_quantized_all_quad(type) \
|
||||
instantiate_quantized_all_splitk(type) \
|
||||
instantiate_quantized_all_single(type) \
|
||||
instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_all_rhs(type) \
|
||||
instantiate_quantize_dequantize_modes(type)
|
||||
|
||||
instantiate_quantized_types(float)
|
||||
instantiate_quantized_types(bfloat16_t)
|
||||
instantiate_quantized_types(float16_t)
|
||||
// clang-format on
|
||||
@@ -9,11 +9,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
index *= N;
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
out[index + i] = static_cast<U>(Op()(in[index + i]));
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
out[index + i] = static_cast<U>(Op()(in[index + i]));
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -28,11 +28,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,7 @@ template <
|
||||
IdxT xstride = in_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
out[out_idx++] = static_cast<U>(Op()(in[idx]));
|
||||
out[out_idx++] = Op()(in[idx]);
|
||||
idx += xstride;
|
||||
}
|
||||
}
|
||||
|
||||
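The unary kernels above process N contiguous elements per thread and only fall back to a per-element bounds check on the final partial chunk; the diff drops the redundant static_cast because Op() already returns the output type. A plain C++ sketch of that tail-handling control flow (hypothetical, just to make the pattern explicit):

#include <cstddef>

// Each "thread" handles N contiguous elements; only the last chunk needs a
// bounds check, mirroring the Metal kernels above.
template <typename T, typename U, typename Op, int N = 4>
void unary_chunk(const T* in, U* out, size_t index, size_t size) {
  index *= N;
  if (N > 1 && index + N > size) {
    for (size_t i = 0; index + i < size; ++i) {
      out[index + i] = Op()(in[index + i]);
    }
  } else {
    for (int i = 0; i < N; ++i) {
      out[index + i] = Op()(in[index + i]);
    }
  }
}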
@@ -103,13 +103,4 @@ instantiate_unary_base_same(Round, complex64, complex64_t)
|
||||
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
|
||||
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool)
|
||||
|
||||
instantiate_unary_all(ToFP8, float16, uint8, float16_t, uint8_t)
|
||||
instantiate_unary_all(ToFP8, bfloat16, uint8, bfloat16_t, uint8_t)
|
||||
instantiate_unary_all(ToFP8, float32, uint8, float, uint8_t)
|
||||
instantiate_unary_all(FromFP8, uint8, float16, uint8_t, float16_t)
|
||||
instantiate_unary_all(FromFP8, uint8, bfloat16, uint8_t, bfloat16_t)
|
||||
instantiate_unary_all(FromFP8, uint8, float32, uint8_t, float)
|
||||
|
||||
// clang-format on
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "mlx/backend/metal/kernels/cexpf.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||
#include "mlx/backend/metal/kernels/fp8.h"
|
||||
|
||||
namespace {
|
||||
constant float inf = metal::numeric_limits<float>::infinity();
|
||||
@@ -226,7 +225,8 @@ struct Floor {
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
float operator()(complex64_t x) {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x.imag;
|
||||
};
|
||||
};
|
||||
@@ -290,7 +290,8 @@ struct Negative {
|
||||
};
|
||||
|
||||
struct Real {
|
||||
float operator()(complex64_t x) {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x.real;
|
||||
};
|
||||
};
|
||||
@@ -439,16 +440,3 @@ complex64_t ArcTan::operator()(complex64_t x) {
|
||||
auto ix = i * x;
|
||||
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||
};
|
||||
|
||||
struct ToFP8 {
|
||||
template <typename T>
|
||||
uint8_t operator()(T f) {
|
||||
return fp8_e4m3(f).bits;
|
||||
}
|
||||
};
|
||||
|
||||
struct FromFP8 {
|
||||
float operator()(uint8_t x) {
|
||||
return float(*(thread fp8_e4m3*)(&x));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -344,7 +344,7 @@ void steel_gemm_splitk_axpby(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int /* batch_size_out */,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
|
||||
@@ -179,8 +179,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
const std::optional<array>& /* mask_out */,
|
||||
const std::optional<array>& /* mask_op */,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
|
||||
@@ -134,7 +134,7 @@ void RMSNormVJP::eval_gpu(
|
||||
d.add_temporary(g, s.index);
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
auto axis_size = x.shape().back();
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate the gradient accumulator gw and a temporary to store the
|
||||
@@ -246,7 +246,7 @@ void LayerNorm::eval_gpu(
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
auto axis_size = x.shape().back();
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
int simd_size = 32;
|
||||
@@ -344,7 +344,7 @@ void LayerNormVJP::eval_gpu(
|
||||
d.add_temporary(g, s.index);
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
auto axis_size = x.shape().back();
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
|
||||
@@ -26,6 +26,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
(void)inputs;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
@@ -152,7 +153,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Load::eval_gpu(const std::vector<array>& /* inputs */, array& /* out */) {
|
||||
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
|
||||
}
|
||||
|
||||
@@ -201,41 +202,45 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void QRF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
|
||||
}
|
||||
|
||||
void SVD::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
|
||||
}
|
||||
|
||||
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
|
||||
void Inverse::eval_gpu(
|
||||
const std::vector<array>& /* inputs */,
|
||||
array& /* output */) {
|
||||
throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
|
||||
}
|
||||
|
||||
void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Cholesky::eval_gpu(
|
||||
const std::vector<array>& /* inputs */,
|
||||
array& /* out */) {
|
||||
throw std::runtime_error(
|
||||
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
|
||||
}
|
||||
|
||||
void Eig::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
|
||||
}
|
||||
|
||||
void Eigh::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
|
||||
}
|
||||
|
||||
void LUF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -27,9 +26,14 @@ auto get_quantized_kernel_wrapped(
|
||||
int bits,
|
||||
Args... args) {
|
||||
std::string template_def;
|
||||
std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
|
||||
auto fname = mode + "_" + func;
|
||||
if (mode == "affine") {
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, bits, std::forward<Args>(args)...);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
|
||||
}
|
||||
return get_quantized_kernel(d, name, template_def, mode);
|
||||
}
|
||||
|
||||
@@ -1040,31 +1044,26 @@ void fast::Quantize::eval_gpu(
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
if (dequantize_) {
|
||||
auto scales = ensure_row_contiguous(inputs[1], d, s);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[2], d, s);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
}
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
} else {
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_output_array(scales, 2);
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto& biases = outputs[2];
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
compute_encoder.set_output_array(biases, 3);
|
||||
}
|
||||
}
|
||||
|
||||
auto type_string = dequantize_ ? get_type_string(out.dtype())
|
||||
: get_type_string(w_pre.dtype());
|
||||
auto mode = quantization_mode_to_string(mode_);
|
||||
std::string kname;
|
||||
concatenate(
|
||||
kname,
|
||||
mode + (dequantize_ ? "_dequantize" : "_quantize"),
|
||||
dequantize_ ? "affine_dequantize" : "affine_quantize",
|
||||
"_",
|
||||
type_string,
|
||||
"_gs_",
|
||||
@@ -1075,7 +1074,7 @@ void fast::Quantize::eval_gpu(
|
||||
d,
|
||||
kname,
|
||||
dequantize_ ? "dequantize" : "quantize",
|
||||
mode,
|
||||
"affine",
|
||||
type_string,
|
||||
group_size_,
|
||||
bits_);
|
||||
@@ -1088,8 +1087,7 @@ void fast::Quantize::eval_gpu(
|
||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
||||
: bits_ == 6 ? 4
|
||||
: 8 / bits_;
|
||||
int per_thread =
|
||||
dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1);
|
||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
|
||||
size_t nthreads =
|
||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
||||
|
||||
@@ -1110,12 +1108,4 @@ void fast::Quantize::eval_gpu(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fast::ConvertFP8::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
unary_op_gpu(inputs, out, name(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
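With the quantized kernels split into per-mode source files above, the runtime assembles kernel names from the quantization mode, the base op, the element type, and the group size/bits suffix, following the <mode>_<op>_<type>_gs_<group_size>_b_<bits> pattern visible in the instantiation macros. A small sketch of that convention; the helper below is hypothetical, only the naming pattern comes from the diff:

#include <string>

// Hypothetical helper mirroring the kernel-name convention used by the
// instantiation macros above: <mode>_<op>_<type>_gs_<group_size>_b_<bits>.
std::string quantized_kernel_name(
    const std::string& mode,   // e.g. "affine", "mxfp4", "nvfp4", "mxfp8"
    const std::string& op,     // e.g. "quantize", "dequantize", "qmv_fast"
    const std::string& type,   // e.g. "float16_t", "bfloat16_t", "float"
    int group_size,
    int bits) {
  return mode + "_" + op + "_" + type + "_gs_" + std::to_string(group_size) +
      "_b_" + std::to_string(bits);
}

// quantized_kernel_name("mxfp4", "quantize", "float16_t", 32, 4)
//   -> "mxfp4_quantize_float16_t_gs_32_b_4"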
@@ -291,7 +291,7 @@ void init_reduce(
|
||||
const std::string& op_name,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [_, out_type] = remap_reduce_types(out, op_name);
|
||||
const std::string func_name = "init_reduce";
|
||||
std::string kname = func_name;
|
||||
@@ -397,7 +397,7 @@ void row_reduce_small(
|
||||
RowReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
// Set the kernel
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
@@ -453,7 +453,7 @@ void row_reduce_simple(
|
||||
RowReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
// Set the kernel
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
const std::string func_name = "row_reduce_simple";
|
||||
@@ -493,7 +493,7 @@ void row_reduce_looped(
|
||||
RowReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Set the kernel
|
||||
@@ -570,7 +570,7 @@ void strided_reduce_small(
|
||||
ColReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Figure out the grid dims
|
||||
@@ -747,7 +747,7 @@ void strided_reduce_looped(
|
||||
ColReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
@@ -959,7 +959,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Continue with reduction operation
|
||||
// Minimum of 4 bytes since we use size 4 structs for all reduce
|
||||
// and metal will complain o/w
|
||||
size_t min_bytes = std::max(out.nbytes(), 4ul);
|
||||
size_t min_bytes = std::max<int64_t>(out.nbytes(), 4);
|
||||
out.set_data(allocator::malloc(min_bytes));
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
|
||||
@@ -80,7 +80,7 @@ void ResidencySet::resize(size_t size) {
|
||||
// Remove wired allocations until under capacity
|
||||
auto allocations = wired_set_->allAllocations();
|
||||
auto num_allocations = wired_set_->allocationCount();
|
||||
for (int i = 0; i < num_allocations && current_size > size; ++i) {
|
||||
for (size_t i = 0; i < num_allocations && current_size > size; ++i) {
|
||||
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
|
||||
wired_set_->removeAllocation(buf);
|
||||
current_size -= buf->allocatedSize();
|
||||
|
||||
@@ -76,7 +76,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
size_t size = in.shape(axis_);
|
||||
int64_t size = in.shape(axis_);
|
||||
compute_encoder.set_bytes(size, 2);
|
||||
|
||||
// Compute the thread grid
|
||||
|
||||
@@ -33,7 +33,7 @@ void concatenate_gpu(
|
||||
auto& d = metal::device(s.device);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto concurrent_ctx = compute_encoder.start_concurrent();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(inputs); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
|
||||
@@ -144,7 +144,17 @@ UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu(inputs, out, name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
@@ -29,6 +29,10 @@ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
|
||||
std::ostringstream label;
|
||||
label << "Stream " << index;
|
||||
queue->setLabel(make_string(label));
|
||||
#else
|
||||
// appease warnings
|
||||
(void)queue;
|
||||
(void)index;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -42,6 +46,9 @@ inline void debug_set_primitive_buffer_label(
|
||||
}
|
||||
label << primitive.name();
|
||||
command_buffer->setLabel(make_string(label));
|
||||
#else
|
||||
(void)command_buffer;
|
||||
(void)primitive;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -130,7 +130,6 @@ NO_CPU(View)
|
||||
|
||||
namespace fast {
|
||||
NO_CPU_MULTI(Quantize)
|
||||
NO_CPU_MULTI(ConvertFP8)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.