Compare commits


24 Commits

Author SHA1 Message Date
Pedro Cuenca
eba6a9d163 Compatibility with pip-installed openmpi (#2741)
2025-11-07 16:58:31 -08:00
CCYeh
be9e2aebd6 Shapeless support for zeros/ones_like (#2726)
* shapeless support for zeros/ones_like

* Improvements

* fix access after moved
2025-11-06 19:12:20 -08:00
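A minimal sketch of what shapeless support for zeros/ones_like enables, assuming the documented `shapeless=True` option of `mx.compile`; the function body is purely illustrative:

from functools import partial
import mlx.core as mx

# Illustrative only: with shapeless support, zeros_like/ones_like inside a
# shapeless-compiled function should no longer tie the graph to one input shape.
@partial(mx.compile, shapeless=True)
def fn(x):
    return x * 2 + mx.zeros_like(x)

print(fn(mx.ones((2, 3))))
print(fn(mx.ones((4, 8))))  # different shape, same compiled graph (assumed)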
Awni Hannun
df58b4133a [CUDA] Reduce use of managed memory (#2725)
* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
2025-11-05 16:05:23 -08:00
Anastasiia Filippova
27778156dc Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-05 08:21:11 -08:00
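A hedged sketch of using the new nccl collectives from Python, assuming the script is started under a multi-process CUDA launch (the exact launcher flags are an assumption):

import mlx.core as mx

# Assumes a multi-process launch (e.g. via mlx.launch) so the nccl backend can initialize.
group = mx.distributed.init(backend="nccl")
x = mx.ones((4, 4))
y = mx.distributed.all_gather(x, group=group)  # concatenated along the first axis
mx.eval(y)
print(group.rank(), y.shape)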
Mike Drob
761f901a41 fix property name (#2736) 2025-11-05 06:31:56 -06:00
Angelos Katharopoulos
6ece97f69b Make cpu binary_op easily accessible (#2733) 2025-11-05 01:08:41 -08:00
Awni Hannun
d3bc6a9bff don't test when doing release (#2734)
2025-11-04 15:54:23 -08:00
Awni Hannun
26ceb507eb only build for macos 14 and up (#2731)
* only build for macos 14 and up

* bump metal cpp
2025-11-04 09:44:15 -08:00
Mike Drob
910b3e3299 skip self-hosted runners on forks (#2730)
2025-11-03 16:22:13 -06:00
Harsh Sutaria
50fa315d18 Fix addmm with empty matrices and beta != 1.0 (#2715) 2025-11-03 14:16:15 -08:00
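A hedged illustration of the edge case named in the title, using the standard mx.addmm semantics (out = beta * c + alpha * (a @ b)):

import mlx.core as mx

c = mx.ones((2, 2))
a = mx.zeros((2, 0))   # empty inner dimension
b = mx.zeros((0, 2))
out = mx.addmm(c, a, b, alpha=1.0, beta=0.5)
print(out)  # expected 0.5 * c, since the empty matmul contributes nothing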
AN Long
1ff2b713b6 Check isnan in maximum / minimum with CPU backend (#2652)
* Check isnan in maximum / minimum with CPU backend

* Add tests

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-03 08:51:14 -08:00
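A short example of the behavior this change targets (NaN propagation in elementwise maximum/minimum on the CPU backend); the expected outputs are an assumption based on the title:

import mlx.core as mx

a = mx.array([1.0, float("nan")])
b = mx.array([2.0, 3.0])
print(mx.maximum(a, b))  # with the NaN check: [2.0, nan]
print(mx.minimum(a, b))  # with the NaN check: [1.0, nan]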
Mike Drob
50514a6146 Set up publishing to PyPI and Test-PyPI (#2721) 2025-11-03 07:20:11 -08:00
Awni Hannun
93d76b0f30 Fix compile multi capture (#2678)
* fix compile when compiling multiple lambdas with the same capture

* add test
2025-11-03 06:33:43 -08:00
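A minimal sketch of the scenario described by the commit (two compiled closures sharing the same captured array), with illustrative function bodies:

import mlx.core as mx

y = mx.array(2.0)                # shared capture
f = mx.compile(lambda x: x + y)
g = mx.compile(lambda x: x * y)
# Before the fix, compiling multiple lambdas with the same capture could
# collide; afterwards both calls should give independent, correct results.
print(f(mx.array(1.0)), g(mx.array(3.0)))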
David Koski
78678de0cd add null check -- the bundleIdentifier is optional (#2709)
* add null check -- the bundleIdentifier is optional

* use variable
2025-11-03 06:33:21 -08:00
Melissa Kilby
ed9c6b1117 update: add linux fedora container CI - CPP build test only (#2722)
* update: add linux_fedora_build_cpp CI - CPP build test only - x86-64

Signed-off-by: Melissa Kilby <mkilby@apple.com>

* update: add linux_fedora_build_cpp_aarch64 CI - CPP build test only - arm64

Co-authored-by: Mike Drob <mdrob@apple.com>
Signed-off-by: Melissa Kilby <mkilby@apple.com>

* update: convert linux_fedora_build_cpp to matrix.arch loop

Co-authored-by: Mike Drob <mdrob@apple.com>
Signed-off-by: Melissa Kilby <mkilby@apple.com>

---------

Signed-off-by: Melissa Kilby <mkilby@apple.com>
Co-authored-by: Mike Drob <mdrob@apple.com>
2025-11-03 06:33:00 -08:00
Awni Hannun
39b04ce638 use faster dequant for fp4 qmv (#2720)
2025-10-31 11:49:59 -07:00
Mike Drob
d9e6349657 fix docs path (#2719) 2025-10-30 19:12:49 -05:00
Angelos Katharopoulos
b901a9f311 Fix the order of hosts in the ring (#2718)
2025-10-30 15:02:39 -07:00
Awni Hannun
68c5fa1c95 fix memory count bug (#2717) 2025-10-30 14:27:15 -07:00
Christopher Webb
793a31eeb6 Fix missing domain_uuid_key in thunderbolt ring setup (#2682) 2025-10-30 13:17:20 -07:00
Mike Drob
74c1ed25bb Migrate CircleCI to GitHub Actions (#2716)
Co-authored-by: Joseph Heck <j_heck@apple.com>
2025-10-30 12:26:55 -05:00
Awni Hannun
ec72b44417 Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
2025-10-28 16:23:12 -07:00
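For context, the existing affine quantize/dequantize round trip looks like the sketch below; per the commit, the same entry points gain mxfp8 and nvfp4 paths (the exact keyword arguments for the new modes are not shown here and would need to be checked against the docs added in this PR):

import mlx.core as mx

w = mx.random.normal((128, 128))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
print(mx.abs(w - w_hat).max())  # small reconstruction error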
Melissa Kilby
460691a0e8 fix: linux-{fedora}x86_64-build (#2707)
Signed-off-by: Melissa Kilby <mkilby@apple.com>
2025-10-27 16:36:08 -07:00
Awni Hannun
969924cc69 Fp8 conversion (#2686)
* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
2025-10-27 16:35:50 -07:00
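A hedged sketch of what the new converters enable; the dtype name below is an assumption based on the commit text ("fp8 e4m3") and should be checked against the actual API:

import mlx.core as mx

x = mx.array([0.5, -3.0, 448.0, 1000.0])
y = x.astype(mx.float8_e4m3)     # assumed dtype name
print(y.astype(mx.float32))      # out-of-range values saturate, per the commit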
232 changed files with 5145 additions and 3075 deletions


@@ -0,0 +1,24 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
MLX_BUILD_STAGE: 2
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
python -m build -w
if [ -f "python/scripts/repair_cuda.sh" ]; then
bash python/scripts/repair_cuda.sh
fi

.github/actions/build-cuda/action.yml

@@ -0,0 +1,64 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
# this value is dependent on the CUDA tools installed in the setup-linux workflow
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Run Python tests - CPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m unittest discover python/tests -v
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- name: Build Python package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: ${{ inputs.nvcc-location }}

.github/actions/build-docs/action.yml

@@ -0,0 +1,38 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

.github/actions/build-linux/action.yml

@@ -0,0 +1,78 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
build-type:
description: 'Build type'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
type: boolean
runs:
using: "composite"
steps:
- name: Set DEBUG
shell: sh
if: inputs.build-type == 'debug'
run: echo "DEBUG=1" >> $GITHUB_ENV
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
if: inputs.build-type == 'debug'
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
shell: sh
run: ./build/tests/tests
- name: Build Python package
if: inputs.build-type == 'release'
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
if [ -f "python/scripts/repair_linux.sh" ]; then
bash python/scripts/repair_linux.sh
fi
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64


@@ -0,0 +1,22 @@
name: 'Build macOS release'
description: 'Build MLX releases for macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
runs:
using: "composite"
steps:
- name: Build Python package(s)
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

.github/actions/build-macos/action.yml

@@ -0,0 +1,120 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
inputs:
build-type:
description: 'Build type (debug, release)'
required: false
default: 'debug'
type: choice
options:
- debug
- release
run-tests:
description: 'Whether to run tests'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
DEV_RELEASE: 1
run: |
uv pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Install tests dependencies
if: inputs.run-tests == 'true'
shell: sh
run: |
uv pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
if: inputs.run-tests == 'true'
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
if: inputs.run-tests == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
if: inputs.run-tests == 'true'
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
if: inputs.run-tests == 'true'
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
if: inputs.run-tests == 'true'
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
- name: Build macOS 14 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
- name: Build macOS 15 package
if: inputs.build-type == 'release'
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0

.github/actions/setup-linux/action.yml

@@ -0,0 +1,83 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

.github/actions/setup-macos/action.yml

@@ -0,0 +1,31 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
install-mpi:
description: 'Whether to install MPI'
required: false
default: 'true'
type: boolean
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
if: inputs.install-mpi == 'true'
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

.github/dependabot.yml

@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"


@@ -0,0 +1,27 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd

.github/workflows/documentation.yml

@@ -0,0 +1,28 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

.github/workflows/nightly.yml

@@ -0,0 +1,116 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
env:
MACOSX_DEPLOYMENT_TARGET: "15.0"
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
build_cuda_with_tests:
if: github.repository == 'ml-explore/mlx'
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh


@@ -1,20 +1,71 @@
on:
pull_request:
branches:
- main
name: Build and Test
on: pull_request
permissions:
contents: read
jobs:
check_lint:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
mac_build_and_test:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
if: github.repository == 'ml-explore/mlx'
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: 3.8
- name: Install dependencies
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

.github/workflows/release.yml

@@ -0,0 +1,210 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
outputs:
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
with:
build-type: release
run-tests: false
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}


@@ -1,4 +1,10 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:


@@ -20,13 +20,9 @@ project(
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
add_compile_options(-Wall -Wextra)
endif()
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -92,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
@@ -126,9 +127,12 @@ if(MLX_BUILD_METAL)
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS >= 14.0")
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
@@ -137,7 +141,6 @@ if(MLX_BUILD_METAL)
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>


@@ -17,11 +17,10 @@ To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.10
- macOS >= 13.5
- macOS >= 14.0
.. note::
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
MLX is only available on devices running macOS >= 14.0.
CUDA
^^^^


@@ -7,12 +7,13 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
moment we support three different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
* A **ring** backend of our own that uses native TCP sockets. It should be
faster for thunderbolt connections, but it also works over Ethernet.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
@@ -84,9 +85,8 @@ Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
@@ -220,7 +220,7 @@ print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
@@ -228,14 +228,16 @@ with the Anaconda package manager as follows:
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply
$ mlx.launch -n 2 test.py
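A small Python sketch of the backend selection described in the docs above; passing ``any`` tries all available backends before falling back to a singleton group:

import mlx.core as mx

world = mx.distributed.init(backend="any")   # or "ring", "mpi", "nccl"
print(world.rank(), world.size())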


@@ -14,17 +14,14 @@ void array_basics() {
// Get the value out of it:
auto s = x.item<float>();
assert(s == 1.0);
(void)s;
// Scalars have a size of 1:
int64_t size = x.size();
size_t size = x.size();
assert(size == 1);
(void)size;
// Scalars have 0 dimensions:
int ndim = x.ndim();
assert(ndim == 0);
(void)ndim;
// The shape should be an empty vector:
auto shape = x.shape();
@@ -33,7 +30,6 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == mx::float32);
(void)dtype;
// Specify the dtype when constructing the array:
x = mx::array(1, mx::int32);


@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr) {};
explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();


@@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (int i = 0; i < std::ssize(shapes); ++i) {
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
}
// For each node in |outputs|, its siblings are the other nodes.
for (int i = 0; i < std::ssize(outputs); ++i) {
for (size_t i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
cpy.array_desc_->offset = other.array_desc_->offset;
return cpy;
}
@@ -141,13 +141,12 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
auto max_dim =
static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
auto max_dim = std::max_element(shape().begin(), shape().end());
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
}
void array::set_data(
@@ -157,7 +156,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -173,9 +172,8 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
array_desc_->offset =
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
}
void array::copy_shared_buffer(const array& other) {
@@ -193,7 +191,7 @@ array::~array() {
}
// Break circular reference for non-detached arrays with siblings
if (auto n = std::ssize(siblings()); n > 0) {
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
@@ -275,7 +273,7 @@ array::ArrayDesc::~ArrayDesc() {
ad.inputs.clear();
for (auto& [_, a] : input_map) {
bool is_deletable =
(a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
(a.array_desc_.use_count() <= a.siblings().size() + 1);
// An array with siblings is deletable only if all of its siblings
// are deletable
for (auto& s : a.siblings()) {
@@ -284,7 +282,7 @@ array::ArrayDesc::~ArrayDesc() {
}
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
s.array_desc_.use_count() <= a.siblings().size() + is_input;
}
if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_));


@@ -81,22 +81,22 @@ class array {
}
/** The size of the array's datatype in bytes. */
int itemsize() const {
size_t itemsize() const {
return size_of(dtype());
}
/** The number of elements in the array. */
int64_t size() const {
size_t size() const {
return array_desc_->size;
}
/** The number of bytes in the array. */
int64_t nbytes() const {
size_t nbytes() const {
return size() * itemsize();
}
/** The number of dimensions of the array. */
int ndim() const {
size_t ndim() const {
return array_desc_->shape.size();
}
@@ -294,6 +294,11 @@ class array {
return array_desc_->siblings;
}
/** The array's position in the sibling list. */
int sibling_position() const {
return array_desc_->position;
}
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
@@ -329,7 +334,7 @@ class array {
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
int64_t data_size() const {
size_t data_size() const {
return array_desc_->data_size;
}
@@ -340,7 +345,7 @@ class array {
return array_desc_->data->buffer;
}
int64_t buffer_size() const {
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
@@ -349,15 +354,23 @@ class array {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
// Return a raw pointer to the arrays data. This function may do a copy if
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
}
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
return const_cast<array&>(*this).data<T>();
}
int64_t offset() const {
return array_desc_->offset;
}
enum Status {
@@ -461,8 +474,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
// Properly offset data pointer
void* data_ptr{nullptr};
// Offset from beginning of data pointer
int64_t offset{0};
// The size in elements of the data buffer the array accesses
size_t data_size;
@@ -530,7 +543,7 @@ array::array(
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (std::ssize(data) != size()) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
}


@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(b.data_size() * out.itemsize()),
mallocfn(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}


@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
Strides strides(out.ndim(), 0);


@@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
int64_t r = 1, c = 1;
for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) {
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
@@ -60,8 +60,7 @@ void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs);
i < std::ssize(outputs);
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
@@ -71,7 +70,7 @@ void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < std::ssize(outputs); i++) {
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
@@ -207,11 +206,11 @@ void Split::eval(
auto compute_new_flags = [](const auto& shape,
const auto& strides,
int64_t in_data_size,
size_t in_data_size,
auto flags) {
int64_t data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
@@ -241,7 +240,7 @@ void Split::eval(
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < std::ssize(indices); i++) {
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
@@ -255,7 +254,7 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < std::ssize(axes_) && i == axes_[j]) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
@@ -273,7 +272,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < std::ssize(axes_); ++ax) {
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}


@@ -114,13 +114,15 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous) {
bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
if (contiguous) {
int o = 0;
Strides strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
@@ -138,16 +140,16 @@ void compiled_allocate_outputs(
data_size = in.data_size();
}
}
for (; o < std::ssize(outputs); ++o) {
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()),
mallocfn(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
@@ -162,8 +164,8 @@ void compiled_allocate_outputs(
o++;
}
}
for (; o < std::ssize(outputs); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
for (; o < outputs.size(); ++o) {
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
}
}
}
@@ -193,7 +195,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
// Broadcast the inputs to the output shape.
Strides xstrides;
int j = 0;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
@@ -201,7 +203,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); ++i, ++j) {
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
@@ -224,13 +226,13 @@ bool compiled_use_large_index(
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
int64_t max_size = 0;
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
int64_t max_size = 0;
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}


@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous);
bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(


@@ -22,7 +22,11 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
return false;
}
}


@@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& /* inputs */, array& out) {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),


@@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
@@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < std::ssize(axes); i++) {
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];


@@ -24,8 +24,8 @@ std::tuple<int64_t, Strides> prepare_slice(
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
int64_t data_offset,
int64_t data_size,
size_t data_offset,
size_t data_size,
array& out) {
// Compute row/col contiguity
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -61,7 +61,7 @@ void slice(
if (data_end < 0) {
data_end += in.data_size();
}
int64_t data_size = (data_end - data_offset);
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}


@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt) {
TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc(out.itemsize() * b.data_size()),
mallocfn(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}


@@ -7,19 +7,22 @@
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
}


@@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
if (shape[0] != 1) {
to_collapse.push_back(0);
}
int64_t size = shape[0];
size_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
@@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < std::ssize(strides); j++) {
for (int j = 0; j < strides.size(); j++) {
const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}


@@ -162,7 +162,7 @@ struct ContiguousIterator {
};
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
int64_t no_broadcast_data_size = 1;
size_t no_broadcast_data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
bool is_row_contiguous = true;
@@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
}
inline bool is_donatable(const array& in, const array& out) {
constexpr int64_t donation_extra = 16384;
constexpr size_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;

View File

@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, int64_t size, Stream stream) {
void arange(T start, T next, array& out, size_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
auto& encoder = cpu::get_command_encoder(stream);

View File

@@ -19,12 +19,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
for (int64_t i = 0; i < out.size(); ++i) {
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out_ptr[i] = ind_v;

View File

@@ -14,238 +14,11 @@
namespace mlx::core {
namespace {
template <typename Op>
void binary(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add(), stream());
binary_op_cpu(a, b, out, detail::Add(), stream());
}
void DivMod::eval_cpu(
@@ -329,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide(), stream());
binary_op_cpu(a, b, out, detail::Divide(), stream());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder(), stream());
binary_op_cpu(a, b, out, detail::Remainder(), stream());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -377,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
}
});
} else {
comparison_op(a, b, out, detail::Equal(), stream());
comparison_op_cpu(a, b, out, detail::Equal(), stream());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
comparison_op_cpu(
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_float(a, b, out, detail::LogAddExp(), stream());
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd(), stream());
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr(), stream());
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum(), stream());
binary_op_cpu(a, b, out, detail::Maximum(), stream());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum(), stream());
binary_op_cpu(a, b, out, detail::Minimum(), stream());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply(), stream());
binary_op_cpu(a, b, out, detail::Multiply(), stream());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power(), stream());
binary_op_cpu(a, b, out, detail::Power(), stream());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract(), stream());
binary_op_cpu(a, b, out, detail::Subtract(), stream());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -468,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
switch (op_) {
case BitwiseBinary::And:
binary_int(a, b, out, detail::BitwiseAnd(), stream());
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
break;
case BitwiseBinary::Or:
binary_int(a, b, out, detail::BitwiseOr(), stream());
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
break;
case BitwiseBinary::Xor:
binary_int(a, b, out, detail::BitwiseXor(), stream());
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
break;
case BitwiseBinary::LeftShift:
binary_int(a, b, out, detail::LeftShift(), stream());
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
break;
case BitwiseBinary::RightShift:
binary_int(a, b, out, detail::RightShift(), stream());
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
break;
}
}
@@ -489,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
binary_float(a, b, out, detail::ArcTan2(), stream());
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core
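The per-dtype dispatch that used to be duplicated in this file now lives in the templates binary_op_cpu, comparison_op_cpu, binary_float_op_cpu and binary_int_op_cpu (see the header hunk below), so any CPU source can reuse it. A minimal sketch of a caller, using only names that appear in this diff:
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
namespace mlx::core {
// Functionally the same as Add::eval_cpu above; shown only to illustrate that
// the dtype switch, output allocation/donation and encoder dispatch now sit in
// one reusable template instead of in this translation unit.
void add_like_eval_cpu(const array& a, const array& b, array& out, Stream s) {
  binary_op_cpu(a, b, out, detail::Add(), s);
}
} // namespace mlx::core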

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
}
template <typename Op>
void binary_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace mlx::core

View File

@@ -99,7 +99,7 @@ void binary_op_dispatch_dims(
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (int64_t elem = 0; elem < std::ssize(a); elem += stride) {
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -137,21 +137,21 @@ void binary_op(
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (int64_t i = 0; i < b.data_size(); ++i) {
for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
for (int64_t i = 0; i < a.data_size(); ++i) {
for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
} else { // VectorVector
for (int64_t i = 0; i < a.size(); ++i) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;

View File

@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
int64_t num_matrices = size / (N * N);
for (int64_t i = 0; i < num_matrices; i++) {
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(

View File

@@ -49,7 +49,7 @@ static CompilerCache& cache() {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& /* device */) {
bool compile_available_for_device(const Device& device) {
return true;
}
@@ -168,7 +168,7 @@ inline void build_kernel(
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (int i = 0; i < std::ssize(inputs); ++i) {
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
continue;
@@ -238,7 +238,7 @@ inline void build_kernel(
} else {
os << x.primitive().name();
os << "()(";
for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;

View File

@@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& /* wt_dilation */,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
@@ -1003,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu(
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& /* wt_dilation */,
const std::vector<int>& wt_dilation,
const bool flip,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
@@ -1023,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu(
// Pad input
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (int i = 0; i < iDim.size(); i++) {
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
}
padded_shape.back() = C;
@@ -1054,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu(
// Make strided view
Shape strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (int i = 0; i < oDim.size(); i++) {
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
}
for (int i = 0; i < wDim.size(); i++) {
for (size_t i = 0; i < wDim.size(); i++) {
strided_shape[i + 1 + oDim.size()] = wDim[i];
}
strided_shape.back() = C;
Strides strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (int i = 0; i < std::ssize(wt_strides); i++) {
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
}
for (int i = 1; i < std::ssize(in_padded.strides()); i++) {
for (size_t i = 1; i < in_padded.strides().size(); i++) {
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
}

View File

@@ -90,10 +90,14 @@ void Recv::eval_cpu(
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
(void)inputs;
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed

View File

@@ -70,7 +70,7 @@ void eig_impl(
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (int64_t i = 0; i < size / (N * N); ++i) {
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,

View File

@@ -165,7 +165,7 @@ void eigh_impl(
EighWork<T> work(jobz, uplo, N);
// Work loop
for (int64_t i = 0; i < size / (N * N); ++i) {
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(vec_ptr, eig_ptr);
vec_ptr += N * N;
eig_ptr += N;

View File

@@ -20,8 +20,8 @@ struct CommandEncoder {
CommandEncoder(CommandEncoder&&) = delete;
CommandEncoder& operator=(CommandEncoder&&) = delete;
void set_input_array(const array& /* a */) {}
void set_output_array(array& /* a */) {}
void set_input_array(const array& a) {}
void set_output_array(array& a) {}
// Hold onto a temporary until any already scheduled tasks which use it as
// an input are complete.

View File

@@ -12,12 +12,12 @@ void matmul(
T* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,

View File

@@ -34,7 +34,7 @@ void matmul_bnns(
bool b_transposed,
size_t lda,
size_t ldb,
size_t /* ldc */,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (size_t i = 0; i < batch_size * M * N; ++i) {
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (size_t i = 0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
float16_t* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,

View File

@@ -13,20 +13,20 @@ void matmul<float>(
float* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
double* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
complex64_t* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);

View File

@@ -11,9 +11,9 @@ namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
for (int b = 0; b < size / n; b++) {
int64_t loc = b * n;
size_t loc = b * n;
T* data_ptr = out + loc;
int h = 1;
int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
// m component
template <typename T>
void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < std::ssize(row); i++) {
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
}
for (int b = 0; b < size / m / n; b++) {
int64_t loc = b * n * m;
size_t loc = b * n * m;
T* data_ptr = out + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
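For reference, the loop structure in hadamard_n is the standard in-place fast Walsh-Hadamard butterfly over one length-n block (n a power of two). A scalar sketch of that transform, with the scale applied once at the end for simplicity (where the kernel above applies its scale is not shown in this hunk):
#include <vector>
void fwht_block(std::vector<float>& x, float scale) {
  int n = static_cast<int>(x.size());
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; j++) {
        float u = x[j];
        float v = x[j + h];
        x[j] = u + v;
        x[j + h] = u - v;
      }
    }
  }
  for (auto& v : x) {
    v *= scale;
  }
}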

View File

@@ -78,7 +78,7 @@ void gather(
can_copy = true;
// Ignore leading 1s
int64_t i = 0;
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
@@ -91,7 +91,7 @@ void gather(
can_copy = true;
// Ignore trailing 1s
int64_t i = slice_sizes.size() - 1;
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
@@ -101,11 +101,11 @@ void gather(
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
int64_t slice_size = 1;
size_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
@@ -115,10 +115,10 @@ void gather(
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
int64_t out_idx = 0;
for (int64_t idx = 0; idx < ind_size; idx++) {
int64_t src_idx = 0;
for (int ii = 0; ii < std::ssize(inds); ++ii) {
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
@@ -134,7 +134,7 @@ void gather(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int64_t jj = 0; jj < slice_size; jj++) {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
@@ -403,11 +403,11 @@ void scatter(
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
int64_t n_updates = nind ? inds[0].size() : 1;
size_t n_updates = nind ? inds[0].size() : 1;
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
int64_t update_size = 1;
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
@@ -418,9 +418,9 @@ void scatter(
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int64_t i = 0; i < n_updates; ++i) {
int64_t out_offset = 0;
for (int j = 0; j < std::ssize(inds); ++j) {
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
@@ -429,7 +429,7 @@ void scatter(
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int64_t j = 0; j < update_size; ++j) {
for (int j = 0; j < update_size; ++j) {
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();

View File

@@ -122,7 +122,7 @@ void inverse_impl(
stream);
const int N = a.shape(-1);
const int64_t num_matrices = a.size() / (N * N);
const size_t num_matrices = a.size() / (N * N);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(inv);
@@ -130,13 +130,13 @@ void inverse_impl(
auto inv_ptr = inv.data<T>();
if (tri) {
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int64_t i = 0; i < num_matrices; i++) {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int64_t i = 0; i < num_matrices; i++) {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});

View File

@@ -25,7 +25,7 @@ inline void mask_matrix(
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const int64_t mask_offset) {
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
@@ -61,13 +61,13 @@ inline void segmented_mm(
T* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
int64_t num_segments,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
int64_t M = a.shape(-2);
int64_t N = b.shape(-1);
int64_t K = a.shape(-1);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
int batch_idx,
int X,
int Y,
int64_t X_data_str,
int64_t Y_data_str,
size_t X_data_str,
size_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
int64_t num_matrices = out.size() / (M * int64_t(N));
size_t num_matrices = out.size() / (M * size_t(N));
auto ldc = out.shape(-1);
encoder.dispatch([a_ptr,
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
int64_t M = a.shape(-2);
int64_t N = b.shape(-1);
int64_t K = a.shape(-1);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Get batch dims
auto batch_size_out = out.size() / (M * N);
int64_t matrix_stride_out = M * N;
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};

View File

@@ -2,6 +2,8 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,29 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
// Handle empty matrix case (K=0)
if (inputs[0].shape(-1) == 0) {
auto& c = inputs[2];
if (beta_ == 1.0f) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
} else {
array beta_scalar = array(beta_, c.dtype());
auto& encoder = cpu::get_command_encoder(stream());
binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
encoder.add_temporary(std::move(beta_scalar));
}
return;
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}
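A note on the new K == 0 branch: when the inner dimension is zero the product contributes nothing, so alpha * (A @ B) + beta * C reduces to beta * C — a plain copy when beta == 1, otherwise the scalar multiply done through binary_float_op_cpu. A small check of that behaviour through the public C++ API (shapes and values are illustrative only):
#include "mlx/mlx.h"
int main() {
  using namespace mlx::core;
  auto a = zeros({2, 0}); // inner dimension K == 0
  auto b = zeros({0, 3});
  auto c = ones({2, 3});
  auto out = addmm(c, a, b, /* alpha */ 1.0f, /* beta */ 2.0f);
  eval(out); // every entry should equal beta * 1 == 2, with no GEMM involved
  return 0;
}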

View File

@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < std::ssize(axes); ++i) {
for (int i = 0; i < axes.size(); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
@@ -124,7 +124,6 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
(void)inputs;
out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) {
case bool_:
@@ -194,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < std::ssize(inputs); i++) {
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
int64_t data_offset = strides[axis_] * sizes[i];
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
@@ -206,7 +205,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr int64_t extra_bytes = 16384;
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
@@ -255,8 +254,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
copy_cpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
int64_t data_offset = 0;
for (int i = 0; i < std::ssize(axes_); i++) {
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
@@ -275,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
int64_t num_keys = keys.size() / 2;
size_t num_keys = keys.size() / 2;
int64_t elems_per_key = out.size() / num_keys;
int64_t bytes_per_key = out.itemsize() * elems_per_key;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
@@ -292,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
uintptr_t half_size = out_skip / 2;
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
@@ -334,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
auto& in = inputs[0];
@@ -362,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -397,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}

View File

@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
int64_t num_matrices = a.size() / (M * N);
size_t num_matrices = a.size() / (M * N);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), a.dtype(), nullptr, {});
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
}
allocator::free(work);
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < num_reflectors; ++j) {
for (int k = 0; k < j; ++k) {
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
orgqr<T>(
&M,
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&info);
}
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < M; ++j) {
for (int k = 0; k < num_reflectors; ++k) {

View File

@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -1102,4 +1105,44 @@ void fast::Quantize::eval_cpu(
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core

View File

@@ -1,5 +1,6 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
@@ -207,14 +217,20 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::max(a.value, b.value);
auto out = Simd<T, N>(asd::max(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::min(a.value, b.value);
auto out = Simd<T, N>(asd::min(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
}
template <typename T, int N>
@@ -253,12 +269,12 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
} else {
Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined
if (any(exp < static_cast<T>(0))) {
if (any(exp < 0)) {
return 0;
}
while (any(exp > static_cast<T>(0))) {
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > static_cast<T>(0), base * base, base);
base = select(exp > 0, base * base, base);
exp = exp >> 1;
}
return res;
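The added selects make maximum/minimum propagate NaN instead of returning the other operand, which is what asd::max/asd::min would otherwise do. A scalar model of the same logic, for reference (the real code applies the selects lane-wise on Simd<T, N>):
#include <cmath>
inline float nan_propagating_maximum(float a, float b) {
  float out = std::fmax(a, b); // fmax ignores a single NaN operand
  if (std::isnan(b)) {
    return b; // select(isnan(b), b, ...) restores it
  }
  if (std::isnan(a)) {
    return a; // select(isnan(a), a, out) likewise
  }
  return out;
}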

View File

@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;

View File

@@ -79,8 +79,7 @@ Simd<T, N> sincos(Simd<T, N> in) {
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
auto poly_mask =
(emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
auto poly_mask = (emm2 & 2) != 0;
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
@@ -88,8 +87,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
// and the second polynom (Pi/4 <= x <= 0) in y2

View File

@@ -120,8 +120,8 @@ template <typename T>
void sort(array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
int64_t in_size = out.size();
int64_t n_rows = in_size / out.shape(axis);
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
@@ -151,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
int64_t n_rows = in.size() / in.shape(axis);
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
@@ -214,8 +214,8 @@ template <typename T>
void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
int64_t in_size = out.size();
int64_t n_rows = in_size / out.shape(axis);
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
@@ -248,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
int64_t n_rows = in.size() / in.shape(axis);
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();

View File

@@ -27,7 +27,7 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
int64_t num_matrices = a.size() / (M * N);
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
@@ -121,7 +121,7 @@ void svd_impl(
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices.
for (int64_t i = 0; i < num_matrices; i++) {
for (int i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
@@ -153,10 +153,10 @@ void svd_impl(
template <typename T>
void compute_svd(
const array& /* a */,
bool /* compute_uv */,
std::vector<array>& /* outputs */,
Stream /* stream */) {}
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,

View File

@@ -136,7 +136,7 @@ void ternary_op(
if (topt == TernaryOpType::ScalarScalarScalar) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
} else if (topt == TernaryOpType::VectorVectorVector) {
for (int64_t i = 0; i < out.size(); ++i) {
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;

View File

@@ -10,8 +10,8 @@
namespace mlx::core {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
for (int64_t i = 0; i < shape; i += 1) {
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = Op{}(*a);
a += stride;
}
@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N;
src += N;
dst += N;
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
src++;
}
} else {
int64_t shape = ndim > 0 ? a.shape().back() : 1;
int64_t stride = ndim > 0 ? a.strides().back() : 1;
size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (int64_t elem = 0; elem < a.size(); elem += shape) {
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}

View File

@@ -108,4 +108,73 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail
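ToFP8/FromFP8 appear to implement an 8-bit E4M3-style encoding (1 sign, 4 exponent and 3 mantissa bits, bias 7) by bit manipulation, with 0x7E as the saturation value. A small round-trip sketch using the scalar overloads above; the header path is an assumption based on the unary_ops.h include added earlier in this diff.
#include <cstdint>
#include <cstdio>
#include "mlx/backend/cpu/unary_ops.h"
int main() {
  mlx::core::detail::ToFP8 to_fp8;
  mlx::core::detail::FromFP8 from_fp8;
  // With an E4M3 layout, 1.0f encodes as 0x38 and anything at or above the
  // fp8_max threshold saturates to 0x7E.
  uint8_t one = to_fp8(1.0f);
  float back = from_fp8(one);
  std::printf("0x%02x -> %f\n", one, back);
  return 0;
}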

View File

@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
@@ -51,12 +52,19 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -170,11 +178,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
@@ -67,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
b->buf.device = -1;
return &b->buf;
}
@@ -88,16 +90,41 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
int curr;
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
for (int i = 0; i < device_count; ++i) {
CHECK_CUDA_ERROR(cudaSetDevice(i));
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
Buffer CudaAllocator::malloc(size_t size) {
void copy_to_managed(CudaBuffer& buf) {
// TODO maybe make this async on an i/o stream to avoid synchronizing the
// device on malloc and free
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
buf.device = -1;
CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
CHECK_CUDA_ERROR(cudaFree(buf.data));
buf.data = new_data;
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}};
}
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
@@ -107,6 +134,11 @@ Buffer CudaAllocator::malloc(size_t size) {
size = page_size * ((size + page_size - 1) / page_size);
}
int device = -1;
if (size > small_block_size && stream != nullptr) {
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
@@ -122,8 +154,13 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
} else {
err = cudaMallocAsync(&buf->data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -131,21 +168,37 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.lock();
}
active_memory_ += size;
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
// Copy to managed here if the buffer is not on the right device
if (buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf};
}
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
if (buf->size == 0) {
delete buf;
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
@@ -169,7 +222,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf->data);
if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
} else {
cudaFree(buf->data);
}
delete buf;
}
}
@@ -220,6 +277,16 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace cu
namespace allocator {
@@ -232,7 +299,11 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (cbuf.device != -1) {
copy_to_managed(cbuf);
}
return cbuf.data;
}
} // namespace allocator
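The allocator changes above choose between two allocation paths: when a CUDA stream (and therefore an owning device) is known, large buffers come from the stream-ordered allocator; otherwise they fall back to managed memory that the host and any device can touch. A minimal stand-alone sketch of that decision, using only the CUDA runtime; every name below is illustrative, not MLX code:

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>

// Allocate `size` bytes. With a stream, use cudaMallocAsync so the memory is
// owned by that device's pool; without one, fall back to managed memory.
// `device_out` receives the owning device, or -1 for managed (as in the hunk).
static void* alloc_example(size_t size, cudaStream_t stream, int* device_out) {
  void* ptr = nullptr;
  cudaError_t err;
  if (stream != nullptr) {
    cudaStreamGetDevice(stream, device_out); // device whose pool owns the buffer
    err = cudaMallocAsync(&ptr, size, stream); // stream-ordered allocation
  } else {
    *device_out = -1; // -1 marks "managed" in this sketch
    err = cudaMallocManaged(&ptr, size); // visible to host and all devices
  }
  if (err != cudaSuccess) {
    std::fprintf(stderr, "allocation failed: %s\n", cudaGetErrorString(err));
    return nullptr;
  }
  return ptr;
}

int main() {
  cudaStream_t s;
  cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
  int device = -1;
  void* p = alloc_example(1 << 20, s, &device);
  if (p) {
    // Mirror the free path: async free for pool memory, cudaFree for managed.
    if (device >= 0) {
      cudaFreeAsync(p, s);
    } else {
      cudaFree(p);
    }
  }
  cudaStreamSynchronize(s);
  cudaStreamDestroy(s);
  return 0;
}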

View File

@@ -4,7 +4,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
@@ -17,6 +19,7 @@ using allocator::Buffer;
struct CudaBuffer {
void* data;
size_t size;
int device; // -1 for managed
};
class SmallSizePool {
@@ -45,6 +48,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -58,6 +62,7 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
@@ -69,9 +74,12 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
} // namespace mlx::core::cu

View File

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
@@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dims,
0,
out.data<OutType>(),
gpu_ptr<OutType>(out),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));

View File

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
gpu_ptr<T>(in),
gpu_ptr<uint32_t>(out),
out.size(),
const_param(shape),
const_param(in_strides),

View File

@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
out.data_size());
});
}
@@ -365,7 +365,11 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param(shape),
const_param(a_strides),
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
out_a.data_size());
});
}

View File

@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
}
}
auto& encoder = cu::get_command_encoder(s);
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
for (auto& x : outputs) {
args.append(x);
}
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}

View File

@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
if (out_.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -5,6 +5,22 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -87,11 +103,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core

View File

@@ -77,8 +77,8 @@ void copy_contiguous(
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
gpu_ptr<InType>(in) + in_offset,
gpu_ptr<OutType>(out) + out_offset,
out.data_size());
});
});

View File

@@ -106,8 +106,8 @@ void copy_general(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)

View File

@@ -69,8 +69,8 @@ void copy_general_dynamic(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
}
});
});

View File

@@ -92,8 +92,8 @@ void copy_general_input(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;

View File

@@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace mlx::core {
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
namespace cu {
class Device;
}; // namespace cu
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core
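The CudaHandle template above gives RAII ownership to any CUDA handle whose destroy function has the cudaError_t (*)(Handle) shape, and the header instantiates it for graphs, graph-execs, and streams. As a hedged illustration, the same shape also fits CUDA events; the stripped-down ScopedHandle/ScopedEvent names below are made up for this sketch and are not part of the header:

#include <cuda_runtime.h>
#include <utility>

// Stand-in for the CudaHandle pattern: owns a handle, destroys it on scope exit.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class ScopedHandle {
 public:
  explicit ScopedHandle(Handle h = nullptr) : handle_(h) {}
  ScopedHandle(const ScopedHandle&) = delete;
  ScopedHandle& operator=(const ScopedHandle&) = delete;
  ScopedHandle(ScopedHandle&& o) : handle_(std::exchange(o.handle_, nullptr)) {}
  ~ScopedHandle() {
    if (handle_ != nullptr) {
      Destroy(handle_); // e.g. cudaEventDestroy when Handle is cudaEvent_t
    }
  }
  operator Handle() const {
    return handle_;
  }

 private:
  Handle handle_;
};

// cudaEventDestroy matches cudaError_t (*)(cudaEvent_t), so events slot into
// the same pattern the header uses for cudaGraph_t and cudaStream_t.
using ScopedEvent = ScopedHandle<cudaEvent_t, cudaEventDestroy>;

int main() {
  cudaEvent_t raw = nullptr;
  cudaEventCreateWithFlags(&raw, cudaEventDisableTiming);
  ScopedEvent event(raw); // destroyed automatically at the end of main
  return 0;
}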

View File

@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false;
}
encoder.add_temporary(workspace);
return true;
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>

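The get_alignment helper in this hunk probes how strongly a data pointer is aligned, doubling a candidate alignment until the address stops dividing evenly and capping the answer at 32 bytes. A stand-alone sketch of the same probe; pointer_alignment is a hypothetical name, not the MLX function:

#include <cstdint>
#include <cstdio>

// Largest power of two (capped at 32) that divides the pointer's address.
static int pointer_alignment(const void* p) {
  uintptr_t address = reinterpret_cast<uintptr_t>(p);
  int alignment = 1;
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment; // aligned to `alignment` but not to `2 * alignment`
    }
  }
  return alignment; // 32: aligned to at least 32 bytes
}

int main() {
  alignas(32) char buf[64];
  std::printf("%d\n", pointer_alignment(buf));     // 32
  std::printf("%d\n", pointer_alignment(buf + 4)); // 4
  return 0;
}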
View File

@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
std::vector<array> copies;
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
}
}
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"

View File

@@ -2,6 +2,8 @@
#pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -334,4 +336,17 @@ struct Tanh {
}
};
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu
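ToFP8 and FromFP8 above round a value through the one-byte e4m3 storage format: the forward functor keeps only the storage byte of __nv_fp8_e4m3, and the reverse one reinterprets that byte back into the fp8 type and widens it to float. A small hedged sketch of the round trip; the kernel and variable names are illustrative only:

#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void fp8_round_trip(float x, float* out) {
  uint8_t byte = __nv_fp8_e4m3(x).__x; // ToFP8: keep only the storage byte
  *out = float(*reinterpret_cast<__nv_fp8_e4m3*>(&byte)); // FromFP8
}

int main() {
  float* out;
  cudaMallocManaged(&out, sizeof(float));
  fp8_round_trip<<<1, 1>>>(1.0f / 3.0f, out);
  cudaDeviceSynchronize();
  std::printf("1/3 stored as e4m3 reads back as %f\n", *out); // quantized value
  cudaFree(out);
  return 0;
}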

View File

@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto set_input_output = [&](const array& in,
array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
"Only all reduce sum, max, and min are supported.");
}
}
void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
distributed::detail::all_gather(group(), input, outputs[0], s);
}
void ReduceScatter::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
switch (reduce_type_) {
case Sum:
distributed::detail::sum_scatter(group(), input, outputs[0], s);
break;
default:
throw std::runtime_error("Only sum scatter is supported. ");
}
}
} // namespace mlx::core::distributed

View File

@@ -1,6 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
namespace mlx::core {
@@ -20,8 +22,24 @@ void Fence::wait(Stream s, const array&) {
fence->event.wait(fence->count);
}
void Fence::update(Stream s, const array&) {
void Fence::update(Stream s, const array& a, bool cross_device) {
auto* fence = static_cast<FenceImpl*>(fence_.get());
if (cross_device) {
// Move to managed memory if there is a device switch
auto& cbuf =
*static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
if (cbuf.device != -1) {
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
cbuf.device = -1;
auto& encoder = cu::device(s.device).get_command_encoder(s);
encoder.commit();
CHECK_CUDA_ERROR(cudaMemcpyAsync(
new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream()));
CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
cbuf.data = new_data;
}
}
fence->count++;
fence->event.signal(s, fence->count);
}
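The cross_device branch above migrates an array's stream-ordered allocation into managed memory so that a consumer on another device can read it; the copy and the free are enqueued on the producing stream after the encoder commits. A minimal sketch of the same migration in plain CUDA, assuming the buffer was allocated with cudaMallocAsync on the given stream (all names are illustrative):

#include <cuda_runtime.h>
#include <cstddef>

// Assumes `data` holds `size` valid bytes allocated with cudaMallocAsync on
// `stream`. Returns a managed pointer visible to the host and every device.
static void* migrate_to_managed(void* data, size_t size, cudaStream_t stream) {
  void* managed = nullptr;
  cudaMallocManaged(&managed, size);
  // Both operations are stream-ordered, so whatever produced `data` finishes
  // before the bytes move and before the old buffer is returned to the pool.
  cudaMemcpyAsync(managed, data, size, cudaMemcpyDefault, stream);
  cudaFreeAsync(data, stream);
  return managed;
}

int main() {
  cudaStream_t s;
  cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
  void* dev = nullptr;
  cudaMallocAsync(&dev, 4096, s);
  void* managed = migrate_to_managed(dev, 4096, s);
  cudaStreamSynchronize(s);
  cudaFree(managed);
  cudaStreamDestroy(s);
  return 0;
}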

View File

@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
auto* bias_ptr = gpu_ptr<void>(bias);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
@@ -278,9 +278,9 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
nullptr,
alpha);
}
@@ -321,10 +321,10 @@ void CublasGemm::run(
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
gpu_ptr<void>(c),
alpha,
beta);
}
@@ -370,11 +370,11 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
cu::malloc_async(nbytes, encoder.stream()),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
workspace_ptr = gpu_ptr<void>(workspace);
}
auto capture = encoder.capture_context();

View File

@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
nullptr,
alpha);
a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
for (size_t i = 0; i < nbatch; ++i) {
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
gpu_ptr<int8_t>(out) +
out.itemsize() * i * batch_shape.back() * M_ * N_,
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
{batch_count * 3},
uint64);
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
{batch_count * 4},
uint64);
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
gpu_ptr<int8_t*>(pointers),
gpu_ptr<int8_t>(a),
gpu_ptr<int8_t>(b),
gpu_ptr<int8_t>(c),
gpu_ptr<int8_t>(out),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto a_pointers = gpu_ptr<int8_t*>(pointers);
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;

View File

@@ -149,13 +149,13 @@ void gemv(
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
mat = gpu_ptr<DataType>(b);
vec = gpu_ptr<DataType>(a);
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
mat = gpu_ptr<DataType>(a);
vec = gpu_ptr<DataType>(b);
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -177,7 +177,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols);
} else {
@@ -189,7 +189,7 @@ void gemv(
0,
mat,
vec,
out.data<DataType>(),
gpu_ptr<DataType>(out),
rows,
cols,
const_param(batch_shape),

View File

@@ -31,7 +31,7 @@ void append_indices_arg(
int idx_ndim) {
SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
indices[i] = gpu_ptr<void>(inputs[i + 1]);
}
args.append(std::move(indices));
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() > 0);
const auto& src = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(idx_dtype),
nidx);
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_ndim,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
const auto& src = inputs[0];
const auto& idx = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
if (out.size() == 0) {
return;
}
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_string(out.dtype()),
dtype_to_string(idx.dtype()));
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::vector<std::string> kernel_names;
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.flags().row_contiguous,
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}

View File

@@ -31,7 +31,7 @@ struct KernelArgs {
}
void append(const array& a) {
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
}
template <typename T>

View File

@@ -9,6 +9,7 @@
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuda.h>

View File

@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(),
x.strides(),
x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(b),
gpu_ptr<DataType>(out),
eps_,
axis_size,
w_stride,
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp);
}
}
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);

mlx/backend/cuda/load.cpp (new file, 60 lines)
View File

@@ -0,0 +1,60 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <utility>
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/primitives.h"
namespace {
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
Elem* data = reinterpret_cast<Elem*>(data_bytes);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < (scalar_size / 2); j++) {
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
}
}
}
} // namespace
namespace mlx::core {
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
auto size = out.size();
auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
CHECK_CUDA_ERROR(cudaMemcpyAsync(
gpu_ptr<void>(out),
out_ptr,
nbytes,
cudaMemcpyDefault,
encoder.stream()));
CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
}
} // namespace mlx::core
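Load::eval_gpu above stages the file's bytes in an ordinary host buffer, optionally fixes the endianness in place, copies the bytes to the device asynchronously, and then hands the staging buffer to the stream so it is freed only after the copy completes. A hedged, stand-alone sketch of that staging pattern without any MLX types:

#include <cuda_runtime.h>
#include <cstdlib>
#include <cstring>

int main() {
  const size_t nbytes = 1 << 20;

  // Stand-in for reader_->read(...): fill a plain (pageable) host buffer.
  void* staging = std::malloc(nbytes);
  std::memset(staging, 0xAB, nbytes);

  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  void* device_ptr = nullptr;
  cudaMallocAsync(&device_ptr, nbytes, stream);

  // Async copy; the staging buffer must stay alive until it completes...
  cudaMemcpyAsync(device_ptr, staging, nbytes, cudaMemcpyDefault, stream);

  // ...so enqueue free() as a host callback that the stream runs after the copy.
  // cudaHostFn_t is void (*)(void*), which std::free satisfies.
  cudaLaunchHostFunc(stream, std::free, staging);

  cudaFreeAsync(device_ptr, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}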

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
cu::malloc_async(in.nbytes() / n, encoder.stream()),
in.data_size() / n,
std::move(strides),
flags);
@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(out),
axis_size);
});
});

Some files were not shown because too many files have changed in this diff.