mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
16 Commits
39b04ce638
...
v0.29.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60d80a3728 | ||
|
|
eba6a9d163 | ||
|
|
be9e2aebd6 | ||
|
|
df58b4133a | ||
|
|
27778156dc | ||
|
|
761f901a41 | ||
|
|
6ece97f69b | ||
|
|
d3bc6a9bff | ||
|
|
26ceb507eb | ||
|
|
910b3e3299 | ||
|
|
50fa315d18 | ||
|
|
1ff2b713b6 | ||
|
|
50514a6146 | ||
|
|
93d76b0f30 | ||
|
|
78678de0cd | ||
|
|
ed9c6b1117 |
@@ -12,13 +12,9 @@ runs:
|
||||
- name: Build package
|
||||
shell: bash
|
||||
env:
|
||||
MLX_BUILD_STAGE: 2
|
||||
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
|
||||
run: |
|
||||
pip install auditwheel build patchelf setuptools
|
||||
python setup.py clean --all
|
||||
python -m build -w
|
||||
|
||||
if [ -f "python/scripts/repair_cuda.sh" ]; then
|
||||
bash python/scripts/repair_cuda.sh
|
||||
fi
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
|
||||
23
.github/actions/build-cuda/action.yml
vendored
23
.github/actions/build-cuda/action.yml
vendored
@@ -2,19 +2,10 @@ name: 'Build and Test with CUDA'
|
||||
description: 'Build and test MLX with CUDA'
|
||||
|
||||
inputs:
|
||||
build-type:
|
||||
description: 'Build type (debug, release)'
|
||||
required: false
|
||||
default: 'debug'
|
||||
run-tests:
|
||||
description: 'Whether to run tests'
|
||||
required: false
|
||||
default: 'true'
|
||||
nvcc-location:
|
||||
description: 'Location of nvcc compiler'
|
||||
required: true
|
||||
default: '/usr/local/cuda-12.9/bin/nvcc'
|
||||
# this value is dependent on the CUDA tools installed in the setup-linux workflow
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
@@ -26,12 +17,7 @@ runs:
|
||||
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
|
||||
run: pip install -e ".[dev]" -v
|
||||
|
||||
- name: Check if build actually worked
|
||||
shell: bash
|
||||
run: python -c "import mlx.core"
|
||||
|
||||
- name: Run Python tests - CPU
|
||||
if: inputs.run-tests == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
@@ -39,7 +25,6 @@ runs:
|
||||
run: python -m unittest discover python/tests -v
|
||||
|
||||
- name: Run Python tests - GPU
|
||||
if: inputs.run-tests == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
@@ -47,7 +32,6 @@ runs:
|
||||
run: python -m tests discover python/tests -v
|
||||
|
||||
- name: Build CPP only
|
||||
if: inputs.build-type == 'debug'
|
||||
shell: bash
|
||||
run: |
|
||||
cmake . -B build \
|
||||
@@ -57,12 +41,5 @@ runs:
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
- name: Run CPP tests
|
||||
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
|
||||
shell: bash
|
||||
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||
|
||||
- name: Build Python package
|
||||
if: inputs.build-type == 'release'
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
nvcc-location: ${{ inputs.nvcc-location }}
|
||||
|
||||
33
.github/actions/build-linux-release/action.yml
vendored
Normal file
33
.github/actions/build-linux-release/action.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: 'Build Linux wheel'
|
||||
description: 'Build Linux wheel'
|
||||
|
||||
inputs:
|
||||
build-backend:
|
||||
description: 'Build the backend mlx-cpu package'
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Generate package stubs
|
||||
shell: bash
|
||||
run: |
|
||||
pip install -e ".[dev]" -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- name: Build Python package
|
||||
shell: bash
|
||||
run: |
|
||||
pip install auditwheel patchelf build
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 python -m build -w
|
||||
bash python/scripts/repair_linux.sh
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash
|
||||
run: |
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||
39
.github/actions/build-linux/action.yml
vendored
39
.github/actions/build-linux/action.yml
vendored
@@ -1,33 +1,14 @@
|
||||
name: 'Build and Test on Linux'
|
||||
description: 'Build and test MLX on Linux'
|
||||
|
||||
inputs:
|
||||
build-type:
|
||||
description: 'Build type'
|
||||
required: false
|
||||
default: 'debug'
|
||||
type: choice
|
||||
options:
|
||||
- debug
|
||||
- release
|
||||
run-tests:
|
||||
description: 'Whether to run tests'
|
||||
required: false
|
||||
default: 'true'
|
||||
type: boolean
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Set DEBUG
|
||||
shell: sh
|
||||
if: inputs.build-type == 'debug'
|
||||
run: echo "DEBUG=1" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Python package
|
||||
shell: sh
|
||||
env:
|
||||
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
DEBUG: 1
|
||||
run: pip install -e ".[dev]" -v
|
||||
|
||||
- name: Generate package stubs
|
||||
@@ -37,7 +18,6 @@ runs:
|
||||
python setup.py generate_stubs
|
||||
|
||||
- name: Run Python tests
|
||||
if: inputs.run-tests == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
python -m unittest discover python/tests -v
|
||||
@@ -50,7 +30,6 @@ runs:
|
||||
fi
|
||||
|
||||
- name: Build CPP only
|
||||
if: inputs.build-type == 'debug'
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p build && cd build
|
||||
@@ -58,21 +37,5 @@ runs:
|
||||
make -j $(nproc)
|
||||
|
||||
- name: Run CPP tests
|
||||
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
|
||||
shell: sh
|
||||
run: ./build/tests/tests
|
||||
|
||||
- name: Build Python package
|
||||
if: inputs.build-type == 'release'
|
||||
shell: bash
|
||||
run: |
|
||||
pip install auditwheel patchelf build
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 python -m build -w
|
||||
if [ -f "python/scripts/repair_linux.sh" ]; then
|
||||
bash python/scripts/repair_linux.sh
|
||||
fi
|
||||
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||
|
||||
13
.github/actions/build-macos-release/action.yml
vendored
13
.github/actions/build-macos-release/action.yml
vendored
@@ -6,11 +6,16 @@ inputs:
|
||||
description: 'macOS build target'
|
||||
required: false
|
||||
default: '15.0'
|
||||
build-backend:
|
||||
description: 'Build the backend mlx-metal package'
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build Python package(s)
|
||||
- name: Build Python package
|
||||
shell: bash
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
@@ -18,5 +23,11 @@ runs:
|
||||
uv pip install build
|
||||
uv run --no-project setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 uv run -m build -w
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
uv run --no-project setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 uv run -m build -w
|
||||
|
||||
52
.github/actions/build-macos/action.yml
vendored
52
.github/actions/build-macos/action.yml
vendored
@@ -1,24 +1,6 @@
|
||||
name: 'Build and Test on macOS'
|
||||
description: 'Build and test MLX on macOS'
|
||||
|
||||
inputs:
|
||||
build-type:
|
||||
description: 'Build type (debug, release)'
|
||||
required: false
|
||||
default: 'debug'
|
||||
type: choice
|
||||
options:
|
||||
- debug
|
||||
- release
|
||||
run-tests:
|
||||
description: 'Whether to run tests'
|
||||
required: false
|
||||
default: 'true'
|
||||
build-jit:
|
||||
description: 'Whether to build with JIT'
|
||||
required: false
|
||||
default: 'true'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
@@ -26,11 +8,10 @@ runs:
|
||||
shell: sh
|
||||
env:
|
||||
DEBUG: 1
|
||||
DEV_RELEASE: 1
|
||||
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
run: |
|
||||
uv pip install --upgrade pip cmake setuptools
|
||||
uv pip install nanobind==2.4.0 \
|
||||
numpy torch tensorflow unittest-xml-reporting
|
||||
uv pip install --upgrade pip
|
||||
uv pip install cmake setuptools nanobind==2.4.0
|
||||
uv pip install -e . -v
|
||||
|
||||
- name: Generate package stubs
|
||||
@@ -39,8 +20,12 @@ runs:
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
|
||||
- name: Install tests dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
uv pip install numpy torch tensorflow unittest-xml-reporting
|
||||
|
||||
- name: Run Python tests
|
||||
if: inputs.run-tests == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
@@ -52,7 +37,6 @@ runs:
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
|
||||
- name: Build example extension
|
||||
if: inputs.run-tests == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
cd examples/extensions
|
||||
@@ -61,7 +45,6 @@ runs:
|
||||
uv run --no-project test.py
|
||||
|
||||
- name: Build CPP only
|
||||
if: inputs.build-type == 'debug'
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p build
|
||||
@@ -70,7 +53,6 @@ runs:
|
||||
make -j $(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Run CPP tests
|
||||
if: ${{ inputs.build-type == 'debug' && inputs.run-tests == 'true' }}
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: gpu
|
||||
@@ -79,7 +61,6 @@ runs:
|
||||
run: ./build/tests/tests
|
||||
|
||||
- name: Build small binary with JIT
|
||||
if: inputs.build-jit == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p build
|
||||
@@ -93,7 +74,6 @@ runs:
|
||||
make -j $(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Run Python tests with JIT
|
||||
if: ${{ inputs.build-jit == 'true' && inputs.run-tests == 'true' }}
|
||||
shell: bash
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
@@ -106,19 +86,3 @@ runs:
|
||||
uv run -m xmlrunner discover \
|
||||
-v python/tests \
|
||||
-o test-results/gpu_jit
|
||||
|
||||
- name: Build macOS 13 package
|
||||
if: inputs.build-type == 'release'
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 13.0
|
||||
- name: Build macOS 14 package
|
||||
if: inputs.build-type == 'release'
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 14.0
|
||||
- name: Build macOS 15 package
|
||||
if: inputs.build-type == 'release'
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 15.0
|
||||
6
.github/actions/setup-macos/action.yml
vendored
6
.github/actions/setup-macos/action.yml
vendored
@@ -2,11 +2,6 @@ name: 'Setup macOS Environment'
|
||||
description: 'Install dependencies for macOS builds'
|
||||
|
||||
inputs:
|
||||
install-mpi:
|
||||
description: 'Whether to install MPI'
|
||||
required: false
|
||||
default: 'true'
|
||||
type: boolean
|
||||
python-version:
|
||||
description: 'Python version to use'
|
||||
required: false
|
||||
@@ -17,7 +12,6 @@ runs:
|
||||
steps:
|
||||
- name: Install Homebrew packages
|
||||
shell: sh
|
||||
if: inputs.install-mpi == 'true'
|
||||
run: /opt/homebrew/bin/brew install openmpi
|
||||
|
||||
- name: Verify MetalToolchain installed
|
||||
|
||||
27
.github/scripts/setup+build-cpp-linux-fedora-container.sh
vendored
Executable file
27
.github/scripts/setup+build-cpp-linux-fedora-container.sh
vendored
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# [Setup] Install dependencies inside the container.
|
||||
dnf update -y
|
||||
dnf install -y \
|
||||
blas-devel \
|
||||
lapack-devel \
|
||||
openblas-devel \
|
||||
make \
|
||||
cmake \
|
||||
clang \
|
||||
git
|
||||
dnf clean all
|
||||
|
||||
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
|
||||
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
export DEBUG=1
|
||||
export CMAKE_C_COMPILER=/usr/bin/clang
|
||||
export CMAKE_CXX_COMPILER=/usr/bin/clang++
|
||||
|
||||
mkdir -p build
|
||||
pushd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j $(nproc)
|
||||
./tests/tests
|
||||
popd
|
||||
43
.github/workflows/nightly.yml
vendored
43
.github/workflows/nightly.yml
vendored
@@ -18,10 +18,9 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
- uses: ./.github/actions/build-linux
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-type: release
|
||||
run-tests: false
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Upload mlx artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
@@ -50,12 +49,10 @@ jobs:
|
||||
- uses: ./.github/actions/build-linux
|
||||
|
||||
build_mac_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.13"]
|
||||
# TODO: 3.14 had issues finding a compatible tensorflow
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: "15.0"
|
||||
runs-on: [self-hosted, macos]
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -64,7 +61,19 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- uses: ./.github/actions/build-macos
|
||||
|
||||
- name: Build macOS 15 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 15.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Build macOS 14 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 14.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
|
||||
build_cuda_with_tests:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: gpu-t4-4-core
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -74,6 +83,7 @@ jobs:
|
||||
- uses: ./.github/actions/build-cuda
|
||||
|
||||
build_cuda_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22-large
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -91,3 +101,24 @@ jobs:
|
||||
path: wheelhouse/mlx_cuda-*.whl
|
||||
retention-days: 7
|
||||
|
||||
linux_fedora_build_cpp:
|
||||
name: Linux Fedora CPP Build (${{ matrix.arch }})
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- host: ubuntu-22.04
|
||||
arch: x86_64
|
||||
- host: ubuntu-22.04-arm
|
||||
arch: aarch64
|
||||
|
||||
runs-on: ${{ matrix.host }}
|
||||
container:
|
||||
image: fedora:42
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: CPP Build Test - No Release
|
||||
run: |
|
||||
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
|
||||
|
||||
25
.github/workflows/pull_request.yml
vendored
25
.github/workflows/pull_request.yml
vendored
@@ -21,6 +21,7 @@ jobs:
|
||||
- uses: ./.github/actions/build-linux
|
||||
|
||||
mac_build_and_test:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: [self-hosted, macos]
|
||||
needs: check_lint
|
||||
steps:
|
||||
@@ -29,6 +30,7 @@ jobs:
|
||||
- uses: ./.github/actions/build-macos
|
||||
|
||||
cuda_build_and_test:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: gpu-t4-4-core
|
||||
needs: check_lint
|
||||
steps:
|
||||
@@ -39,8 +41,31 @@ jobs:
|
||||
- uses: ./.github/actions/build-cuda
|
||||
|
||||
build_documentation:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: [self-hosted, macos]
|
||||
needs: check_lint
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/build-docs
|
||||
|
||||
linux_fedora_build_cpp:
|
||||
name: Linux Fedora CPP Build (${{ matrix.arch }})
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- host: ubuntu-22.04
|
||||
arch: x86_64
|
||||
- host: ubuntu-22.04-arm
|
||||
arch: aarch64
|
||||
|
||||
runs-on: ${{ matrix.host }}
|
||||
container:
|
||||
image: fedora:42
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: CPP Build Test - No Release
|
||||
run: |
|
||||
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
|
||||
|
||||
106
.github/workflows/release.yml
vendored
106
.github/workflows/release.yml
vendored
@@ -10,7 +10,17 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
|
||||
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
|
||||
steps:
|
||||
- name: Set publishing variables
|
||||
run: echo "Publishing setup complete"
|
||||
|
||||
build_documentation:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: [self-hosted, macos]
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -31,6 +41,7 @@ jobs:
|
||||
uses: actions/deploy-pages@v4
|
||||
|
||||
build_linux_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
@@ -42,10 +53,9 @@ jobs:
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
- uses: ./.github/actions/build-linux
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-type: release
|
||||
run-tests: false
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
@@ -59,10 +69,10 @@ jobs:
|
||||
path: wheelhouse/mlx_cpu-*.whl
|
||||
|
||||
build_mac_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
# TODO: 3.14 had issues finding a compatible tensorflow
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
runs-on: [self-hosted, macos]
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
@@ -71,9 +81,27 @@ jobs:
|
||||
- uses: ./.github/actions/setup-macos
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- uses: ./.github/actions/build-macos
|
||||
- name: Install dependencies
|
||||
shell: sh
|
||||
run: |
|
||||
uv pip install --upgrade pip
|
||||
uv pip install cmake setuptools nanobind==2.4.0
|
||||
uv pip install -e . -v
|
||||
- name: Generate package stubs
|
||||
shell: bash
|
||||
run: |
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- name: Build macOS 14 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
build-type: release
|
||||
macos-target: 14.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Build macOS 15 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 15.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
@@ -87,6 +115,7 @@ jobs:
|
||||
path: dist/mlx_metal-*.whl
|
||||
|
||||
build_cuda_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22-large
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
@@ -108,81 +137,90 @@ jobs:
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build_linux_release, build_mac_release]
|
||||
needs: [setup, build_linux_release, build_mac_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
url: https://pypi.org/p/mlx
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
pattern: linux-wheels-*
|
||||
merge-multiples: true
|
||||
path: artifacts
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
pattern: mac-wheels-*
|
||||
merge-multiples: true
|
||||
path: artifacts
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R artifacts
|
||||
# - name: Publish package distributions to PyPI
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
|
||||
pypi-publish-cuda:
|
||||
name: Upload CUDA release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build_cuda_release
|
||||
needs: [setup, build_cuda_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
url: https://pypi.org/p/mlx-cuda
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-cuda
|
||||
path: artifacts
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R artifacts
|
||||
# - name: Publish package distributions to PyPI
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
|
||||
pypi-publish-cpu:
|
||||
name: Upload CPU release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build_linux_release
|
||||
needs: [setup, build_linux_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
url: https://pypi.org/p/mlx-cpu
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-cpu
|
||||
path: artifacts
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R artifacts
|
||||
# - name: Publish package distributions to PyPI
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
|
||||
pypi-publish-metal:
|
||||
name: Upload Metal release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build_mac_release
|
||||
needs: [setup, build_mac_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
name: ${{ needs.setup.outputs.pypi_env }}
|
||||
url: https://pypi.org/p/mlx-metal
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-metal
|
||||
path: artifacts
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R artifacts
|
||||
# - name: Publish package distributions to PyPI
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: ${{ needs.setup.outputs.pypi_url }}
|
||||
|
||||
|
||||
@@ -127,9 +127,12 @@ if(MLX_BUILD_METAL)
|
||||
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||
|
||||
set(METAL_CPP_URL
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
|
||||
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
|
||||
message(FATAL_ERROR "MLX requires macOS >= 14.0")
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
execute_process(
|
||||
@@ -138,7 +141,6 @@ if(MLX_BUILD_METAL)
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||
|
||||
@@ -17,11 +17,10 @@ To install from PyPI your system must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.10
|
||||
- macOS >= 13.5
|
||||
- macOS >= 14.0
|
||||
|
||||
.. note::
|
||||
MLX is only available on devices running macOS >= 13.5
|
||||
It is highly recommended to use macOS 14 (Sonoma)
|
||||
MLX is only available on devices running macOS >= 14.0 and higher.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
@@ -7,12 +7,13 @@ Distributed Communication
|
||||
|
||||
MLX supports distributed communication operations that allow the computational cost
|
||||
of training or inference to be shared across many physical machines. At the
|
||||
moment we support two different communication backends:
|
||||
moment we support three different communication backends:
|
||||
|
||||
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
|
||||
full-featured and mature distributed communications library
|
||||
* A **ring** backend of our own that uses native TCP sockets and should be
|
||||
faster for thunderbolt connections.
|
||||
* A **ring** backend of our own that uses native TCP sockets. It should be
|
||||
faster for thunderbolt connections, but it also works over Ethernet.
|
||||
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
|
||||
|
||||
The list of all currently supported operations and their documentation can be
|
||||
seen in the :ref:`API docs<distributed>`.
|
||||
@@ -84,9 +85,8 @@ Selecting Backend
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
You can select the backend you want to use when calling :func:`init` by passing
|
||||
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
|
||||
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
|
||||
both fail then a singleton group is created.
|
||||
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
|
||||
available backends. If they all fail then a singleton group is created.
|
||||
|
||||
.. note::
|
||||
After a distributed backend is successfully initialized :func:`init` will
|
||||
@@ -220,7 +220,7 @@ print 4 etc.
|
||||
Installing MPI
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
|
||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||
with the Anaconda package manager as follows:
|
||||
|
||||
@@ -228,14 +228,16 @@ with the Anaconda package manager as follows:
|
||||
|
||||
$ conda install conda-forge::openmpi
|
||||
|
||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
|
||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
|
||||
done automatically by ``mlx.launch``.
|
||||
done automatically by ``mlx.launch``. Some environments use a non-standard
|
||||
library filename that can be specified using the ``MPI_LIBNAME`` environment
|
||||
variable. This is automatically taken care of by ``mlx.launch`` as well.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
|
||||
$ # or simply
|
||||
$ mlx.launch -n 2 test.py
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class Buffer {
|
||||
void* ptr_;
|
||||
|
||||
public:
|
||||
Buffer(void* ptr) : ptr_(ptr) {};
|
||||
explicit Buffer(void* ptr) : ptr_(ptr) {};
|
||||
|
||||
// Get the raw data pointer from the buffer
|
||||
void* raw_ptr();
|
||||
|
||||
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
cpy.array_desc_->offset = other.array_desc_->offset;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->offset = 0;
|
||||
array_desc_->data_size = size();
|
||||
array_desc_->flags.contiguous = true;
|
||||
array_desc_->flags.row_contiguous = true;
|
||||
@@ -156,7 +156,7 @@ void array::set_data(
|
||||
Flags flags,
|
||||
Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->offset = 0;
|
||||
array_desc_->data_size = data_size;
|
||||
array_desc_->strides = std::move(strides);
|
||||
array_desc_->flags = flags;
|
||||
@@ -172,9 +172,8 @@ void array::copy_shared_buffer(
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
array_desc_->data_ptr = static_cast<void*>(
|
||||
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||
array_desc_->offset =
|
||||
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
|
||||
}
|
||||
|
||||
void array::copy_shared_buffer(const array& other) {
|
||||
|
||||
23
mlx/array.h
23
mlx/array.h
@@ -294,6 +294,11 @@ class array {
|
||||
return array_desc_->siblings;
|
||||
}
|
||||
|
||||
/** The array's position in the sibling list. */
|
||||
int sibling_position() const {
|
||||
return array_desc_->position;
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
array_desc_->position = position;
|
||||
@@ -349,15 +354,23 @@ class array {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
// Return a raw pointer to the arrays data. This function may do a copy if
|
||||
// the underlying buffer is not accessible on the CPU. When accessing the
|
||||
// data for GPU kernels, be sure to use the correct method / function for the
|
||||
// given backend to access the GPU pointer.
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
return reinterpret_cast<T*>(
|
||||
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
return const_cast<array&>(*this).data<T>();
|
||||
}
|
||||
|
||||
int64_t offset() const {
|
||||
return array_desc_->offset;
|
||||
}
|
||||
|
||||
enum Status {
|
||||
@@ -461,8 +474,8 @@ class array {
|
||||
// can share the underlying data buffer.
|
||||
std::shared_ptr<Data> data;
|
||||
|
||||
// Properly offset data pointer
|
||||
void* data_ptr{nullptr};
|
||||
// Offset from beginning of data pointer
|
||||
int64_t offset{0};
|
||||
|
||||
// The size in elements of the data buffer the array accesses
|
||||
size_t data_size;
|
||||
|
||||
@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
BinaryOpType bopt) {
|
||||
BinaryOpType bopt,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
bool b_donatable = is_donatable(b, out);
|
||||
bool a_donatable = is_donatable(a, out);
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
mallocfn(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
mallocfn(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
mallocfn(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
|
||||
@@ -114,7 +114,9 @@ void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous) {
|
||||
bool contiguous,
|
||||
const std::function<allocator::Buffer(size_t)>&
|
||||
mallocfn /* = allocator::malloc */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
Strides strides;
|
||||
@@ -140,7 +142,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
mallocfn(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -163,7 +165,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous);
|
||||
bool contiguous,
|
||||
const std::function<allocator::Buffer(size_t)>& mallocfn =
|
||||
allocator::malloc);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
|
||||
@@ -22,7 +22,11 @@ enum class CopyType {
|
||||
GeneralGeneral
|
||||
};
|
||||
|
||||
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
inline bool set_copy_output_data(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
return true;
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
mallocfn(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ void slice(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
TernaryOpType topt) {
|
||||
TernaryOpType topt,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
auto maybe_donate = [&out](const array& x) {
|
||||
if (is_donatable(x, out)) {
|
||||
out.copy_shared_buffer(x);
|
||||
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
|
||||
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize() * b.data_size()),
|
||||
mallocfn(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -7,19 +7,22 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(const array& in, array& out) {
|
||||
inline void set_unary_output_data(
|
||||
const array& in,
|
||||
array& out,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
mallocfn(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,233 +14,11 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Add(), stream());
|
||||
}
|
||||
|
||||
void DivMod::eval_cpu(
|
||||
@@ -324,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Divide(), stream());
|
||||
}
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Remainder(), stream());
|
||||
}
|
||||
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -372,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
});
|
||||
} else {
|
||||
comparison_op(a, b, out, detail::Equal(), stream());
|
||||
comparison_op_cpu(a, b, out, detail::Equal(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
comparison_op_cpu(
|
||||
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
}
|
||||
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
}
|
||||
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
}
|
||||
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
||||
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
|
||||
}
|
||||
|
||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr(), stream());
|
||||
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
|
||||
}
|
||||
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Maximum(), stream());
|
||||
}
|
||||
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Minimum(), stream());
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Multiply(), stream());
|
||||
}
|
||||
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Power(), stream());
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Subtract(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Subtract(), stream());
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -463,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_int(a, b, out, detail::LeftShift(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_int(a, b, out, detail::RightShift(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -484,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
binary_float(a, b, out, detail::ArcTan2(), stream());
|
||||
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
binary_op<T, T, Op>(a, b, out, bopt);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -95,4 +95,9 @@ void Recv::eval_cpu(
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
#include <cstring>
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
@@ -135,15 +137,29 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle empty matrix case (K=0)
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
auto& c = inputs[2];
|
||||
if (beta_ == 1.0f) {
|
||||
CopyType ctype = c.data_size() == 1
|
||||
? CopyType::Scalar
|
||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy_cpu(c, out, ctype, stream());
|
||||
} else {
|
||||
array beta_scalar = array(beta_, c.dtype());
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
|
||||
encoder.add_temporary(std::move(beta_scalar));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1
|
||||
? CopyType::Scalar
|
||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy_cpu(c, out, ctype, stream());
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
return;
|
||||
}
|
||||
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
|
||||
}
|
||||
|
||||
|
||||
@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -217,14 +217,20 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::max(a.value, b.value);
|
||||
auto out = Simd<T, N>(asd::max(a.value, b.value));
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
out = select(isnan(b), b, select(isnan(a), a, out));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::min(a.value, b.value);
|
||||
auto out = Simd<T, N>(asd::min(a.value, b.value));
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
out = select(isnan(b), b, select(isnan(a), a, out));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
|
||||
@@ -32,6 +32,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -67,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
|
||||
next_free_ = next_free_->next;
|
||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||
b->buf.size = small_block_size;
|
||||
b->buf.device = -1;
|
||||
return &b->buf;
|
||||
}
|
||||
|
||||
@@ -88,14 +90,40 @@ CudaAllocator::CudaAllocator()
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.95;
|
||||
max_pool_size_ = memory_limit_;
|
||||
|
||||
int device_count = 0;
|
||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||
int curr;
|
||||
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(i));
|
||||
cudaStream_t s;
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
|
||||
free_streams_.push_back(s);
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(curr));
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
void copy_to_managed(CudaBuffer& buf) {
|
||||
// TODO maybe make this async on a i/o stream to avoid synchronizing the
|
||||
// device on malloc/and free
|
||||
void* new_data;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
|
||||
buf.device = -1;
|
||||
CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
|
||||
CHECK_CUDA_ERROR(cudaFree(buf.data));
|
||||
buf.data = new_data;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
|
||||
if (size == 0) {
|
||||
return Buffer{new CudaBuffer{nullptr, 0, -1}};
|
||||
}
|
||||
|
||||
// Find available buffer from cache.
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size <= small_block_size) {
|
||||
@@ -106,6 +134,11 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
}
|
||||
|
||||
int device = -1;
|
||||
if (size > small_block_size && stream != nullptr) {
|
||||
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
|
||||
}
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||
@@ -121,8 +154,13 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
}
|
||||
lock.unlock();
|
||||
if (!buf) {
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
buf = new CudaBuffer{nullptr, size, device};
|
||||
cudaError_t err;
|
||||
if (device == -1) {
|
||||
err = cudaMallocManaged(&buf->data, size);
|
||||
} else {
|
||||
err = cudaMallocAsync(&buf->data, size, stream);
|
||||
}
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
@@ -137,14 +175,30 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
// Copy to managed here if the buffer is not on the right device
|
||||
if (buf->device != device) {
|
||||
copy_to_managed(*buf);
|
||||
}
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
|
||||
return malloc_impl(size, stream);
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
return malloc_impl(size, nullptr);
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
if (buf->size == 0) {
|
||||
delete buf;
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
@@ -168,7 +222,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
if (scalar_pool_.in_pool(buf)) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
if (buf->device >= 0) {
|
||||
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
}
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
@@ -219,6 +277,16 @@ CudaAllocator& allocator() {
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
Buffer malloc_async(size_t size, cudaStream_t stream) {
|
||||
auto buffer = allocator().malloc_async(size, stream);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
@@ -231,7 +299,11 @@ void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
|
||||
if (cbuf.device != -1) {
|
||||
copy_to_managed(cbuf);
|
||||
}
|
||||
return cbuf.data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
#include "mlx/backend/cuda/cuda_utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
@@ -17,6 +19,7 @@ using allocator::Buffer;
|
||||
struct CudaBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
int device; // -1 for managed
|
||||
};
|
||||
|
||||
class SmallSizePool {
|
||||
@@ -45,6 +48,7 @@ class SmallSizePool {
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
Buffer malloc_async(size_t size, cudaStream_t stream);
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
@@ -58,6 +62,7 @@ class CudaAllocator : public allocator::Allocator {
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
Buffer malloc_impl(size_t size, cudaStream_t stream);
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
CudaAllocator();
|
||||
@@ -69,9 +74,12 @@ class CudaAllocator : public allocator::Allocator {
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
std::vector<cudaStream_t> free_streams_;
|
||||
SmallSizePool scalar_pool_;
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
Buffer malloc_async(size_t size, cudaStream_t stream);
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
encoder.set_output_array(out);
|
||||
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
@@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
out.data<OutType>(),
|
||||
gpu_ptr<OutType>(out),
|
||||
out.data_size(),
|
||||
static_cast<CTYPE>(start_),
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
|
||||
|
||||
@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgReduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
Shape shape = remove_index(in.shape(), axis_);
|
||||
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int32_t ndim = shape.size();
|
||||
|
||||
// ArgReduce.
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
|
||||
@@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
num_blocks,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
gpu_ptr<T>(in),
|
||||
gpu_ptr<uint32_t>(out),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
|
||||
@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
gpu_ptr<InType>(a),
|
||||
gpu_ptr<InType>(b),
|
||||
gpu_ptr<OutType>(out),
|
||||
rest,
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
gpu_ptr<InType>(a),
|
||||
gpu_ptr<InType>(b),
|
||||
gpu_ptr<OutType>(out),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
gpu_ptr<InType>(a),
|
||||
gpu_ptr<InType>(b),
|
||||
gpu_ptr<OutType>(out),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
@@ -365,7 +365,11 @@ void binary_op_gpu(
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
|
||||
@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
|
||||
if (out_a.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
gpu_ptr<InType>(a),
|
||||
gpu_ptr<InType>(b),
|
||||
gpu_ptr<OutType>(out_a),
|
||||
gpu_ptr<OutType>(out_b),
|
||||
rest,
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
gpu_ptr<InType>(a),
|
||||
gpu_ptr<InType>(b),
|
||||
gpu_ptr<OutType>(out_a),
|
||||
gpu_ptr<OutType>(out_b),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
gpu_ptr<InType>(a),
|
||||
gpu_ptr<InType>(b),
|
||||
gpu_ptr<OutType>(out_a),
|
||||
gpu_ptr<OutType>(out_b),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Put outputs.
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, is_constant_, contiguous, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
for (auto& x : outputs) {
|
||||
args.append(x);
|
||||
}
|
||||
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
|
||||
kernel_name += fmt::format(
|
||||
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||
}
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
|
||||
@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
if (out_.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
array in = inputs[0];
|
||||
array wt = inputs[1];
|
||||
array out = out_;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
Dtype dtype = out.dtype();
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
|
||||
@@ -86,7 +86,7 @@ array unfold_inputs_nd(
|
||||
int mat_N,
|
||||
ConvParams<NDIM>& params) {
|
||||
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
|
||||
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
|
||||
encoder.add_temporary(unfolded);
|
||||
|
||||
int filter_size = params.C;
|
||||
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
unfolded.data<DataType>(),
|
||||
gpu_ptr<DataType>(in),
|
||||
gpu_ptr<DataType>(unfolded),
|
||||
filter_size,
|
||||
out_pixels,
|
||||
params);
|
||||
|
||||
@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
|
||||
int mat_N,
|
||||
ConvParams<NDIM>& params) {
|
||||
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
|
||||
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
|
||||
encoder.add_temporary(unfolded);
|
||||
|
||||
int filter_size = params.C;
|
||||
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
unfolded.data<DataType>(),
|
||||
gpu_ptr<DataType>(in),
|
||||
gpu_ptr<DataType>(unfolded),
|
||||
filter_size,
|
||||
out_pixels,
|
||||
params);
|
||||
|
||||
@@ -5,6 +5,22 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -87,11 +103,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
|
||||
}
|
||||
|
||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -77,8 +77,8 @@ void copy_contiguous(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
gpu_ptr<InType>(in) + in_offset,
|
||||
gpu_ptr<OutType>(out) + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -106,8 +106,8 @@ void copy_general(
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
|
||||
@@ -69,8 +69,8 @@ void copy_general_dynamic(
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
@@ -90,8 +90,8 @@ void copy_general_dynamic(
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in),
|
||||
const_param<dims_constant()>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
gpu_ptr<int64_t>(dynamic_offset_in),
|
||||
gpu_ptr<int64_t>(dynamic_offset_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
@@ -107,8 +107,8 @@ void copy_general_dynamic(
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
gpu_ptr<int64_t>(dynamic_offset_in),
|
||||
gpu_ptr<int64_t>(dynamic_offset_out));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -92,8 +92,8 @@ void copy_general_input(
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
|
||||
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
|
||||
82
mlx/backend/cuda/cuda_utils.h
Normal file
82
mlx/backend/cuda/cuda_utils.h
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
void check_cuda_error(const char* name, CUresult err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
// Base class for RAII managed CUDA resources.
|
||||
template <typename Handle, cudaError_t (*Destroy)(Handle)>
|
||||
class CudaHandle {
|
||||
public:
|
||||
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
|
||||
|
||||
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
|
||||
assert(this != &other);
|
||||
other.handle_ = nullptr;
|
||||
}
|
||||
|
||||
~CudaHandle() {
|
||||
reset();
|
||||
}
|
||||
|
||||
CudaHandle(const CudaHandle&) = delete;
|
||||
CudaHandle& operator=(const CudaHandle&) = delete;
|
||||
|
||||
CudaHandle& operator=(CudaHandle&& other) {
|
||||
assert(this != &other);
|
||||
reset();
|
||||
std::swap(handle_, other.handle_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (handle_ != nullptr) {
|
||||
CHECK_CUDA_ERROR(Destroy(handle_));
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
operator Handle() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
protected:
|
||||
Handle handle_;
|
||||
};
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
}; // namespace cu
|
||||
|
||||
// Wrappers of CUDA resources.
|
||||
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
|
||||
public:
|
||||
using CudaHandle::CudaHandle;
|
||||
explicit CudaGraph(cu::Device& device);
|
||||
void end_capture(cudaStream_t stream);
|
||||
};
|
||||
|
||||
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
|
||||
public:
|
||||
void instantiate(cudaGraph_t graph);
|
||||
};
|
||||
|
||||
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
|
||||
void** data_ptrs,
|
||||
F&& execute) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(
|
||||
workspace_size > 0 ? allocator::malloc(workspace_size)
|
||||
: allocator::Buffer(nullptr),
|
||||
{workspace_size},
|
||||
uint8);
|
||||
void* workspace_ptr = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
cu::malloc_async(workspace_size, encoder.stream()),
|
||||
{workspace_size},
|
||||
uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
auto args = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setWorkspacePointer(workspace_ptr)
|
||||
.setDataPointers(num_args, data_ptrs)
|
||||
.setUids(num_args, uids)
|
||||
.build();
|
||||
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
|
||||
return false;
|
||||
}
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
@@ -23,7 +24,7 @@ class CommandEncoder;
|
||||
// Return pointer alignment of |x|'s data.
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
||||
|
||||
// Helpers used by get_data_ptrs to get pointers.
|
||||
inline void* get_data_ptr(const array& arr) {
|
||||
return const_cast<void*>(arr.data<void>());
|
||||
return const_cast<void*>(gpu_ptr<void>(arr));
|
||||
}
|
||||
|
||||
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
|
||||
|
||||
@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("CustomKernel::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
|
||||
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
|
||||
|
||||
// Call the kernel
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : checked_inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto set_input_output =
|
||||
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
auto set_input_output = [&](const array& in,
|
||||
array& out) -> std::pair<array, array> {
|
||||
if (!in.flags().row_contiguous) {
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
return {out, out};
|
||||
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
|
||||
out.copy_shared_buffer(in);
|
||||
return {in, out};
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
return {in, out};
|
||||
}
|
||||
};
|
||||
|
||||
auto [input, output] = set_input_output(inputs[0], outputs[0]);
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(output);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
auto& s = stream();
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
|
||||
"Only all reduce sum, max, and min are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto input = ensure_contiguous(inputs[0]);
|
||||
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(outputs[0]);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
distributed::detail::all_gather(group(), input, outputs[0], s);
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto input = ensure_contiguous(inputs[0]);
|
||||
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(outputs[0]);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::sum_scatter(group(), input, outputs[0], s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only sum scatter is supported. ");
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -20,8 +22,24 @@ void Fence::wait(Stream s, const array&) {
|
||||
fence->event.wait(fence->count);
|
||||
}
|
||||
|
||||
void Fence::update(Stream s, const array&) {
|
||||
void Fence::update(Stream s, const array& a, bool cross_device) {
|
||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||
if (cross_device) {
|
||||
// Move to managed memory if there is a device switch
|
||||
auto& cbuf =
|
||||
*static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
|
||||
if (cbuf.device != -1) {
|
||||
void* new_data;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
|
||||
cbuf.device = -1;
|
||||
auto& encoder = cu::device(s.device).get_command_encoder(s);
|
||||
encoder.commit();
|
||||
CHECK_CUDA_ERROR(cudaMemcpyAsync(
|
||||
new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream()));
|
||||
CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
|
||||
cbuf.data = new_data;
|
||||
}
|
||||
}
|
||||
fence->count++;
|
||||
fence->event.signal(s, fence->count);
|
||||
}
|
||||
|
||||
@@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
|
||||
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
&epilogue,
|
||||
sizeof(epilogue)));
|
||||
auto* bias_ptr = bias.data<void>();
|
||||
auto* bias_ptr = gpu_ptr<void>(bias);
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
@@ -278,9 +278,9 @@ void CublasGemm::run(
|
||||
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
gpu_ptr<void>(out),
|
||||
gpu_ptr<void>(a),
|
||||
gpu_ptr<void>(b),
|
||||
nullptr,
|
||||
alpha);
|
||||
}
|
||||
@@ -321,10 +321,10 @@ void CublasGemm::run(
|
||||
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c.data<void>(),
|
||||
gpu_ptr<void>(out),
|
||||
gpu_ptr<void>(a),
|
||||
gpu_ptr<void>(b),
|
||||
gpu_ptr<void>(c),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
@@ -370,11 +370,11 @@ void CublasGemm::execute(
|
||||
// Ensure workspace is 256-byte aligned
|
||||
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
|
||||
array workspace(
|
||||
allocator::malloc(nbytes),
|
||||
cu::malloc_async(nbytes, encoder.stream()),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = workspace.data<void>();
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
|
||||
@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
execute(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
gpu_ptr<int8_t>(out) +
|
||||
out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
|
||||
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
|
||||
nullptr,
|
||||
alpha);
|
||||
a_it.step();
|
||||
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
execute(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
c.data<int8_t>() + c.itemsize() * c_it.loc,
|
||||
gpu_ptr<int8_t>(out) +
|
||||
out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
|
||||
gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
|
||||
gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
|
||||
alpha,
|
||||
beta);
|
||||
a_it.step();
|
||||
|
||||
@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(void*) * 3),
|
||||
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
|
||||
{batch_count * 3},
|
||||
uint64);
|
||||
|
||||
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
gpu_ptr<int8_t*>(pointers),
|
||||
gpu_ptr<int8_t>(a),
|
||||
gpu_ptr<int8_t>(b),
|
||||
gpu_ptr<int8_t>(out),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
gpu_ptr<int8_t*>(pointers),
|
||||
gpu_ptr<int8_t>(a),
|
||||
gpu_ptr<int8_t>(b),
|
||||
gpu_ptr<int8_t>(out),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto a_pointers = gpu_ptr<int8_t*>(pointers);
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto out_pointers = b_pointers + batch_count;
|
||||
execute(
|
||||
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
|
||||
{batch_count * 4},
|
||||
uint64);
|
||||
|
||||
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
gpu_ptr<int8_t*>(pointers),
|
||||
gpu_ptr<int8_t>(a),
|
||||
gpu_ptr<int8_t>(b),
|
||||
gpu_ptr<int8_t>(c),
|
||||
gpu_ptr<int8_t>(out),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
gpu_ptr<int8_t*>(pointers),
|
||||
gpu_ptr<int8_t>(a),
|
||||
gpu_ptr<int8_t>(b),
|
||||
gpu_ptr<int8_t>(c),
|
||||
gpu_ptr<int8_t>(out),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto a_pointers = gpu_ptr<int8_t*>(pointers);
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto c_pointers = b_pointers + batch_count;
|
||||
auto out_pointers = c_pointers + batch_count;
|
||||
|
||||
@@ -149,13 +149,13 @@ void gemv(
|
||||
auto vec_strides = const_param(b_batch_strides);
|
||||
|
||||
if (M == 1) {
|
||||
mat = b.data<DataType>();
|
||||
vec = a.data<DataType>();
|
||||
mat = gpu_ptr<DataType>(b);
|
||||
vec = gpu_ptr<DataType>(a);
|
||||
rows = N;
|
||||
std::swap(mat_strides, vec_strides);
|
||||
} else {
|
||||
mat = a.data<DataType>();
|
||||
vec = b.data<DataType>();
|
||||
mat = gpu_ptr<DataType>(a);
|
||||
vec = gpu_ptr<DataType>(b);
|
||||
rows = M;
|
||||
}
|
||||
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
||||
@@ -177,7 +177,7 @@ void gemv(
|
||||
0,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
gpu_ptr<DataType>(out),
|
||||
rows,
|
||||
cols);
|
||||
} else {
|
||||
@@ -189,7 +189,7 @@ void gemv(
|
||||
0,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
gpu_ptr<DataType>(out),
|
||||
rows,
|
||||
cols,
|
||||
const_param(batch_shape),
|
||||
|
||||
@@ -31,7 +31,7 @@ void append_indices_arg(
|
||||
int idx_ndim) {
|
||||
SmallVector<const void*> indices(nidx);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
indices[i] = inputs[i + 1].data<void>();
|
||||
indices[i] = gpu_ptr<void>(inputs[i + 1]);
|
||||
}
|
||||
args.append(std::move(indices));
|
||||
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() > 0);
|
||||
const auto& src = inputs[0];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_string(idx_dtype),
|
||||
nidx);
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& src = inputs[0];
|
||||
const auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_string(out.dtype()),
|
||||
dtype_to_string(idx.dtype()));
|
||||
|
||||
auto& s = stream();
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::vector<std::string> kernel_names;
|
||||
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ struct KernelArgs {
|
||||
}
|
||||
|
||||
void append(const array& a) {
|
||||
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
|
||||
nvtx3::scoped_range r("LayerNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
auto set_output = [&s, &out, &encoder](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(b),
|
||||
gpu_ptr<DataType>(out),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
|
||||
g_in_gw = true;
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(g),
|
||||
gpu_ptr<DataType>(gx),
|
||||
gpu_ptr<DataType>(gw_temp),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
|
||||
60
mlx/backend/cuda/load.cpp
Normal file
60
mlx/backend/cuda/load.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <const uint8_t scalar_size>
|
||||
void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
struct Elem {
|
||||
uint8_t bytes[scalar_size];
|
||||
};
|
||||
|
||||
Elem* data = reinterpret_cast<Elem*>(data_bytes);
|
||||
|
||||
for (size_t i = 0; i < N; i++) {
|
||||
for (size_t j = 0; j < (scalar_size / 2); j++) {
|
||||
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
auto size = out.size();
|
||||
auto nbytes = size * out.itemsize();
|
||||
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
|
||||
auto out_ptr = malloc(nbytes);
|
||||
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
case 4:
|
||||
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
case 8:
|
||||
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaMemcpyAsync(
|
||||
gpu_ptr<void>(out),
|
||||
out_ptr,
|
||||
nbytes,
|
||||
cudaMemcpyDefault,
|
||||
encoder.stream()));
|
||||
CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
cu::malloc_async(in.nbytes() / n, encoder.stream()),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
gpu_ptr<DataType>(in),
|
||||
gpu_ptr<DataType>(out),
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
|
||||
c.data_size() == out.shape(-1)) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
gemm_and_bias(
|
||||
encoder,
|
||||
M,
|
||||
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto sty = c.strides()[c.ndim() - 1];
|
||||
if (sty == 1 && stx == c.shape(-1)) {
|
||||
ldc = stx;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
} else if (sty == 1 && stx == 0) {
|
||||
ldc = 0;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
} else {
|
||||
// Copy C into out and set C to out
|
||||
ldc = c.shape(-1);
|
||||
|
||||
@@ -28,7 +28,6 @@ NO_GPU(FFT)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
@@ -40,7 +39,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
@@ -262,10 +262,10 @@ void affine_quantize(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
w.data<T>(),
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
biases.data<T>(),
|
||||
gpu_ptr<T>(w),
|
||||
gpu_ptr<uint8_t>(wq),
|
||||
gpu_ptr<T>(scales),
|
||||
gpu_ptr<T>(biases),
|
||||
w.size());
|
||||
});
|
||||
});
|
||||
@@ -318,10 +318,10 @@ void affine_dequantize(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
biases.data<T>(),
|
||||
w.data<T>(),
|
||||
gpu_ptr<uint8_t>(wq),
|
||||
gpu_ptr<T>(scales),
|
||||
gpu_ptr<T>(biases),
|
||||
gpu_ptr<T>(w),
|
||||
w.size());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -156,9 +156,9 @@ void fp_quantize(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
w.data<T>(),
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<uint8_t>(),
|
||||
gpu_ptr<T>(w),
|
||||
gpu_ptr<uint8_t>(wq),
|
||||
gpu_ptr<uint8_t>(scales),
|
||||
w.size());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
@@ -202,9 +202,9 @@ void fp_dequantize(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
w.data<T>(),
|
||||
gpu_ptr<uint8_t>(wq),
|
||||
gpu_ptr<uint8_t>(scales),
|
||||
gpu_ptr<T>(w),
|
||||
w.size());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
|
||||
@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto& w = outputs[0];
|
||||
|
||||
w.set_data(allocator::malloc(w.nbytes()));
|
||||
w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
|
||||
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
|
||||
auto& wq = outputs[0];
|
||||
auto& scales = outputs[1];
|
||||
|
||||
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
|
||||
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto& biases = outputs[2];
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
|
||||
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
|
||||
|
||||
@@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
uint32_t elems_per_key = out.size() / num_keys;
|
||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
uint32_t half_size = out_per_key / 2;
|
||||
bool odd = out_per_key % 2;
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(keys);
|
||||
encoder.set_output_array(out);
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
@@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
gpu_ptr<uint32_t>(keys),
|
||||
gpu_ptr<uint8_t>(out),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key);
|
||||
@@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
gpu_ptr<uint32_t>(keys),
|
||||
gpu_ptr<uint8_t>(out),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key,
|
||||
|
||||
@@ -66,7 +66,7 @@ void all_reduce(
|
||||
Reduce::ReduceType reduce_type) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
|
||||
auto get_args = [](size_t size, int N) {
|
||||
int threads = std::min(512UL, (size + N - 1) / N);
|
||||
@@ -100,14 +100,15 @@ void all_reduce(
|
||||
Dtype dt = in.dtype();
|
||||
|
||||
// Cub doesn't like const pointers for load (sigh).
|
||||
void* indata = const_cast<void*>(in.data<void>());
|
||||
void* indata = const_cast<void*>(gpu_ptr<void>(in));
|
||||
|
||||
// Large array so allocate an intermediate and accumulate there
|
||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
||||
encoder.set_input_array(in);
|
||||
if (blocks > 1) {
|
||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
intermediate.set_data(
|
||||
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_output_array(intermediate);
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
@@ -122,14 +123,14 @@ void all_reduce(
|
||||
threads,
|
||||
0,
|
||||
static_cast<T*>(indata),
|
||||
intermediate.data<U>(),
|
||||
gpu_ptr<U>(intermediate),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
});
|
||||
|
||||
// Set the input for the next step and recalculate the blocks
|
||||
indata = intermediate.data<void>();
|
||||
indata = gpu_ptr<void>(intermediate);
|
||||
dt = intermediate.dtype();
|
||||
insize = intermediate.size();
|
||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
||||
@@ -149,7 +150,7 @@ void all_reduce(
|
||||
threads,
|
||||
0,
|
||||
static_cast<T*>(indata),
|
||||
out.data<U>(),
|
||||
gpu_ptr<U>(out),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
|
||||
@@ -250,7 +250,7 @@ void col_reduce_looped(
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
allocate_same_layout(out, in, axes, encoder);
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
@@ -261,7 +261,7 @@ void col_reduce_looped(
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
T* indata = const_cast<T*>(gpu_ptr<T>(in));
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
@@ -276,7 +276,7 @@ void col_reduce_looped(
|
||||
blocks,
|
||||
0,
|
||||
indata,
|
||||
out.data<U>(),
|
||||
gpu_ptr<U>(out),
|
||||
static_cast<cu::ColReduceArgs>(args));
|
||||
});
|
||||
});
|
||||
@@ -293,7 +293,7 @@ void col_reduce_small(
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
allocate_same_layout(out, in, axes, encoder);
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
@@ -312,8 +312,8 @@ void col_reduce_small(
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
gpu_ptr<T>(in),
|
||||
gpu_ptr<U>(out),
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size());
|
||||
});
|
||||
|
||||
@@ -28,7 +28,7 @@ void init_reduce(
|
||||
Reduce::ReduceType reduce_type) {
|
||||
// Allocate if needed
|
||||
if (out.data_shared_ptr() == nullptr) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
}
|
||||
|
||||
encoder.set_output_array(out);
|
||||
@@ -42,7 +42,7 @@ void init_reduce(
|
||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||
grid.x = (grid.x + 1023) / 1024;
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, out.data<U>(), out.size());
|
||||
kernel, grid, block, 0, gpu_ptr<U>(out), out.size());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
@@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
|
||||
inline void allocate_same_layout(
|
||||
array& out,
|
||||
const array& in,
|
||||
const std::vector<int>& axes) {
|
||||
const std::vector<int>& axes,
|
||||
cu::CommandEncoder& encoder) {
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -133,7 +135,7 @@ inline void allocate_same_layout(
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = true;
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
cu::malloc_async(out.nbytes(), encoder.stream()),
|
||||
data_size,
|
||||
final_strides,
|
||||
fl,
|
||||
|
||||
@@ -238,7 +238,7 @@ void row_reduce_simple(
|
||||
const ReductionPlan& plan) {
|
||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||
// kernel.
|
||||
allocate_same_layout(out, in, axes);
|
||||
allocate_same_layout(out, in, axes, encoder);
|
||||
|
||||
// TODO: If out.size() < 1024 which will be a common case then write this in
|
||||
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
||||
@@ -268,10 +268,10 @@ void row_reduce_simple(
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
T* indata = const_cast<T*>(gpu_ptr<T>(in));
|
||||
int size = plan.shape.back();
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||
kernel, grid, block, 0, indata, gpu_ptr<U>(out), out.size(), size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -286,7 +286,7 @@ void row_reduce_looped(
|
||||
cu::RowReduceArgs args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
allocate_same_layout(out, in, axes, encoder);
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
@@ -315,7 +315,7 @@ void row_reduce_looped(
|
||||
});
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
|
||||
kernel, grid, block, 0, gpu_ptr<T>(in), gpu_ptr<U>(out), args);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
|
||||
nvtx3::scoped_range r("RMSNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
auto set_output = [&s, &out, &encoder](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
@@ -223,9 +223,9 @@ void RMSNorm::eval_gpu(
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(out),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
@@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu(
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(g),
|
||||
gpu_ptr<DataType>(gx),
|
||||
gpu_ptr<DataType>(gw_temp),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
|
||||
@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
|
||||
nvtx3::scoped_range r("RoPE::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
auto& in = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
|
||||
donated = true;
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(donated ? out : in);
|
||||
encoder.set_input_array(offset);
|
||||
if (with_freqs) {
|
||||
@@ -340,9 +340,9 @@ void RoPE::eval_gpu(
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
gpu_ptr<DataType>(donated ? out : in),
|
||||
gpu_ptr<DataType>(out),
|
||||
gpu_ptr<int32_t>(offset),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
@@ -357,10 +357,10 @@ void RoPE::eval_gpu(
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
gpu_ptr<DataType>(donated ? out : in),
|
||||
gpu_ptr<DataType>(out),
|
||||
gpu_ptr<int32_t>(offset),
|
||||
gpu_ptr<float>(inputs[2]),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
@@ -381,10 +381,10 @@ void RoPE::eval_gpu(
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
gpu_ptr<DataType>(donated ? out : in),
|
||||
gpu_ptr<DataType>(out),
|
||||
gpu_ptr<int32_t>(offset),
|
||||
gpu_ptr<float>(inputs[2]),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
@@ -408,9 +408,9 @@ void RoPE::eval_gpu(
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
gpu_ptr<DataType>(donated ? out : in),
|
||||
gpu_ptr<DataType>(out),
|
||||
gpu_ptr<int32_t>(offset),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
|
||||
@@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback(
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
o.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
gpu_ptr<DataType>(q),
|
||||
gpu_ptr<DataType>(k),
|
||||
gpu_ptr<DataType>(v),
|
||||
gpu_ptr<DataType>(o),
|
||||
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
|
||||
params);
|
||||
});
|
||||
});
|
||||
@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
|
||||
array sums(intermediate_shape, float32, nullptr, {});
|
||||
array maxs(std::move(intermediate_shape), float32, nullptr, {});
|
||||
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
sums.set_data(allocator::malloc(sums.nbytes()));
|
||||
maxs.set_data(allocator::malloc(maxs.nbytes()));
|
||||
intermediate.set_data(
|
||||
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
|
||||
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
|
||||
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
|
||||
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.add_temporary(sums);
|
||||
@@ -601,13 +602,13 @@ void sdpa_vector_2pass_fallback(
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
gpu_ptr<DataType>(q),
|
||||
gpu_ptr<DataType>(k),
|
||||
gpu_ptr<DataType>(v),
|
||||
sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
|
||||
gpu_ptr<float>(intermediate),
|
||||
gpu_ptr<float>(sums),
|
||||
gpu_ptr<float>(maxs),
|
||||
params);
|
||||
}
|
||||
|
||||
@@ -628,10 +629,10 @@ void sdpa_vector_2pass_fallback(
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
o.data<DataType>(),
|
||||
gpu_ptr<float>(intermediate),
|
||||
gpu_ptr<float>(sums),
|
||||
gpu_ptr<float>(maxs),
|
||||
gpu_ptr<DataType>(o),
|
||||
params);
|
||||
}
|
||||
});
|
||||
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()),
|
||||
cu::malloc_async(o.nbytes(), encoder.stream()),
|
||||
o.size(),
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
|
||||
@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto in = inputs[0];
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
if (in.flags().contiguous && in.strides()[axis_] != 0) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int32_t axis_size = in.shape(axis_);
|
||||
bool contiguous = in.strides()[axis_] == 1;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
@@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
in.data_size() / axis_size,
|
||||
block_dim,
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
gpu_ptr<T>(in),
|
||||
gpu_ptr<U>(out),
|
||||
axis_size);
|
||||
} else {
|
||||
constexpr int BM = WARP_SIZE;
|
||||
@@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
num_blocks,
|
||||
block_dim,
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
gpu_ptr<T>(in),
|
||||
gpu_ptr<U>(out),
|
||||
axis_size,
|
||||
stride,
|
||||
stride_blocks);
|
||||
|
||||
@@ -23,14 +23,15 @@ void concatenate_gpu(
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
auto concurrent = cu::get_command_encoder(s).concurrent_context();
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
|
||||
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
|
||||
});
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
// Prepare output.
|
||||
array offset({1}, int64, nullptr, {});
|
||||
bool donate = indices.is_donatable() &&
|
||||
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.add_temporary(offset);
|
||||
encoder.set_input_array(indices);
|
||||
encoder.set_output_array(offset);
|
||||
|
||||
@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Softmax::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
auto set_output = [&s, &out, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
|
||||
@@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
gpu_ptr<DataType>(in),
|
||||
gpu_ptr<DataType>(out),
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||
in = contiguous_copy_gpu(trans, s);
|
||||
encoder.add_temporary(in);
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
out = array(
|
||||
cu::malloc_async(out.nbytes(), encoder.stream()),
|
||||
in.shape(),
|
||||
out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
@@ -70,22 +73,28 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
thrust::make_counting_iterator(0), OffsetTransform{nsort});
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
array indices(
|
||||
cu::malloc_async(out.nbytes(), encoder.stream()),
|
||||
in.shape(),
|
||||
out.dtype());
|
||||
encoder.add_temporary(indices);
|
||||
|
||||
// In argsort though we don't need the result of sorted values, the
|
||||
// API requires us to provide an array to store it.
|
||||
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
|
||||
array discard(
|
||||
cu::malloc_async(in.nbytes(), encoder.stream()),
|
||||
in.shape(),
|
||||
in.dtype());
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
gpu_ptr<Type>(in),
|
||||
gpu_ptr<Type>(discard),
|
||||
gpu_ptr<uint32_t>(indices),
|
||||
gpu_ptr<uint32_t>(out),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
array temp(
|
||||
cu::malloc_async(size, encoder.stream()),
|
||||
{static_cast<int>(size)},
|
||||
uint8);
|
||||
encoder.add_temporary(temp);
|
||||
|
||||
// Start capturing after allocations
|
||||
@@ -103,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(indices.data_size()),
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
thrust::device_pointer_cast(gpu_ptr<uint32_t>(indices)),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
temp.data<void>(),
|
||||
gpu_ptr<void>(temp),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
gpu_ptr<Type>(in),
|
||||
gpu_ptr<Type>(discard),
|
||||
gpu_ptr<uint32_t>(indices),
|
||||
gpu_ptr<uint32_t>(out),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
@@ -125,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
gpu_ptr<Type>(in),
|
||||
gpu_ptr<Type>(out),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
@@ -135,16 +147,19 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
array temp(
|
||||
cu::malloc_async(size, encoder.stream()),
|
||||
{static_cast<int>(size)},
|
||||
uint8);
|
||||
encoder.add_temporary(temp);
|
||||
|
||||
// Start capturing after allocations
|
||||
auto capture = encoder.capture_context();
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||
temp.data<void>(),
|
||||
gpu_ptr<void>(temp),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
gpu_ptr<Type>(in),
|
||||
gpu_ptr<Type>(out),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
|
||||
@@ -168,10 +168,10 @@ void ternary_op_gpu_inplace(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
gpu_ptr<bool>(a),
|
||||
gpu_ptr<DType>(b),
|
||||
gpu_ptr<DType>(c),
|
||||
gpu_ptr<DType>(out),
|
||||
out.data_size());
|
||||
});
|
||||
} else {
|
||||
@@ -211,10 +211,10 @@ void ternary_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
gpu_ptr<bool>(a),
|
||||
gpu_ptr<DType>(b),
|
||||
gpu_ptr<DType>(c),
|
||||
gpu_ptr<DType>(out),
|
||||
rest,
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
@@ -231,10 +231,10 @@ void ternary_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
gpu_ptr<bool>(a),
|
||||
gpu_ptr<DType>(b),
|
||||
gpu_ptr<DType>(c),
|
||||
gpu_ptr<DType>(out),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
@@ -256,7 +256,10 @@ void ternary_op_gpu(
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||
}
|
||||
|
||||
|
||||
@@ -158,8 +158,8 @@ void unary_op_gpu_inplace(
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
gpu_ptr<InType>(in),
|
||||
gpu_ptr<OutType>(out),
|
||||
out.data_size());
|
||||
} else {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
@@ -182,8 +182,8 @@ void unary_op_gpu_inplace(
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
gpu_ptr<InType>(in),
|
||||
gpu_ptr<OutType>(out),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(strides),
|
||||
@@ -207,7 +207,10 @@ void unary_op_gpu(
|
||||
array& out,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
set_unary_output_data(inputs[0], out);
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
set_unary_output_data(inputs[0], out, [&](auto n) {
|
||||
return cu::malloc_async(n, encoder.stream());
|
||||
});
|
||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,89 +4,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/cuda_utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
|
||||
}
|
||||
|
||||
struct Dtype;
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
void check_cuda_error(const char* name, CUresult err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
// Convert Dtype to CUDA C++ types.
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||
|
||||
// Base class for RAII managed CUDA resources.
|
||||
template <typename Handle, cudaError_t (*Destroy)(Handle)>
|
||||
class CudaHandle {
|
||||
public:
|
||||
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
|
||||
|
||||
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
|
||||
assert(this != &other);
|
||||
other.handle_ = nullptr;
|
||||
}
|
||||
|
||||
~CudaHandle() {
|
||||
reset();
|
||||
}
|
||||
|
||||
CudaHandle(const CudaHandle&) = delete;
|
||||
CudaHandle& operator=(const CudaHandle&) = delete;
|
||||
|
||||
CudaHandle& operator=(CudaHandle&& other) {
|
||||
assert(this != &other);
|
||||
reset();
|
||||
std::swap(handle_, other.handle_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (handle_ != nullptr) {
|
||||
CHECK_CUDA_ERROR(Destroy(handle_));
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
operator Handle() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
protected:
|
||||
Handle handle_;
|
||||
};
|
||||
|
||||
// Wrappers of CUDA resources.
|
||||
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
|
||||
public:
|
||||
using CudaHandle::CudaHandle;
|
||||
explicit CudaGraph(cu::Device& device);
|
||||
void end_capture(cudaStream_t stream);
|
||||
};
|
||||
|
||||
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
|
||||
public:
|
||||
void instantiate(cudaGraph_t graph);
|
||||
};
|
||||
|
||||
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
@@ -100,4 +23,22 @@ inline uint max_occupancy_block_dim(T kernel) {
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T* gpu_ptr(array& arr) {
|
||||
return reinterpret_cast<T*>(
|
||||
static_cast<char*>(
|
||||
static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
|
||||
arr.offset());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline const T* gpu_ptr(const array& arr) {
|
||||
return gpu_ptr<T>(const_cast<array&>(arr));
|
||||
}
|
||||
|
||||
struct Dtype;
|
||||
|
||||
// Convert Dtype to CUDA C++ types.
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -7,18 +7,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
@@ -52,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
|
||||
int ndim = x.ndim();
|
||||
if (start_axis < 0) {
|
||||
|
||||
@@ -83,7 +83,7 @@ void Depends::eval_gpu(
|
||||
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ void DynamicSliceUpdate::eval_gpu(
|
||||
array& out) {
|
||||
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -209,7 +209,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Slice::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,14 +21,8 @@ function(make_jit_source SRC_FILE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||
endfunction(make_jit_source)
|
||||
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/jit/bf16.h
|
||||
kernels/metal_3_0/bf16.h
|
||||
kernels/metal_3_1/bf16.h
|
||||
kernels/bf16_math.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(utils kernels/bf16.h kernels/bf16_math.h kernels/complex.h
|
||||
kernels/defines.h)
|
||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
|
||||
@@ -10,6 +10,19 @@ namespace mlx::core {
|
||||
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
bool donated = set_copy_output_data(in, out, ctype);
|
||||
if (donated && in.dtype() == out.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_gpu_inplace(in, out, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@@ -201,4 +214,23 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -21,12 +21,12 @@ constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto get_metal_version() {
|
||||
auto get_metal_version_ = []() {
|
||||
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
||||
if (__builtin_available(macOS 26, iOS 26, tvOS 26, visionOS 26, *)) {
|
||||
return MTL::LanguageVersion4_0;
|
||||
} else if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
||||
return MTL::LanguageVersion3_2;
|
||||
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
|
||||
return MTL::LanguageVersion3_1;
|
||||
} else {
|
||||
return MTL::LanguageVersion3_0;
|
||||
return MTL::LanguageVersion3_1;
|
||||
}
|
||||
};
|
||||
static auto metal_version_ = get_metal_version_();
|
||||
@@ -119,8 +119,10 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
// if SWIFTPM_BUNDLE is a framework identifier, try loading from that
|
||||
auto frameworks = NS::Bundle::allFrameworks();
|
||||
for (int i = 0, c = (int)frameworks->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(frameworks->object(i));
|
||||
if (!strcmp(bundle->bundleIdentifier()->utf8String(), SWIFTPM_BUNDLE)) {
|
||||
const auto bundle = reinterpret_cast<NS::Bundle*>(frameworks->object(i));
|
||||
const auto identifier = bundle->bundleIdentifier();
|
||||
if (identifier != nullptr &&
|
||||
!strcmp(identifier->utf8String(), SWIFTPM_BUNDLE)) {
|
||||
library = try_load_framework(device, bundle->resourceURL(), lib_name);
|
||||
if (library != nullptr) {
|
||||
return {library, nullptr};
|
||||
@@ -259,10 +261,7 @@ void CommandEncoder::set_input_array(
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc_->setBuffer(a_buf, base_offset, idx);
|
||||
enc_->setBuffer(a_buf, a.offset() + offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(
|
||||
@@ -446,10 +445,8 @@ void Device::end_encoding(int index) {
|
||||
auto& enc = *stream.encoder;
|
||||
// Remove temporaries from inputs and outputs
|
||||
for (auto& t : stream.temporaries) {
|
||||
if (t.data<void>() != nullptr) {
|
||||
enc.outputs().erase(t.buffer().ptr());
|
||||
enc.inputs().erase(t.buffer().ptr());
|
||||
}
|
||||
enc.outputs().erase(t.buffer().ptr());
|
||||
enc.inputs().erase(t.buffer().ptr());
|
||||
}
|
||||
|
||||
// Keep references to the fences we waited on and put them
|
||||
|
||||
@@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error(
|
||||
"[ReduceScatter::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
@@ -31,7 +31,7 @@ struct FenceImpl {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
static_cast<MTL::SharedEvent*>(fence)->release();
|
||||
} else {
|
||||
allocator::free(static_cast<MTL::Buffer*>(fence));
|
||||
allocator::free(allocator::Buffer{static_cast<MTL::Buffer*>(fence)});
|
||||
}
|
||||
}
|
||||
bool use_fast{false};
|
||||
@@ -99,7 +99,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void Fence::update(Stream stream, const array& x) {
|
||||
void Fence::update(Stream stream, const array& x, bool cross_device) {
|
||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||
f.count++;
|
||||
|
||||
@@ -130,21 +130,23 @@ void Fence::update(Stream stream, const array& x) {
|
||||
|
||||
// Launch input visibility kernels
|
||||
auto& compute_encoder = d.get_command_encoder(idx);
|
||||
auto kernel = d.get_kernel("input_coherent");
|
||||
uint32_t nthreads =
|
||||
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
|
||||
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
if (cross_device) {
|
||||
auto kernel = d.get_kernel("input_coherent");
|
||||
uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) /
|
||||
sizeof(uint32_t);
|
||||
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Barrier on previous kernels
|
||||
compute_encoder.barrier();
|
||||
|
||||
// Launch value update kernel
|
||||
kernel = d.get_kernel("fence_update");
|
||||
auto kernel = d.get_kernel("fence_update");
|
||||
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
|
||||
@@ -768,9 +768,12 @@ void nd_fft_op(
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
std::vector<array> temp_arrs;
|
||||
temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});
|
||||
if (axes.size() > 2) {
|
||||
temp_arrs.emplace_back(
|
||||
temp_shape, complex64, nullptr, std::vector<array>{});
|
||||
}
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
@@ -781,8 +784,8 @@ void nd_fft_op(
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
set(BASE_HEADERS
|
||||
metal_3_1/bf16.h
|
||||
metal_3_0/bf16.h
|
||||
bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
@@ -18,16 +17,9 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(VERSION_INCLUDES
|
||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#define jit_if #if
|
||||
#define jit_else #else
|
||||
#define jit_endif #endif
|
||||
|
||||
jit_if (__METAL_VERSION__ >= 310)
|
||||
|
||||
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
|
||||
|
||||
jit_else
|
||||
|
||||
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
|
||||
|
||||
jit_endif // clang-format on
|
||||
@@ -1,314 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
|
||||
// Check for nan
|
||||
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
|
||||
_fp_encoding_traits<float>::inf_mask) {
|
||||
return uint16_t(as_type<uint32_t>(0x7FC0));
|
||||
}
|
||||
// Take bits
|
||||
uint32_t float_bits = as_type<uint32_t>(x);
|
||||
|
||||
// Round to nearest even
|
||||
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
|
||||
|
||||
// Take upper 16 bits
|
||||
return float_bits >> 16;
|
||||
}
|
||||
|
||||
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
|
||||
// Upper 16 bits are the data and lower 16 bits are 0s
|
||||
return as_type<float>((uint32_t)x << 16);
|
||||
}
|
||||
|
||||
struct _MLX_BFloat16;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat struct
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct _MLX_BFloat16 {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Constructors
|
||||
uint16_t bits_;
|
||||
_MLX_BFloat16() thread = default;
|
||||
_MLX_BFloat16() threadgroup = default;
|
||||
_MLX_BFloat16() device = default;
|
||||
_MLX_BFloat16() constant = default;
|
||||
|
||||
struct bits_to_bfloat_struct {};
|
||||
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
||||
return bits_to_bfloat_struct();
|
||||
}
|
||||
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
||||
: bits_(bits) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions to bfloat
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions from bfloat
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const thread {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const threadgroup {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const device {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const constant {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat operators
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops
|
||||
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
|
||||
return -static_cast<float>(x);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Binary operators
|
||||
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
} \
|
||||
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Arithmetic Operators
|
||||
#define bfloat_binop(_op_, _operator_) \
|
||||
bfloat_binop_base( \
|
||||
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
||||
|
||||
bfloat_binop(+, operator+);
|
||||
bfloat_binop(-, operator-);
|
||||
bfloat_binop(*, operator*);
|
||||
bfloat_binop(/, operator/);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Comparison ops
|
||||
#define bfloat_compop(__op__, __operator__) \
|
||||
bfloat_binop_base( \
|
||||
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||
|
||||
bfloat_compop(>, operator>);
|
||||
bfloat_compop(<, operator<);
|
||||
bfloat_compop(>=, operator>=);
|
||||
bfloat_compop(<=, operator<=);
|
||||
bfloat_compop(==, operator==);
|
||||
bfloat_compop(!=, operator!=);
|
||||
|
||||
#undef bfloat_compop
|
||||
#undef bfloat_binop_base
|
||||
#undef bfloat_binop_helper
|
||||
#undef bfloat_binop
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Inplace Operators
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
} \
|
||||
constexpr METAL_FUNC addr_space itype& __operator__( \
|
||||
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
||||
|
||||
#define bfloat_inplace_op(itype) \
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
||||
|
||||
bfloat_inplace_op(float);
|
||||
bfloat_inplace_op(half);
|
||||
bfloat_inplace_op(int16_t);
|
||||
bfloat_inplace_op(int32_t);
|
||||
bfloat_inplace_op(int64_t);
|
||||
bfloat_inplace_op(uint16_t);
|
||||
bfloat_inplace_op(uint32_t);
|
||||
bfloat_inplace_op(uint64_t);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
#undef bfloat_inplace_op
|
||||
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
||||
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat typedef
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat numeric limits
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#pragma METAL internals : enable
|
||||
|
||||
namespace metal {
|
||||
|
||||
template <>
|
||||
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
||||
static constexpr constant int digits = 8;
|
||||
static constexpr constant int digits10 = 2;
|
||||
static constexpr constant int max_digits10 = 4;
|
||||
static constexpr constant int radix = 2;
|
||||
static constexpr constant int min_exponent = -125;
|
||||
static constexpr constant int min_exponent10 = -37;
|
||||
static constexpr constant int max_exponent = 128;
|
||||
static constexpr constant int max_exponent10 = 38;
|
||||
|
||||
static constexpr bfloat16_t min() {
|
||||
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t lowest() {
|
||||
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t max() {
|
||||
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t epsilon() {
|
||||
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t round_error() {
|
||||
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t quiet_NaN() {
|
||||
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t signaling_NaN() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t denorm_min() {
|
||||
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
};
|
||||
|
||||
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
return x != x;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
||||
return x.bits_;
|
||||
}
|
||||
|
||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
||||
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
@@ -4,11 +4,7 @@
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
// The correct bf16.h is included based on the metal version
|
||||
// by giving the correct path to -I during compilation
|
||||
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
|
||||
#include "bf16.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||
#include "mlx/backend/metal/kernels/complex.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@@ -925,19 +926,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Copy c into out and return
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Handle empty matrix case (K=0)
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
copy_gpu(
|
||||
inputs[2],
|
||||
out,
|
||||
inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
auto& c = inputs[2];
|
||||
if (beta_ == 1.0f) {
|
||||
copy_gpu(
|
||||
c,
|
||||
out,
|
||||
c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
array beta_scalar = array(beta_, c.dtype());
|
||||
binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
|
||||
d.add_temporary(std::move(beta_scalar), s.index);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
@@ -13,7 +13,7 @@ bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void start_capture(std::string path, id object) {
|
||||
void start_capture(std::string path, NS::Object* object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user