mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
119 Commits
v0.26.3
...
3bb6b1d44a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bb6b1d44a | ||
|
|
4ee0d0bb55 | ||
|
|
cd53eb1ae3 | ||
|
|
f7c11b965e | ||
|
|
512281781c | ||
|
|
ac85ddfdb7 | ||
|
|
65d0d40232 | ||
|
|
cea9369610 | ||
|
|
e7c6e1db82 | ||
|
|
c5fcd5b61b | ||
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 | ||
|
|
1ba18ff7d9 | ||
|
|
37b440faa8 | ||
|
|
888b13ed63 | ||
|
|
4abb218d21 | ||
|
|
6441c21a94 | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
984cefb14d | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a | ||
|
|
dadf8d9c93 | ||
|
|
389276e2b8 | ||
|
|
2e255c8eb4 | ||
|
|
062aa80b84 | ||
|
|
f540b1d612 | ||
|
|
56be773610 | ||
|
|
a9bdd67baa | ||
|
|
f2adb5638d | ||
|
|
728d4db582 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 | ||
|
|
fa89f0b150 | ||
|
|
ca973d1e83 | ||
|
|
828c5f1137 | ||
|
|
7d86a5c108 | ||
|
|
0b807893a7 | ||
|
|
6ad0889c8a | ||
|
|
737dd6d1ac | ||
|
|
aaf78f4c6b | ||
|
|
8831064493 | ||
|
|
be9bc96da4 | ||
|
|
86258f292f | ||
|
|
b26d88591c | ||
|
|
86c6a15571 | ||
|
|
8b25ce62d5 | ||
|
|
da5912e4f2 | ||
|
|
daafee676f | ||
|
|
d32519c8ee | ||
|
|
b405591249 | ||
|
|
3bf81ed1bd | ||
|
|
2204182bba | ||
|
|
3628e5d497 | ||
|
|
a0ae49d397 | ||
|
|
254476718b | ||
|
|
3adba92ebe | ||
|
|
ef631d63af | ||
|
|
970dbe8e25 | ||
|
|
641be9463b | ||
|
|
ab0e608862 | ||
|
|
1588659062 | ||
|
|
b9e88fb976 | ||
|
|
4ad53414dd | ||
|
|
d1165b215e | ||
|
|
dcb8319f3d | ||
|
|
5597fa089c | ||
|
|
9acec364c2 | ||
|
|
7d9d6ef456 | ||
|
|
6f5874a2f2 | ||
|
|
70dc336785 | ||
|
|
4e504039f5 | ||
|
|
d1f4d291e8 | ||
|
|
e1840853ce | ||
|
|
0f5ce173da | ||
|
|
588854195f | ||
|
|
28d068bce6 | ||
|
|
d107d8d495 | ||
|
|
1e496ddb82 | ||
|
|
74eccbf3fa | ||
|
|
08638223ca | ||
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 | ||
|
|
84b4d96efa | ||
|
|
aec67f2fa6 | ||
|
|
deee214a95 | ||
|
|
45adec102c | ||
|
|
31fc530c76 | ||
|
|
fbb3f65a1a | ||
|
|
6b1b8ea91b | ||
|
|
b2273733ea | ||
|
|
f409b229a4 | ||
|
|
30571e2326 | ||
|
|
d7734edd9f | ||
|
|
2ba69bc8fa | ||
|
|
cb349a291c | ||
|
|
f0a0b077a0 | ||
|
|
49114f28ab | ||
|
|
e7d2ebadd2 | ||
|
|
e569803d7c | ||
|
|
d34f887abc | ||
|
|
5201df5030 | ||
|
|
2d3c26c565 | ||
|
|
6325f60d52 | ||
|
|
42cc9cfbc7 | ||
|
|
8347575ba1 | ||
|
|
b6eec20260 | ||
|
|
0eb035b4b1 | ||
|
|
afb9817599 | ||
|
|
8fb3e7a26c | ||
|
|
8c7bc30ce4 | ||
|
|
85873cb162 | ||
|
|
e14ee12491 | ||
|
|
8b9a3f3cea |
@@ -7,18 +7,9 @@ parameters:
|
|||||||
nightly_build:
|
nightly_build:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
weekly_build:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
test_release:
|
test_release:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
linux_release:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
cuda_release:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@@ -73,9 +64,9 @@ jobs:
|
|||||||
git push -f origin gh-pages
|
git push -f origin gh-pages
|
||||||
|
|
||||||
linux_build_and_test:
|
linux_build_and_test:
|
||||||
docker:
|
machine:
|
||||||
- image: cimg/python:3.9
|
image: ubuntu-2204:current
|
||||||
|
resource_class: large
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
@@ -87,34 +78,35 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
pip install --upgrade cmake
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
pip install nanobind==2.4.0
|
export NEEDRESTART_MODE=a
|
||||||
pip install numpy
|
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
uv venv
|
||||||
python3 setup.py build_ext --inplace
|
uv pip install cmake
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
uv pip install -e ".[dev]" -v
|
||||||
python3 setup.py develop
|
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
uv pip install typing_extensions
|
||||||
pip install typing_extensions
|
uv run --no-project setup.py generate_stubs
|
||||||
python setup.py generate_stubs
|
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
python3 -m unittest discover python/tests -v
|
source .venv/bin/activate
|
||||||
|
python -m unittest discover python/tests -v
|
||||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||||
|
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
|
source .venv/bin/activate
|
||||||
mkdir -p build && cd build
|
mkdir -p build && cd build
|
||||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||||
make -j `nproc`
|
make -j `nproc`
|
||||||
@@ -140,50 +132,49 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
brew install python@3.9
|
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
|
||||||
brew install openmpi
|
brew install openmpi uv
|
||||||
python3.9 -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install --upgrade cmake
|
|
||||||
pip install nanobind==2.4.0
|
|
||||||
pip install numpy
|
|
||||||
pip install torch
|
|
||||||
pip install tensorflow
|
|
||||||
pip install unittest-xml-reporting
|
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
uv venv --python 3.9
|
||||||
|
uv pip install \
|
||||||
|
nanobind==2.4.0 \
|
||||||
|
cmake \
|
||||||
|
numpy \
|
||||||
|
torch \
|
||||||
|
tensorflow \
|
||||||
|
unittest-xml-reporting
|
||||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||||
pip install -e . -v
|
uv pip install -e . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
uv pip install typing_extensions
|
||||||
pip install typing_extensions
|
uv run --no-project setup.py generate_stubs
|
||||||
python setup.py generate_stubs
|
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||||
|
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||||
- run:
|
- run:
|
||||||
name: Build example extension
|
name: Build example extension
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
cd examples/extensions
|
cd examples/extensions
|
||||||
pip install -r requirements.txt
|
uv pip install -r requirements.txt
|
||||||
python setup.py build_ext -j8
|
uv run --no-project setup.py build_ext --inplace
|
||||||
|
uv run --no-project python test.py
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||||
- run:
|
- run:
|
||||||
name: Run CPP tests
|
name: Run CPP tests
|
||||||
@@ -192,7 +183,7 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Build small binary
|
name: Build small binary
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
cd build/
|
cd build/
|
||||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||||
-DBUILD_SHARED_LIBS=ON \
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
@@ -204,34 +195,60 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Run Python tests with JIT
|
name: Run Python tests with JIT
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
|
||||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||||
pip install -e . -v
|
uv pip install -e .
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||||
METAL_DEBUG_ERROR_MODE=0 \
|
METAL_DEBUG_ERROR_MODE=0 \
|
||||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
uv run --no-project python -m xmlrunner discover \
|
||||||
|
-v python/tests \
|
||||||
|
-o test-results/gpu_jit
|
||||||
|
|
||||||
cuda_build_and_test:
|
cuda_build_and_test:
|
||||||
|
parameters:
|
||||||
|
image_date:
|
||||||
|
type: string
|
||||||
|
default: "2023.11.1"
|
||||||
machine:
|
machine:
|
||||||
image: linux-cuda-12:default
|
image: "linux-cuda-12:<< parameters.image_date >>"
|
||||||
resource_class: gpu.nvidia.small.gen2
|
resource_class: gpu.nvidia.small.gen2
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
|
- restore_cache:
|
||||||
|
keys:
|
||||||
|
- cuda-<< parameters.image_date >>-{{ arch }}-
|
||||||
|
- run:
|
||||||
|
name: Install dependencies
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libcudnn9-dev-cuda-12
|
||||||
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||||
|
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||||
|
rm -rf ccache-4.11.3-linux-x86_64
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
uv venv
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
|
||||||
python -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
pip install -e ".[dev]"
|
uv pip install -e ".[dev]" -v
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||||
|
- run:
|
||||||
|
name: CCache report
|
||||||
|
command: |
|
||||||
|
ccache --show-stats
|
||||||
|
ccache --zero-stats
|
||||||
|
ccache --max-size 400MB
|
||||||
|
ccache --cleanup
|
||||||
|
- save_cache:
|
||||||
|
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||||
|
paths:
|
||||||
|
- /home/circleci/.cache/ccache
|
||||||
|
|
||||||
build_release:
|
build_release:
|
||||||
parameters:
|
parameters:
|
||||||
@@ -284,7 +301,18 @@ jobs:
|
|||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
<< parameters.build_env >> python -m build -w
|
python setup.py clean --all
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||||
|
- when:
|
||||||
|
condition:
|
||||||
|
equal: ["3.9", << parameters.python_version >>]
|
||||||
|
steps:
|
||||||
|
- run:
|
||||||
|
name: Build common package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
python setup.py clean --all
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||||
- when:
|
- when:
|
||||||
condition: << parameters.build_env >>
|
condition: << parameters.build_env >>
|
||||||
steps:
|
steps:
|
||||||
@@ -301,88 +329,100 @@ jobs:
|
|||||||
python_version:
|
python_version:
|
||||||
type: string
|
type: string
|
||||||
default: "3.9"
|
default: "3.9"
|
||||||
extra_env:
|
build_env:
|
||||||
type: string
|
type: string
|
||||||
default: "DEV_RELEASE=1"
|
default: ""
|
||||||
docker:
|
machine:
|
||||||
- image: ubuntu:20.04
|
image: ubuntu-2204:current
|
||||||
|
resource_class: large
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Build wheel
|
name: Build wheel
|
||||||
command: |
|
command: |
|
||||||
PYTHON=python<< parameters.python_version >>
|
PYTHON=python<< parameters.python_version >>
|
||||||
apt-get update
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
apt-get upgrade -y
|
export NEEDRESTART_MODE=a
|
||||||
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
sudo apt-get update
|
||||||
apt-get install -y apt-utils
|
TZ=Etc/UTC sudo apt-get -y install tzdata
|
||||||
apt-get install -y software-properties-common
|
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||||
add-apt-repository -y ppa:deadsnakes/ppa
|
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||||
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||||
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
|
||||||
apt-get install -y build-essential git
|
|
||||||
$PYTHON -m venv env
|
$PYTHON -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install nanobind==2.4.0
|
|
||||||
pip install --upgrade setuptools
|
|
||||||
pip install numpy
|
|
||||||
pip install auditwheel
|
pip install auditwheel
|
||||||
pip install patchelf
|
pip install patchelf
|
||||||
pip install build
|
pip install build
|
||||||
pip install twine
|
pip install twine
|
||||||
<< parameters.extra_env >> pip install . -v
|
<< parameters.build_env >> pip install ".[dev]" -v
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> python -m build --wheel
|
python setup.py clean --all
|
||||||
auditwheel show dist/*
|
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
bash python/scripts/repair_linux.sh
|
||||||
- run:
|
- when:
|
||||||
name: Upload package
|
condition:
|
||||||
command: |
|
equal: ["3.9", << parameters.python_version >>]
|
||||||
source env/bin/activate
|
steps:
|
||||||
twine upload wheelhouse/*
|
- run:
|
||||||
|
name: Build common package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
python setup.py clean --all
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||||
|
python -m build -w
|
||||||
|
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||||
|
- when:
|
||||||
|
condition: << parameters.build_env >>
|
||||||
|
steps:
|
||||||
|
- run:
|
||||||
|
name: Upload packages
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
twine upload wheelhouse/*.whl
|
||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
build_cuda_release:
|
build_cuda_release:
|
||||||
parameters:
|
parameters:
|
||||||
python_version:
|
build_env:
|
||||||
type: string
|
type: string
|
||||||
default: "3.9"
|
default: ""
|
||||||
extra_env:
|
|
||||||
type: string
|
|
||||||
default: "DEV_RELEASE=1"
|
|
||||||
machine:
|
machine:
|
||||||
image: linux-cuda-12:default
|
image: ubuntu-2204:current
|
||||||
resource_class: gpu.nvidia.small.gen2
|
resource_class: large
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Build wheel
|
name: Build wheel
|
||||||
command: |
|
command: |
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
export NEEDRESTART_MODE=a
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
|
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
python -m venv env
|
sudo apt-get install zip
|
||||||
source env/bin/activate
|
|
||||||
pip install auditwheel
|
pip install auditwheel
|
||||||
pip install patchelf
|
pip install patchelf
|
||||||
pip install build
|
pip install build
|
||||||
pip install twine
|
pip install twine
|
||||||
<< parameters.extra_env >> \
|
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
pip install ".[dev]" -v
|
python -m build -w
|
||||||
python setup.py generate_stubs
|
|
||||||
<< parameters.extra_env >> \
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
|
||||||
python -m build --wheel
|
|
||||||
bash python/scripts/repair_cuda.sh
|
bash python/scripts/repair_cuda.sh
|
||||||
- run:
|
- when:
|
||||||
name: Upload package
|
condition: << parameters.build_env >>
|
||||||
command: |
|
steps:
|
||||||
source env/bin/activate
|
- run:
|
||||||
twine upload wheelhouse/*.whl
|
name: Upload package
|
||||||
|
command: |
|
||||||
|
twine upload wheelhouse/*.whl
|
||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
@@ -394,7 +434,6 @@ workflows:
|
|||||||
pattern: "^(?!pull/)[-\\w]+$"
|
pattern: "^(?!pull/)[-\\w]+$"
|
||||||
value: << pipeline.git.branch >>
|
value: << pipeline.git.branch >>
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
- not: << pipeline.parameters.weekly_build >>
|
|
||||||
- not: << pipeline.parameters.test_release >>
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- mac_build_and_test:
|
- mac_build_and_test:
|
||||||
@@ -402,14 +441,16 @@ workflows:
|
|||||||
parameters:
|
parameters:
|
||||||
macosx_deployment_target: ["13.5", "14.0"]
|
macosx_deployment_target: ["13.5", "14.0"]
|
||||||
- linux_build_and_test
|
- linux_build_and_test
|
||||||
- cuda_build_and_test
|
- cuda_build_and_test:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
image_date: ["2023.11.1", "2025.05.1"]
|
||||||
- build_documentation
|
- build_documentation
|
||||||
|
|
||||||
build_pypi_release:
|
build_pypi_release:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
- not: << pipeline.parameters.weekly_build >>
|
|
||||||
- not: << pipeline.parameters.test_release >>
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_release:
|
- build_release:
|
||||||
@@ -501,7 +542,16 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
|
- build_cuda_release:
|
||||||
|
filters:
|
||||||
|
tags:
|
||||||
|
only: /^v.*/
|
||||||
|
branches:
|
||||||
|
ignore: /.*/
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
|
|
||||||
prb:
|
prb:
|
||||||
when:
|
when:
|
||||||
@@ -522,6 +572,9 @@ workflows:
|
|||||||
requires: [ hold ]
|
requires: [ hold ]
|
||||||
- cuda_build_and_test:
|
- cuda_build_and_test:
|
||||||
requires: [ hold ]
|
requires: [ hold ]
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
image_date: ["2023.11.1", "2025.05.1"]
|
||||||
nightly_build:
|
nightly_build:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
@@ -580,11 +633,17 @@ workflows:
|
|||||||
- macosx_deployment_target: "15.0"
|
- macosx_deployment_target: "15.0"
|
||||||
xcode_version: "15.0.0"
|
xcode_version: "15.0.0"
|
||||||
python_version: "3.13"
|
python_version: "3.13"
|
||||||
weekly_build:
|
- build_linux_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
- build_cuda_release
|
||||||
|
|
||||||
|
build_dev_release:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
- equal: [ main, << pipeline.git.branch >> ]
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
- << pipeline.parameters.weekly_build >>
|
- << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_release:
|
- build_release:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -654,25 +713,12 @@ workflows:
|
|||||||
xcode_version: "15.0.0"
|
xcode_version: "15.0.0"
|
||||||
python_version: "3.13"
|
python_version: "3.13"
|
||||||
build_env: "DEV_RELEASE=1"
|
build_env: "DEV_RELEASE=1"
|
||||||
linux_test_release:
|
|
||||||
when:
|
|
||||||
and:
|
|
||||||
- equal: [ main, << pipeline.git.branch >> ]
|
|
||||||
- << pipeline.parameters.linux_release >>
|
|
||||||
jobs:
|
|
||||||
- build_linux_release:
|
- build_linux_release:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
build_env: ["DEV_RELEASE=1"]
|
||||||
cuda_test_release:
|
|
||||||
when:
|
|
||||||
and:
|
|
||||||
- equal: [ main, << pipeline.git.branch >> ]
|
|
||||||
- << pipeline.parameters.cuda_release >>
|
|
||||||
jobs:
|
|
||||||
- build_cuda_release:
|
- build_cuda_release:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
build_env: ["DEV_RELEASE=1"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||||
|
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
|||||||
@@ -41,7 +41,9 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
|||||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||||
|
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
message(
|
message(
|
||||||
@@ -64,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|||||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
else()
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
endif()
|
||||||
|
|
||||||
|
if(MLX_USE_CCACHE)
|
||||||
|
find_program(CCACHE_PROGRAM ccache)
|
||||||
|
if(CCACHE_PROGRAM)
|
||||||
|
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------- Lib -----------------------------
|
# ----------------------------- Lib -----------------------------
|
||||||
@@ -234,12 +243,16 @@ target_include_directories(
|
|||||||
# Do not add mlx_EXPORTS define for shared library.
|
# Do not add mlx_EXPORTS define for shared library.
|
||||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||||
|
|
||||||
FetchContent_Declare(
|
if(USE_SYSTEM_FMT)
|
||||||
fmt
|
find_package(fmt REQUIRED)
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
else()
|
||||||
GIT_TAG 10.2.1
|
FetchContent_Declare(
|
||||||
EXCLUDE_FROM_ALL)
|
fmt
|
||||||
FetchContent_MakeAvailable(fmt)
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
|
GIT_TAG 10.2.1
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
endif()
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||||
|
|
||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
|
|||||||
21
README.md
21
README.md
@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
|
|||||||
|
|
||||||
Some key features of MLX include:
|
Some key features of MLX include:
|
||||||
|
|
||||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||||
more complex models.
|
more complex models.
|
||||||
|
|
||||||
@@ -68,18 +68,23 @@ in the documentation.
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||||
|
macOS, run:
|
||||||
|
|
||||||
**With `pip`**:
|
```bash
|
||||||
|
|
||||||
```
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
**With `conda`**:
|
To install the CUDA backend on Linux, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cuda]
|
||||||
```
|
```
|
||||||
conda install -c conda-forge mlx
|
|
||||||
|
To install a CPU-only Linux package, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cpu]
|
||||||
```
|
```
|
||||||
|
|
||||||
Checkout the
|
Checkout the
|
||||||
|
|||||||
@@ -192,6 +192,22 @@ void time_reductions() {
|
|||||||
|
|
||||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||||
TIME(argmin_along_1);
|
TIME(argmin_along_1);
|
||||||
|
|
||||||
|
auto indices = mx::array({1});
|
||||||
|
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||||
|
std::vector<int> axes{0};
|
||||||
|
auto b = scatter(a, {indices}, updates, axes);
|
||||||
|
mx::eval(b);
|
||||||
|
|
||||||
|
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||||
|
TIME(max_along_0);
|
||||||
|
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||||
|
TIME(max_along_1);
|
||||||
|
|
||||||
|
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||||
|
TIME(min_along_0);
|
||||||
|
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||||
|
TIME(min_along_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_gather_scatter() {
|
void time_gather_scatter() {
|
||||||
|
|||||||
@@ -51,6 +51,20 @@ def time_maximum():
|
|||||||
time_fn(mx.maximum, a, b)
|
time_fn(mx.maximum, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_max():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.max, a, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def time_min():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.min, a, 0)
|
||||||
|
|
||||||
|
|
||||||
def time_negative():
|
def time_negative():
|
||||||
a = mx.random.uniform(shape=(10000, 1000))
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
@@ -108,6 +122,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
time_add()
|
time_add()
|
||||||
time_matmul()
|
time_matmul()
|
||||||
|
time_min()
|
||||||
|
time_max()
|
||||||
time_maximum()
|
time_maximum()
|
||||||
time_exp()
|
time_exp()
|
||||||
time_negative()
|
time_negative()
|
||||||
|
|||||||
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||||
|
# directories.
|
||||||
|
|
||||||
|
set(NCCL_ROOT_DIR
|
||||||
|
$ENV{NCCL_ROOT_DIR}
|
||||||
|
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||||
|
|
||||||
|
find_path(
|
||||||
|
NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||||
|
|
||||||
|
if($ENV{USE_STATIC_NCCL})
|
||||||
|
message(
|
||||||
|
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||||
|
set(NCCL_LIBNAME "libnccl_static.a")
|
||||||
|
else()
|
||||||
|
set(NCCL_LIBNAME "nccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR}
|
||||||
|
${NCCL_ROOT_DIR}
|
||||||
|
${NCCL_ROOT_DIR}/lib
|
||||||
|
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||||
|
${NCCL_ROOT_DIR}/lib64
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||||
|
NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND)
|
||||||
|
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message(
|
||||||
|
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||||
|
file(
|
||||||
|
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||||
|
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||||
|
LIMIT_COUNT 1)
|
||||||
|
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||||
|
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||||
|
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||||
|
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
sphinx
|
sphinx
|
||||||
breathe
|
breathe
|
||||||
sphinx-book-theme
|
sphinx-book-theme
|
||||||
|
sphinx-copybutton
|
||||||
mlx
|
mlx
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ release = version
|
|||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
|
"sphinx_copybutton",
|
||||||
"sphinx.ext.autodoc",
|
"sphinx.ext.autodoc",
|
||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ relying on a copy from ``ensure_row_contiguous``:
|
|||||||
input_names=["inp"],
|
input_names=["inp"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source
|
source=source
|
||||||
|
ensure_row_contiguous=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
def exp_elementwise(a: mx.array):
|
||||||
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
|
|||||||
threadgroup=(256, 1, 1),
|
threadgroup=(256, 1, 1),
|
||||||
output_shapes=[a.shape],
|
output_shapes=[a.shape],
|
||||||
output_dtypes=[a.dtype],
|
output_dtypes=[a.dtype],
|
||||||
ensure_row_contiguous=False,
|
|
||||||
)
|
)
|
||||||
return outputs[0]
|
return outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -138,13 +138,13 @@ more concrete:
|
|||||||
* representing the vectorized computation and the axis which
|
* representing the vectorized computation and the axis which
|
||||||
* corresponds to the output vectorized dimension.
|
* corresponds to the output vectorized dimension.
|
||||||
*/
|
*/
|
||||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
@@ -394,14 +394,14 @@ below.
|
|||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
// Resolve name of kernel
|
// Resolve name of kernel
|
||||||
std::ostringstream kname;
|
std::stream kname;
|
||||||
kname << "axpby_" << "general_" << type_to_name(out);
|
kname = "axpby_general_" + type_to_name(out);
|
||||||
|
|
||||||
// Load the metal library
|
// Load the metal library
|
||||||
auto lib = d.get_library("mlx_ext");
|
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname.str(), lib);
|
auto kernel = d.get_kernel(kname, lib);
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ silicon computer is
|
|||||||
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
|
|
||||||
To install from PyPI you must meet the following requirements:
|
To install from PyPI your system must meet the following requirements:
|
||||||
|
|
||||||
- Using an M series chip (Apple silicon)
|
- Using an M series chip (Apple silicon)
|
||||||
- Using a native Python >= 3.9
|
- Using a native Python >= 3.9
|
||||||
@@ -23,22 +23,39 @@ To install from PyPI you must meet the following requirements:
|
|||||||
MLX is only available on devices running macOS >= 13.5
|
MLX is only available on devices running macOS >= 13.5
|
||||||
It is highly recommended to use macOS 14 (Sonoma)
|
It is highly recommended to use macOS 14 (Sonoma)
|
||||||
|
|
||||||
|
|
||||||
MLX is also available on conda-forge. To install MLX with conda do:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
|
||||||
|
|
||||||
conda install conda-forge::mlx
|
|
||||||
|
|
||||||
CUDA
|
CUDA
|
||||||
^^^^
|
^^^^
|
||||||
|
|
||||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
MLX has a CUDA backend which you can install with:
|
||||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install mlx-cuda
|
pip install mlx[cuda]
|
||||||
|
|
||||||
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Nvidia architecture >= SM 7.0 (Volta)
|
||||||
|
- Nvidia driver >= 550.54.14
|
||||||
|
- CUDA toolkit >= 12.0
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.9
|
||||||
|
|
||||||
|
|
||||||
|
CPU-only (Linux)
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
For a CPU-only version of MLX that runs on Linux use:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx[cpu]
|
||||||
|
|
||||||
|
To install the CPU-only package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.9
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
@@ -254,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
|
|||||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
apt-get update -y
|
apt-get update -y
|
||||||
apt-get -y install cuda-toolkit-12-9
|
apt-get -y install cuda-toolkit-12-9
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||||
|
|
||||||
|
|
||||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
|||||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
# Save the state
|
# Save the state
|
||||||
state = tree_flatten(optimizer.state)
|
state = tree_flatten(optimizer.state, destination={})
|
||||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
mx.save_safetensors("optimizer.safetensors", state)
|
||||||
|
|
||||||
# Later on, for example when loading from a checkpoint,
|
# Later on, for example when loading from a checkpoint,
|
||||||
# recreate the optimizer and load the state
|
# recreate the optimizer and load the state
|
||||||
optimizer = optim.Adam(learning_rate=1e-2)
|
optimizer = optim.Adam(learning_rate=1e-2)
|
||||||
|
|
||||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||||
optimizer.state = state
|
optimizer.state = state
|
||||||
|
|
||||||
Note, not every optimizer configuation parameter is saved in the state. For
|
Note, not every optimizer configuation parameter is saved in the state. For
|
||||||
|
|||||||
@@ -19,3 +19,4 @@ Common Optimizers
|
|||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
MultiOptimizer
|
MultiOptimizer
|
||||||
|
Muon
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
|
|||||||
def fun(x, y):
|
def fun(x, y):
|
||||||
z = x + y
|
z = x + y
|
||||||
state.append(z)
|
state.append(z)
|
||||||
return mx.exp(z), state
|
return mx.exp(z)
|
||||||
|
|
||||||
fun(mx.array(1.0), mx.array(2.0))
|
fun(mx.array(1.0), mx.array(2.0))
|
||||||
# Prints [array(3, dtype=float32)]
|
# Prints [array(3, dtype=float32)]
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
|||||||
model.update(tree_unflatten(list(params.items())))
|
model.update(tree_unflatten(list(params.items())))
|
||||||
return model(x)
|
return model(x)
|
||||||
|
|
||||||
params = dict(tree_flatten(model.parameters()))
|
params = tree_flatten(model.parameters(), destination={})
|
||||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023-2025 Apple Inc.
|
// Copyright © 2023-2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <dlfcn.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
@@ -16,6 +17,19 @@
|
|||||||
|
|
||||||
namespace my_ext {
|
namespace my_ext {
|
||||||
|
|
||||||
|
// A helper function to find the location of the current binary on disk.
|
||||||
|
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
|
||||||
|
std::string current_binary_dir() {
|
||||||
|
static std::string binary_dir = []() {
|
||||||
|
Dl_info info;
|
||||||
|
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||||
|
throw std::runtime_error("Unable to get current binary dir.");
|
||||||
|
}
|
||||||
|
return std::filesystem::path(info.dli_fname).parent_path().string();
|
||||||
|
}();
|
||||||
|
return binary_dir;
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Operation Implementation
|
// Operation Implementation
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve name of kernel (corresponds to axpby.metal)
|
// Resolve name of kernel (corresponds to axpby.metal)
|
||||||
std::ostringstream kname;
|
std::string kname = "axpby_";
|
||||||
kname << "axpby_";
|
kname += (contiguous_kernel ? "contiguous_" : "general_");
|
||||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
kname += type_to_name(out);
|
||||||
kname << type_to_name(out);
|
|
||||||
|
|
||||||
// Load the metal library
|
// Load the metal library
|
||||||
auto lib = d.get_library("mlx_ext");
|
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname.str(), lib);
|
auto kernel = d.get_kernel(kname, lib);
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|||||||
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
|||||||
const std::vector<mx::array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
setuptools>=42
|
setuptools>=42
|
||||||
cmake>=3.25
|
cmake>=3.25
|
||||||
mlx>=0.21.0
|
mlx>=0.21.0
|
||||||
nanobind==2.2.0
|
nanobind==2.4.0
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
|
|||||||
|
|
||||||
a = mx.ones((3, 4))
|
a = mx.ones((3, 4))
|
||||||
b = mx.ones((3, 4))
|
b = mx.ones((3, 4))
|
||||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||||
|
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
||||||
|
|
||||||
print(f"c shape: {c.shape}")
|
print(f"c shape: {c_cpu.shape}")
|
||||||
print(f"c dtype: {c.dtype}")
|
print(f"c dtype: {c_cpu.dtype}")
|
||||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
||||||
|
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/dtype.h"
|
#include "mlx/dtype.h"
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
|
#include "mlx/small_vector.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -18,8 +19,8 @@ class Primitive;
|
|||||||
|
|
||||||
using Deleter = std::function<void(allocator::Buffer)>;
|
using Deleter = std::function<void(allocator::Buffer)>;
|
||||||
using ShapeElem = int32_t;
|
using ShapeElem = int32_t;
|
||||||
using Shape = std::vector<ShapeElem>;
|
using Shape = SmallVector<ShapeElem>;
|
||||||
using Strides = std::vector<int64_t>;
|
using Strides = SmallVector<int64_t>;
|
||||||
|
|
||||||
class array {
|
class array {
|
||||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <dlfcn.h>
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive) {
|
std::filesystem::path current_binary_dir() {
|
||||||
std::ostringstream op_t;
|
static std::filesystem::path binary_dir = []() {
|
||||||
primitive->print(op_t);
|
Dl_info info;
|
||||||
return op_t.str();
|
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||||
|
throw std::runtime_error("Unable to get current binary dir.");
|
||||||
|
}
|
||||||
|
return std::filesystem::path(info.dli_fname).parent_path();
|
||||||
|
}();
|
||||||
|
return binary_dir;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -9,7 +10,8 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive);
|
// Return the directory that contains current shared library.
|
||||||
|
std::filesystem::path current_binary_dir();
|
||||||
|
|
||||||
inline int64_t
|
inline int64_t
|
||||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||||
@@ -195,7 +197,7 @@ void shared_buffer_reshape(
|
|||||||
array& out);
|
array& out);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||||
vec.erase(std::next(vec.begin(), index));
|
vec.erase(std::next(vec.begin(), index));
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
|||||||
|
|
||||||
// The decomposition is computed in place, so just copy the input to the
|
// The decomposition is computed in place, so just copy the input to the
|
||||||
// output.
|
// output.
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
factor,
|
factor,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
|||||||
@@ -157,10 +157,12 @@ inline void build_kernel(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Start the kernel
|
// Start the kernel
|
||||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
os << "void " << kernel_name
|
||||||
|
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
int strides_index = 1;
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
// Skip constants from the input list
|
// Skip constants from the input list
|
||||||
if (is_constant(i)) {
|
if (is_constant(i)) {
|
||||||
@@ -175,8 +177,8 @@ inline void build_kernel(
|
|||||||
<< "];" << std::endl;
|
<< "];" << std::endl;
|
||||||
// Scalars and contiguous need no strides
|
// Scalars and contiguous need no strides
|
||||||
if (!is_scalar(x) && !contiguous) {
|
if (!is_scalar(x) && !contiguous) {
|
||||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
os << " const int64_t* " << xname << "_strides = strides["
|
||||||
<< "];" << std::endl;
|
<< strides_index++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,10 +188,8 @@ inline void build_kernel(
|
|||||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
// Add output strides and shape to extract the indices.
|
// Add output size
|
||||||
if (!contiguous) {
|
if (contiguous) {
|
||||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
|
||||||
} else {
|
|
||||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,7 +231,7 @@ inline void build_kernel(
|
|||||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
x.primitive().print(os);
|
os << x.primitive().name();
|
||||||
os << "()(";
|
os << "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||||
@@ -290,7 +290,6 @@ void Compiled::eval_cpu(
|
|||||||
|
|
||||||
// Collect function input arguments.
|
// Collect function input arguments.
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
int strides_index = 1;
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
if (is_constant_(i)) {
|
if (is_constant_(i)) {
|
||||||
continue;
|
continue;
|
||||||
@@ -298,9 +297,6 @@ void Compiled::eval_cpu(
|
|||||||
const auto& x = inputs[i];
|
const auto& x = inputs[i];
|
||||||
encoder.set_input_array(x);
|
encoder.set_input_array(x);
|
||||||
args.push_back((void*)x.data<void>());
|
args.push_back((void*)x.data<void>());
|
||||||
if (!contiguous && !is_scalar(x)) {
|
|
||||||
args.push_back(strides[strides_index++].data());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the kernel name from the lib
|
// Get the kernel name from the lib
|
||||||
@@ -335,16 +331,20 @@ void Compiled::eval_cpu(
|
|||||||
args.push_back(x.data<void>());
|
args.push_back(x.data<void>());
|
||||||
encoder.set_output_array(x);
|
encoder.set_output_array(x);
|
||||||
}
|
}
|
||||||
if (!contiguous) {
|
if (contiguous) {
|
||||||
args.push_back((void*)shape.data());
|
|
||||||
} else {
|
|
||||||
args.push_back((void*)outputs[0].data_size());
|
args.push_back((void*)outputs[0].data_size());
|
||||||
}
|
}
|
||||||
auto fun = (void (*)(void**))fn_ptr;
|
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||||
encoder.dispatch([fun,
|
encoder.dispatch([fun,
|
||||||
args = std::move(args),
|
args = std::move(args),
|
||||||
strides = std::move(strides),
|
strides = std::move(strides),
|
||||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
shape = std::move(shape)]() mutable {
|
||||||
|
SmallVector<int64_t*> strides_ptrs;
|
||||||
|
for (auto& s : strides) {
|
||||||
|
strides_ptrs.push_back(s.data());
|
||||||
|
}
|
||||||
|
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
std::vector<array> temps;
|
std::vector<array> temps;
|
||||||
temps.push_back(array(0, conv_dtype));
|
temps.push_back(array(0, conv_dtype));
|
||||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
|
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
|
||||||
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
in_padded_slice.size(),
|
in_padded_slice.size(),
|
||||||
data_offset);
|
data_offset);
|
||||||
// Copy input values into the slice
|
// Copy input values into the slice
|
||||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||||
temps.push_back(in_padded_slice);
|
temps.push_back(in_padded_slice);
|
||||||
|
|
||||||
// Make strided view
|
// Make strided view
|
||||||
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
// Materialize strided view
|
// Materialize strided view
|
||||||
Shape strided_reshape = {N * oH, wH * C};
|
Shape strided_reshape = {N * oH, wH * C};
|
||||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
copy(in_strided_view, in_strided, CopyType::General, stream);
|
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||||
temps.push_back(in_strided);
|
temps.push_back(in_strided);
|
||||||
|
|
||||||
// Check wt dtype and prepare
|
// Check wt dtype and prepare
|
||||||
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
wt.size(),
|
wt.size(),
|
||||||
0);
|
0);
|
||||||
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
|
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
|
||||||
copy(wt_transpose, gemm_wt, CopyType::General, stream);
|
copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
|
||||||
temps.push_back(gemm_wt);
|
temps.push_back(gemm_wt);
|
||||||
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||||
auto ctype =
|
auto ctype =
|
||||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||||
copy(wt, gemm_wt, ctype, stream);
|
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||||
temps.push_back(gemm_wt);
|
temps.push_back(gemm_wt);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
|
|
||||||
// Copy results if needed
|
// Copy results if needed
|
||||||
if (out.dtype() != float32) {
|
if (out.dtype() != float32) {
|
||||||
copy_inplace(gemm_out, out, CopyType::Vector, stream);
|
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||||
}
|
}
|
||||||
encoder.add_temporaries(std::move(temps));
|
encoder.add_temporaries(std::move(temps));
|
||||||
}
|
}
|
||||||
@@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
std::vector<array> temps;
|
std::vector<array> temps;
|
||||||
temps.push_back(array(0, conv_dtype));
|
temps.push_back(array(0, conv_dtype));
|
||||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
||||||
@@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
temps.push_back(in_padded_slice);
|
temps.push_back(in_padded_slice);
|
||||||
|
|
||||||
// Copy input values into the slice
|
// Copy input values into the slice
|
||||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||||
|
|
||||||
// Make strided view
|
// Make strided view
|
||||||
Shape strided_shape = {N, oH, oW, wH, wW, C};
|
Shape strided_shape = {N, oH, oW, wH, wW, C};
|
||||||
@@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
// Materialize strided view
|
// Materialize strided view
|
||||||
Shape strided_reshape = {N * oH * oW, wH * wW * C};
|
Shape strided_reshape = {N * oH * oW, wH * wW * C};
|
||||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
copy(in_strided_view, in_strided, CopyType::General, stream);
|
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||||
temps.push_back(in_strided);
|
temps.push_back(in_strided);
|
||||||
|
|
||||||
// Check wt dtype and prepare
|
// Check wt dtype and prepare
|
||||||
@@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
auto ctype =
|
auto ctype =
|
||||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||||
copy(wt, gemm_wt, ctype, stream);
|
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||||
temps.push_back(gemm_wt);
|
temps.push_back(gemm_wt);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
|
|
||||||
// Copy results if needed
|
// Copy results if needed
|
||||||
if (out.dtype() != float32) {
|
if (out.dtype() != float32) {
|
||||||
copy_inplace(gemm_out, out, CopyType::Vector, stream);
|
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||||
}
|
}
|
||||||
encoder.add_temporaries(std::move(temps));
|
encoder.add_temporaries(std::move(temps));
|
||||||
}
|
}
|
||||||
@@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
|
|
||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
std::vector<array> temps = {array(0, conv_dtype)};
|
std::vector<array> temps = {array(0, conv_dtype)};
|
||||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = 0;
|
size_t data_offset = 0;
|
||||||
@@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
data_offset);
|
data_offset);
|
||||||
|
|
||||||
// Copy input values into the slice
|
// Copy input values into the slice
|
||||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||||
temps.push_back(in_padded_slice);
|
temps.push_back(in_padded_slice);
|
||||||
|
|
||||||
// Make strided view
|
// Make strided view
|
||||||
@@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
copy(in_strided_view, in_strided, CopyType::General, stream);
|
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||||
temps.push_back(in_strided);
|
temps.push_back(in_strided);
|
||||||
|
|
||||||
// Check wt dtype and prepare
|
// Check wt dtype and prepare
|
||||||
@@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
auto ctype =
|
auto ctype =
|
||||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||||
copy(wt, gemm_wt, ctype, stream);
|
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||||
temps.push_back(gemm_wt);
|
temps.push_back(gemm_wt);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (flip) {
|
if (flip) {
|
||||||
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
|
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
|
||||||
copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
|
copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
|
||||||
temps.push_back(gemm_wt_);
|
temps.push_back(gemm_wt_);
|
||||||
|
|
||||||
// Calculate the total size of the spatial dimensions
|
// Calculate the total size of the spatial dimensions
|
||||||
@@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
|
|
||||||
// Copy results if needed
|
// Copy results if needed
|
||||||
if (out.dtype() != float32) {
|
if (out.dtype() != float32) {
|
||||||
copy_inplace(gemm_out, out, CopyType::Vector, stream);
|
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||||
}
|
}
|
||||||
encoder.add_temporaries(std::move(temps));
|
encoder.add_temporaries(std::move(temps));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -295,7 +295,11 @@ inline void copy_inplace_dispatch(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
void copy_cpu_inplace(
|
||||||
|
const array& src,
|
||||||
|
array& dst,
|
||||||
|
CopyType ctype,
|
||||||
|
Stream stream) {
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
encoder.set_input_array(src);
|
encoder.set_input_array(src);
|
||||||
encoder.set_output_array(dst);
|
encoder.set_output_array(dst);
|
||||||
@@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
|||||||
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||||
}
|
}
|
||||||
|
|
||||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||||
bool donated = set_copy_output_data(src, dst, ctype);
|
bool donated = set_copy_output_data(src, dst, ctype);
|
||||||
if (donated && src.dtype() == dst.dtype()) {
|
if (donated && src.dtype() == dst.dtype()) {
|
||||||
// If the output has the same type as the input then there is nothing to
|
// If the output has the same type as the input then there is nothing to
|
||||||
@@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
|||||||
if (ctype == CopyType::GeneralGeneral) {
|
if (ctype == CopyType::GeneralGeneral) {
|
||||||
ctype = CopyType::General;
|
ctype = CopyType::General;
|
||||||
}
|
}
|
||||||
copy_inplace(src, dst, ctype, stream);
|
copy_cpu_inplace(src, dst, ctype, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void copy_inplace(
|
void copy_cpu_inplace(
|
||||||
const array& src,
|
const array& src,
|
||||||
array& dst,
|
array& dst,
|
||||||
const Shape& data_shape,
|
const Shape& data_shape,
|
||||||
@@ -373,4 +377,10 @@ void copy_inplace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -10,10 +10,14 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
|
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
|
void copy_cpu_inplace(
|
||||||
|
const array& src,
|
||||||
|
array& dst,
|
||||||
|
CopyType ctype,
|
||||||
|
Stream stream);
|
||||||
|
|
||||||
void copy_inplace(
|
void copy_cpu_inplace(
|
||||||
const array& src,
|
const array& src,
|
||||||
array& dst,
|
array& dst,
|
||||||
const Shape& data_shape,
|
const Shape& data_shape,
|
||||||
@@ -26,4 +30,7 @@ void copy_inplace(
|
|||||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||||
|
|
||||||
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
|||||||
if (arr.flags().row_contiguous) {
|
if (arr.flags().row_contiguous) {
|
||||||
return {arr, false};
|
return {arr, false};
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
return {contiguous_copy_cpu(arr, stream), true};
|
||||||
copy(arr, arr_copy, CopyType::General, stream);
|
|
||||||
return {arr_copy, true};
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
|
|||||||
}
|
}
|
||||||
return in;
|
return in;
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_cpu(in, s);
|
||||||
copy(in, arr_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(arr_copy);
|
out.copy_shared_buffer(arr_copy);
|
||||||
return arr_copy;
|
return arr_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ void Eig::eval_cpu(
|
|||||||
: array(a.shape(), complex64, nullptr, {});
|
: array(a.shape(), complex64, nullptr, {});
|
||||||
|
|
||||||
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
a_copy,
|
a_copy,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ void Eigh::eval_cpu(
|
|||||||
|
|
||||||
values.set_data(allocator::malloc(values.nbytes()));
|
values.set_data(allocator::malloc(values.nbytes()));
|
||||||
|
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
vectors,
|
vectors,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (in.flags().row_contiguous && in.is_donatable()) {
|
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
copy(
|
copy_cpu(
|
||||||
in,
|
in,
|
||||||
out,
|
out,
|
||||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
|||||||
@@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Copy src into out (copy allocates memory for out)
|
// Copy src into out (copy allocates memory for out)
|
||||||
auto ctype =
|
auto ctype =
|
||||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
copy(src, out, ctype, stream());
|
copy_cpu(src, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
std::vector<array> inds;
|
std::vector<array> inds;
|
||||||
@@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Copy src into out (copy allocates memory for out)
|
// Copy src into out (copy allocates memory for out)
|
||||||
auto ctype =
|
auto ctype =
|
||||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
copy(src, out, ctype, stream());
|
copy_cpu(src, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
encoder.set_input_array(idx);
|
encoder.set_input_array(idx);
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ void inverse_impl(
|
|||||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||||
|
|
||||||
// The inverse is computed in place, so just copy the input to the output.
|
// The inverse is computed in place, so just copy the input to the output.
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
inv,
|
inv,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cpu/jit_compiler.h"
|
#include "mlx/backend/cpu/jit_compiler.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
|
|||||||
INSTANTIATE_LAPACK_REAL(syevd)
|
INSTANTIATE_LAPACK_REAL(syevd)
|
||||||
INSTANTIATE_LAPACK_REAL(geev)
|
INSTANTIATE_LAPACK_REAL(geev)
|
||||||
INSTANTIATE_LAPACK_REAL(potrf)
|
INSTANTIATE_LAPACK_REAL(potrf)
|
||||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||||
INSTANTIATE_LAPACK_REAL(getrf)
|
INSTANTIATE_LAPACK_REAL(getrf)
|
||||||
INSTANTIATE_LAPACK_REAL(getri)
|
INSTANTIATE_LAPACK_REAL(getri)
|
||||||
INSTANTIATE_LAPACK_REAL(trtri)
|
INSTANTIATE_LAPACK_REAL(trtri)
|
||||||
|
|||||||
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_cpu(x, s);
|
||||||
copy(x, x_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(x_copy);
|
encoder.add_temporary(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ void luf_impl(
|
|||||||
strides[ndim - 1] = M;
|
strides[ndim - 1] = M;
|
||||||
strides[ndim - 2] = 1;
|
strides[ndim - 2] = 1;
|
||||||
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
|
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||||
copy_inplace(
|
copy_cpu_inplace(
|
||||||
a,
|
a,
|
||||||
lu,
|
lu,
|
||||||
a.shape(),
|
a.shape(),
|
||||||
|
|||||||
@@ -124,21 +124,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
||||||
if (do_copy) {
|
if (do_copy) {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
copy(arr, arr_copy, CopyType::Vector, s);
|
copy_cpu(arr, arr_copy, CopyType::Vector, s);
|
||||||
return std::make_tuple(false, stx, arr_copy, true);
|
return std::make_tuple(false, stx, arr_copy, true);
|
||||||
}
|
}
|
||||||
return std::make_tuple(false, stx, arr, false);
|
return std::make_tuple(false, stx, arr, false);
|
||||||
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
||||||
if (do_copy) {
|
if (do_copy) {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
copy(arr, arr_copy, CopyType::Vector, s);
|
copy_cpu(arr, arr_copy, CopyType::Vector, s);
|
||||||
return std::make_tuple(true, sty, arr_copy, true);
|
return std::make_tuple(true, sty, arr_copy, true);
|
||||||
}
|
}
|
||||||
return std::make_tuple(true, sty, arr, false);
|
return std::make_tuple(true, sty, arr, false);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
|
||||||
copy(arr, arr_copy, CopyType::General, s);
|
|
||||||
int64_t stx = arr.shape(-1);
|
int64_t stx = arr.shape(-1);
|
||||||
|
array arr_copy = contiguous_copy_cpu(arr, s);
|
||||||
return std::make_tuple(false, stx, arr_copy, true);
|
return std::make_tuple(false, stx, arr_copy, true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -386,7 +385,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return std::make_tuple(true, sty, arr);
|
return std::make_tuple(true, sty, arr);
|
||||||
} else {
|
} else {
|
||||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||||
copy(arr, temps.back(), CopyType::General, s);
|
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||||
int64_t stx = arr.shape(-1);
|
int64_t stx = arr.shape(-1);
|
||||||
return std::make_tuple(false, stx, temps.back());
|
return std::make_tuple(false, stx, temps.back());
|
||||||
}
|
}
|
||||||
@@ -504,7 +503,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return std::make_tuple(true, sty, x);
|
return std::make_tuple(true, sty, x);
|
||||||
} else {
|
} else {
|
||||||
array xc(x.shape(), x.dtype(), nullptr, {});
|
array xc(x.shape(), x.dtype(), nullptr, {});
|
||||||
copy(x, xc, CopyType::General, s);
|
copy_cpu(x, xc, CopyType::General, s);
|
||||||
encoder.add_temporary(xc);
|
encoder.add_temporary(xc);
|
||||||
int64_t stx = x.shape(-1);
|
int64_t stx = x.shape(-1);
|
||||||
return std::make_tuple(false, stx, xc);
|
return std::make_tuple(false, stx, xc);
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ void matmul_general(
|
|||||||
return std::make_tuple(true, sty, arr);
|
return std::make_tuple(true, sty, arr);
|
||||||
} else {
|
} else {
|
||||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||||
copy(arr, temps.back(), CopyType::General, stream);
|
copy_cpu(arr, temps.back(), CopyType::General, stream);
|
||||||
stx = arr.shape(-1);
|
stx = arr.shape(-1);
|
||||||
return std::make_tuple(false, stx, temps.back());
|
return std::make_tuple(false, stx, temps.back());
|
||||||
}
|
}
|
||||||
@@ -142,7 +142,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
CopyType ctype = c.data_size() == 1
|
CopyType ctype = c.data_size() == 1
|
||||||
? CopyType::Scalar
|
? CopyType::Scalar
|
||||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||||
copy(c, out, ctype, stream());
|
copy_cpu(c, out, ctype, stream());
|
||||||
if (inputs[0].shape(-1) == 0) {
|
if (inputs[0].shape(-1) == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
|
|||||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||||
if (copy_necessary) {
|
if (copy_necessary) {
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
copy_inplace(in, out, CopyType::General, out.primitive().stream());
|
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||||
} else {
|
} else {
|
||||||
shared_buffer_reshape(in, out_strides, out);
|
shared_buffer_reshape(in, out_strides, out);
|
||||||
}
|
}
|
||||||
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||||
copy(in, out, ctype, stream());
|
copy_cpu(in, out, ctype, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
size_t data_offset = strides[axis_] * sizes[i];
|
size_t data_offset = strides[axis_] * sizes[i];
|
||||||
out_slice.copy_shared_buffer(
|
out_slice.copy_shared_buffer(
|
||||||
out, strides, flags, out_slice.size(), data_offset);
|
out, strides, flags, out_slice.size(), data_offset);
|
||||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
copy(in, out, CopyType::General, stream());
|
copy_cpu(in, out, CopyType::General, stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
} else {
|
} else {
|
||||||
ctype = CopyType::General;
|
ctype = CopyType::General;
|
||||||
}
|
}
|
||||||
copy(in, out, ctype, stream());
|
copy_cpu(in, out, ctype, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||||
|
|
||||||
// Fill output with val
|
// Fill output with val
|
||||||
copy(val, out, CopyType::Scalar, stream());
|
copy_cpu(val, out, CopyType::Scalar, stream());
|
||||||
|
|
||||||
// Find offset for start of input values
|
// Find offset for start of input values
|
||||||
size_t data_offset = 0;
|
size_t data_offset = 0;
|
||||||
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||||
|
|
||||||
// Copy input values into the slice
|
// Copy input values into the slice
|
||||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
auto [in_offset, donated] =
|
auto [in_offset, donated] =
|
||||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||||
copy_inplace(
|
copy_cpu_inplace(
|
||||||
/* const array& src = */ in,
|
/* const array& src = */ in,
|
||||||
/* array& dst = */ out,
|
/* array& dst = */ out,
|
||||||
/* const Shape& data_shape = */ out.shape(),
|
/* const Shape& data_shape = */ out.shape(),
|
||||||
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
|
|||||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||||
? CopyType::Vector
|
? CopyType::Vector
|
||||||
: CopyType::General;
|
: CopyType::General;
|
||||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||||
|
|
||||||
auto [out_offset, donated] =
|
auto [out_offset, donated] =
|
||||||
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
||||||
copy_inplace(
|
copy_cpu_inplace(
|
||||||
/* const array& src = */ upd,
|
/* const array& src = */ upd,
|
||||||
/* array& dst = */ out,
|
/* array& dst = */ out,
|
||||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||||
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||||
? CopyType::Vector
|
? CopyType::Vector
|
||||||
: CopyType::General;
|
: CopyType::General;
|
||||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||||
|
|
||||||
// Calculate out strides, initial offset and if copy needs to be made
|
// Calculate out strides, initial offset and if copy needs to be made
|
||||||
auto [data_offset, out_strides] =
|
auto [data_offset, out_strides] =
|
||||||
prepare_slice(out, start_indices_, strides_);
|
prepare_slice(out, start_indices_, strides_);
|
||||||
|
|
||||||
// Do copy
|
// Do copy
|
||||||
copy_inplace(
|
copy_cpu_inplace(
|
||||||
/* const array& src = */ upd,
|
/* const array& src = */ upd,
|
||||||
/* array& dst = */ out,
|
/* array& dst = */ out,
|
||||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||||
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (in.dtype() == bool_) {
|
if (in.dtype() == bool_) {
|
||||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||||
in_tmp.copy_shared_buffer(in);
|
in_tmp.copy_shared_buffer(in);
|
||||||
copy_inplace(in_tmp, tmp, CopyType::General, stream());
|
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||||
} else {
|
} else {
|
||||||
copy_inplace(in, tmp, CopyType::General, stream());
|
copy_cpu_inplace(in, tmp, CopyType::General, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto flags = out.flags();
|
auto flags = out.flags();
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
|||||||
strides[in.ndim() - 2] = 1;
|
strides[in.ndim() - 2] = 1;
|
||||||
strides[in.ndim() - 1] = M;
|
strides[in.ndim() - 1] = M;
|
||||||
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
||||||
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
|
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
q.set_data(allocator::malloc(q.nbytes()));
|
q.set_data(allocator::malloc(q.nbytes()));
|
||||||
r.set_data(allocator::malloc(r.nbytes()));
|
r.set_data(allocator::malloc(r.nbytes()));
|
||||||
|
|||||||
@@ -529,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return arr;
|
return arr;
|
||||||
} else {
|
} else {
|
||||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||||
copy(arr, temps.back(), CopyType::General, s);
|
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||||
return temps.back();
|
return temps.back();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -579,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return arr;
|
return arr;
|
||||||
} else {
|
} else {
|
||||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||||
copy(arr, temps.back(), CopyType::General, s);
|
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||||
return temps.back();
|
return temps.back();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
|
|||||||
if (arr.flags().row_contiguous) {
|
if (arr.flags().row_contiguous) {
|
||||||
return std::make_pair(arr, false);
|
return std::make_pair(arr, false);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
return std::make_pair(contiguous_copy_cpu(arr, s), true);
|
||||||
copy(arr, arr_copy, CopyType::General, s);
|
|
||||||
return std::make_pair(arr_copy, true);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -325,7 +325,15 @@ struct MaxReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::max(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::max(x);
|
return simd::max(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@@ -342,7 +350,15 @@ struct MinReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::min(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::min(x);
|
return simd::min(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@@ -475,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
case uint8:
|
case uint8:
|
||||||
|
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
case uint16:
|
|
||||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
case uint32:
|
|
||||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
case uint64:
|
|
||||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
@@ -527,10 +551,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||||
|
|||||||
@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Ensure contiguity
|
// Ensure contiguity
|
||||||
auto in = inputs[0];
|
auto in = inputs[0];
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_cpu(in, stream());
|
||||||
copy(in, arr_copy, CopyType::General, stream());
|
encoder.add_temporary(in);
|
||||||
in = arr_copy;
|
|
||||||
encoder.add_temporary(arr_copy);
|
|
||||||
}
|
}
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
|||||||
@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_cpu(x, s);
|
||||||
copy(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -333,45 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
int axis = axis_;
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += in.ndim();
|
||||||
|
}
|
||||||
|
|
||||||
// Copy input to output
|
// Copy input to output
|
||||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
|
||||||
copy(in, out, ctype, stream());
|
? CopyType::Vector
|
||||||
|
: CopyType::General;
|
||||||
|
copy_cpu(in, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
encoder.dispatch(
|
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
|
||||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||||
switch (out.dtype()) {
|
sort<MLX_GET_TYPE(type_tag)>(out, axis);
|
||||||
case bool_:
|
});
|
||||||
return sort<bool>(out, axis_);
|
});
|
||||||
case uint8:
|
|
||||||
return sort<uint8_t>(out, axis_);
|
|
||||||
case uint16:
|
|
||||||
return sort<uint16_t>(out, axis_);
|
|
||||||
case uint32:
|
|
||||||
return sort<uint32_t>(out, axis_);
|
|
||||||
case uint64:
|
|
||||||
return sort<uint64_t>(out, axis_);
|
|
||||||
case int8:
|
|
||||||
return sort<int8_t>(out, axis_);
|
|
||||||
case int16:
|
|
||||||
return sort<int16_t>(out, axis_);
|
|
||||||
case int32:
|
|
||||||
return sort<int32_t>(out, axis_);
|
|
||||||
case int64:
|
|
||||||
return sort<int64_t>(out, axis_);
|
|
||||||
case float32:
|
|
||||||
return sort<float>(out, axis_);
|
|
||||||
case float64:
|
|
||||||
return sort<double>(out, axis_);
|
|
||||||
case float16:
|
|
||||||
return sort<float16_t>(out, axis_);
|
|
||||||
case bfloat16:
|
|
||||||
return sort<bfloat16_t>(out, axis_);
|
|
||||||
case complex64:
|
|
||||||
return sort<complex64_t>(out, axis_);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -426,8 +405,10 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
// Copy input to output
|
// Copy input to output
|
||||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||||
copy(in, out, ctype, stream());
|
? CopyType::Vector
|
||||||
|
: CopyType::General;
|
||||||
|
copy_cpu(in, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ void svd_impl(
|
|||||||
|
|
||||||
// lapack clobbers the input, so we have to make a copy.
|
// lapack clobbers the input, so we have to make a copy.
|
||||||
array in(a.shape(), a.dtype(), nullptr, {});
|
array in(a.shape(), a.dtype(), nullptr, {});
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
in,
|
in,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
@@ -81,9 +81,7 @@ void svd_impl(
|
|||||||
// Vᵀ of shape N x N. (M x M in lapack).
|
// Vᵀ of shape N x N. (M x M in lapack).
|
||||||
const int ldvt = M;
|
const int ldvt = M;
|
||||||
|
|
||||||
auto job_u = (u_ptr) ? "V" : "N";
|
auto jobz = (u_ptr) ? "A" : "N";
|
||||||
auto job_vt = (u_ptr) ? "V" : "N";
|
|
||||||
static constexpr auto range = "A";
|
|
||||||
|
|
||||||
// Will contain the number of singular values after the call has returned.
|
// Will contain the number of singular values after the call has returned.
|
||||||
int ns = 0;
|
int ns = 0;
|
||||||
@@ -91,30 +89,20 @@ void svd_impl(
|
|||||||
|
|
||||||
// Will contain the indices of eigenvectors that failed to converge (not
|
// Will contain the indices of eigenvectors that failed to converge (not
|
||||||
// used here but required by lapack).
|
// used here but required by lapack).
|
||||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
||||||
|
|
||||||
static const int lwork_query = -1;
|
static const int lwork_query = -1;
|
||||||
|
|
||||||
static const int ignored_int = 0;
|
|
||||||
static const T ignored_float = 0;
|
|
||||||
|
|
||||||
int info;
|
int info;
|
||||||
|
|
||||||
// Compute workspace size.
|
// Compute workspace size.
|
||||||
gesvdx<T>(
|
gesdd<T>(
|
||||||
/* jobu = */ job_u,
|
/* jobz = */ jobz,
|
||||||
/* jobvt = */ job_vt,
|
|
||||||
/* range = */ range,
|
|
||||||
// M and N are swapped since lapack expects column-major.
|
// M and N are swapped since lapack expects column-major.
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &M,
|
/* n = */ &M,
|
||||||
/* a = */ nullptr,
|
/* a = */ nullptr,
|
||||||
/* lda = */ &lda,
|
/* lda = */ &lda,
|
||||||
/* vl = */ &ignored_float,
|
|
||||||
/* vu = */ &ignored_float,
|
|
||||||
/* il = */ &ignored_int,
|
|
||||||
/* iu = */ &ignored_int,
|
|
||||||
/* ns = */ &ns,
|
|
||||||
/* s = */ nullptr,
|
/* s = */ nullptr,
|
||||||
/* u = */ nullptr,
|
/* u = */ nullptr,
|
||||||
/* ldu = */ &ldu,
|
/* ldu = */ &ldu,
|
||||||
@@ -136,20 +124,13 @@ void svd_impl(
|
|||||||
|
|
||||||
// Loop over matrices.
|
// Loop over matrices.
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
gesvdx<T>(
|
gesdd<T>(
|
||||||
/* jobu = */ job_u,
|
/* jobz = */ jobz,
|
||||||
/* jobvt = */ job_vt,
|
|
||||||
/* range = */ range,
|
|
||||||
// M and N are swapped since lapack expects column-major.
|
// M and N are swapped since lapack expects column-major.
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &M,
|
/* n = */ &M,
|
||||||
/* a = */ in_ptr + M * N * i,
|
/* a = */ in_ptr + M * N * i,
|
||||||
/* lda = */ &lda,
|
/* lda = */ &lda,
|
||||||
/* vl = */ &ignored_float,
|
|
||||||
/* vu = */ &ignored_float,
|
|
||||||
/* il = */ &ignored_int,
|
|
||||||
/* iu = */ &ignored_int,
|
|
||||||
/* ns = */ &ns,
|
|
||||||
/* s = */ s_ptr + K * i,
|
/* s = */ s_ptr + K * i,
|
||||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||||
@@ -167,13 +148,6 @@ void svd_impl(
|
|||||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ns != K) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
|
||||||
<< " were computed.";
|
|
||||||
throw std::runtime_error(ss.str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
|
|||||||
@@ -6,8 +6,8 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||||
@@ -15,18 +15,25 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||||
@@ -35,14 +42,28 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
|
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
|
||||||
|
|
||||||
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||||
|
target_sources(
|
||||||
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||||
|
else()
|
||||||
|
target_sources(
|
||||||
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||||
|
|
||||||
# Embed kernel sources in binary for JIT compilation.
|
# Embed kernel sources in binary for JIT compilation.
|
||||||
@@ -67,6 +88,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
|||||||
target_compile_options(mlx
|
target_compile_options(mlx
|
||||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||||
|
|
||||||
|
# Enable calling host constexpr functions from device. This is needed because
|
||||||
|
# the constexpr version of isnan is host only.
|
||||||
|
target_compile_options(
|
||||||
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
|
||||||
|
|
||||||
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
||||||
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
||||||
# true but the warning wouldn't be suppressed.
|
# true but the warning wouldn't be suppressed.
|
||||||
@@ -80,11 +106,18 @@ endif()
|
|||||||
target_compile_options(
|
target_compile_options(
|
||||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
||||||
|
|
||||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
|
||||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
# and requires drivers released after CUDA 12.4.
|
||||||
set(MLX_CUDA_ARCHITECTURES
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||||
"70;80"
|
target_compile_options(
|
||||||
CACHE STRING "CUDA architectures")
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
||||||
|
# managed memory.
|
||||||
|
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||||
|
set(MLX_CUDA_ARCHITECTURES "native")
|
||||||
|
endif()
|
||||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||||
"${MLX_CUDA_ARCHITECTURES}")
|
"${MLX_CUDA_ARCHITECTURES}")
|
||||||
@@ -116,6 +149,27 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
|||||||
# Use NVRTC and driver APIs.
|
# Use NVRTC and driver APIs.
|
||||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||||
|
|
||||||
|
# Use the frontend APIs of cuDNN.
|
||||||
|
FetchContent_Declare(
|
||||||
|
cudnn
|
||||||
|
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||||
|
GIT_TAG v1.14.0
|
||||||
|
GIT_SHALLOW TRUE
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
|
||||||
|
FetchContent_MakeAvailable(cudnn)
|
||||||
|
target_link_libraries(mlx PRIVATE cudnn_frontend)
|
||||||
|
# Link with the actual cuDNN libraries.
|
||||||
|
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
|
||||||
|
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||||
|
|
||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
|
||||||
|
# Install CCCL headers for JIT.
|
||||||
|
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||||
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
@@ -17,14 +16,66 @@ namespace cu {
|
|||||||
|
|
||||||
constexpr int page_size = 16384;
|
constexpr int page_size = 16384;
|
||||||
|
|
||||||
|
// Any allocations smaller than this will try to use the small pool
|
||||||
|
constexpr int small_block_size = 8;
|
||||||
|
|
||||||
|
// The small pool size in bytes. This should be a multiple of the host page
|
||||||
|
// size and small_block_size.
|
||||||
|
constexpr int small_pool_size = 4 * page_size;
|
||||||
|
|
||||||
|
SmallSizePool::SmallSizePool() {
|
||||||
|
auto num_blocks = small_pool_size / small_block_size;
|
||||||
|
buffer_ = new Block[num_blocks];
|
||||||
|
|
||||||
|
next_free_ = buffer_;
|
||||||
|
|
||||||
|
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||||
|
|
||||||
|
auto curr = next_free_;
|
||||||
|
for (size_t i = 1; i < num_blocks; ++i) {
|
||||||
|
curr->next = buffer_ + i;
|
||||||
|
curr = curr->next;
|
||||||
|
}
|
||||||
|
curr->next = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallSizePool::~SmallSizePool() {
|
||||||
|
CHECK_CUDA_ERROR(cudaFree(data_));
|
||||||
|
delete[] buffer_;
|
||||||
|
}
|
||||||
|
|
||||||
|
CudaBuffer* SmallSizePool::malloc() {
|
||||||
|
if (next_free_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Block* b = next_free_;
|
||||||
|
uint64_t i = next_free_ - buffer_;
|
||||||
|
next_free_ = next_free_->next;
|
||||||
|
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||||
|
b->buf.size = small_block_size;
|
||||||
|
return &b->buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SmallSizePool::free(CudaBuffer* buf) {
|
||||||
|
auto b = reinterpret_cast<Block*>(buf);
|
||||||
|
b->next = next_free_;
|
||||||
|
next_free_ = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SmallSizePool::in_pool(CudaBuffer* buf) {
|
||||||
|
constexpr int num_blocks = (small_pool_size / small_block_size);
|
||||||
|
auto b = reinterpret_cast<Block*>(buf);
|
||||||
|
int64_t block_num = b - buffer_;
|
||||||
|
return block_num >= 0 && block_num < num_blocks;
|
||||||
|
}
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
page_size,
|
page_size,
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) {
|
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||||
cuda_free(buf->data);
|
|
||||||
delete buf;
|
|
||||||
}) {
|
|
||||||
// TODO: Set memory limit for multi-device.
|
// TODO: Set memory limit for multi-device.
|
||||||
size_t free, total;
|
size_t free, total;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
@@ -36,7 +87,9 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
auto orig_size = size;
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
if (size < page_size) {
|
if (size <= small_block_size) {
|
||||||
|
size = 8;
|
||||||
|
} else if (size < page_size) {
|
||||||
size = next_power_of_2(size);
|
size = next_power_of_2(size);
|
||||||
} else {
|
} else {
|
||||||
size = page_size * ((size + page_size - 1) / page_size);
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
@@ -44,19 +97,25 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
|
|
||||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||||
// try to reclaim memory from the cache.
|
int64_t mem_to_free =
|
||||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
get_active_memory() + get_cache_memory() + size - memory_limit_;
|
||||||
if (mem_required >= memory_limit_) {
|
if (mem_to_free > 0) {
|
||||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
buffer_cache_.release_cached_buffers(mem_to_free);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try the scalar pool first
|
||||||
|
if (size <= small_block_size) {
|
||||||
|
buf = scalar_pool_.malloc();
|
||||||
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
buf = new CudaBuffer{nullptr, size};
|
if (!buf) {
|
||||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
buf = new CudaBuffer{nullptr, size};
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||||
throw std::runtime_error(fmt::format(
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
throw std::runtime_error(fmt::format(
|
||||||
|
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
@@ -67,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
if (get_cache_memory() > max_pool_size_) {
|
if (get_cache_memory() > max_pool_size_) {
|
||||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Buffer{buf};
|
return Buffer{buf};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
|
|||||||
if (get_cache_memory() < max_pool_size_) {
|
if (get_cache_memory() < max_pool_size_) {
|
||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
} else {
|
} else {
|
||||||
lock.unlock();
|
cuda_free(buf);
|
||||||
cuda_free(buf->data);
|
|
||||||
delete buf;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,27 +152,14 @@ size_t CudaAllocator::size(Buffer buffer) const {
|
|||||||
return buf->size;
|
return buf->size;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudaAllocator::register_this_thread() {
|
// This must be called with mutex_ aquired
|
||||||
std::lock_guard lock(worker_mutex_);
|
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||||
allowed_threads_.insert(std::this_thread::get_id());
|
if (scalar_pool_.in_pool(buf)) {
|
||||||
}
|
scalar_pool_.free(buf);
|
||||||
|
} else {
|
||||||
void CudaAllocator::cuda_free(void* buf) {
|
cudaFree(buf->data);
|
||||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
delete buf;
|
||||||
// worker.
|
|
||||||
{
|
|
||||||
std::lock_guard lock(worker_mutex_);
|
|
||||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
|
||||||
if (!worker_) {
|
|
||||||
worker_.reset(new Worker);
|
|
||||||
}
|
|
||||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
|
||||||
worker_->end_batch();
|
|
||||||
worker_->commit();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
cudaFree(buf);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t CudaAllocator::get_active_memory() const {
|
size_t CudaAllocator::get_active_memory() const {
|
||||||
|
|||||||
@@ -7,13 +7,10 @@
|
|||||||
|
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <thread>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
class Worker;
|
|
||||||
|
|
||||||
using allocator::Buffer;
|
using allocator::Buffer;
|
||||||
|
|
||||||
// Stores cuda-managed unified memory.
|
// Stores cuda-managed unified memory.
|
||||||
@@ -22,21 +19,35 @@ struct CudaBuffer {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SmallSizePool {
|
||||||
|
private:
|
||||||
|
union Block {
|
||||||
|
Block* next;
|
||||||
|
CudaBuffer buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
Block* buffer_{nullptr};
|
||||||
|
void* data_{nullptr};
|
||||||
|
Block* next_free_{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
SmallSizePool();
|
||||||
|
~SmallSizePool();
|
||||||
|
|
||||||
|
SmallSizePool(const SmallSizePool&) = delete;
|
||||||
|
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||||
|
|
||||||
|
CudaBuffer* malloc();
|
||||||
|
void free(CudaBuffer* buf);
|
||||||
|
bool in_pool(CudaBuffer* buf);
|
||||||
|
};
|
||||||
|
|
||||||
class CudaAllocator : public allocator::Allocator {
|
class CudaAllocator : public allocator::Allocator {
|
||||||
public:
|
public:
|
||||||
Buffer malloc(size_t size) override;
|
Buffer malloc(size_t size) override;
|
||||||
void free(Buffer buffer) override;
|
void free(Buffer buffer) override;
|
||||||
size_t size(Buffer buffer) const override;
|
size_t size(Buffer buffer) const override;
|
||||||
|
|
||||||
// Register current thread as safe to free buffers.
|
|
||||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
|
||||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
|
||||||
// buffers there would result in dead lock.
|
|
||||||
void register_this_thread();
|
|
||||||
|
|
||||||
// Call cudaFree in the safe thread.
|
|
||||||
void cuda_free(void* buf);
|
|
||||||
|
|
||||||
size_t get_active_memory() const;
|
size_t get_active_memory() const;
|
||||||
size_t get_peak_memory() const;
|
size_t get_peak_memory() const;
|
||||||
void reset_peak_memory();
|
void reset_peak_memory();
|
||||||
@@ -47,19 +58,18 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
void clear_cache();
|
void clear_cache();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void cuda_free(CudaBuffer* buf);
|
||||||
|
|
||||||
CudaAllocator();
|
CudaAllocator();
|
||||||
friend CudaAllocator& allocator();
|
friend CudaAllocator& allocator();
|
||||||
|
|
||||||
std::mutex worker_mutex_;
|
|
||||||
std::unique_ptr<Worker> worker_;
|
|
||||||
std::set<std::thread::id> allowed_threads_;
|
|
||||||
|
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
size_t memory_limit_;
|
size_t memory_limit_;
|
||||||
size_t max_pool_size_;
|
size_t max_pool_size_;
|
||||||
BufferCache<CudaBuffer> buffer_cache_;
|
BufferCache<CudaBuffer> buffer_cache_;
|
||||||
size_t active_memory_{0};
|
size_t active_memory_{0};
|
||||||
size_t peak_memory_{0};
|
size_t peak_memory_{0};
|
||||||
|
SmallSizePool scalar_pool_;
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaAllocator& allocator();
|
CudaAllocator& allocator();
|
||||||
|
|||||||
55
mlx/backend/cuda/arange.cu
Normal file
55
mlx/backend/cuda/arange.cu
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <thrust/device_ptr.h>
|
||||||
|
#include <thrust/transform.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct Arange {
|
||||||
|
const T start;
|
||||||
|
const T step;
|
||||||
|
|
||||||
|
__device__ T operator()(uint32_t i) const {
|
||||||
|
return start + i * step;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(stream());
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||||
|
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||||
|
using OutType = cuda_type_t<CTYPE>;
|
||||||
|
CTYPE step =
|
||||||
|
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||||
|
thrust::transform(
|
||||||
|
cu::thrust_policy(encoder.stream()),
|
||||||
|
thrust::counting_iterator<uint32_t>(0),
|
||||||
|
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||||
|
thrust::device_pointer_cast(out.data<OutType>()),
|
||||||
|
cu::Arange<OutType>{
|
||||||
|
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -44,8 +44,11 @@ struct ArgMin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
__device__ IndexValPair<T>
|
__device__ IndexValPair<T> reduce_many(
|
||||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
IndexValPair<T> best,
|
||||||
|
const AlignedVector<T, N>& vals,
|
||||||
|
uint32_t offset) {
|
||||||
|
#pragma unroll
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if (vals[i] < best.val) {
|
if (vals[i] < best.val) {
|
||||||
best.val = vals[i];
|
best.val = vals[i];
|
||||||
@@ -74,8 +77,11 @@ struct ArgMax {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
__device__ IndexValPair<T>
|
__device__ IndexValPair<T> reduce_many(
|
||||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
IndexValPair<T> best,
|
||||||
|
const AlignedVector<T, N>& vals,
|
||||||
|
uint32_t offset) {
|
||||||
|
#pragma unroll
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if (vals[i] > best.val) {
|
if (vals[i] > best.val) {
|
||||||
best.val = vals[i];
|
best.val = vals[i];
|
||||||
@@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
|
|||||||
|
|
||||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||||
|
in += in_idx;
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
T init = op.init();
|
T init = op.init();
|
||||||
IndexValPair<T> best{0, init};
|
IndexValPair<T> best{0, init};
|
||||||
|
|
||||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
T vals[N_READS];
|
|
||||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||||
cub::LoadDirectBlocked(
|
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
|
||||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
|
||||||
best = op.reduce_many(best, vals, tid * N_READS);
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<uint32_t>(),
|
out.data<uint32_t>(),
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
21
mlx/backend/cuda/binary/CMakeLists.txt
Normal file
21
mlx/backend/cuda/binary/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
target_sources(
|
||||||
|
mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)
|
||||||
7
mlx/backend/cuda/binary/add.cu
Normal file
7
mlx/backend/cuda/binary/add.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Add)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/arctan2.cu
Normal file
7
mlx/backend/cuda/binary/arctan2.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(ArcTan2)
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -20,21 +19,16 @@ namespace cg = cooperative_groups;
|
|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
int remaining = size - index * N_READS;
|
|
||||||
if (remaining <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (remaining < N_READS) {
|
if ((index + 1) * N_READS > size) {
|
||||||
for (int i = 0; i < remaining; ++i) {
|
for (int i = index * N_READS; i < size; ++i) {
|
||||||
IdxT offset = index * N_READS + i;
|
out[i] = Op{}(a[0], b[0]);
|
||||||
out[offset] = Op{}(a[0], b[0]);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
AlignedVector<Out, N_READS> out_vec;
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
out_vec.val[i] = Op{}(a[0], b[0]);
|
out_vec[i] = Op{}(a[0], b[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
store_vector<N_READS>(out, index, out_vec);
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
@@ -44,15 +38,10 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
int remaining = size - index * N_READS;
|
|
||||||
if (remaining <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (remaining < N_READS) {
|
if ((index + 1) * N_READS > size) {
|
||||||
for (int i = 0; i < remaining; ++i) {
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
IdxT offset = index * N_READS + i;
|
out[i] = Op{}(a[0], b[i]);
|
||||||
out[offset] = Op{}(a[0], b[offset]);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto b_vec = load_vector<N_READS>(b, index);
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
@@ -60,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
|||||||
AlignedVector<Out, N_READS> out_vec;
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
|
out_vec[i] = Op{}(a[0], b_vec[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
store_vector<N_READS>(out, index, out_vec);
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
@@ -70,15 +59,10 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
int remaining = size - index * N_READS;
|
|
||||||
if (remaining <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (remaining < N_READS) {
|
if ((index + 1) * N_READS > size) {
|
||||||
for (int i = 0; i < remaining; ++i) {
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
IdxT offset = index * N_READS + i;
|
out[i] = Op{}(a[i], b[0]);
|
||||||
out[offset] = Op{}(a[offset], b[0]);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto a_vec = load_vector<N_READS>(a, index);
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
@@ -86,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
|||||||
AlignedVector<Out, N_READS> out_vec;
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
|
out_vec[i] = Op{}(a_vec[i], b[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
store_vector<N_READS>(out, index, out_vec);
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
@@ -96,15 +80,10 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
int remaining = size - index * N_READS;
|
|
||||||
if (remaining <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (remaining < N_READS) {
|
if ((index + 1) * N_READS > size) {
|
||||||
for (int i = 0; i < remaining; ++i) {
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
IdxT offset = index * N_READS + i;
|
out[i] = Op{}(a[i], b[i]);
|
||||||
out[offset] = Op{}(a[offset], b[offset]);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto a_vec = load_vector<N_READS>(a, index);
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
@@ -113,46 +92,96 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
|||||||
AlignedVector<Out, N_READS> out_vec;
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
|
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
store_vector<N_READS>(out, index, out_vec);
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
template <
|
||||||
|
typename Op,
|
||||||
|
typename In,
|
||||||
|
typename Out,
|
||||||
|
typename IdxT,
|
||||||
|
int NDIM,
|
||||||
|
int N_READS>
|
||||||
__global__ void binary_g_nd(
|
__global__ void binary_g_nd(
|
||||||
const In* a,
|
const In* a,
|
||||||
const In* b,
|
const In* b,
|
||||||
Out* out,
|
Out* out,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
IdxT index_rest =
|
||||||
index, shape.data(), a_strides.data(), b_strides.data());
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
if (index_rest >= size_rest) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[NDIM - 1];
|
||||||
|
auto a_stride_x = a_strides[NDIM - 1];
|
||||||
|
auto b_stride_x = b_strides[NDIM - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
|
||||||
|
auto a_vec =
|
||||||
|
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||||
|
auto b_vec =
|
||||||
|
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||||
|
}
|
||||||
|
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_g(
|
__global__ void binary_g(
|
||||||
const In* a,
|
const In* a,
|
||||||
const In* b,
|
const In* b,
|
||||||
Out* out,
|
Out* out,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ Shape shape,
|
const __grid_constant__ Shape shape,
|
||||||
const __grid_constant__ Strides a_strides,
|
const __grid_constant__ Strides a_strides,
|
||||||
const __grid_constant__ Strides b_strides,
|
const __grid_constant__ Strides b_strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
IdxT index_rest =
|
||||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
if (index_rest >= size_rest) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[ndim - 1];
|
||||||
|
auto a_stride_x = a_strides[ndim - 1];
|
||||||
|
auto b_stride_x = b_strides[ndim - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc(
|
||||||
|
index_rest * shape_x,
|
||||||
|
shape.data(),
|
||||||
|
a_strides.data(),
|
||||||
|
b_strides.data(),
|
||||||
|
ndim);
|
||||||
|
auto a_vec =
|
||||||
|
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||||
|
auto b_vec =
|
||||||
|
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||||
|
}
|
||||||
|
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out>
|
template <typename Op, typename In, typename Out>
|
||||||
@@ -197,7 +226,7 @@ template <typename Op>
|
|||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
@@ -230,36 +259,61 @@ void binary_op_gpu_inplace(
|
|||||||
auto& a_strides = strides[0];
|
auto& a_strides = strides[0];
|
||||||
auto& b_strides = strides[1];
|
auto& b_strides = strides[1];
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
int work_per_thread = 1;
|
||||||
|
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||||
|
auto rest = out.size() / dim0;
|
||||||
|
if (dim0 >= 4) {
|
||||||
|
work_per_thread = 4;
|
||||||
|
}
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
|
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||||
|
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||||
|
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
auto kernel = cu::
|
auto kernel = cu::binary_g_nd<
|
||||||
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
|
Op,
|
||||||
auto [num_blocks, block_dims] =
|
InType,
|
||||||
get_launch_args(kernel, out, large());
|
OutType,
|
||||||
|
IdxT,
|
||||||
|
dims_constant(),
|
||||||
|
1>;
|
||||||
|
if (work_per_thread == 4) {
|
||||||
|
kernel = cu::binary_g_nd<
|
||||||
|
Op,
|
||||||
|
InType,
|
||||||
|
OutType,
|
||||||
|
IdxT,
|
||||||
|
dims_constant(),
|
||||||
|
4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.size(),
|
rest,
|
||||||
const_param<dims_constant()>(shape),
|
const_param<dims_constant()>(shape),
|
||||||
const_param<dims_constant()>(a_strides),
|
const_param<dims_constant()>(a_strides),
|
||||||
const_param<dims_constant()>(b_strides));
|
const_param<dims_constant()>(b_strides));
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
|
||||||
auto [num_blocks, block_dims] =
|
if (work_per_thread == 4) {
|
||||||
get_launch_args(kernel, out, large());
|
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.size(),
|
rest,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(a_strides),
|
const_param(a_strides),
|
||||||
const_param(b_strides),
|
const_param(b_strides),
|
||||||
@@ -267,10 +321,9 @@ void binary_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
constexpr int N_READS = 16 / sizeof(InType);
|
||||||
constexpr int N_READS = 4;
|
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
if (bopt == BinaryOpType::ScalarVector) {
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
@@ -280,16 +333,12 @@ void binary_op_gpu_inplace(
|
|||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
kernel,
|
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||||
out.data_size(),
|
|
||||||
out.shape(),
|
|
||||||
out.strides(),
|
|
||||||
large(),
|
|
||||||
N_READS);
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@@ -311,7 +360,7 @@ template <typename Op>
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@@ -320,63 +369,11 @@ void binary_op_gpu(
|
|||||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_GPU(func) \
|
#define BINARY_GPU(func) \
|
||||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||||
auto& s = out.primitive().stream(); \
|
auto& s = out.primitive().stream(); \
|
||||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
binary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||||
}
|
}
|
||||||
|
|
||||||
BINARY_GPU(Add)
|
|
||||||
BINARY_GPU(ArcTan2)
|
|
||||||
BINARY_GPU(Divide)
|
|
||||||
BINARY_GPU(Remainder)
|
|
||||||
BINARY_GPU(Greater)
|
|
||||||
BINARY_GPU(GreaterEqual)
|
|
||||||
BINARY_GPU(Less)
|
|
||||||
BINARY_GPU(LessEqual)
|
|
||||||
BINARY_GPU(LogicalAnd)
|
|
||||||
BINARY_GPU(LogicalOr)
|
|
||||||
BINARY_GPU(LogAddExp)
|
|
||||||
BINARY_GPU(Maximum)
|
|
||||||
BINARY_GPU(Minimum)
|
|
||||||
BINARY_GPU(Multiply)
|
|
||||||
BINARY_GPU(NotEqual)
|
|
||||||
BINARY_GPU(Power)
|
|
||||||
BINARY_GPU(Subtract)
|
|
||||||
|
|
||||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
|
||||||
auto& s = out.primitive().stream();
|
|
||||||
auto op = get_primitive_string(this);
|
|
||||||
if (equal_nan_) {
|
|
||||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
|
||||||
} else {
|
|
||||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
|
||||||
auto& s = out.primitive().stream();
|
|
||||||
auto op = get_primitive_string(this);
|
|
||||||
switch (op_) {
|
|
||||||
case BitwiseBinary::And:
|
|
||||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::Or:
|
|
||||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::Xor:
|
|
||||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::LeftShift:
|
|
||||||
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::RightShift:
|
|
||||||
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
27
mlx/backend/cuda/binary/bitwise_binary.cu
Normal file
27
mlx/backend/cuda/binary/bitwise_binary.cu
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
switch (op_) {
|
||||||
|
case BitwiseBinary::And:
|
||||||
|
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||||
|
break;
|
||||||
|
case BitwiseBinary::Or:
|
||||||
|
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||||
|
break;
|
||||||
|
case BitwiseBinary::Xor:
|
||||||
|
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||||
|
break;
|
||||||
|
case BitwiseBinary::LeftShift:
|
||||||
|
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||||
|
break;
|
||||||
|
case BitwiseBinary::RightShift:
|
||||||
|
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/divide.cu
Normal file
7
mlx/backend/cuda/binary/divide.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Divide)
|
||||||
|
} // namespace mlx::core
|
||||||
15
mlx/backend/cuda/binary/equal.cu
Normal file
15
mlx/backend/cuda/binary/equal.cu
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
if (equal_nan_) {
|
||||||
|
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||||
|
} else {
|
||||||
|
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/greater.cu
Normal file
7
mlx/backend/cuda/binary/greater.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Greater)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/greater_equal.cu
Normal file
7
mlx/backend/cuda/binary/greater_equal.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(GreaterEqual)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/less.cu
Normal file
7
mlx/backend/cuda/binary/less.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Less)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/less_equal.cu
Normal file
7
mlx/backend/cuda/binary/less_equal.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(LessEqual)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/log_add_exp.cu
Normal file
7
mlx/backend/cuda/binary/log_add_exp.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(LogAddExp)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/logical_and.cu
Normal file
7
mlx/backend/cuda/binary/logical_and.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(LogicalAnd)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/logical_or.cu
Normal file
7
mlx/backend/cuda/binary/logical_or.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(LogicalOr)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/maximum.cu
Normal file
7
mlx/backend/cuda/binary/maximum.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Maximum)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/minimum.cu
Normal file
7
mlx/backend/cuda/binary/minimum.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Minimum)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/multiply.cu
Normal file
7
mlx/backend/cuda/binary/multiply.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Multiply)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/not_equal.cu
Normal file
7
mlx/backend/cuda/binary/not_equal.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(NotEqual)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/power.cu
Normal file
7
mlx/backend/cuda/binary/power.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Power)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/remainder.cu
Normal file
7
mlx/backend/cuda/binary/remainder.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Remainder)
|
||||||
|
} // namespace mlx::core
|
||||||
7
mlx/backend/cuda/binary/subtract.cu
Normal file
7
mlx/backend/cuda/binary/subtract.cu
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
BINARY_GPU(Subtract)
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -17,93 +16,214 @@ namespace cu {
|
|||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[0], b[0]);
|
if ((index + 1) * N_READS > size) {
|
||||||
out_a[0] = out[0];
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
out_b[0] = out[1];
|
auto out = Op{}(a[0], b[0]);
|
||||||
|
out_a[i] = out[0];
|
||||||
|
out_b[i] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a[0], b[0]);
|
||||||
|
out_a_vec[i] = out[0];
|
||||||
|
out_b_vec[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[0], b[index]);
|
if ((index + 1) * N_READS > size) {
|
||||||
out_a[index] = out[0];
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
out_b[index] = out[1];
|
auto out = Op{}(a[0], b[i]);
|
||||||
|
out_a[i] = out[0];
|
||||||
|
out_b[i] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a[0], b_vec[i]);
|
||||||
|
out_a_vec[i] = out[0];
|
||||||
|
out_b_vec[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[index], b[0]);
|
if ((index + 1) * N_READS > size) {
|
||||||
out_a[index] = out[0];
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
out_b[index] = out[1];
|
auto out = Op{}(a[i], b[0]);
|
||||||
|
out_a[i] = out[0];
|
||||||
|
out_b[i] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a_vec[i], b[0]);
|
||||||
|
out_a_vec[i] = out[0];
|
||||||
|
out_b_vec[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void
|
__global__ void
|
||||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
|
||||||
auto out = Op{}(a[index], b[index]);
|
if ((index + 1) * N_READS > size) {
|
||||||
out_a[index] = out[0];
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
out_b[index] = out[1];
|
auto out = Op{}(a[i], b[i]);
|
||||||
|
out_a[i] = out[0];
|
||||||
|
out_b[i] = out[1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_a_vec;
|
||||||
|
AlignedVector<Out, N_READS> out_b_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a_vec[i], b_vec[i]);
|
||||||
|
out_a_vec[i] = out[0];
|
||||||
|
out_b_vec[i] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||||
|
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
template <
|
||||||
__global__ void binary_g_nd(
|
typename Op,
|
||||||
|
typename In,
|
||||||
|
typename Out,
|
||||||
|
typename IdxT,
|
||||||
|
int NDIM,
|
||||||
|
int N_READS>
|
||||||
|
__global__ void binary_two_g_nd(
|
||||||
const In* a,
|
const In* a,
|
||||||
const In* b,
|
const In* b,
|
||||||
Out* out_a,
|
Out* out_a,
|
||||||
Out* out_b,
|
Out* out_b,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
IdxT index_rest =
|
||||||
index, shape.data(), a_strides.data(), b_strides.data());
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
if (index_rest >= size_rest) {
|
||||||
out_a[index] = out[0];
|
return;
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[NDIM - 1];
|
||||||
|
auto a_stride_x = a_strides[NDIM - 1];
|
||||||
|
auto b_stride_x = b_strides[NDIM - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
|
||||||
|
auto a_vec =
|
||||||
|
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||||
|
auto b_vec =
|
||||||
|
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec_a;
|
||||||
|
AlignedVector<Out, N_READS> out_vec_b;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a_vec[i], b_vec[i]);
|
||||||
|
out_vec_a[i] = out[0];
|
||||||
|
out_vec_b[i] = out[1];
|
||||||
|
}
|
||||||
|
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
|
||||||
|
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT>
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void binary_g(
|
__global__ void binary_two_g(
|
||||||
const In* a,
|
const In* a,
|
||||||
const In* b,
|
const In* b,
|
||||||
Out* out_a,
|
Out* out_a,
|
||||||
Out* out_b,
|
Out* out_b,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ Shape shape,
|
const __grid_constant__ Shape shape,
|
||||||
const __grid_constant__ Strides a_strides,
|
const __grid_constant__ Strides a_strides,
|
||||||
const __grid_constant__ Strides b_strides,
|
const __grid_constant__ Strides b_strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
IdxT index_rest =
|
||||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
if (index_rest >= size_rest) {
|
||||||
out_a[index] = out[0];
|
return;
|
||||||
out_b[index] = out[1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[ndim - 1];
|
||||||
|
auto a_stride_x = a_strides[ndim - 1];
|
||||||
|
auto b_stride_x = b_strides[ndim - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc(
|
||||||
|
index_rest * shape_x,
|
||||||
|
shape.data(),
|
||||||
|
a_strides.data(),
|
||||||
|
b_strides.data(),
|
||||||
|
ndim);
|
||||||
|
auto a_vec =
|
||||||
|
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||||
|
auto b_vec =
|
||||||
|
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec_a;
|
||||||
|
AlignedVector<Out, N_READS> out_vec_b;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
auto out = Op{}(a_vec[i], b_vec[i]);
|
||||||
|
out_vec_a[i] = out[0];
|
||||||
|
out_vec_b[i] = out[1];
|
||||||
|
}
|
||||||
|
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
|
||||||
|
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out>
|
template <typename Op, typename In, typename Out>
|
||||||
constexpr bool supports_binary_op() {
|
constexpr bool supports_binary_two_op() {
|
||||||
if (std::is_same_v<Op, DivMod>) {
|
if (std::is_same_v<Op, DivMod>) {
|
||||||
return std::is_same_v<In, Out> &&
|
return std::is_same_v<In, Out> &&
|
||||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||||
@@ -114,10 +234,10 @@ constexpr bool supports_binary_op() {
|
|||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
void binary_op_gpu_inplace(
|
void binary_two_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
@@ -141,7 +261,7 @@ void binary_op_gpu_inplace(
|
|||||||
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
|
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
|
||||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||||
using InType = cuda_type_t<CTYPE_IN>;
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
|
||||||
@@ -159,38 +279,64 @@ void binary_op_gpu_inplace(
|
|||||||
auto& a_strides = strides[0];
|
auto& a_strides = strides[0];
|
||||||
auto& b_strides = strides[1];
|
auto& b_strides = strides[1];
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
int work_per_thread = 1;
|
||||||
|
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||||
|
auto rest = out_a.size() / dim0;
|
||||||
|
if (dim0 >= 4) {
|
||||||
|
work_per_thread = 4;
|
||||||
|
}
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
|
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||||
|
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||||
|
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||||
|
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
auto kernel = cu::
|
auto kernel = cu::binary_two_g_nd<
|
||||||
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
|
Op,
|
||||||
auto [num_blocks, block_dims] =
|
InType,
|
||||||
get_launch_args(kernel, out_a, large());
|
OutType,
|
||||||
|
IdxT,
|
||||||
|
dims_constant(),
|
||||||
|
1>;
|
||||||
|
if (work_per_thread == 4) {
|
||||||
|
kernel = cu::binary_two_g_nd<
|
||||||
|
Op,
|
||||||
|
InType,
|
||||||
|
OutType,
|
||||||
|
IdxT,
|
||||||
|
dims_constant(),
|
||||||
|
4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
out_b.data<OutType>(),
|
out_b.data<OutType>(),
|
||||||
out_a.size(),
|
rest,
|
||||||
const_param<dims_constant()>(shape),
|
const_param<dims_constant()>(shape),
|
||||||
const_param<dims_constant()>(a_strides),
|
const_param<dims_constant()>(a_strides),
|
||||||
const_param<dims_constant()>(b_strides));
|
const_param<dims_constant()>(b_strides));
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
|
||||||
auto [num_blocks, block_dims] =
|
if (work_per_thread == 4) {
|
||||||
get_launch_args(kernel, out_a, large());
|
kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
out_b.data<OutType>(),
|
out_b.data<OutType>(),
|
||||||
out_a.size(),
|
rest,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(a_strides),
|
const_param(a_strides),
|
||||||
const_param(b_strides),
|
const_param(b_strides),
|
||||||
@@ -198,26 +344,28 @@ void binary_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
constexpr int N_READS = 16 / sizeof(InType);
|
||||||
|
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
if (bopt == BinaryOpType::ScalarVector) {
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
|
||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
kernel,
|
|
||||||
out_a.data_size(),
|
out_a.data_size(),
|
||||||
out_a.shape(),
|
out_a.shape(),
|
||||||
out_a.strides(),
|
out_a.strides(),
|
||||||
large());
|
large(),
|
||||||
|
N_READS);
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@@ -237,17 +385,17 @@ void binary_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
void binary_op_gpu(
|
void binary_two_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
auto bopt = get_binary_op_type(a, b);
|
auto bopt = get_binary_op_type(a, b);
|
||||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DivMod::eval_gpu(
|
void DivMod::eval_gpu(
|
||||||
@@ -255,7 +403,7 @@ void DivMod::eval_gpu(
|
|||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||||
auto& s = outputs[0].primitive().stream();
|
auto& s = outputs[0].primitive().stream();
|
||||||
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
// Build function signature.
|
// Build function signature.
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
os += "template <typename IdxT = uint32_t>\n";
|
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
} else {
|
} else {
|
||||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
os +=
|
||||||
|
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
}
|
}
|
||||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||||
for (size_t i = 0; i < params.size(); ++i) {
|
for (size_t i = 0; i < params.size(); ++i) {
|
||||||
@@ -67,12 +68,77 @@ struct FusedKernelBuilder {
|
|||||||
}
|
}
|
||||||
os += ") {\n";
|
os += ") {\n";
|
||||||
|
|
||||||
// Index.
|
// Index. For non contiguous kernels we create a separate index
|
||||||
|
// variable per variable otherwise everyone uses `index`.
|
||||||
os +=
|
os +=
|
||||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
|
||||||
" if (index >= size) {\n"
|
" if (index >= size) {\n"
|
||||||
" return;\n"
|
" return;\n"
|
||||||
" }\n";
|
" }\n";
|
||||||
|
if (!contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " IdxT " + xname + "_idx = 0;\n";
|
||||||
|
}
|
||||||
|
os += " {\n";
|
||||||
|
os += " IdxT loc = index;\n";
|
||||||
|
os +=
|
||||||
|
" #pragma unroll\n"
|
||||||
|
" for (int i = NDIM - 1; i >= 0; i--) {\n";
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
|
||||||
|
"_strides[i]);\n";
|
||||||
|
}
|
||||||
|
os +=
|
||||||
|
" loc /= shape[i];\n"
|
||||||
|
" }\n"
|
||||||
|
" }\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vectorized read loop
|
||||||
|
if (contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
std::string type = dtype_to_cuda_type(x.dtype());
|
||||||
|
os += fmt::format(
|
||||||
|
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
|
||||||
|
xname,
|
||||||
|
type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create some space for the outputs
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
std::string type = dtype_to_cuda_type(x.dtype());
|
||||||
|
os += fmt::format(
|
||||||
|
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work loop
|
||||||
|
if (!contiguous) {
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||||
|
} else {
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" #pragma unroll\n"
|
||||||
|
" for (int i = 0; i < work_per_thread; i++) {\n";
|
||||||
|
}
|
||||||
|
|
||||||
// Read inputs.
|
// Read inputs.
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
@@ -87,14 +153,11 @@ struct FusedKernelBuilder {
|
|||||||
} else if (is_scalar(x)) {
|
} else if (is_scalar(x)) {
|
||||||
value = fmt::format("{}[0]", xname);
|
value = fmt::format("{}[0]", xname);
|
||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
value = fmt::format("{}[index]", xname);
|
value = fmt::format("vec_{}[i]", xname);
|
||||||
} else {
|
} else {
|
||||||
std::string index = fmt::format(
|
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
|
||||||
xname);
|
|
||||||
value = fmt::format("{}[{}]", xname, index);
|
|
||||||
}
|
}
|
||||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write tape.
|
// Write tape.
|
||||||
@@ -106,21 +169,40 @@ struct FusedKernelBuilder {
|
|||||||
value = fmt::format(
|
value = fmt::format(
|
||||||
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream ss;
|
value = x.primitive().name();
|
||||||
x.primitive().print(ss);
|
|
||||||
value = ss.str();
|
|
||||||
value += "{}(";
|
value += "{}(";
|
||||||
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
||||||
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
||||||
}
|
}
|
||||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||||
}
|
}
|
||||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write output.
|
// Write output.
|
||||||
for (const auto& x : outputs) {
|
for (const auto& x : outputs) {
|
||||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
// End of work loop
|
||||||
|
if (!contiguous) {
|
||||||
|
os += "\n";
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os += " }\n";
|
||||||
|
|
||||||
|
// Store the output to global memory
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
os += fmt::format(
|
||||||
|
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||||
|
namer.get_name(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
os += "}\n";
|
os += "}\n";
|
||||||
@@ -146,6 +228,15 @@ void Compiled::eval_gpu(
|
|||||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
|
|
||||||
|
// Determine the work per thread for the vectorized reads/writes. We take it
|
||||||
|
// as 16 over the max itemsize for the outputs. Another heuristic could be
|
||||||
|
// over the max itemsize of all arrays.
|
||||||
|
int max_size = 1;
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
|
||||||
|
}
|
||||||
|
int work_per_thread = 16 / max_size;
|
||||||
|
|
||||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||||
// Build source code.
|
// Build source code.
|
||||||
cu::FusedKernelBuilder builder{
|
cu::FusedKernelBuilder builder{
|
||||||
@@ -158,16 +249,24 @@ void Compiled::eval_gpu(
|
|||||||
builder.build("_strided", false);
|
builder.build("_strided", false);
|
||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names = {
|
std::vector<std::string> kernel_names;
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
kernel_names.push_back(fmt::format(
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
};
|
lib_name(),
|
||||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
work_per_thread));
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||||
kernel_names.push_back(
|
lib_name(),
|
||||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
work_per_thread));
|
||||||
|
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||||
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -209,13 +308,20 @@ void Compiled::eval_gpu(
|
|||||||
args.append<uint32_t>(outputs[0].data_size());
|
args.append<uint32_t>(outputs[0].data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose work per thread
|
||||||
|
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||||
|
work_per_thread = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Launch kernel.
|
// Launch kernel.
|
||||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
kernel_name +=
|
||||||
|
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||||
} else {
|
} else {
|
||||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
kernel_name += fmt::format(
|
||||||
|
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||||
}
|
}
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
@@ -226,8 +332,9 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
auto [num_blocks, block_dims] =
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
get_launch_args(outputs[0], large, work_per_thread);
|
||||||
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
418
mlx/backend/cuda/conv.cpp
Normal file
418
mlx/backend/cuda/conv.cpp
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
|
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Alias for better readability.
|
||||||
|
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
|
||||||
|
#define CONV_BACKWARD_INPUT \
|
||||||
|
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
|
||||||
|
#define CONV_BACKWARD_WEIGHT \
|
||||||
|
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||||
|
|
||||||
|
// Custom placeholder representing fallback kernel.
|
||||||
|
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
|
||||||
|
|
||||||
|
struct ConvCacheKey {
|
||||||
|
int device_id;
|
||||||
|
cudnnDataType_t cudnn_dtype;
|
||||||
|
std::array<int, MAX_NDIM> input_shape;
|
||||||
|
std::array<int, MAX_NDIM> weight_shape;
|
||||||
|
std::array<int, MAX_NDIM> stride;
|
||||||
|
std::array<int, MAX_NDIM> padding_lo;
|
||||||
|
std::array<int, MAX_NDIM> padding_hi;
|
||||||
|
std::array<int, MAX_NDIM> dilation;
|
||||||
|
int groups;
|
||||||
|
bool flip;
|
||||||
|
uint8_t input_alignment;
|
||||||
|
uint8_t weight_alignment;
|
||||||
|
uint8_t output_alignment;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto& conv_cache() {
|
||||||
|
static LRUBytesKeyCache<
|
||||||
|
ConvCacheKey,
|
||||||
|
std::pair<
|
||||||
|
cudnnBackendDescriptorType_t,
|
||||||
|
std::optional<cudnn_frontend::ExecutionPlan>>>
|
||||||
|
cache(/* capacity */ 128);
|
||||||
|
return cache;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_conv_op_settings(
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array& x,
|
||||||
|
array& w,
|
||||||
|
array& y,
|
||||||
|
const std::vector<int>& kernel_strides,
|
||||||
|
const std::vector<int>& padding_lo_,
|
||||||
|
const std::vector<int>& padding_hi_,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation) {
|
||||||
|
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||||
|
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||||
|
|
||||||
|
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||||
|
for (int i = 0; i < padding_lo.size(); ++i) {
|
||||||
|
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
|
||||||
|
padding_lo[i] = wt_size - padding_lo[i] - 1;
|
||||||
|
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
|
||||||
|
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
|
||||||
|
padding_hi[i] = out_size - in_size + padding_hi[i];
|
||||||
|
}
|
||||||
|
return std::make_tuple(
|
||||||
|
convert_vector<int64_t>(input_dilation),
|
||||||
|
std::move(padding_lo),
|
||||||
|
std::move(padding_hi),
|
||||||
|
convert_vector<int64_t>(kernel_dilation));
|
||||||
|
|
||||||
|
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||||
|
padding_hi = padding_lo;
|
||||||
|
return std::make_tuple(
|
||||||
|
convert_vector<int64_t>(kernel_dilation),
|
||||||
|
std::move(padding_lo),
|
||||||
|
std::move(padding_hi),
|
||||||
|
convert_vector<int64_t>(kernel_strides));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
return std::make_tuple(
|
||||||
|
convert_vector<int64_t>(kernel_strides),
|
||||||
|
std::move(padding_lo),
|
||||||
|
std::move(padding_hi),
|
||||||
|
convert_vector<int64_t>(kernel_dilation));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
|
array& x,
|
||||||
|
array& w,
|
||||||
|
array& y,
|
||||||
|
const SmallVector<int64_t>& stride,
|
||||||
|
const SmallVector<int64_t>& padding_lo,
|
||||||
|
const SmallVector<int64_t>& padding_hi,
|
||||||
|
const SmallVector<int64_t>& dilation) {
|
||||||
|
try {
|
||||||
|
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||||
|
? CUDNN_DATA_FLOAT
|
||||||
|
: dtype_to_cudnn_type(dtype);
|
||||||
|
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||||
|
.setDataType(compute_dtype)
|
||||||
|
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||||
|
.setNDims(stride.size())
|
||||||
|
.setStrides(stride.size(), stride.data())
|
||||||
|
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||||
|
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||||
|
.setDilation(dilation.size(), dilation.data())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||||
|
.setxDesc(build_cudnn_tensor_nchw('x', x))
|
||||||
|
.setwDesc(build_cudnn_tensor_nchw('w', w))
|
||||||
|
.setyDesc(build_cudnn_tensor_nchw('y', y))
|
||||||
|
.setcDesc(conv_desc)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||||
|
return cudnn_frontend::OperationGraphBuilder()
|
||||||
|
.setHandle(encoder.device().cudnn_handle())
|
||||||
|
.setOperationGraph(ops.size(), ops.data())
|
||||||
|
.build();
|
||||||
|
} catch (cudnn_frontend::cudnnException& error) {
|
||||||
|
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
|
||||||
|
array group_transpose(
|
||||||
|
const array& x,
|
||||||
|
int groups,
|
||||||
|
int group_dim,
|
||||||
|
int axis1,
|
||||||
|
int axis2,
|
||||||
|
Stream s) {
|
||||||
|
if (groups == 1) {
|
||||||
|
return swapaxes_in_eval(x, axis1, axis2);
|
||||||
|
}
|
||||||
|
int ndim = x.ndim();
|
||||||
|
if (group_dim < 0) {
|
||||||
|
group_dim += ndim;
|
||||||
|
}
|
||||||
|
if (axis1 < 0) {
|
||||||
|
axis1 += ndim;
|
||||||
|
}
|
||||||
|
if (axis2 < 0) {
|
||||||
|
axis2 += ndim;
|
||||||
|
}
|
||||||
|
if (group_dim <= axis1) {
|
||||||
|
axis1 += 1;
|
||||||
|
}
|
||||||
|
if (group_dim <= axis2) {
|
||||||
|
axis2 += 1;
|
||||||
|
}
|
||||||
|
auto shape = x.shape();
|
||||||
|
shape.insert(shape.begin() + group_dim, groups);
|
||||||
|
shape[group_dim + 1] = shape[group_dim + 1] / groups;
|
||||||
|
array x_trans = reshape_in_eval(x, std::move(shape), s);
|
||||||
|
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
|
||||||
|
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
|
||||||
|
return x_trans;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||||
|
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||||
|
// eval_gpu, with cost of possible redundant copies.
|
||||||
|
std::tuple<array, array, array> prepare_args(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array in,
|
||||||
|
array wt,
|
||||||
|
array out,
|
||||||
|
int groups,
|
||||||
|
Stream s) {
|
||||||
|
// Transpose the args depending on the backend type.
|
||||||
|
// TODO: Handle groups.
|
||||||
|
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||||
|
wt = group_transpose(wt, groups, 0, 0, -1, s);
|
||||||
|
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||||
|
in = group_transpose(in, groups, -1, 0, -1, s);
|
||||||
|
wt = swapaxes_in_eval(wt, 0, -1);
|
||||||
|
// Create a contiguous array that shares the data with |out|, but with dim
|
||||||
|
// C_in and C_out swapped.
|
||||||
|
Shape shape(out.shape());
|
||||||
|
std::swap(shape.front(), shape.back());
|
||||||
|
Strides strides(shape.size(), 1);
|
||||||
|
for (int i = shape.size() - 2; i >= 0; --i) {
|
||||||
|
strides[i] = shape[i + 1] * strides[i + 1];
|
||||||
|
}
|
||||||
|
array intermediate(std::move(shape), out.dtype(), nullptr, {});
|
||||||
|
intermediate.copy_shared_buffer(
|
||||||
|
out, std::move(strides), {true, true, false}, out.data_size());
|
||||||
|
out = intermediate;
|
||||||
|
}
|
||||||
|
|
||||||
|
// cuDNN requires contiguous input.
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
in = contiguous_copy_gpu(in, s);
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
|
encoder.add_temporary(wt);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {std::move(in), std::move(wt), std::move(out)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the x/w/y args from the in/wt/out args depending on backend type.
|
||||||
|
inline std::tuple<array&, array&, array&> dispatch_args(
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array& in,
|
||||||
|
array& wt,
|
||||||
|
array& out) {
|
||||||
|
switch (backend_type) {
|
||||||
|
case CONV_BACKWARD_INPUT:
|
||||||
|
return {out, wt, in};
|
||||||
|
case CONV_BACKWARD_WEIGHT:
|
||||||
|
return {in, out, wt};
|
||||||
|
default:
|
||||||
|
return {in, wt, out};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register inputs and outputs before actually running conv op. Can only be
|
||||||
|
// called once per eval_gpu.
|
||||||
|
void register_args(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array& in,
|
||||||
|
array& wt,
|
||||||
|
array& intermediate_out,
|
||||||
|
array& final_out) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_input_array(wt);
|
||||||
|
encoder.set_output_array(final_out);
|
||||||
|
|
||||||
|
if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||||
|
// Turn |out| into a strided array, which will have C_in and C_out swapped
|
||||||
|
// in vjp and the final |grad_weight| will then be contiguous.
|
||||||
|
Strides strides = intermediate_out.strides();
|
||||||
|
std::swap(strides.front(), strides.back());
|
||||||
|
final_out.copy_shared_buffer(
|
||||||
|
intermediate_out,
|
||||||
|
std::move(strides),
|
||||||
|
{false, false, false},
|
||||||
|
intermediate_out.data_size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||||
|
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||||
|
if (out_.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
array in = inputs[0];
|
||||||
|
array wt = inputs[1];
|
||||||
|
array out = out_;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
Dtype dtype = out.dtype();
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
// Search cache.
|
||||||
|
ConvCacheKey cache_key{
|
||||||
|
encoder.device().cuda_device(),
|
||||||
|
dtype_to_cudnn_type(dtype),
|
||||||
|
vector_key(in.shape()),
|
||||||
|
vector_key(wt.shape()),
|
||||||
|
vector_key(kernel_strides_),
|
||||||
|
vector_key(padding_lo_),
|
||||||
|
vector_key(padding_hi_),
|
||||||
|
vector_key(kernel_dilation_),
|
||||||
|
groups_,
|
||||||
|
flip_,
|
||||||
|
get_alignment(in),
|
||||||
|
get_alignment(wt),
|
||||||
|
get_alignment(out)};
|
||||||
|
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||||
|
auto& [backend_type, plan] = it->second;
|
||||||
|
if (plan) {
|
||||||
|
// Run cached plan.
|
||||||
|
std::tie(in, wt, out) =
|
||||||
|
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
||||||
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
|
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||||
|
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Run fallback kernel.
|
||||||
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
kernel_strides_,
|
||||||
|
padding_lo_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_,
|
||||||
|
groups_,
|
||||||
|
flip_,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||||
|
// convolution, so we make a best guess and then try.
|
||||||
|
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
|
||||||
|
if (flip_) {
|
||||||
|
// When weight is flipped, we assume it is backward input convolution.
|
||||||
|
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||||
|
} else {
|
||||||
|
// Otherwise it could be backward weight convolution or forward convolution,
|
||||||
|
// mathematically there is no difference so we have to use heuristics.
|
||||||
|
// Empirically backward convolutions have large kernel dimensions, and
|
||||||
|
// usually have |in| and |wt| transposed.
|
||||||
|
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
|
||||||
|
wt.shape(2) > out.shape(2)) {
|
||||||
|
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
|
||||||
|
} else {
|
||||||
|
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to build op graph.
|
||||||
|
cudnnBackendDescriptorType_t backend_type;
|
||||||
|
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||||
|
for (auto try_backend : try_backends) {
|
||||||
|
auto [in_copy, wt_copy, out_copy] =
|
||||||
|
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
|
||||||
|
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||||
|
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||||
|
try_backend,
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
y,
|
||||||
|
kernel_strides_,
|
||||||
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_);
|
||||||
|
op_graph = build_conv_op_graph(
|
||||||
|
encoder,
|
||||||
|
try_backend,
|
||||||
|
dtype,
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
y,
|
||||||
|
stride,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
dilation);
|
||||||
|
if (op_graph) {
|
||||||
|
backend_type = try_backend;
|
||||||
|
in = std::move(in_copy);
|
||||||
|
wt = std::move(wt_copy);
|
||||||
|
out = std::move(out_copy);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_graph) {
|
||||||
|
// Setup inputs and outputs.
|
||||||
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
|
|
||||||
|
// Find a plan for the graph and execute it.
|
||||||
|
auto plan = find_cudnn_plan_from_op_graph(
|
||||||
|
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
||||||
|
if (!plan) {
|
||||||
|
throw std::runtime_error("[conv] Unable to find an execution plan.");
|
||||||
|
}
|
||||||
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
|
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||||
|
conv_cache().emplace(
|
||||||
|
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use fallback kernel for settings not supported by cuDNN.
|
||||||
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
kernel_strides_,
|
||||||
|
padding_lo_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_,
|
||||||
|
groups_,
|
||||||
|
flip_,
|
||||||
|
s);
|
||||||
|
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
126
mlx/backend/cuda/conv/conv.h
Normal file
126
mlx/backend/cuda/conv/conv.h
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
struct ConvParams {
|
||||||
|
int N; // Batch size
|
||||||
|
int C; // In channels
|
||||||
|
int O; // Out channels
|
||||||
|
int strides[NDIM];
|
||||||
|
int padding[NDIM];
|
||||||
|
int kernel_dilation[NDIM];
|
||||||
|
int input_dilation[NDIM];
|
||||||
|
int groups;
|
||||||
|
bool flip;
|
||||||
|
int in_spatial_dims[NDIM];
|
||||||
|
int wt_spatial_dims[NDIM];
|
||||||
|
int out_spatial_dims[NDIM];
|
||||||
|
int64_t in_strides[NDIM + 2];
|
||||||
|
|
||||||
|
ConvParams(
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
const array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip)
|
||||||
|
: N(in.shape(0)),
|
||||||
|
C(in.shape(-1)),
|
||||||
|
O(wt.shape(0)),
|
||||||
|
groups(groups),
|
||||||
|
flip(flip) {
|
||||||
|
std::copy_n(strides.begin(), NDIM, this->strides);
|
||||||
|
std::copy_n(padding.begin(), NDIM, this->padding);
|
||||||
|
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
|
||||||
|
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
|
||||||
|
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
|
||||||
|
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
|
||||||
|
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
|
||||||
|
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void gemm_grouped_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s);
|
||||||
|
|
||||||
|
void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
bool flip,
|
||||||
|
Stream s);
|
||||||
|
|
||||||
|
inline void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array in,
|
||||||
|
array wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
in = contiguous_copy_gpu(in, s);
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
|
encoder.add_temporary(wt);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groups == 1) {
|
||||||
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
flip,
|
||||||
|
s);
|
||||||
|
} else {
|
||||||
|
gemm_grouped_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
groups,
|
||||||
|
flip,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
217
mlx/backend/cuda/conv/gemm_conv.cu
Normal file
217
mlx/backend/cuda/conv/gemm_conv.cu
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, int NDIM>
|
||||||
|
__global__ void naive_unfold_nd(
|
||||||
|
const T* in,
|
||||||
|
T* out,
|
||||||
|
int filter_size,
|
||||||
|
int out_pixels,
|
||||||
|
const __grid_constant__ ConvParams<NDIM> params) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto tid = block.group_index();
|
||||||
|
auto lid = block.thread_index();
|
||||||
|
|
||||||
|
int index_batch = tid.z / out_pixels; // [0, N)
|
||||||
|
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||||
|
int index_wt_spatial =
|
||||||
|
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||||
|
|
||||||
|
if (index_wt_spatial >= filter_size / params.C) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
in += tid.y; // [0, C)
|
||||||
|
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
|
||||||
|
|
||||||
|
bool valid = index_batch < params.N;
|
||||||
|
|
||||||
|
// Get the coordinates in input.
|
||||||
|
int index_in[NDIM] = {};
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||||
|
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||||
|
|
||||||
|
if (params.flip) {
|
||||||
|
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int index = index_out * params.strides[i] - params.padding[i] +
|
||||||
|
index_wt * params.kernel_dilation[i];
|
||||||
|
int index_max =
|
||||||
|
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||||
|
|
||||||
|
valid &= (index >= 0) && (index < index_max) &&
|
||||||
|
(index % params.input_dilation[i] == 0);
|
||||||
|
|
||||||
|
index_in[i] = index / params.input_dilation[i];
|
||||||
|
|
||||||
|
index_out_spatial /= params.out_spatial_dims[i];
|
||||||
|
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (valid) {
|
||||||
|
int in_offset = index_batch * params.in_strides[0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||||
|
}
|
||||||
|
*out = in[in_offset];
|
||||||
|
} else {
|
||||||
|
*out = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
array unfold_inputs_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
int mat_M,
|
||||||
|
int mat_K,
|
||||||
|
int mat_N,
|
||||||
|
ConvParams<NDIM>& params) {
|
||||||
|
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
|
||||||
|
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||||
|
encoder.add_temporary(unfolded);
|
||||||
|
|
||||||
|
int filter_size = params.C;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
filter_size *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_pixels = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
out_pixels *= params.out_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int wt_spatial_size = mat_K / params.C;
|
||||||
|
dim3 block_dims;
|
||||||
|
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||||
|
dim3 num_blocks;
|
||||||
|
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||||
|
num_blocks.y = params.C;
|
||||||
|
num_blocks.z = mat_M;
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(unfolded);
|
||||||
|
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::naive_unfold_nd<DataType, NDIM>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
in.data<DataType>(),
|
||||||
|
unfolded.data<DataType>(),
|
||||||
|
filter_size,
|
||||||
|
out_pixels,
|
||||||
|
params);
|
||||||
|
});
|
||||||
|
|
||||||
|
return unfolded;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
void gemm_conv_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
ConvParams<NDIM>& params,
|
||||||
|
Stream s) {
|
||||||
|
// Get gemm shapes.
|
||||||
|
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||||
|
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
|
||||||
|
int mat_N = params.O; // O
|
||||||
|
|
||||||
|
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||||
|
array in_unfolded =
|
||||||
|
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
|
||||||
|
|
||||||
|
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
|
||||||
|
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
|
||||||
|
wt_reshaped.copy_shared_buffer(
|
||||||
|
wt,
|
||||||
|
{1, mat_K},
|
||||||
|
{false, false, /* col_contiguous */ true},
|
||||||
|
wt.data_size());
|
||||||
|
|
||||||
|
// Single batch.
|
||||||
|
Shape batch_shape{1};
|
||||||
|
Strides a_batch_strides{0};
|
||||||
|
Strides b_batch_strides{0};
|
||||||
|
|
||||||
|
// Run matmul.
|
||||||
|
CublasGemm gemm(
|
||||||
|
encoder.device(),
|
||||||
|
in.dtype(),
|
||||||
|
false, // a_transposed
|
||||||
|
mat_M, // a_rows
|
||||||
|
mat_K, // a_cols
|
||||||
|
mat_K, // lda
|
||||||
|
true, // b_transposed
|
||||||
|
mat_K, // b_rows
|
||||||
|
mat_N, // b_cols
|
||||||
|
mat_K, // ldb
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
gemm.run(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
in_unfolded,
|
||||||
|
wt_reshaped,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
int conv_ndim = in.ndim() - 2;
|
||||||
|
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||||
|
}
|
||||||
|
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||||
|
ConvParams<ndim_constant()> params(
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
1, // groups
|
||||||
|
flip);
|
||||||
|
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
231
mlx/backend/cuda/conv/gemm_grouped_conv.cu
Normal file
231
mlx/backend/cuda/conv/gemm_grouped_conv.cu
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, int NDIM>
|
||||||
|
__global__ void naive_grouped_unfold_transpose_nd(
|
||||||
|
const T* in,
|
||||||
|
T* out,
|
||||||
|
int filter_size,
|
||||||
|
int out_pixels,
|
||||||
|
const __grid_constant__ ConvParams<NDIM> params) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto tid = block.group_index();
|
||||||
|
auto lid = block.thread_index();
|
||||||
|
|
||||||
|
int index_batch = tid.z / out_pixels; // [0, N)
|
||||||
|
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||||
|
int index_wt_spatial =
|
||||||
|
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||||
|
|
||||||
|
if (index_wt_spatial >= filter_size / params.C) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
in += tid.y; // [0, C)
|
||||||
|
out += tid.z * filter_size + tid.y * (filter_size / params.C);
|
||||||
|
|
||||||
|
bool valid = index_batch < params.N;
|
||||||
|
|
||||||
|
// Get the coordinates in input.
|
||||||
|
int index_in[NDIM] = {};
|
||||||
|
int wt_stride = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||||
|
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||||
|
out += index_wt * wt_stride;
|
||||||
|
|
||||||
|
if (params.flip) {
|
||||||
|
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int index = index_out * params.strides[i] - params.padding[i] +
|
||||||
|
index_wt * params.kernel_dilation[i];
|
||||||
|
int index_max =
|
||||||
|
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||||
|
|
||||||
|
valid &= (index >= 0) && (index < index_max) &&
|
||||||
|
(index % params.input_dilation[i] == 0);
|
||||||
|
|
||||||
|
index_in[i] = index / params.input_dilation[i];
|
||||||
|
|
||||||
|
index_out_spatial /= params.out_spatial_dims[i];
|
||||||
|
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||||
|
wt_stride *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (valid) {
|
||||||
|
int in_offset = index_batch * params.in_strides[0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||||
|
}
|
||||||
|
*out = in[in_offset];
|
||||||
|
} else {
|
||||||
|
*out = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
array grouped_unfold_transpose_inputs_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
int mat_M,
|
||||||
|
int mat_K,
|
||||||
|
int mat_N,
|
||||||
|
ConvParams<NDIM>& params) {
|
||||||
|
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
|
||||||
|
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||||
|
encoder.add_temporary(unfolded);
|
||||||
|
|
||||||
|
int filter_size = params.C;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
filter_size *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_pixels = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
out_pixels *= params.out_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int wt_spatial_size = (mat_K * params.groups) / params.C;
|
||||||
|
dim3 block_dims;
|
||||||
|
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||||
|
dim3 num_blocks;
|
||||||
|
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||||
|
num_blocks.y = params.C;
|
||||||
|
num_blocks.z = mat_M;
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(unfolded);
|
||||||
|
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
in.data<DataType>(),
|
||||||
|
unfolded.data<DataType>(),
|
||||||
|
filter_size,
|
||||||
|
out_pixels,
|
||||||
|
params);
|
||||||
|
});
|
||||||
|
|
||||||
|
return unfolded;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
void gemm_grouped_conv_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
ConvParams<NDIM>& params,
|
||||||
|
Stream s) {
|
||||||
|
// Get gemm shapes.
|
||||||
|
int C_per_group = params.C / params.groups;
|
||||||
|
int O_per_group = params.O / params.groups;
|
||||||
|
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||||
|
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
|
||||||
|
int mat_N = O_per_group; // O_per_group
|
||||||
|
|
||||||
|
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||||
|
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
|
||||||
|
encoder, in, mat_M, mat_K, mat_N, params);
|
||||||
|
|
||||||
|
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
|
||||||
|
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
|
||||||
|
array wt_view(
|
||||||
|
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
|
||||||
|
wt_view.copy_shared_buffer(
|
||||||
|
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||||
|
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
|
||||||
|
|
||||||
|
// Batch with size of groups.
|
||||||
|
Shape batch_shape{params.groups};
|
||||||
|
Strides a_batch_strides{mat_K};
|
||||||
|
Strides b_batch_strides{mat_N * mat_K};
|
||||||
|
|
||||||
|
// Run matmul.
|
||||||
|
CublasGemm gemm(
|
||||||
|
encoder.device(),
|
||||||
|
in.dtype(),
|
||||||
|
false, // a_transposed
|
||||||
|
mat_M, // a_rows
|
||||||
|
mat_K, // a_cols
|
||||||
|
mat_K * params.groups, // lda
|
||||||
|
true, // b_transposed
|
||||||
|
mat_K, // b_rows
|
||||||
|
mat_N, // b_cols
|
||||||
|
mat_K, // ldb
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
gemm.set_out(
|
||||||
|
out.dtype(),
|
||||||
|
false, // out_transposed
|
||||||
|
mat_M, // out_rows
|
||||||
|
mat_N, // out_cols
|
||||||
|
mat_N * params.groups, // out_ld
|
||||||
|
params.groups, // batch_count
|
||||||
|
mat_N); // batch_stride
|
||||||
|
gemm.run(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
in_unfolded,
|
||||||
|
wt_reshaped,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_grouped_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
int conv_ndim = in.ndim() - 2;
|
||||||
|
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||||
|
}
|
||||||
|
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||||
|
ConvParams<ndim_constant()> params(
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
groups,
|
||||||
|
flip);
|
||||||
|
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -10,19 +10,43 @@ namespace cu {
|
|||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <typename In, typename Out, typename IdxT>
|
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void copy_s(const In* in, Out* out, IdxT size) {
|
__global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
|
||||||
out[index] = CastOp<In, Out>{}(in[0]);
|
if ((index + 1) * N_READS > size) {
|
||||||
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
|
out[i] = cast_to<Out>(in[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = cast_to<Out>(in[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename In, typename Out, typename IdxT>
|
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void copy_v(const In* in, Out* out, IdxT size) {
|
__global__ void copy_v(const In* in, Out* out, IdxT size) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
|
||||||
out[index] = CastOp<In, Out>{}(in[index]);
|
if ((index + 1) * N_READS > size) {
|
||||||
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
|
out[i] = cast_to<Out>(in[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto in_vec = load_vector<N_READS>(in, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = cast_to<Out>(in_vec[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,16 +65,18 @@ void copy_contiguous(
|
|||||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
constexpr int N_READS = 16 / sizeof(InType);
|
||||||
|
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
|
||||||
if (ctype == CopyType::Vector) {
|
if (ctype == CopyType::Vector) {
|
||||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
out.data_size());
|
out.data_size());
|
||||||
|
|||||||
@@ -10,37 +10,80 @@ namespace cu {
|
|||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
|
||||||
__global__ void copy_gg_nd(
|
__global__ void copy_gg_nd(
|
||||||
const In* in,
|
const In* in,
|
||||||
Out* out,
|
Out* out,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
IdxT index_rest =
|
||||||
index, shape.data(), strides_in.data(), strides_out.data());
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
if (index_rest >= size_rest) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[NDIM - 1];
|
||||||
|
auto in_stride_x = strides_in[NDIM - 1];
|
||||||
|
auto out_stride_x = strides_out[NDIM - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||||
|
index_rest * shape_x,
|
||||||
|
shape.data(),
|
||||||
|
strides_in.data(),
|
||||||
|
strides_out.data());
|
||||||
|
|
||||||
|
auto in_vec =
|
||||||
|
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||||
|
}
|
||||||
|
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename In, typename Out, typename IdxT>
|
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void copy_gg(
|
__global__ void copy_gg(
|
||||||
const In* in,
|
const In* in,
|
||||||
Out* out,
|
Out* out,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ Shape shape,
|
const __grid_constant__ Shape shape,
|
||||||
const __grid_constant__ Strides strides_in,
|
const __grid_constant__ Strides strides_in,
|
||||||
const __grid_constant__ Strides strides_out,
|
const __grid_constant__ Strides strides_out,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
IdxT index_rest =
|
||||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
if (index_rest >= size_rest) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[ndim - 1];
|
||||||
|
auto in_stride_x = strides_in[ndim - 1];
|
||||||
|
auto out_stride_x = strides_out[ndim - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto [idx_in, idx_out] = elem_to_loc(
|
||||||
|
index_rest * shape_x,
|
||||||
|
shape.data(),
|
||||||
|
strides_in.data(),
|
||||||
|
strides_out.data(),
|
||||||
|
ndim);
|
||||||
|
|
||||||
|
auto in_vec =
|
||||||
|
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||||
|
}
|
||||||
|
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
@@ -69,34 +112,52 @@ void copy_general(
|
|||||||
size_t data_size = 1;
|
size_t data_size = 1;
|
||||||
for (auto& s : shape)
|
for (auto& s : shape)
|
||||||
data_size *= s;
|
data_size *= s;
|
||||||
|
|
||||||
|
int work_per_thread = 1;
|
||||||
|
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||||
|
auto rest = data_size / dim0;
|
||||||
|
if (dim0 >= 4) {
|
||||||
|
work_per_thread = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
|
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||||
|
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||||
|
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||||
|
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||||
auto kernel =
|
auto kernel =
|
||||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
|
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
if (work_per_thread == 4) {
|
||||||
kernel, data_size, shape, out.strides(), large());
|
kernel =
|
||||||
|
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
rest,
|
||||||
const_param<ndim_constant()>(shape),
|
const_param<ndim_constant()>(shape),
|
||||||
const_param<ndim_constant()>(strides_in),
|
const_param<ndim_constant()>(strides_in),
|
||||||
const_param<ndim_constant()>(strides_out));
|
const_param<ndim_constant()>(strides_out));
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
if (work_per_thread == 4) {
|
||||||
kernel, data_size, shape, out.strides(), large());
|
kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
rest,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
|
|||||||
const int64_t* offset_out) {
|
const int64_t* offset_out) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
if (index < size) {
|
if (index < size) {
|
||||||
auto [idx_in, idx_out] = elem_to_loc_4d(
|
auto [idx_in, idx_out] = elem_to_loc(
|
||||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||||
}
|
}
|
||||||
@@ -74,14 +74,16 @@ void copy_general_dynamic(
|
|||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
auto kernel = cu::
|
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||||
copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
|
|
||||||
auto [num_blocks, block_dims] =
|
|
||||||
get_launch_args(kernel, out, large());
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
cu::copy_gg_dynamic_nd<
|
||||||
|
InType,
|
||||||
|
OutType,
|
||||||
|
IdxT,
|
||||||
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@@ -92,13 +94,12 @@ void copy_general_dynamic(
|
|||||||
dynamic_offset_out.data<int64_t>());
|
dynamic_offset_out.data<int64_t>());
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||||
auto [num_blocks, block_dims] =
|
|
||||||
get_launch_args(kernel, out, large());
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
cu::copy_gg_dynamic<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -10,33 +10,67 @@ namespace cu {
|
|||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
|
||||||
__global__ void copy_g_nd(
|
__global__ void copy_g_nd(
|
||||||
const In* in,
|
const In* in,
|
||||||
Out* out,
|
Out* out,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
|
IdxT index_rest =
|
||||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
|
if (index_rest >= size_rest) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[NDIM - 1];
|
||||||
|
auto stride_x = strides[NDIM - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto idx =
|
||||||
|
elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
|
||||||
|
auto in_vec =
|
||||||
|
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||||
|
}
|
||||||
|
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename In, typename Out, typename IdxT>
|
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||||
__global__ void copy_g(
|
__global__ void copy_g(
|
||||||
const In* in,
|
const In* in,
|
||||||
Out* out,
|
Out* out,
|
||||||
IdxT size,
|
IdxT size_rest,
|
||||||
const __grid_constant__ Shape shape,
|
const __grid_constant__ Shape shape,
|
||||||
const __grid_constant__ Strides strides_in,
|
const __grid_constant__ Strides strides,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
auto block = cg::this_thread_block();
|
||||||
if (index < size) {
|
auto grid = cg::this_grid();
|
||||||
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
|
IdxT index_rest =
|
||||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||||
|
if (index_rest >= size_rest) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto shape_x = shape[ndim - 1];
|
||||||
|
auto stride_x = strides[ndim - 1];
|
||||||
|
IdxT index_x =
|
||||||
|
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||||
|
auto idx =
|
||||||
|
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
|
||||||
|
auto in_vec =
|
||||||
|
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||||
|
}
|
||||||
|
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
@@ -61,33 +95,49 @@ void copy_general_input(
|
|||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
int work_per_thread = 1;
|
||||||
|
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||||
|
auto rest = out.size() / dim0;
|
||||||
|
if (dim0 >= 4) {
|
||||||
|
work_per_thread = 4;
|
||||||
|
}
|
||||||
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
|
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||||
|
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||||
|
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||||
|
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
auto kernel =
|
auto kernel =
|
||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||||
auto [num_blocks, block_dims] =
|
if (work_per_thread == 4) {
|
||||||
get_launch_args(kernel, out, large());
|
kernel =
|
||||||
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
rest,
|
||||||
const_param<dims_constant()>(shape),
|
const_param<dims_constant()>(shape),
|
||||||
const_param<dims_constant()>(strides_in));
|
const_param<dims_constant()>(strides_in));
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||||
auto [num_blocks, block_dims] =
|
if (work_per_thread == 4) {
|
||||||
get_launch_args(kernel, out, large());
|
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||||
|
}
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
{num_blocks_x, num_blocks_y},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
rest,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
ndim);
|
ndim);
|
||||||
|
|||||||
252
mlx/backend/cuda/cudnn_utils.cpp
Normal file
252
mlx/backend/cuda/cudnn_utils.cpp
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Create a cudnn tensor descriptor.
|
||||||
|
template <typename Vec>
|
||||||
|
inline cudnn_frontend::Tensor build_cudnn_tensor(
|
||||||
|
int64_t id,
|
||||||
|
const array& x,
|
||||||
|
const Vec& shape,
|
||||||
|
const Vec& strides) {
|
||||||
|
return cudnn_frontend::TensorBuilder()
|
||||||
|
.setDim(shape.size(), shape.data())
|
||||||
|
.setStrides(strides.size(), strides.data())
|
||||||
|
.setId(id)
|
||||||
|
.setAlignment(get_alignment(x))
|
||||||
|
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the shape and strides after transposing from NHWC to NCHW.
|
||||||
|
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
||||||
|
assert(shape.size() >= 3);
|
||||||
|
shape.insert(shape.begin() + 1, shape.back());
|
||||||
|
shape.erase(shape.end() - 1);
|
||||||
|
strides.insert(strides.begin() + 1, strides.back());
|
||||||
|
strides.erase(strides.end() - 1);
|
||||||
|
return std::make_tuple(std::move(shape), std::move(strides));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto nhwc_to_nchw(const array& x) {
|
||||||
|
return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return available engines for a |op_graph|.
|
||||||
|
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
|
cudnn_frontend::OperationGraph& op_graph,
|
||||||
|
bool use_fallback = true) {
|
||||||
|
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
|
||||||
|
sources.push_back([](auto& op_graph) {
|
||||||
|
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||||
|
.setOperationGraph(op_graph)
|
||||||
|
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||||
|
.build();
|
||||||
|
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||||
|
});
|
||||||
|
if (use_fallback) {
|
||||||
|
sources.push_back([&backend_type](auto& op_graph) {
|
||||||
|
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||||
|
.setOperationGraph(op_graph)
|
||||||
|
.setOperation(backend_type)
|
||||||
|
.build();
|
||||||
|
return fallback.getFallbackList();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto configs =
|
||||||
|
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
|
||||||
|
.generate_engine_config(op_graph);
|
||||||
|
|
||||||
|
cudnn_frontend::EngineConfigList filtered_configs;
|
||||||
|
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||||
|
if (cudnn_frontend::hasNumericalNote<
|
||||||
|
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||||
|
dtype == float32 && !env::enable_tf32()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
return filtered_configs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take |engine_configs| and |op_graph| and find a working execution plans
|
||||||
|
// from them.
|
||||||
|
std::optional<cudnn_frontend::ExecutionPlan>
|
||||||
|
find_cudnn_plan_from_engine_configs(
|
||||||
|
cudnnHandle_t handle,
|
||||||
|
const cudnn_frontend::EngineConfigList& engine_configs,
|
||||||
|
const cudnn_frontend::OperationGraph& op_graph) {
|
||||||
|
auto op_graph_tag = op_graph.getTag();
|
||||||
|
for (const auto& config : engine_configs) {
|
||||||
|
try {
|
||||||
|
return cudnn_frontend::ExecutionPlanBuilder()
|
||||||
|
.setHandle(handle)
|
||||||
|
.setEngineConfig(config, op_graph_tag)
|
||||||
|
.build();
|
||||||
|
} catch (cudnn_frontend::cudnnException& error) {
|
||||||
|
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare workspace and args to execute plan.
|
||||||
|
template <typename F>
|
||||||
|
bool prepare_cudnn_plan(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
int num_args,
|
||||||
|
const int64_t* uids,
|
||||||
|
void** data_ptrs,
|
||||||
|
F&& execute) {
|
||||||
|
int workspace_size = plan.getWorkspaceSize();
|
||||||
|
array workspace(
|
||||||
|
workspace_size > 0 ? allocator::malloc(workspace_size)
|
||||||
|
: allocator::Buffer(nullptr),
|
||||||
|
{workspace_size},
|
||||||
|
uint8);
|
||||||
|
|
||||||
|
auto args = cudnn_frontend::VariantPackBuilder()
|
||||||
|
.setWorkspacePointer(workspace.data<void>())
|
||||||
|
.setDataPointers(num_args, data_ptrs)
|
||||||
|
.setUids(num_args, uids)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
cudnnSetStream(handle, encoder.stream());
|
||||||
|
|
||||||
|
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
|
||||||
|
auto shape = convert_vector<int64_t>(x.shape());
|
||||||
|
return build_cudnn_tensor(id, x, shape, x.strides());
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
|
||||||
|
auto [shape, strides] = nhwc_to_nchw(x);
|
||||||
|
return build_cudnn_tensor(id, x, shape, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
|
||||||
|
if (x.ndim() == 0) {
|
||||||
|
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||||
|
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
|
||||||
|
}
|
||||||
|
if (x.ndim() == 1) {
|
||||||
|
int64_t s = x.shape(0);
|
||||||
|
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
|
||||||
|
SmallVector<int64_t, 4> strides = {s, 1, s, s};
|
||||||
|
return build_cudnn_tensor(id, x, shape, strides);
|
||||||
|
}
|
||||||
|
if (x.ndim() == 2) {
|
||||||
|
int64_t s = x.strides(0);
|
||||||
|
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
|
||||||
|
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
|
||||||
|
return build_cudnn_tensor(id, x, shape, strides);
|
||||||
|
}
|
||||||
|
if (x.ndim() == 3 || x.ndim() == 4) {
|
||||||
|
return build_cudnn_tensor_nchw(id, x);
|
||||||
|
}
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("Unsupported array with {} dims.", x.ndim()));
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
|
||||||
|
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||||
|
return cudnn_frontend::TensorBuilder()
|
||||||
|
.setDim(scalar_dims.size(), scalar_dims.data())
|
||||||
|
.setStrides(scalar_dims.size(), scalar_dims.data())
|
||||||
|
.setId(id)
|
||||||
|
.setAlignment(16)
|
||||||
|
.setDataType(dtype_to_cudnn_type(dtype))
|
||||||
|
.setByValue(true)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||||
|
cudnnHandle_t handle,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
|
cudnn_frontend::OperationGraph& op_graph) {
|
||||||
|
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
|
||||||
|
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool encode_cudnn_plan_with_capturing(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
int num_args,
|
||||||
|
const int64_t* uids,
|
||||||
|
void** data_ptrs) {
|
||||||
|
return prepare_cudnn_plan(
|
||||||
|
encoder,
|
||||||
|
plan,
|
||||||
|
num_args,
|
||||||
|
uids,
|
||||||
|
data_ptrs,
|
||||||
|
[&](auto handle, auto plan, auto args) {
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
|
||||||
|
// Discard the captured graph when failed.
|
||||||
|
capture.discard = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 90500
|
||||||
|
bool encode_cudnn_plan_with_graph_api(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
CudaGraph& graph,
|
||||||
|
int num_args,
|
||||||
|
const int64_t* uids,
|
||||||
|
void** data_ptrs) {
|
||||||
|
return prepare_cudnn_plan(
|
||||||
|
encoder,
|
||||||
|
plan,
|
||||||
|
num_args,
|
||||||
|
uids,
|
||||||
|
data_ptrs,
|
||||||
|
[&](auto handle, auto plan, auto args) {
|
||||||
|
if (!graph) {
|
||||||
|
graph = CudaGraph(encoder.device());
|
||||||
|
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
|
||||||
|
CUDNN_STATUS_SUCCESS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
|
||||||
|
CUDNN_STATUS_SUCCESS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoder.add_graph_node(graph);
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
164
mlx/backend/cuda/cudnn_utils.h
Normal file
164
mlx/backend/cuda/cudnn_utils.h
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
#include "mlx/backend/cuda/utils.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cudnn_frontend.h>
|
||||||
|
#include <cudnn_frontend_find_plan.h>
|
||||||
|
#include <fmt/format.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
class CommandEncoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return pointer alignment of |x|'s data.
|
||||||
|
inline uint8_t get_alignment(const array& x) {
|
||||||
|
uint8_t alignment = 1;
|
||||||
|
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||||
|
for (; alignment < 32; alignment *= 2) {
|
||||||
|
if (address % (alignment * 2)) {
|
||||||
|
return alignment;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return alignment;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the type of elements in |vec| to |T|.
|
||||||
|
template <typename T, typename Vec>
|
||||||
|
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||||
|
return SmallVector<T>(vec.begin(), vec.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
|
||||||
|
//
|
||||||
|
// There are 2 differences from the const_param util from kernel_utils.cuh:
|
||||||
|
// 1. The rest of array is filled with 0.
|
||||||
|
// 2. This util can be used in .cpp files.
|
||||||
|
template <typename T, template <typename U> class Vec>
|
||||||
|
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
||||||
|
if (vec.size() > MAX_NDIM) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||||
|
}
|
||||||
|
std::array<T, MAX_NDIM> result = {};
|
||||||
|
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers used by get_data_ptrs to get pointers.
|
||||||
|
inline void* get_data_ptr(const array& arr) {
|
||||||
|
return const_cast<void*>(arr.data<void>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
|
||||||
|
inline void* get_data_ptr(T& scalar) {
|
||||||
|
return &scalar;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return an array filled with data pointers of args.
|
||||||
|
template <typename... Args>
|
||||||
|
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
|
||||||
|
return {get_data_ptr(args)...};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map dtype to cudnn data type.
|
||||||
|
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case int8:
|
||||||
|
return CUDNN_DATA_INT8;
|
||||||
|
case int32:
|
||||||
|
return CUDNN_DATA_INT32;
|
||||||
|
case uint8:
|
||||||
|
return CUDNN_DATA_UINT8;
|
||||||
|
case float16:
|
||||||
|
return CUDNN_DATA_HALF;
|
||||||
|
case bfloat16:
|
||||||
|
return CUDNN_DATA_BFLOAT16;
|
||||||
|
case float32:
|
||||||
|
return CUDNN_DATA_FLOAT;
|
||||||
|
case float64:
|
||||||
|
return CUDNN_DATA_DOUBLE;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a tensor descriptor from |x|.
|
||||||
|
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
|
||||||
|
|
||||||
|
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
|
||||||
|
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
|
||||||
|
|
||||||
|
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
|
||||||
|
// from NHWC to NCHW.
|
||||||
|
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
|
||||||
|
|
||||||
|
// Create a 4D scalar tensor descriptor, which is passed by value.
|
||||||
|
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
|
||||||
|
|
||||||
|
// Find a working plan for |op_graph|.
|
||||||
|
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||||
|
cudnnHandle_t handle,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
|
cudnn_frontend::OperationGraph& op_graph);
|
||||||
|
|
||||||
|
// Encode the plan to command buffer by capturing.
|
||||||
|
bool encode_cudnn_plan_with_capturing(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
int num_args,
|
||||||
|
const int64_t* uids,
|
||||||
|
void** data_ptrs);
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 90500
|
||||||
|
// Encode the plan to command buffer by using native graph api of cudnn. If the
|
||||||
|
// |graph| is empty it will be populated, otherwise it will be updated.
|
||||||
|
bool encode_cudnn_plan_with_graph_api(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
CudaGraph& graph,
|
||||||
|
int num_args,
|
||||||
|
const int64_t* uids,
|
||||||
|
void** data_ptrs);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
|
||||||
|
template <typename... Args>
|
||||||
|
bool encode_cudnn_plan(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
std::initializer_list<int64_t> uids,
|
||||||
|
Args&... args) {
|
||||||
|
assert(uids.size() == sizeof...(args));
|
||||||
|
auto data_ptrs = get_data_ptrs(args...);
|
||||||
|
return encode_cudnn_plan_with_capturing(
|
||||||
|
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 90500
|
||||||
|
template <typename... Args>
|
||||||
|
bool encode_cudnn_plan(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
CudaGraph& graph,
|
||||||
|
std::initializer_list<int64_t> uids,
|
||||||
|
Args&... args) {
|
||||||
|
assert(uids.size() == sizeof...(args));
|
||||||
|
auto data_ptrs = get_data_ptrs(args...);
|
||||||
|
return encode_cudnn_plan_with_graph_api(
|
||||||
|
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@@ -9,12 +10,23 @@
|
|||||||
#include <future>
|
#include <future>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||||
// This should be less than 255
|
// This should be less than 255
|
||||||
constexpr int default_max_nodes_per_graph = 20;
|
constexpr int default_max_nodes_per_graph = 20;
|
||||||
|
|
||||||
|
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||||
|
if (err != CUDNN_STATUS_SUCCESS) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int cuda_graph_cache_size() {
|
int cuda_graph_cache_size() {
|
||||||
static int cache_size = []() {
|
static int cache_size = []() {
|
||||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||||
@@ -22,7 +34,7 @@ int cuda_graph_cache_size() {
|
|||||||
return cache_size;
|
return cache_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace cu {
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||||
@@ -40,11 +52,18 @@ Device::Device(int device) : device_(device) {
|
|||||||
}
|
}
|
||||||
// The cublasLt handle is used by matmul.
|
// The cublasLt handle is used by matmul.
|
||||||
make_current();
|
make_current();
|
||||||
cublasLtCreate(<_);
|
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
||||||
|
// The cudnn handle is used by Convolution.
|
||||||
|
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
||||||
|
|
||||||
|
// Initialize the jit module cache here ensures it is not
|
||||||
|
// unloaded before any evaluation is done
|
||||||
|
get_jit_module_cache();
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
cublasLtDestroy(lt_);
|
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::make_current() {
|
void Device::make_current() {
|
||||||
@@ -57,31 +76,26 @@ void Device::make_current() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||||
|
auto it = encoders_.find(s.index);
|
||||||
|
if (it == encoders_.end()) {
|
||||||
|
it = encoders_.try_emplace(s.index, *this).first;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
enc.device().make_current();
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
graph.end_capture(enc.stream());
|
||||||
size_t num_nodes;
|
if (discard) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
return;
|
||||||
if (num_nodes == 1) {
|
|
||||||
cudaGraphNode_t captured_node;
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
|
||||||
CUDA_KERNEL_NODE_PARAMS params;
|
|
||||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
|
||||||
cudaGraphNode_t node;
|
|
||||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms));
|
|
||||||
enc.insert_graph_dependencies(GraphNode{node, 'K'});
|
|
||||||
} else {
|
|
||||||
cudaGraphNode_t node;
|
|
||||||
CHECK_CUDA_ERROR(
|
|
||||||
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
|
|
||||||
enc.insert_graph_dependencies(GraphNode{node, 'G'});
|
|
||||||
}
|
}
|
||||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
|
enc.add_graph_node(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
||||||
@@ -168,28 +182,11 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
CommandEncoder::CommandEncoder(Device& d)
|
||||||
auto it = encoders_.find(s.index);
|
: device_(d),
|
||||||
if (it == encoders_.end()) {
|
stream_(d),
|
||||||
it = encoders_.try_emplace(s.index, *this).first;
|
graph_(d),
|
||||||
}
|
graph_cache_(cuda_graph_cache_size()) {}
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
|
||||||
}
|
|
||||||
|
|
||||||
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
|
|
||||||
for (auto& [_, graph_exec] : graphs) {
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
|
|
||||||
}
|
|
||||||
graphs.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandEncoder::~CommandEncoder() {
|
|
||||||
clear_graphs(graph_cache_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||||
worker_.add_task(std::move(task));
|
worker_.add_task(std::move(task));
|
||||||
@@ -216,22 +213,22 @@ void CommandEncoder::add_kernel_node(
|
|||||||
void* func,
|
void* func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
kernel_params.blockDim = block_dim;
|
kernel_params.blockDim = block_dim;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
cudaGraphNode_t node;
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
CHECK_CUDA_ERROR(
|
add_kernel_node(kernel_params);
|
||||||
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_kernel_node(
|
void CommandEncoder::add_kernel_node(
|
||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
@@ -242,13 +239,30 @@ void CommandEncoder::add_kernel_node(
|
|||||||
kernel_params.blockDimY = block_dim.y;
|
kernel_params.blockDimY = block_dim.y;
|
||||||
kernel_params.blockDimZ = block_dim.z;
|
kernel_params.blockDimZ = block_dim.z;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
CUgraphNode node;
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
CHECK_CUDA_ERROR(
|
add_kernel_node(kernel_params);
|
||||||
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
||||||
|
cudaGraphNode_t node;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||||
|
CUgraphNode node;
|
||||||
|
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
|
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
|
cudaGraphNode_t node;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
|
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||||
|
}
|
||||||
|
|
||||||
void CommandEncoder::commit() {
|
void CommandEncoder::commit() {
|
||||||
|
nvtx3::scoped_range r("CommandEncoder::commit");
|
||||||
if (!temporaries_.empty()) {
|
if (!temporaries_.empty()) {
|
||||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||||
}
|
}
|
||||||
@@ -265,7 +279,7 @@ void CommandEncoder::commit() {
|
|||||||
graph_key_ += ".";
|
graph_key_ += ".";
|
||||||
graph_key_ += std::to_string(empty_node_count_);
|
graph_key_ += std::to_string(empty_node_count_);
|
||||||
|
|
||||||
cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
|
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
|
||||||
|
|
||||||
if (graph_exec != nullptr) {
|
if (graph_exec != nullptr) {
|
||||||
cudaGraphExecUpdateResult update_result;
|
cudaGraphExecUpdateResult update_result;
|
||||||
@@ -279,34 +293,27 @@ void CommandEncoder::commit() {
|
|||||||
#endif // CUDART_VERSION >= 12000
|
#endif // CUDART_VERSION >= 12000
|
||||||
if (update_result != cudaGraphExecUpdateSuccess) {
|
if (update_result != cudaGraphExecUpdateSuccess) {
|
||||||
cudaGetLastError(); // reset error
|
cudaGetLastError(); // reset error
|
||||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
|
graph_exec.reset();
|
||||||
graph_exec = nullptr;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (graph_exec == nullptr) {
|
if (graph_exec == nullptr) {
|
||||||
CHECK_CUDA_ERROR(
|
graph_exec.instantiate(graph_);
|
||||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
|
||||||
}
|
}
|
||||||
|
device_.make_current();
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
|
||||||
// TODO smarter cache policy
|
|
||||||
if (graph_cache_.size() > cuda_graph_cache_size()) {
|
|
||||||
clear_graphs(graph_cache_);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset state
|
// Reset state
|
||||||
node_count_ = 0;
|
node_count_ = 0;
|
||||||
graph_node_count_ = 0;
|
graph_node_count_ = 0;
|
||||||
|
empty_node_count_ = 0;
|
||||||
from_nodes_.clear();
|
from_nodes_.clear();
|
||||||
to_nodes_.clear();
|
to_nodes_.clear();
|
||||||
graph_key_.clear();
|
graph_key_.clear();
|
||||||
node_map_.clear();
|
node_map_.clear();
|
||||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
|
graph_ = CudaGraph(device_);
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put completion handlers in a batch.
|
// Put completion handlers in a batch.
|
||||||
worker_.end_batch();
|
|
||||||
worker_.commit(stream_);
|
worker_.commit(stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,7 +322,6 @@ void CommandEncoder::synchronize() {
|
|||||||
auto p = std::make_shared<std::promise<void>>();
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
std::future<void> f = p->get_future();
|
std::future<void> f = p->get_future();
|
||||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||||
worker_.end_batch();
|
|
||||||
commit();
|
commit();
|
||||||
f.wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
@@ -333,6 +339,4 @@ CommandEncoder& get_command_encoder(Stream s) {
|
|||||||
return device(s.device).get_command_encoder(s);
|
return device(s.device).get_command_encoder(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace mlx::core::cu
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
|
|||||||
@@ -3,11 +3,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
#include <cublasLt.h>
|
#include <cublasLt.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
|
#include <cudnn.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@@ -19,8 +21,9 @@ class CommandEncoder {
|
|||||||
struct CaptureContext {
|
struct CaptureContext {
|
||||||
CaptureContext(CommandEncoder& enc);
|
CaptureContext(CommandEncoder& enc);
|
||||||
~CaptureContext();
|
~CaptureContext();
|
||||||
cudaGraph_t graph;
|
CudaGraph graph;
|
||||||
CommandEncoder& enc;
|
CommandEncoder& enc;
|
||||||
|
bool discard{false};
|
||||||
};
|
};
|
||||||
struct ConcurrentContext {
|
struct ConcurrentContext {
|
||||||
ConcurrentContext(CommandEncoder& enc);
|
ConcurrentContext(CommandEncoder& enc);
|
||||||
@@ -29,7 +32,6 @@ class CommandEncoder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
explicit CommandEncoder(Device& d);
|
explicit CommandEncoder(Device& d);
|
||||||
~CommandEncoder();
|
|
||||||
|
|
||||||
CommandEncoder(const CommandEncoder&) = delete;
|
CommandEncoder(const CommandEncoder&) = delete;
|
||||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||||
@@ -45,25 +47,39 @@ class CommandEncoder {
|
|||||||
void set_output_array(const array& arr);
|
void set_output_array(const array& arr);
|
||||||
|
|
||||||
template <typename F, typename... Params>
|
template <typename F, typename... Params>
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
|
F* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
Params&&... params) {
|
||||||
constexpr size_t num = sizeof...(Params);
|
constexpr size_t num = sizeof...(Params);
|
||||||
void* ptrs[num];
|
void* ptrs[num];
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||||
std::forward<Params>(params)),
|
std::forward<Params>(params)),
|
||||||
...);
|
...);
|
||||||
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
|
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_kernel_node(
|
void add_kernel_node(
|
||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
void* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
void** params);
|
||||||
|
|
||||||
|
// Low-level graph helpers.
|
||||||
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||||
|
void add_graph_node(cudaGraph_t child);
|
||||||
|
|
||||||
void add_temporary(const array& arr) {
|
void add_temporary(const array& arr) {
|
||||||
temporaries_.push_back(arr.data_shared_ptr());
|
temporaries_.push_back(arr.data_shared_ptr());
|
||||||
@@ -73,6 +89,10 @@ class CommandEncoder {
|
|||||||
void maybe_commit();
|
void maybe_commit();
|
||||||
void commit();
|
void commit();
|
||||||
|
|
||||||
|
Device& device() {
|
||||||
|
return device_;
|
||||||
|
}
|
||||||
|
|
||||||
CudaStream& stream() {
|
CudaStream& stream() {
|
||||||
return stream_;
|
return stream_;
|
||||||
}
|
}
|
||||||
@@ -93,8 +113,9 @@ class CommandEncoder {
|
|||||||
void insert_graph_dependencies(GraphNode node);
|
void insert_graph_dependencies(GraphNode node);
|
||||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||||
|
|
||||||
|
Device& device_;
|
||||||
CudaStream stream_;
|
CudaStream stream_;
|
||||||
cudaGraph_t graph_;
|
CudaGraph graph_;
|
||||||
Worker worker_;
|
Worker worker_;
|
||||||
char node_count_{0};
|
char node_count_{0};
|
||||||
char graph_node_count_{0};
|
char graph_node_count_{0};
|
||||||
@@ -105,7 +126,7 @@ class CommandEncoder {
|
|||||||
std::string graph_key_;
|
std::string graph_key_;
|
||||||
std::vector<GraphNode> concurrent_nodes_;
|
std::vector<GraphNode> concurrent_nodes_;
|
||||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||||
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
|
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
||||||
std::vector<std::uintptr_t> active_deps_;
|
std::vector<std::uintptr_t> active_deps_;
|
||||||
std::vector<std::uintptr_t> active_outputs_;
|
std::vector<std::uintptr_t> active_outputs_;
|
||||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||||
@@ -136,12 +157,16 @@ class Device {
|
|||||||
cublasLtHandle_t lt_handle() const {
|
cublasLtHandle_t lt_handle() const {
|
||||||
return lt_;
|
return lt_;
|
||||||
}
|
}
|
||||||
|
cudnnHandle_t cudnn_handle() const {
|
||||||
|
return cudnn_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int device_;
|
int device_;
|
||||||
int compute_capability_major_;
|
int compute_capability_major_;
|
||||||
int compute_capability_minor_;
|
int compute_capability_minor_;
|
||||||
cublasLtHandle_t lt_;
|
cublasLtHandle_t lt_;
|
||||||
|
cudnnHandle_t cudnn_;
|
||||||
std::unordered_map<int, CommandEncoder> encoders_;
|
std::unordered_map<int, CommandEncoder> encoders_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Arange {
|
|
||||||
const T start;
|
|
||||||
const T step;
|
|
||||||
|
|
||||||
__device__ T operator()(uint32_t i) const {
|
|
||||||
return start + i * step;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/complex.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
|
|
||||||
#include <cuda/atomic>
|
#include <cuda/atomic>
|
||||||
@@ -48,22 +48,13 @@ inline __device__ void atomic_add(__half* out, __half val) {
|
|||||||
atomicAdd(out, val);
|
atomicAdd(out, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
||||||
#if __CUDA_ARCH__ < 900
|
|
||||||
atomic_add_general(out, val);
|
atomic_add_general(out, val);
|
||||||
#else
|
|
||||||
atomicAdd(out, val);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||||
#if __CUDA_ARCH__ < 800
|
#if __CUDA_ARCH__ < 800
|
||||||
#if CCCL_VERSION >= 2008000
|
|
||||||
atomic_add_general(out, val);
|
atomic_add_general(out, val);
|
||||||
#else
|
|
||||||
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
|
|
||||||
assert(cccl_version_too_old_for_bfloat16_atomic_add);
|
|
||||||
#endif
|
|
||||||
#else
|
#else
|
||||||
atomicAdd(out, val);
|
atomicAdd(out, val);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
@@ -47,7 +44,7 @@ struct Remainder {
|
|||||||
} else {
|
} else {
|
||||||
return x % y;
|
return x % y;
|
||||||
}
|
}
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
return x % y;
|
return x % y;
|
||||||
} else {
|
} else {
|
||||||
T r = fmod(x, y);
|
T r = fmod(x, y);
|
||||||
@@ -69,14 +66,12 @@ struct Equal {
|
|||||||
struct NaNEqual {
|
struct NaNEqual {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ bool operator()(T x, T y) {
|
__device__ bool operator()(T x, T y) {
|
||||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return x == y ||
|
return x == y ||
|
||||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
|
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
|
||||||
isnan(cuCimagf(y))) ||
|
isnan(y.imag())) ||
|
||||||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
|
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
|
||||||
isnan(cuCimagf(y))) ||
|
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
|
||||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
|
|
||||||
cuCimagf(x) == cuCimagf(y));
|
|
||||||
} else {
|
} else {
|
||||||
return x == y || (isnan(x) && isnan(y));
|
return x == y || (isnan(x) && isnan(y));
|
||||||
}
|
}
|
||||||
@@ -114,36 +109,38 @@ struct LessEqual {
|
|||||||
struct LogAddExp {
|
struct LogAddExp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if (isnan(x) || isnan(y)) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
|
||||||
|
isnan(y.imag())) {
|
||||||
|
return {
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||||
|
}
|
||||||
|
auto max = x.real() > y.real() ? x : y;
|
||||||
|
auto min = x.real() < y.real() ? x : y;
|
||||||
|
auto min_real = min.real();
|
||||||
|
auto max_real = max.real();
|
||||||
|
if (!isfinite(min_real) && (min_real == max_real)) {
|
||||||
|
if (min_real < 0) {
|
||||||
|
return min;
|
||||||
|
} else {
|
||||||
|
return Log{}(Exp{}(min) + Exp{}(max));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Log1p{}(Exp{}(min - max)) + max;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (isnan(x) || isnan(y)) {
|
||||||
|
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||||
|
}
|
||||||
|
T maxval = max(x, y);
|
||||||
|
T minval = min(x, y);
|
||||||
|
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||||
|
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||||
|
? maxval
|
||||||
|
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||||
}
|
}
|
||||||
T maxval = max(x, y);
|
|
||||||
T minval = min(x, y);
|
|
||||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
|
||||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
|
||||||
? maxval
|
|
||||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
|
||||||
isnan(cuCimagf(y))) {
|
|
||||||
return {
|
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
|
||||||
}
|
|
||||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
|
||||||
auto maxval = x > y ? x : y;
|
|
||||||
auto minval = x < y ? x : y;
|
|
||||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
|
||||||
return maxval;
|
|
||||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
|
||||||
cuComplex dexp{
|
|
||||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
|
||||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
|
||||||
};
|
|
||||||
return maxval + log1p(dexp);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Maximum {
|
struct Maximum {
|
||||||
@@ -151,8 +148,8 @@ struct Maximum {
|
|||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return max(x, y);
|
return max(x, y);
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
if (isnan(x.real()) || isnan(x.imag())) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
return x > y ? x : y;
|
return x > y ? x : y;
|
||||||
@@ -170,8 +167,8 @@ struct Minimum {
|
|||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return min(x, y);
|
return min(x, y);
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
if (isnan(x.real()) || isnan(x.imag())) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
return x < y ? x : y;
|
return x < y ? x : y;
|
||||||
@@ -194,8 +191,8 @@ struct Multiply {
|
|||||||
struct NotEqual {
|
struct NotEqual {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ bool operator()(T x, T y) {
|
__device__ bool operator()(T x, T y) {
|
||||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
|
return x.real() != y.real() || x.imag() != y.imag();
|
||||||
} else {
|
} else {
|
||||||
return x != y;
|
return x != y;
|
||||||
}
|
}
|
||||||
@@ -215,19 +212,8 @@ struct Power {
|
|||||||
base *= base;
|
base *= base;
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (base.y == 0 && base.x == 0) {
|
return pow(base, exp);
|
||||||
if (isnan(exp.x) || isnan(exp.y)) {
|
|
||||||
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
|
|
||||||
return make_cuFloatComplex(nan, nan);
|
|
||||||
}
|
|
||||||
return make_cuFloatComplex(0.0, 0.0);
|
|
||||||
}
|
|
||||||
auto x_theta = atan2f(base.y, base.x);
|
|
||||||
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
|
|
||||||
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
|
|
||||||
auto phase = exp.y * x_ln_r + exp.x * x_theta;
|
|
||||||
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
|
|
||||||
} else {
|
} else {
|
||||||
return powf(base, exp);
|
return powf(base, exp);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include "mlx/backend/cuda/device/complex.cuh"
|
||||||
|
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <thrust/iterator/transform_iterator.h>
|
#include <thrust/iterator/transform_iterator.h>
|
||||||
@@ -20,50 +21,43 @@ struct CastOp {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Castings between complex and boolean.
|
// Castings between complex and boolean.
|
||||||
// TODO: Should make a custom complex type.
|
template <typename T>
|
||||||
template <>
|
struct CastOp<complex_t<T>, bool> {
|
||||||
struct CastOp<cuComplex, bool> {
|
|
||||||
static constexpr bool is_castable = true;
|
static constexpr bool is_castable = true;
|
||||||
|
|
||||||
__device__ bool operator()(cuComplex x) {
|
__device__ bool operator()(complex_t<T> x) {
|
||||||
return x.x != 0 && x.y != 0;
|
return x.real() != 0 && x.imag() != 0;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <typename T>
|
||||||
struct CastOp<bool, cuComplex> {
|
struct CastOp<bool, complex_t<T>> {
|
||||||
static constexpr bool is_castable = true;
|
static constexpr bool is_castable = true;
|
||||||
|
|
||||||
__device__ cuComplex operator()(bool x) {
|
__device__ complex_t<T> operator()(bool x) {
|
||||||
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
|
return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converting a complex number to real number discards the imaginary part.
|
// Converting a complex number to real number discards the imaginary part.
|
||||||
template <typename DstT>
|
template <typename T, typename DstT>
|
||||||
struct CastOp<
|
struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {
|
||||||
cuComplex,
|
static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;
|
||||||
DstT,
|
|
||||||
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
|
|
||||||
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
|
|
||||||
|
|
||||||
__device__ DstT operator()(cuComplex x) {
|
__device__ DstT operator()(complex_t<T> x) {
|
||||||
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
|
static_assert(!is_complex_v<DstT>);
|
||||||
return static_cast<DstT>(cuCrealf(x));
|
return static_cast<DstT>(x.real());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Allow converting a real number to complex number.
|
// Allow converting a real number to complex number.
|
||||||
template <typename SrcT>
|
template <typename SrcT, typename T>
|
||||||
struct CastOp<
|
struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {
|
||||||
SrcT,
|
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;
|
||||||
cuComplex,
|
|
||||||
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
|
|
||||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
|
|
||||||
|
|
||||||
__device__ cuComplex operator()(SrcT x) {
|
__device__ complex_t<T> operator()(SrcT x) {
|
||||||
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
|
static_assert(!is_complex_v<SrcT>);
|
||||||
return cuComplex{static_cast<float>(x), 0};
|
return complex_t<T>{static_cast<T>(x), 0};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -88,8 +82,7 @@ struct CastOp<
|
|||||||
SrcT,
|
SrcT,
|
||||||
DstT,
|
DstT,
|
||||||
cuda::std::enable_if_t<
|
cuda::std::enable_if_t<
|
||||||
!cuda::std::is_convertible_v<SrcT, DstT> &&
|
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
|
||||||
!cuda::std::is_same_v<SrcT, cuComplex> &&
|
|
||||||
(cuda::std::is_same_v<DstT, __half> ||
|
(cuda::std::is_same_v<DstT, __half> ||
|
||||||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
|
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
|
||||||
static constexpr bool is_castable = true;
|
static constexpr bool is_castable = true;
|
||||||
@@ -104,8 +97,7 @@ struct CastOp<
|
|||||||
SrcT,
|
SrcT,
|
||||||
DstT,
|
DstT,
|
||||||
cuda::std::enable_if_t<
|
cuda::std::enable_if_t<
|
||||||
!cuda::std::is_convertible_v<SrcT, DstT> &&
|
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
|
||||||
!cuda::std::is_same_v<DstT, cuComplex> &&
|
|
||||||
!cuda::std::is_same_v<DstT, __half> &&
|
!cuda::std::is_same_v<DstT, __half> &&
|
||||||
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
|
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
|
||||||
(cuda::std::is_same_v<SrcT, __half> ||
|
(cuda::std::is_same_v<SrcT, __half> ||
|
||||||
|
|||||||
60
mlx/backend/cuda/device/complex.cuh
Normal file
60
mlx/backend/cuda/device/complex.cuh
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Make multiplication and division faster.
|
||||||
|
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
|
||||||
|
|
||||||
|
#include <cuda/std/complex>
|
||||||
|
#include <cuda/std/type_traits>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// TODO: Consider using a faster implementation as cuda::std::complex has to
|
||||||
|
// conform to C++ standard.
|
||||||
|
template <typename T>
|
||||||
|
using complex_t = cuda::std::complex<T>;
|
||||||
|
|
||||||
|
using complex64_t = complex_t<float>;
|
||||||
|
using complex128_t = complex_t<double>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_complex : cuda::std::false_type {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_complex_v = is_complex<T>::value;
|
||||||
|
|
||||||
|
// cuda::std::complex is missing some operators.
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ complex_t<T> operator%(
|
||||||
|
complex_t<T> a,
|
||||||
|
complex_t<T> b) {
|
||||||
|
T r = a.real() - floor(a.real() / b.real()) * b.real();
|
||||||
|
T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
|
||||||
|
return complex_t<T>{r, i};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return operator>(b, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return !(a > b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return !(a < b);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
@@ -1,240 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
// Copyright © 2017-2024 The Simons Foundation, Inc.
|
|
||||||
//
|
|
||||||
// FINUFFT is licensed under the Apache License, Version 2.0 (the
|
|
||||||
// "License"); you may not use this file except in compliance with the
|
|
||||||
// License. You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
//
|
|
||||||
// Forked from
|
|
||||||
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
|
|
||||||
// This header provides some helper functions for cuComplex types.
|
|
||||||
// It mainly wraps existing CUDA implementations to provide operator overloads
|
|
||||||
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
|
|
||||||
// all provided by CUDA
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCadd(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCsub(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCmul(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCdiv(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
|
|
||||||
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
|
|
||||||
return make_cuDoubleComplex(r, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator==(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator!=(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
|
|
||||||
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
|
|
||||||
return mag_a > mag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>=(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return a > b || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return b > a;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<=(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return b > a || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator+(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator+(double a, const cuDoubleComplex& b) {
|
|
||||||
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator-(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator-(double a, const cuDoubleComplex& b) {
|
|
||||||
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator*(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator*(double a, const cuDoubleComplex& b) {
|
|
||||||
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator/(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator/(double a, const cuDoubleComplex& b) {
|
|
||||||
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
|
|
||||||
return make_cuDoubleComplex(
|
|
||||||
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCaddf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCsubf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCmulf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCdivf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
|
|
||||||
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
|
|
||||||
return make_cuFloatComplex(r, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator==(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator!=(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
|
|
||||||
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
|
|
||||||
return mag_a > mag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>=(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return a > b || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return b > a;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<=(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return b > a || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator+(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator+(float a, const cuFloatComplex& b) {
|
|
||||||
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator-(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator-(float a, const cuFloatComplex& b) {
|
|
||||||
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator*(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator*(float a, const cuFloatComplex& b) {
|
|
||||||
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator/(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator/(float a, const cuFloatComplex& b) {
|
|
||||||
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
|
|
||||||
return make_cuFloatComplex(
|
|
||||||
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
|
|
||||||
}
|
|
||||||
@@ -14,8 +14,6 @@ struct Abs {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||||
return x;
|
return x;
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
|
||||||
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
|
|
||||||
} else {
|
} else {
|
||||||
return abs(x);
|
return abs(x);
|
||||||
}
|
}
|
||||||
@@ -27,8 +25,6 @@ struct ArcCos {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return acos(x);
|
return acos(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcCosh {
|
struct ArcCosh {
|
||||||
@@ -43,8 +39,6 @@ struct ArcSin {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return asin(x);
|
return asin(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcSinh {
|
struct ArcSinh {
|
||||||
@@ -59,8 +53,6 @@ struct ArcTan {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return atan(x);
|
return atan(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcTanh {
|
struct ArcTanh {
|
||||||
@@ -82,6 +74,8 @@ struct Ceil {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return x;
|
return x;
|
||||||
|
} else if constexpr (is_complex_v<T>) {
|
||||||
|
return T{ceil(x.real()), ceil(x.imag())};
|
||||||
} else {
|
} else {
|
||||||
return ceil(x);
|
return ceil(x);
|
||||||
}
|
}
|
||||||
@@ -89,34 +83,23 @@ struct Ceil {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Conjugate {
|
struct Conjugate {
|
||||||
__device__ cuComplex operator()(cuComplex x) {
|
template <typename T>
|
||||||
return {cuCrealf(x), -cuCimagf(x)};
|
__device__ complex_t<T> operator()(complex_t<T> x) {
|
||||||
|
return conj(x);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Cos {
|
struct Cos {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return cos(x);
|
||||||
return {
|
|
||||||
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
|
|
||||||
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return cos(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Cosh {
|
struct Cosh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return cosh(x);
|
||||||
return {
|
|
||||||
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
|
|
||||||
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return cosh(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -149,12 +132,7 @@ struct ErfInv {
|
|||||||
struct Exp {
|
struct Exp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return exp(x);
|
||||||
auto m = exp(cuCrealf(x));
|
|
||||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return exp(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -176,6 +154,8 @@ struct Floor {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return x;
|
return x;
|
||||||
|
} else if constexpr (is_complex_v<T>) {
|
||||||
|
return T{floor(x.real()), floor(x.imag())};
|
||||||
} else {
|
} else {
|
||||||
return floor(x);
|
return floor(x);
|
||||||
}
|
}
|
||||||
@@ -183,30 +163,25 @@ struct Floor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Imag {
|
struct Imag {
|
||||||
__device__ float operator()(cuComplex x) {
|
template <typename T>
|
||||||
return cuCimagf(x);
|
__device__ auto operator()(complex_t<T> x) {
|
||||||
|
return x.imag();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log {
|
struct Log {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return log(x);
|
||||||
auto r = log(cuCrealf(Abs{}(x)));
|
|
||||||
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
|
||||||
return {r, i};
|
|
||||||
} else {
|
|
||||||
return log(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log2 {
|
struct Log2 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
auto y = Log{}(x);
|
auto y = Log{}(x);
|
||||||
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
|
||||||
} else {
|
} else {
|
||||||
return log2(x);
|
return log2(x);
|
||||||
}
|
}
|
||||||
@@ -216,20 +191,31 @@ struct Log2 {
|
|||||||
struct Log10 {
|
struct Log10 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return log10(x);
|
||||||
auto y = Log{}(x);
|
|
||||||
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
|
||||||
return y;
|
|
||||||
} else {
|
|
||||||
return log10(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log1p {
|
struct Log1p {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T z) {
|
||||||
return log1p(x);
|
if constexpr (is_complex_v<T>) {
|
||||||
|
float x = z.real();
|
||||||
|
float y = z.imag();
|
||||||
|
float zabs = Abs{}(z).real();
|
||||||
|
float theta = atan2f(y, x + 1);
|
||||||
|
if (zabs < 0.5f) {
|
||||||
|
float r = x * (2 + x) + y * y;
|
||||||
|
if (r == 0) { // handle underflow
|
||||||
|
return {x, theta};
|
||||||
|
}
|
||||||
|
return {0.5f * log1pf(r), theta};
|
||||||
|
} else {
|
||||||
|
float z0 = hypotf(x + 1, y);
|
||||||
|
return {logf(z0), theta};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return log1p(z);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -242,8 +228,8 @@ struct LogicalNot {
|
|||||||
struct Negative {
|
struct Negative {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return 0 - x;
|
return T{0, 0} - x;
|
||||||
} else {
|
} else {
|
||||||
return -x;
|
return -x;
|
||||||
}
|
}
|
||||||
@@ -251,16 +237,17 @@ struct Negative {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Real {
|
struct Real {
|
||||||
__device__ float operator()(cuComplex x) {
|
template <typename T>
|
||||||
return cuCrealf(x);
|
__device__ auto operator()(complex_t<T> x) {
|
||||||
|
return x.real();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Round {
|
struct Round {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
|
return {rint(x.real()), rint(x.imag())};
|
||||||
} else {
|
} else {
|
||||||
return rint(x);
|
return rint(x);
|
||||||
}
|
}
|
||||||
@@ -280,8 +267,8 @@ struct Sign {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||||
return x != 0;
|
return x != 0;
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
|
if (x.real() == 0 && x.imag() == 0) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
return x / Abs()(x);
|
return x / Abs()(x);
|
||||||
@@ -297,26 +284,14 @@ struct Sign {
|
|||||||
struct Sin {
|
struct Sin {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return sin(x);
|
||||||
return {
|
|
||||||
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
|
|
||||||
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return sin(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Sinh {
|
struct Sinh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return sinh(x);
|
||||||
return {
|
|
||||||
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
|
|
||||||
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return sinh(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -332,77 +307,31 @@ struct Sqrt {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return sqrt(x);
|
return sqrt(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x) {
|
|
||||||
auto xr = cuCrealf(x);
|
|
||||||
auto xi = cuCimagf(x);
|
|
||||||
if (xr == 0.0f && xi == 0.0f) {
|
|
||||||
return {0.0f, 0.0f};
|
|
||||||
}
|
|
||||||
auto r = cuCrealf(Abs{}(x));
|
|
||||||
auto a = sqrt((r + xr) / 2.0f);
|
|
||||||
auto b_abs = sqrt((r - xr) / 2.0f);
|
|
||||||
auto b = copysign(b_abs, xi);
|
|
||||||
return {a, b};
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Rsqrt {
|
struct Rsqrt {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return rsqrt(x);
|
if constexpr (is_complex_v<T>) {
|
||||||
}
|
return 1.0f / Sqrt{}(x);
|
||||||
__device__ cuComplex operator()(cuComplex x) {
|
} else {
|
||||||
return 1.0f / Sqrt{}(x);
|
return rsqrt(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Tan {
|
struct Tan {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return tan(x);
|
||||||
float tan_a = tan(cuCrealf(x));
|
|
||||||
float tanh_b = tanh(cuCimagf(x));
|
|
||||||
float t1 = tan_a * tanh_b;
|
|
||||||
float denom = 1. + t1 * t1;
|
|
||||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
|
||||||
} else {
|
|
||||||
return tan(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Tanh {
|
struct Tanh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return tanh(x);
|
||||||
float tanh_a = tanh(cuCrealf(x));
|
|
||||||
float tan_b = tan(cuCimagf(x));
|
|
||||||
float t1 = tanh_a * tan_b;
|
|
||||||
float denom = 1. + t1 * t1;
|
|
||||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
|
||||||
} else {
|
|
||||||
return tanh(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
|
||||||
auto i = cuComplex{0.0, 1.0};
|
|
||||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
|
||||||
return {cuCimagf(y), -cuCrealf(y)};
|
|
||||||
};
|
|
||||||
|
|
||||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
|
||||||
auto i = cuComplex{0.0f, 1.0f};
|
|
||||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
|
||||||
return {cuCimagf(y), -cuCrealf(y)};
|
|
||||||
};
|
|
||||||
|
|
||||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
|
||||||
auto i = cuComplex{0.0f, 1.0f};
|
|
||||||
auto ix = i * x;
|
|
||||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
@@ -8,9 +8,9 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/complex.cuh"
|
||||||
#include "mlx/backend/cuda/device/config.h"
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
||||||
@@ -32,23 +32,137 @@ using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
|||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct alignas(sizeof(T) * N) AlignedVector {
|
struct alignas(sizeof(T) * N) AlignedVector {
|
||||||
T val[N];
|
T val[N];
|
||||||
|
|
||||||
|
__device__ T& operator[](int i) {
|
||||||
|
return val[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ T operator[](int i) const {
|
||||||
|
return val[i];
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
inline __device__ AlignedVector<T, N> load_vector(
|
inline __host__ __device__ bool is_aligned(T* x) {
|
||||||
|
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ AlignedVector<T, N> unsafe_load_vector(
|
||||||
const T* ptr,
|
const T* ptr,
|
||||||
uint32_t offset) {
|
uint32_t offset) {
|
||||||
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||||
return from[offset];
|
return from[offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ AlignedVector<T, N> load_vector(
|
||||||
|
const T* ptr,
|
||||||
|
uint32_t offset) {
|
||||||
|
if (is_aligned<N>(ptr)) {
|
||||||
|
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||||
|
return from[offset];
|
||||||
|
} else {
|
||||||
|
AlignedVector<T, N> v;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
v[i] = ptr[offset * N + i];
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, typename T, typename SizeT>
|
||||||
|
inline __device__ AlignedVector<T, N>
|
||||||
|
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
|
||||||
|
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
|
||||||
|
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||||
|
return from[offset];
|
||||||
|
} else {
|
||||||
|
AlignedVector<T, N> v;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, typename T, typename SizeT>
|
||||||
|
inline __device__ AlignedVector<T, N> load_vector(
|
||||||
|
const T* ptr,
|
||||||
|
uint32_t offset,
|
||||||
|
SizeT size,
|
||||||
|
int64_t stride,
|
||||||
|
T fallback) {
|
||||||
|
if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
|
||||||
|
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||||
|
return from[offset];
|
||||||
|
} else {
|
||||||
|
AlignedVector<T, N> v;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
v[i] =
|
||||||
|
(N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
inline __device__ void
|
inline __device__ void
|
||||||
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||||
to[offset] = vec;
|
to[offset] = vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ void
|
||||||
|
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||||
|
if (is_aligned<N>(ptr)) {
|
||||||
|
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||||
|
to[offset] = vec;
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
ptr[offset * N + i] = vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, typename T, typename SizeT>
|
||||||
|
inline __device__ void store_vector(
|
||||||
|
T* ptr,
|
||||||
|
uint32_t offset,
|
||||||
|
const AlignedVector<T, N>& vec,
|
||||||
|
SizeT size) {
|
||||||
|
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
|
||||||
|
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||||
|
to[offset] = vec;
|
||||||
|
} else {
|
||||||
|
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
|
||||||
|
ptr[offset * N + i] = vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, typename T, typename SizeT>
|
||||||
|
inline __device__ void store_vector(
|
||||||
|
T* ptr,
|
||||||
|
uint32_t offset,
|
||||||
|
const AlignedVector<T, N>& vec,
|
||||||
|
SizeT size,
|
||||||
|
int64_t stride) {
|
||||||
|
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
|
||||||
|
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||||
|
to[offset] = vec;
|
||||||
|
} else {
|
||||||
|
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
|
||||||
|
ptr[stride * (offset * N + i)] = vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Type limits utils
|
// Type limits utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -127,13 +241,13 @@ struct Limits<bool> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <typename T>
|
||||||
struct Limits<cuComplex> {
|
struct Limits<complex_t<T>> {
|
||||||
static constexpr __host__ __device__ cuComplex max() {
|
static constexpr __host__ __device__ complex_t<T> max() {
|
||||||
return {Limits<float>::max(), Limits<float>::max()};
|
return {Limits<T>::max(), Limits<T>::max()};
|
||||||
}
|
}
|
||||||
static constexpr __host__ __device__ cuComplex min() {
|
static constexpr __host__ __device__ complex_t<T> min() {
|
||||||
return {Limits<float>::min(), Limits<float>::min()};
|
return {Limits<T>::min(), Limits<T>::min()};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -204,20 +318,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optimized version when ndim is larger than 4.
|
|
||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ IdxT
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
|
||||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
|
||||||
IdxT loc = 0;
|
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
|
||||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
|
||||||
elem /= shape[i];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename IdxT = int64_t>
|
|
||||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|
||||||
IdxT elem,
|
IdxT elem,
|
||||||
const int* shape,
|
const int* shape,
|
||||||
const int64_t* a_strides,
|
const int64_t* a_strides,
|
||||||
@@ -235,7 +337,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
|
||||||
IdxT elem,
|
IdxT elem,
|
||||||
const int* shape,
|
const int* shape,
|
||||||
const int64_t* a_strides,
|
const int64_t* a_strides,
|
||||||
@@ -359,21 +461,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline __device__ cuComplex log1p(cuComplex in) {
|
|
||||||
float x = cuCrealf(in);
|
|
||||||
float y = cuCimagf(in);
|
|
||||||
float zabs = sqrt(x * x + y * y);
|
|
||||||
float theta = atan2f(y, x + 1);
|
|
||||||
if (zabs < 0.5f) {
|
|
||||||
float r = x * (2 + x) + y * y;
|
|
||||||
if (r == 0) { // handle underflow
|
|
||||||
return {x, theta};
|
|
||||||
}
|
|
||||||
return {0.5f * log1pf(r), theta};
|
|
||||||
} else {
|
|
||||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
|
||||||
return {log(z0), theta};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
51
mlx/backend/cuda/distributed.cu
Normal file
51
mlx/backend/cuda/distributed.cu
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/distributed/primitives.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
namespace distributed {
|
||||||
|
void AllReduce::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto& input = inputs[0];
|
||||||
|
auto& output = outputs[0];
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(stream());
|
||||||
|
|
||||||
|
if (input.is_donatable()) {
|
||||||
|
output.copy_shared_buffer(input);
|
||||||
|
} else {
|
||||||
|
output.set_data(allocator::malloc(output.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(input);
|
||||||
|
encoder.set_output_array(output);
|
||||||
|
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Sum:
|
||||||
|
distributed::detail::all_sum(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
case Max:
|
||||||
|
distributed::detail::all_max(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
case Min:
|
||||||
|
distributed::detail::all_min(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Only all reduce sum, max, and min are supported.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace distributed
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -19,8 +19,6 @@ void new_stream(Stream s) {
|
|||||||
cudaFree(nullptr);
|
cudaFree(nullptr);
|
||||||
// Ensure the static stream objects get created.
|
// Ensure the static stream objects get created.
|
||||||
cu::get_command_encoder(s);
|
cu::get_command_encoder(s);
|
||||||
// The main thread is safe to free buffers.
|
|
||||||
cu::allocator().register_this_thread();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void eval(array& arr) {
|
void eval(array& arr) {
|
||||||
@@ -38,18 +36,15 @@ void eval(array& arr) {
|
|||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||||
// Keep used buffers alive until kernel finishes running.
|
// Keep used buffers alive until kernel finishes running.
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
buffers.insert(in.data_shared_ptr());
|
// Except for the donated one.
|
||||||
|
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto& s : arr.siblings()) {
|
for (auto& s : arr.siblings()) {
|
||||||
buffers.insert(s.data_shared_ptr());
|
encoder.add_temporary(s);
|
||||||
}
|
}
|
||||||
// Remove the output if it was donated to by an input.
|
|
||||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
|
||||||
buffers.erase(it);
|
|
||||||
}
|
|
||||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
|
||||||
encoder.maybe_commit();
|
encoder.maybe_commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -90,8 +90,6 @@ bool CudaEvent::completed() const {
|
|||||||
// SharedEvent implementations
|
// SharedEvent implementations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||||
uint64_t current;
|
uint64_t current;
|
||||||
while ((current = ac->load()) < value) {
|
while ((current = ac->load()) < value) {
|
||||||
@@ -112,26 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
|||||||
event_signal(ac, value);
|
event_signal(ac, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
||||||
|
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
SharedEvent::SharedEvent() {
|
SharedEvent::SharedEvent() {
|
||||||
// Allocate cuda::atomic on managed memory.
|
buf_ = std::shared_ptr<Buffer>(
|
||||||
Atomic* ac;
|
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
||||||
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
allocator().free(*ptr);
|
||||||
new (ac) Atomic(0);
|
delete ptr;
|
||||||
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
});
|
||||||
ptr->~Atomic();
|
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
||||||
allocator().cuda_free(ptr);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(uint64_t value) {
|
void SharedEvent::wait(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||||
event_wait(ac_.get(), value);
|
event_wait(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||||
@@ -142,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
|||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.commit();
|
encoder.commit();
|
||||||
wait(encoder.stream(), value);
|
wait(encoder.stream(), value);
|
||||||
encoder.add_completed_handler([ac = ac_]() {});
|
encoder.add_completed_handler([buf = buf_]() {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(uint64_t value) {
|
void SharedEvent::signal(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||||
event_signal(ac_.get(), value);
|
event_signal(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||||
@@ -166,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
|||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.commit();
|
encoder.commit();
|
||||||
signal(encoder.stream(), value);
|
signal(encoder.stream(), value);
|
||||||
encoder.add_completed_handler([ac = ac_]() {});
|
encoder.add_completed_handler([buf = buf_]() {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||||
return ac_->load() >= value;
|
return to_atomic(buf_)->load() >= value;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t SharedEvent::value() const {
|
uint64_t SharedEvent::value() const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||||
return ac_->load();
|
return to_atomic(buf_)->load();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
@@ -55,12 +56,8 @@ class SharedEvent {
|
|||||||
bool is_signaled(uint64_t value) const;
|
bool is_signaled(uint64_t value) const;
|
||||||
uint64_t value() const;
|
uint64_t value() const;
|
||||||
|
|
||||||
const std::shared_ptr<Atomic>& atomic() const {
|
|
||||||
return ac_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Atomic> ac_;
|
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user