mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
217 Commits
v0.26.5
...
sign-warns
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24828b1b2f | ||
|
|
9f649b5658 | ||
|
|
18aa921388 | ||
|
|
8d13a0bc6b | ||
|
|
ac75c87fd7 | ||
|
|
7107802e09 | ||
|
|
c5913131cf | ||
|
|
19ab7911f6 | ||
|
|
4a1b1796b7 | ||
|
|
b48d298205 | ||
|
|
8277e71ea9 | ||
|
|
b0d985416a | ||
|
|
8d10f3ec75 | ||
|
|
6343622c67 | ||
|
|
979abf462b | ||
|
|
981d2fdaf0 | ||
|
|
5a306d3495 | ||
|
|
5baa361779 | ||
|
|
1bac0db7e3 | ||
|
|
a1212b4e44 | ||
|
|
45a8b226af | ||
|
|
76ef1e98f3 | ||
|
|
63d91557e0 | ||
|
|
310e501e6a | ||
|
|
cacc3ab7fd | ||
|
|
53525cba23 | ||
|
|
3d67b717a0 | ||
|
|
953b2f5be2 | ||
|
|
26f7155537 | ||
|
|
66fcb9fe94 | ||
|
|
d1e06117e8 | ||
|
|
539d8322d1 | ||
|
|
c4767d110f | ||
|
|
895217f25b | ||
|
|
0cfeeb60ca | ||
|
|
8f8af61a37 | ||
|
|
233384161e | ||
|
|
5bcf3a6794 | ||
|
|
7707196297 | ||
|
|
7e3471c987 | ||
|
|
9f0ba3ddf1 | ||
|
|
4bce5f9b2d | ||
|
|
e9eab527eb | ||
|
|
36ca62dba8 | ||
|
|
9cbb1b0148 | ||
|
|
9bfc476d72 | ||
|
|
25e2356316 | ||
|
|
226a1d24e0 | ||
|
|
630350ad3e | ||
|
|
380aeb58ae | ||
|
|
f37389d100 | ||
|
|
e89e8b4272 | ||
|
|
85a8824a8c | ||
|
|
f5d4397e5c | ||
|
|
343e33b6d5 | ||
|
|
0073096dd1 | ||
|
|
e3d004fed9 | ||
|
|
a393435d28 | ||
|
|
a7a94b29d7 | ||
|
|
22a5da76c8 | ||
|
|
287c63a093 | ||
|
|
1c9ae1eaa1 | ||
|
|
c2c3e0b0a2 | ||
|
|
b0cc71ae71 | ||
|
|
e88f2d4a8e | ||
|
|
9cee557423 | ||
|
|
bbf1423953 | ||
|
|
eb24267b56 | ||
|
|
dc371ae7a5 | ||
|
|
e76a8dd5c5 | ||
|
|
b466dea982 | ||
|
|
7a6adda1e6 | ||
|
|
1a9f820af6 | ||
|
|
d4f4ff3c5e | ||
|
|
7c7e48dbd1 | ||
|
|
fbbf3b9b3e | ||
|
|
bf01ad9367 | ||
|
|
ae438d05fa | ||
|
|
711a645807 | ||
|
|
aa9d44b3d4 | ||
|
|
ec2ab42888 | ||
|
|
787c0d90cd | ||
|
|
e8b604a6a3 | ||
|
|
50cc09887f | ||
|
|
3f730e77aa | ||
|
|
caecbe876a | ||
|
|
8afb6d62f2 | ||
|
|
6ccfa603cd | ||
|
|
36cad99a11 | ||
|
|
ee18e1cbf0 | ||
|
|
af120c2bc0 | ||
|
|
6a3acf2301 | ||
|
|
d6977f2a57 | ||
|
|
db5443e831 | ||
|
|
52b8384d10 | ||
|
|
44cc5da4bc | ||
|
|
dde3682b69 | ||
|
|
17310d91a6 | ||
|
|
b194d65a6a | ||
|
|
a44b27f5f8 | ||
|
|
e5a33f2223 | ||
|
|
c1e3340b23 | ||
|
|
8f163a367d | ||
|
|
89a3df9014 | ||
|
|
c5d2937aa5 | ||
|
|
b61a65e313 | ||
|
|
04cbb4191c | ||
|
|
c5460762e7 | ||
|
|
8ce49cd39e | ||
|
|
9c68b50853 | ||
|
|
111f1e71af | ||
|
|
827003d568 | ||
|
|
d363a76aa4 | ||
|
|
70560b6bd5 | ||
|
|
7ef8a6f2d5 | ||
|
|
31c6f6e33f | ||
|
|
584d48458e | ||
|
|
5cf984ca87 | ||
|
|
a9bac3d9e5 | ||
|
|
5458d43247 | ||
|
|
a4dba65220 | ||
|
|
3dcb286baf | ||
|
|
4822c3dbe9 | ||
|
|
2ca75bb529 | ||
|
|
db14e29a0b | ||
|
|
d2f540f4e0 | ||
|
|
333ffea273 | ||
|
|
f55b6f1f2f | ||
|
|
30561229c7 | ||
|
|
068a4612e9 | ||
|
|
5722c147de | ||
|
|
f6819a1f26 | ||
|
|
f93f87c802 | ||
|
|
9392fc3f88 | ||
|
|
e843c4d8d5 | ||
|
|
0c5fc63a36 | ||
|
|
e397177f6e | ||
|
|
f4c8888cbe | ||
|
|
25c1e03205 | ||
|
|
512281781c | ||
|
|
ac85ddfdb7 | ||
|
|
65d0d40232 | ||
|
|
cea9369610 | ||
|
|
e7c6e1db82 | ||
|
|
c5fcd5b61b | ||
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 | ||
|
|
1ba18ff7d9 | ||
|
|
37b440faa8 | ||
|
|
888b13ed63 | ||
|
|
4abb218d21 | ||
|
|
6441c21a94 | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a | ||
|
|
56be773610 | ||
|
|
a9bdd67baa | ||
|
|
f2adb5638d | ||
|
|
728d4db582 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 | ||
|
|
fa89f0b150 | ||
|
|
ca973d1e83 | ||
|
|
828c5f1137 | ||
|
|
7d86a5c108 | ||
|
|
0b807893a7 | ||
|
|
6ad0889c8a | ||
|
|
737dd6d1ac | ||
|
|
aaf78f4c6b | ||
|
|
8831064493 | ||
|
|
be9bc96da4 | ||
|
|
86258f292f | ||
|
|
b26d88591c | ||
|
|
86c6a15571 | ||
|
|
8b25ce62d5 | ||
|
|
da5912e4f2 | ||
|
|
daafee676f | ||
|
|
d32519c8ee | ||
|
|
b405591249 | ||
|
|
3bf81ed1bd | ||
|
|
2204182bba | ||
|
|
3628e5d497 | ||
|
|
a0ae49d397 | ||
|
|
254476718b | ||
|
|
3adba92ebe | ||
|
|
ef631d63af | ||
|
|
970dbe8e25 | ||
|
|
641be9463b | ||
|
|
ab0e608862 | ||
|
|
1588659062 | ||
|
|
b9e88fb976 | ||
|
|
4ad53414dd | ||
|
|
d1165b215e | ||
|
|
dcb8319f3d | ||
|
|
5597fa089c | ||
|
|
9acec364c2 | ||
|
|
7d9d6ef456 | ||
|
|
6f5874a2f2 | ||
|
|
70dc336785 | ||
|
|
4e504039f5 | ||
|
|
d1f4d291e8 | ||
|
|
e1840853ce | ||
|
|
0f5ce173da | ||
|
|
588854195f | ||
|
|
28d068bce6 | ||
|
|
d107d8d495 | ||
|
|
1e496ddb82 | ||
|
|
74eccbf3fa | ||
|
|
08638223ca | ||
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 |
@@ -7,6 +7,9 @@ parameters:
|
||||
nightly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
@@ -15,16 +18,17 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
xcode: "26.0.0"
|
||||
resource_class: m4pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install
|
||||
command: |
|
||||
brew install python@3.9
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
brew install python@3.10
|
||||
brew install doxygen
|
||||
python3.9 -m venv env
|
||||
python3.10 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
@@ -78,23 +82,25 @@ jobs:
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
sudo apt-get upgrade -y
|
||||
pip install --upgrade cmake
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
pip install -e ".[dev]"
|
||||
uv venv
|
||||
uv pip install cmake
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
python -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
@@ -102,6 +108,7 @@ jobs:
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
@@ -113,7 +120,7 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
default: "26.0.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
@@ -121,39 +128,37 @@ jobs:
|
||||
xcode: << parameters.xcode_version >>
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
resource_class: m4pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.9
|
||||
brew install openmpi
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
pip install unittest-xml-reporting
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
|
||||
brew install openmpi uv
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
uv venv --python 3.10
|
||||
uv pip install \
|
||||
nanobind==2.4.0 \
|
||||
cmake \
|
||||
numpy \
|
||||
torch \
|
||||
tensorflow \
|
||||
unittest-xml-reporting
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
uv pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
@@ -162,16 +167,17 @@ jobs:
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext -j8
|
||||
uv pip install -r requirements.txt
|
||||
uv run --no-project setup.py build_ext --inplace
|
||||
uv run --no-project python test.py
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
@@ -180,7 +186,7 @@ jobs:
|
||||
- run:
|
||||
name: Build small binary
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
@@ -192,43 +198,85 @@ jobs:
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
uv pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
uv run --no-project python -m xmlrunner discover \
|
||||
-v python/tests \
|
||||
-o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
parameters:
|
||||
image_date:
|
||||
type: string
|
||||
default: "2023.11.1"
|
||||
machine:
|
||||
image: linux-cuda-12:2023.11.1
|
||||
image: "linux-cuda-12:<< parameters.image_date >>"
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
keys:
|
||||
- cuda-<< parameters.image_date >>-{{ arch }}-
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install libnccl2 libnccl-dev
|
||||
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||
rm -rf ccache-4.11.3-linux-x86_64
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Set CCache size
|
||||
command: ccache --max-size 1G
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python3 -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install -e ".[dev]"
|
||||
uv venv
|
||||
uv pip install cmake
|
||||
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
cmake . -B build \
|
||||
-DMLX_BUILD_CUDA=ON \
|
||||
-DCMAKE_CUDA_COMPILER=`which nvcc` \
|
||||
-DCMAKE_BUILD_TYPE=DEBUG
|
||||
cmake --build build -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||
- run:
|
||||
name: CCache report
|
||||
command: |
|
||||
ccache --show-stats
|
||||
ccache --zero-stats
|
||||
ccache --cleanup
|
||||
- save_cache:
|
||||
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||
paths:
|
||||
- /home/circleci/.cache/ccache
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
default: "3.10"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
default: "26.0.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
@@ -237,7 +285,7 @@ jobs:
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: m2pro.medium
|
||||
resource_class: m4pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
@@ -245,11 +293,15 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@<< parameters.python_version >>
|
||||
brew install openmpi
|
||||
python<< parameters.python_version >> -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
mkdir -p ~/miniconda3
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
source ~/miniconda3/bin/activate
|
||||
conda init --all
|
||||
conda create -n env python=<< parameters.python_version >> -y
|
||||
conda activate env
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
@@ -259,29 +311,29 @@ jobs:
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.9", << parameters.python_version >>]
|
||||
equal: ["3.10", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||
- when:
|
||||
@@ -290,7 +342,7 @@ jobs:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
twine upload dist/*
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
@@ -299,7 +351,7 @@ jobs:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
default: "3.10"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
@@ -315,14 +367,10 @@ jobs:
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
sudo apt-get upgrade -y
|
||||
TZ=Etc/UTC sudo apt-get -y install tzdata
|
||||
sudo apt-get install -y apt-utils
|
||||
sudo apt-get install -y software-properties-common
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install -y build-essential git
|
||||
$PYTHON -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
@@ -339,7 +387,7 @@ jobs:
|
||||
bash python/scripts/repair_linux.sh
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.9", << parameters.python_version >>]
|
||||
equal: ["3.10", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
@@ -366,22 +414,27 @@ jobs:
|
||||
type: string
|
||||
default: ""
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
image: ubuntu-2204:current
|
||||
resource_class: xlarge
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build -w
|
||||
@@ -392,7 +445,6 @@ jobs:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
@@ -405,19 +457,24 @@ workflows:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
macosx_deployment_target: ["13.5", "15.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test
|
||||
- cuda_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
image_date: ["2023.11.1", "2025.05.1"]
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
and:
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
filters:
|
||||
@@ -427,71 +484,10 @@ workflows:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -507,7 +503,7 @@ workflows:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
filters:
|
||||
@@ -533,11 +529,14 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
macosx_deployment_target: ["13.5", "15.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
image_date: ["2023.11.1", "2025.05.1"]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -547,57 +546,34 @@ workflows:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
- build_cuda_release
|
||||
|
||||
build_dev_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
matrix:
|
||||
parameters:
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
|
||||
@@ -19,12 +19,17 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
||||
# Organizations
|
||||
|
||||
MLX has received contributions from the following companies:
|
||||
- NVIDIA Corporation & Affiliates
|
||||
|
||||
# Third-Party Software
|
||||
|
||||
MLX leverages several third-party software, listed here together with
|
||||
|
||||
@@ -20,12 +20,17 @@ project(
|
||||
LANGUAGES C CXX
|
||||
VERSION ${MLX_PROJECT_VERSION})
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
|
||||
add_compile_options(-Wall -Wextra)
|
||||
endif()
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# ----------------------------- Configuration -----------------------------
|
||||
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
||||
@@ -41,7 +46,9 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
message(
|
||||
@@ -68,6 +75,15 @@ else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
endif()
|
||||
|
||||
if(MLX_USE_CCACHE)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
|
||||
include(FetchContent)
|
||||
@@ -76,22 +92,21 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_LIB "-framework Metal")
|
||||
set(FOUNDATION_LIB "-framework Foundation")
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
if(MLX_BUILD_METAL)
|
||||
find_library(METAL_LIB Metal)
|
||||
find_library(FOUNDATION_LIB Foundation)
|
||||
find_library(QUARTZ_LIB QuartzCore)
|
||||
if(METAL_LIB)
|
||||
message(STATUS "Metal found ${METAL_LIB}")
|
||||
else()
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
|
||||
endif()
|
||||
|
||||
if(MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
@@ -100,7 +115,8 @@ elseif(MLX_BUILD_METAL)
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -129,6 +145,12 @@ elseif(MLX_BUILD_METAL)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||
# With newer clang/gcc versions following libs are implicitly linked, but when
|
||||
# building on old distributions they need to be explicitly listed.
|
||||
target_link_libraries(mlx PRIVATE dl pthread)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
if(MSVC)
|
||||
# GGUF does not build with MSVC.
|
||||
@@ -156,7 +178,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
message(STATUS "Accelerate not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
@@ -232,12 +254,16 @@ target_include_directories(
|
||||
# Do not add mlx_EXPORTS define for shared library.
|
||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
if(USE_SYSTEM_FMT)
|
||||
find_package(fmt REQUIRED)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
endif()
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
|
||||
51
README.md
51
README.md
@@ -11,31 +11,31 @@ brought to you by Apple machine learning research.
|
||||
|
||||
Some key features of MLX include:
|
||||
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||
more complex models.
|
||||
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and the GPU).
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and the GPU).
|
||||
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without transferring data.
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without transferring data.
|
||||
|
||||
MLX is designed by machine learning researchers for machine learning
|
||||
researchers. The framework is intended to be user-friendly, but still efficient
|
||||
@@ -68,18 +68,23 @@ in the documentation.
|
||||
|
||||
## Installation
|
||||
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||
macOS, run:
|
||||
|
||||
**With `pip`**:
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install mlx
|
||||
```
|
||||
|
||||
**With `conda`**:
|
||||
To install the CUDA backend on Linux, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cuda]
|
||||
```
|
||||
conda install -c conda-forge mlx
|
||||
|
||||
To install a CPU-only Linux package, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cpu]
|
||||
```
|
||||
|
||||
Checkout the
|
||||
@@ -105,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
||||
MLX useful in your research and wish to cite it, please use the following
|
||||
BibTex entry:
|
||||
|
||||
```
|
||||
```text
|
||||
@software{mlx2023,
|
||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||
|
||||
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
||||
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
||||
|
||||
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
||||
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
|
||||
np.float32
|
||||
)
|
||||
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
|
||||
dtypes = ("float32", "float16")
|
||||
dtypes = ("float32", "float16", "complex64")
|
||||
transposes = ("nn", "nt", "tn")
|
||||
shapes = (
|
||||
(16, 234, 768, 3072),
|
||||
@@ -187,7 +185,7 @@ if __name__ == "__main__":
|
||||
diff = gflops_mx / gflops_pt - 1.0
|
||||
|
||||
print(
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
|
||||
)
|
||||
if gflops_pt >= 2.0 * gflops_mx:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
||||
@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
||||
|
||||
|
||||
for transpose in (False, True):
|
||||
for dtype in ("float32", "float16"):
|
||||
for dtype in ("float32", "float16", "complex64"):
|
||||
fig, axs = plt.subplots(
|
||||
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
||||
)
|
||||
@@ -215,7 +215,7 @@ for transpose in (False, True):
|
||||
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
||||
fig.savefig(
|
||||
os.path.join(
|
||||
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
|
||||
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
|
||||
)
|
||||
)
|
||||
plt.close(fig)
|
||||
|
||||
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
||||
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||
# directories.
|
||||
|
||||
set(NCCL_ROOT_DIR
|
||||
$ENV{NCCL_ROOT_DIR}
|
||||
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||
|
||||
find_path(
|
||||
NCCL_INCLUDE_DIRS
|
||||
NAMES nccl.h
|
||||
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||
|
||||
if($ENV{USE_STATIC_NCCL})
|
||||
message(
|
||||
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||
set(NCCL_LIBNAME "libnccl_static.a")
|
||||
else()
|
||||
set(NCCL_LIBNAME "nccl")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
NCCL_LIBRARIES
|
||||
NAMES ${NCCL_LIBNAME}
|
||||
HINTS ${NCCL_LIB_DIR}
|
||||
${NCCL_ROOT_DIR}
|
||||
${NCCL_ROOT_DIR}/lib
|
||||
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||
${NCCL_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||
NCCL_LIBRARIES)
|
||||
|
||||
if(NCCL_FOUND)
|
||||
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||
message(
|
||||
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||
file(
|
||||
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||
LIMIT_COUNT 1)
|
||||
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||
endif()
|
||||
message(
|
||||
STATUS
|
||||
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||
endif()
|
||||
@@ -1,4 +1,5 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
||||
sphinx-copybutton
|
||||
mlx
|
||||
|
||||
@@ -18,6 +18,7 @@ release = version
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
extensions = [
|
||||
"sphinx_copybutton",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
|
||||
@@ -127,7 +127,8 @@ relying on a copy from ``ensure_row_contiguous``:
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
source=source,
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
|
||||
@@ -394,14 +394,14 @@ below.
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
std::stream kname;
|
||||
kname = "axpby_general_" + type_to_name(out);
|
||||
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
auto kernel = d.get_kernel(kname, lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
@@ -70,6 +70,7 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/cuda
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
|
||||
@@ -13,10 +13,10 @@ silicon computer is
|
||||
|
||||
pip install mlx
|
||||
|
||||
To install from PyPI you must meet the following requirements:
|
||||
To install from PyPI your system must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.9
|
||||
- Using a native Python >= 3.10
|
||||
- macOS >= 13.5
|
||||
|
||||
.. note::
|
||||
@@ -26,12 +26,21 @@ To install from PyPI you must meet the following requirements:
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||
MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "mlx[cuda]"
|
||||
pip install mlx[cuda]
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
@@ -40,7 +49,14 @@ For a CPU-only version of MLX that runs on Linux use:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "mlx[cpu]"
|
||||
pip install mlx[cpu]
|
||||
|
||||
To install the CPU-only package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
@@ -255,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
|
||||
9
docs/src/python/cuda.rst
Normal file
9
docs/src/python/cuda.rst
Normal file
@@ -0,0 +1,9 @@
|
||||
CUDA
|
||||
=====
|
||||
|
||||
.. currentmodule:: mlx.core.cuda
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
@@ -13,3 +13,4 @@ Fast
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
metal_kernel
|
||||
cuda_kernel
|
||||
|
||||
@@ -27,6 +27,7 @@ simple functions.
|
||||
mish
|
||||
prelu
|
||||
relu
|
||||
relu2
|
||||
relu6
|
||||
selu
|
||||
sigmoid
|
||||
|
||||
@@ -50,6 +50,7 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
ReLU2
|
||||
ReLU6
|
||||
RNN
|
||||
RoPE
|
||||
|
||||
@@ -112,6 +112,7 @@ Operations
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
median
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
|
||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
state = tree_flatten(optimizer.state, destination={})
|
||||
mx.save_safetensors("optimizer.safetensors", state)
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
|
||||
@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
|
||||
.. code-block:: python
|
||||
|
||||
x = mx.random.uniform(shape=(32, 1000, 4096))
|
||||
timeit(nn.gelu, x)
|
||||
timeit(mx.compile(nn.gelu), x)
|
||||
timeit(gelu, x)
|
||||
timeit(mx.compile(gelu), x)
|
||||
|
||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||
five times faster.
|
||||
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z), state
|
||||
return mx.exp(z)
|
||||
|
||||
fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints [array(3, dtype=float32)]
|
||||
|
||||
@@ -184,7 +184,7 @@ almost identical to the example above:
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
||||
grads = mx.nn.average_gradients(grads) # <---- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
||||
model.update(tree_unflatten(list(params.items())))
|
||||
return model(x)
|
||||
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||
|
||||
|
||||
@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
|
||||
imported_abs = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
out, = imported_abs(mx.array([-1.0]))
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
@@ -107,8 +107,20 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
Note that unlike NumPy, slicing an array creates a copy, not a view. So
|
||||
mutating it does not mutate the original array:
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> b = a[:]
|
||||
>>> b[2] = 0
|
||||
>>> b
|
||||
array([1, 2, 0], dtype=int32)
|
||||
>>> a
|
||||
array([1, 2, 3], dtype=int32)
|
||||
|
||||
Also unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
|
||||
@@ -14,14 +14,17 @@ void array_basics() {
|
||||
// Get the value out of it:
|
||||
auto s = x.item<float>();
|
||||
assert(s == 1.0);
|
||||
(void)s;
|
||||
|
||||
// Scalars have a size of 1:
|
||||
size_t size = x.size();
|
||||
int64_t size = x.size();
|
||||
assert(size == 1);
|
||||
(void)size;
|
||||
|
||||
// Scalars have 0 dimensions:
|
||||
int ndim = x.ndim();
|
||||
assert(ndim == 0);
|
||||
(void)ndim;
|
||||
|
||||
// The shape should be an empty vector:
|
||||
auto shape = x.shape();
|
||||
@@ -30,6 +33,7 @@ void array_basics() {
|
||||
// The datatype should be float32:
|
||||
auto dtype = x.dtype();
|
||||
assert(dtype == mx::float32);
|
||||
(void)dtype;
|
||||
|
||||
// Specify the dtype when constructing the array:
|
||||
x = mx::array(1, mx::int32);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@@ -16,6 +17,19 @@
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
// A helper function to find the location of the current binary on disk.
|
||||
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
|
||||
std::string current_binary_dir() {
|
||||
static std::string binary_dir = []() {
|
||||
Dl_info info;
|
||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||
throw std::runtime_error("Unable to get current binary dir.");
|
||||
}
|
||||
return std::filesystem::path(info.dli_fname).parent_path().string();
|
||||
}();
|
||||
return binary_dir;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_";
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
std::string kname = "axpby_";
|
||||
kname += (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname += type_to_name(out);
|
||||
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
auto kernel = d.get_kernel(kname, lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.25
|
||||
mlx>=0.21.0
|
||||
nanobind==2.2.0
|
||||
nanobind==2.4.0
|
||||
|
||||
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c shape: {c_cpu.shape}")
|
||||
print(f"c dtype: {c_cpu.dtype}")
|
||||
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
||||
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
||||
|
||||
@@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<array> outputs;
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(shapes); ++i) {
|
||||
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
|
||||
}
|
||||
// For each node in |outputs|, its siblings are the other nodes.
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(outputs); ++i) {
|
||||
auto siblings = outputs;
|
||||
siblings.erase(siblings.begin() + i);
|
||||
outputs[i].set_siblings(std::move(siblings), i);
|
||||
@@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
array_desc_->data_size = size();
|
||||
array_desc_->flags.contiguous = true;
|
||||
array_desc_->flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(shape().begin(), shape().end());
|
||||
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
|
||||
auto max_dim =
|
||||
static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
|
||||
array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
|
||||
}
|
||||
|
||||
void array::set_data(
|
||||
@@ -192,7 +193,7 @@ array::~array() {
|
||||
}
|
||||
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
if (auto n = std::ssize(siblings()); n > 0) {
|
||||
bool do_detach = true;
|
||||
// If all siblings have siblings.size() references except
|
||||
// the one we are currently destroying (which has siblings.size() + 1)
|
||||
@@ -241,8 +242,8 @@ array::ArrayDesc::ArrayDesc(
|
||||
std::vector<array> inputs)
|
||||
: shape(std::move(shape)),
|
||||
dtype(dtype),
|
||||
status(Status::unscheduled),
|
||||
primitive(std::move(primitive)),
|
||||
status(Status::unscheduled),
|
||||
inputs(std::move(inputs)) {
|
||||
init();
|
||||
}
|
||||
@@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
ad.inputs.clear();
|
||||
for (auto& [_, a] : input_map) {
|
||||
bool is_deletable =
|
||||
(a.array_desc_.use_count() <= a.siblings().size() + 1);
|
||||
(a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
|
||||
// An array with siblings is deletable only if all of its siblings
|
||||
// are deletable
|
||||
for (auto& s : a.siblings()) {
|
||||
@@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
}
|
||||
int is_input = (input_map.find(s.id()) != input_map.end());
|
||||
is_deletable &=
|
||||
s.array_desc_.use_count() <= a.siblings().size() + is_input;
|
||||
s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
|
||||
}
|
||||
if (is_deletable) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
|
||||
19
mlx/array.h
19
mlx/array.h
@@ -10,6 +10,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/small_vector.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -18,8 +19,8 @@ class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
using Shape = SmallVector<ShapeElem>;
|
||||
using Strides = SmallVector<int64_t>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
@@ -80,22 +81,22 @@ class array {
|
||||
}
|
||||
|
||||
/** The size of the array's datatype in bytes. */
|
||||
size_t itemsize() const {
|
||||
int itemsize() const {
|
||||
return size_of(dtype());
|
||||
}
|
||||
|
||||
/** The number of elements in the array. */
|
||||
size_t size() const {
|
||||
int64_t size() const {
|
||||
return array_desc_->size;
|
||||
}
|
||||
|
||||
/** The number of bytes in the array. */
|
||||
size_t nbytes() const {
|
||||
int64_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
}
|
||||
|
||||
/** The number of dimensions of the array. */
|
||||
size_t ndim() const {
|
||||
int ndim() const {
|
||||
return array_desc_->shape.size();
|
||||
}
|
||||
|
||||
@@ -328,7 +329,7 @@ class array {
|
||||
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
|
||||
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
|
||||
**/
|
||||
size_t data_size() const {
|
||||
int64_t data_size() const {
|
||||
return array_desc_->data_size;
|
||||
}
|
||||
|
||||
@@ -339,7 +340,7 @@ class array {
|
||||
return array_desc_->data->buffer;
|
||||
}
|
||||
|
||||
size_t buffer_size() const {
|
||||
int64_t buffer_size() const {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
@@ -529,7 +530,7 @@ array::array(
|
||||
Shape shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
if (std::ssize(data) != size()) {
|
||||
throw std::invalid_argument(
|
||||
"Data size and provided shape mismatch in array construction.");
|
||||
}
|
||||
|
||||
@@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
int64_t r = 1, c = 1;
|
||||
for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
@@ -60,7 +60,8 @@ void CustomTransforms::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs);
|
||||
i < std::ssize(outputs);
|
||||
i++, j++) {
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
@@ -70,7 +71,7 @@ void Depends::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(outputs); i++) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
@@ -206,11 +207,11 @@ void Split::eval(
|
||||
|
||||
auto compute_new_flags = [](const auto& shape,
|
||||
const auto& strides,
|
||||
size_t in_data_size,
|
||||
int64_t in_data_size,
|
||||
auto flags) {
|
||||
size_t data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
int64_t data_size = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
flags.row_contiguous = true;
|
||||
flags.col_contiguous = true;
|
||||
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
|
||||
@@ -240,7 +241,7 @@ void Split::eval(
|
||||
|
||||
std::vector<int> indices(1, 0);
|
||||
indices.insert(indices.end(), indices_.begin(), indices_.end());
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(indices); i++) {
|
||||
size_t offset = indices[i] * in.strides()[axis_];
|
||||
auto [new_flags, data_size] = compute_new_flags(
|
||||
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
|
||||
@@ -254,7 +255,7 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
const auto& in = inputs[0];
|
||||
Strides strides;
|
||||
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||
if (j < axes_.size() && i == axes_[j]) {
|
||||
if (j < std::ssize(axes_) && i == axes_[j]) {
|
||||
j++;
|
||||
} else {
|
||||
strides.push_back(in.strides(i));
|
||||
@@ -272,7 +273,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
Strides out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
for (int ax = 0; ax < std::ssize(axes_); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ void compiled_allocate_outputs(
|
||||
Strides strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Correct size
|
||||
@@ -138,7 +138,7 @@ void compiled_allocate_outputs(
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
for (; o < std::ssize(outputs); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
@@ -147,7 +147,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
@@ -162,7 +162,7 @@ void compiled_allocate_outputs(
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
for (; o < std::ssize(outputs); ++o) {
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
@@ -193,7 +193,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
size_t j = 0;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); ++j) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
@@ -201,7 +201,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||
for (int i = 0; i < x.ndim(); ++i, ++j) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
@@ -224,13 +224,13 @@ bool compiled_use_large_index(
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
int64_t max_size = 0;
|
||||
for (const auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
int64_t max_size = 0;
|
||||
for (const auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Load::eval_cpu(const std::vector<array>& /* inputs */, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto read_task = [out_ptr = out.data<char>(),
|
||||
size = out.size(),
|
||||
|
||||
@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}};
|
||||
return {Shape{1}, Strides{0}, Strides{0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}, {0}};
|
||||
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
|
||||
@@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
@@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// Merge consecutive axes
|
||||
Shape shape = {x.shape(axes[0])};
|
||||
Strides strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
for (int i = 1; i < std::ssize(axes); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
|
||||
@@ -24,8 +24,8 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
int64_t data_offset,
|
||||
int64_t data_size,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
|
||||
@@ -61,7 +61,7 @@ void slice(
|
||||
if (data_end < 0) {
|
||||
data_end += in.data_size();
|
||||
}
|
||||
size_t data_size = (data_end - data_offset);
|
||||
int64_t data_size = (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ namespace mlx::core {
|
||||
enum class TernaryOpType {
|
||||
ScalarScalarScalar,
|
||||
VectorVectorVector,
|
||||
VectorVectorScalar,
|
||||
VectorScalarVector,
|
||||
General,
|
||||
};
|
||||
|
||||
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
||||
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
||||
c.flags().col_contiguous)) {
|
||||
topt = TernaryOpType::VectorVectorVector;
|
||||
} else if (
|
||||
b.data_size() == 1 && a.flags().row_contiguous &&
|
||||
c.flags().row_contiguous) {
|
||||
topt = TernaryOpType::VectorScalarVector;
|
||||
} else if (
|
||||
c.data_size() == 1 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
topt = TernaryOpType::VectorVectorScalar;
|
||||
} else {
|
||||
topt = TernaryOpType::General;
|
||||
}
|
||||
@@ -59,6 +69,8 @@ inline void set_ternary_op_output_data(
|
||||
b.flags());
|
||||
}
|
||||
break;
|
||||
case TernaryOpType::VectorVectorScalar:
|
||||
case TernaryOpType::VectorScalarVector:
|
||||
case TernaryOpType::General:
|
||||
// Try to donate an input which is row_contiguous
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
|
||||
@@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
if (shape[0] != 1) {
|
||||
to_collapse.push_back(0);
|
||||
}
|
||||
size_t size = shape[0];
|
||||
int64_t size = shape[0];
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
size *= shape[i];
|
||||
@@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
current_shape *= shape[to_collapse[k]];
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
for (int j = 0; j < std::ssize(strides); j++) {
|
||||
const auto& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[k - 1]]);
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ struct ContiguousIterator {
|
||||
};
|
||||
|
||||
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
size_t no_broadcast_data_size = 1;
|
||||
int64_t no_broadcast_data_size = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
bool is_row_contiguous = true;
|
||||
@@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
|
||||
}
|
||||
|
||||
inline bool is_donatable(const array& in, const array& out) {
|
||||
constexpr size_t donation_extra = 16384;
|
||||
constexpr int64_t donation_extra = 16384;
|
||||
|
||||
return in.is_donatable() && in.itemsize() == out.itemsize() &&
|
||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||
@@ -197,7 +197,7 @@ void shared_buffer_reshape(
|
||||
array& out);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void arange(T start, T next, array& out, size_t size, Stream stream) {
|
||||
void arange(T start, T next, array& out, int64_t size, Stream stream) {
|
||||
auto ptr = out.data<T>();
|
||||
auto step_size = next - start;
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
|
||||
@@ -19,12 +19,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
for (int64_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
uint32_t ind_v = 0;
|
||||
InT v = (*local_in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
op(j, (*local_in_ptr), &ind_v, &v);
|
||||
}
|
||||
out_ptr[i] = ind_v;
|
||||
|
||||
@@ -17,7 +17,12 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
void binary(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
@@ -81,7 +86,7 @@ void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
@@ -146,7 +151,7 @@ void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
@@ -187,7 +192,7 @@ void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Op /* op */,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
@@ -99,7 +99,7 @@ void binary_op_dispatch_dims(
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
for (int64_t elem = 0; elem < std::ssize(a); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
@@ -137,21 +137,21 @@ void binary_op(
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||
for (size_t i = 0; i < b.data_size(); ++i) {
|
||||
for (int64_t i = 0; i < b.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
for (int64_t i = 0; i < a.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else { // VectorVector
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
for (int64_t i = 0; i < a.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
|
||||
@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
||||
N = a.shape(-1),
|
||||
size = a.size()]() mutable {
|
||||
char uplo = (upper) ? 'L' : 'U';
|
||||
size_t num_matrices = size / (N * N);
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
int64_t num_matrices = size / (N * N);
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info;
|
||||
potrf<T>(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/version.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -48,7 +49,7 @@ static CompilerCache& cache() {
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
bool compile_available_for_device(const Device& /* device */) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -94,7 +95,11 @@ void* compile(
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
auto output_dir = std::filesystem::temp_directory_path();
|
||||
auto output_dir =
|
||||
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
|
||||
if (!std::filesystem::exists(output_dir)) {
|
||||
std::filesystem::create_directories(output_dir);
|
||||
}
|
||||
|
||||
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
||||
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
||||
@@ -157,11 +162,13 @@ inline void build_kernel(
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
os << "void " << kernel_name
|
||||
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
int strides_index = 1;
|
||||
for (int i = 0; i < std::ssize(inputs); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
@@ -175,8 +182,8 @@ inline void build_kernel(
|
||||
<< "];" << std::endl;
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
os << " const int64_t* " << xname << "_strides = strides["
|
||||
<< strides_index++ << "];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,10 +193,8 @@ inline void build_kernel(
|
||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||
} else {
|
||||
// Add output size
|
||||
if (contiguous) {
|
||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
|
||||
@@ -233,7 +238,7 @@ inline void build_kernel(
|
||||
} else {
|
||||
os << x.primitive().name();
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
||||
@@ -290,7 +295,6 @@ void Compiled::eval_cpu(
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
@@ -298,9 +302,6 @@ void Compiled::eval_cpu(
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
@@ -335,16 +336,20 @@ void Compiled::eval_cpu(
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
if (contiguous) {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
shape = std::move(shape)]() mutable {
|
||||
SmallVector<int64_t*> strides_ptrs;
|
||||
for (auto& s : strides) {
|
||||
strides_ptrs.push_back(s.data());
|
||||
}
|
||||
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& /* wt_dilation */,
|
||||
Stream stream) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
Stream stream) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int iW = in.shape(2); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(3); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
const int wW = wt.shape(2); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = out.dtype();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
|
||||
// Pad input
|
||||
Shape padded_shape = {
|
||||
N,
|
||||
iH + padding_lo[0] + padding_hi[0],
|
||||
iW + padding_lo[1] + padding_hi[1],
|
||||
C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
std::vector<array> temps;
|
||||
temps.push_back(array(0, conv_dtype));
|
||||
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
||||
padding_lo[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
temps.push_back(in_padded_slice);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
|
||||
|
||||
// Make strided view
|
||||
Shape strided_shape = {N, oH, oW, wH, wW, C};
|
||||
|
||||
Strides strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * wt_strides[0],
|
||||
in_padded.strides()[2] * wt_strides[1],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2],
|
||||
in_padded.strides()[3]};
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
Shape strided_reshape = {N * oH * oW, wH * wW * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
|
||||
temps.push_back(in_strided);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy_cpu(wt, gemm_wt, ctype, stream);
|
||||
temps.push_back(gemm_wt);
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
|
||||
temps.push_back(gemm_out);
|
||||
}
|
||||
|
||||
encoder.set_input_array(in_strided);
|
||||
encoder.set_input_array(gemm_wt);
|
||||
encoder.set_output_array(gemm_out);
|
||||
|
||||
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
|
||||
gemm_wt_ptr = gemm_wt.data<float>(),
|
||||
gemm_out_ptr = gemm_out.data<float>(),
|
||||
strided_reshape = std::move(strided_reshape),
|
||||
O]() {
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
1.0f, // alpha
|
||||
in_strided_ptr,
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt_ptr,
|
||||
strided_reshape[1], // ldb
|
||||
0.0f, // beta
|
||||
gemm_out_ptr,
|
||||
O // ldc
|
||||
);
|
||||
});
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
|
||||
}
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
@@ -1128,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& /* wt_dilation */,
|
||||
const bool flip,
|
||||
Stream stream) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
@@ -1148,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
// Pad input
|
||||
Shape padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
for (int i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
@@ -1179,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu(
|
||||
// Make strided view
|
||||
Shape strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
for (int i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
}
|
||||
for (size_t i = 0; i < wDim.size(); i++) {
|
||||
for (int i = 0; i < wDim.size(); i++) {
|
||||
strided_shape[i + 1 + oDim.size()] = wDim[i];
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
Strides strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(wt_strides); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
}
|
||||
for (size_t i = 1; i < in_padded.strides().size(); i++) {
|
||||
for (int i = 1; i < std::ssize(in_padded.strides()); i++) {
|
||||
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
|
||||
}
|
||||
|
||||
|
||||
@@ -377,4 +377,10 @@ void copy_cpu_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -30,4 +30,7 @@ void copy_cpu_inplace(
|
||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return {arr, false};
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||
return {arr_copy, true};
|
||||
return {contiguous_copy_cpu(arr, stream), true};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_cpu(in, arr_copy, CopyType::General, s);
|
||||
array arr_copy = contiguous_copy_cpu(in, s);
|
||||
out.copy_shared_buffer(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
@@ -93,6 +90,7 @@ void Recv::eval_cpu(
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
(void)inputs;
|
||||
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
|
||||
@@ -46,7 +46,6 @@ void eig_impl(
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
@@ -71,7 +70,7 @@ void eig_impl(
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
for (int64_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
|
||||
@@ -165,7 +165,7 @@ void eigh_impl(
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
|
||||
// Work loop
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
for (int64_t i = 0; i < size / (N * N); ++i) {
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
|
||||
@@ -20,8 +20,8 @@ struct CommandEncoder {
|
||||
CommandEncoder(CommandEncoder&&) = delete;
|
||||
CommandEncoder& operator=(CommandEncoder&&) = delete;
|
||||
|
||||
void set_input_array(const array& a) {}
|
||||
void set_output_array(array& a) {}
|
||||
void set_input_array(const array& /* a */) {}
|
||||
void set_output_array(array& /* a */) {}
|
||||
|
||||
// Hold onto a temporary until any already scheduled tasks which use it as
|
||||
// an input are complete.
|
||||
|
||||
@@ -12,12 +12,12 @@ void matmul(
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
int64_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -35,7 +34,7 @@ void matmul_bnns(
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
size_t /* ldc */,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
@@ -49,9 +48,15 @@ void matmul_bnns(
|
||||
size_t K = a_shape[ndim - 1];
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
if (beta != 1.0 && beta != 0.0) {
|
||||
// scale the output
|
||||
for (size_t i = 0; i < batch_size * M * N; ++i) {
|
||||
out[i] *= beta;
|
||||
}
|
||||
beta = 1.0;
|
||||
}
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
@@ -122,7 +127,7 @@ void matmul_bnns(
|
||||
auto bnns_filter =
|
||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
BNNSFilterApplyTwoInput(
|
||||
bnns_filter,
|
||||
reinterpret_cast<const uint8_t*>(
|
||||
@@ -143,12 +148,12 @@ void matmul<float16_t>(
|
||||
float16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
int64_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
@@ -178,12 +183,12 @@ void matmul<bfloat16_t>(
|
||||
bfloat16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
int64_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
|
||||
@@ -13,20 +13,20 @@ void matmul<float>(
|
||||
float* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
int64_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
int64_t M = a_shape[ndim - 2];
|
||||
int64_t N = b_shape[ndim - 1];
|
||||
int64_t K = a_shape[ndim - 1];
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
cblas_sgemm(
|
||||
@@ -54,20 +54,20 @@ void matmul<double>(
|
||||
double* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
int64_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
int64_t M = a_shape[ndim - 2];
|
||||
int64_t N = b_shape[ndim - 1];
|
||||
int64_t K = a_shape[ndim - 1];
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
cblas_dgemm(
|
||||
@@ -88,4 +88,47 @@ void matmul<double>(
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void matmul<complex64_t>(
|
||||
const complex64_t* a,
|
||||
const complex64_t* b,
|
||||
complex64_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
int64_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
int64_t M = a_shape[ndim - 2];
|
||||
int64_t N = b_shape[ndim - 1];
|
||||
int64_t K = a_shape[ndim - 1];
|
||||
auto calpha = static_cast<complex64_t>(alpha);
|
||||
auto cbeta = static_cast<complex64_t>(beta);
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
cblas_cgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
&calpha,
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
lda,
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
ldb,
|
||||
&cbeta,
|
||||
out + M * N * i,
|
||||
ldc);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -11,9 +11,9 @@ namespace mlx::core {
|
||||
|
||||
// n = 2^k component
|
||||
template <typename T>
|
||||
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
|
||||
void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
|
||||
for (int b = 0; b < size / n; b++) {
|
||||
size_t loc = b * n;
|
||||
int64_t loc = b * n;
|
||||
T* data_ptr = out + loc;
|
||||
int h = 1;
|
||||
int n_over_2 = n / 2;
|
||||
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {
|
||||
|
||||
// m component
|
||||
template <typename T>
|
||||
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
|
||||
void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
auto start = 1;
|
||||
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
|
||||
std::vector<bool> hmat_vec;
|
||||
while (end != std::string_view::npos) {
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
for (int i = 0; i < std::ssize(row); i++) {
|
||||
hmat_vec.push_back(row[i] == '+');
|
||||
}
|
||||
start = end + 1;
|
||||
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
|
||||
}
|
||||
|
||||
for (int b = 0; b < size / m / n; b++) {
|
||||
size_t loc = b * n * m;
|
||||
int64_t loc = b * n * m;
|
||||
T* data_ptr = out + loc;
|
||||
for (int i = 0; i < n; i++) {
|
||||
std::vector<float> out(m);
|
||||
|
||||
@@ -78,7 +78,7 @@ void gather(
|
||||
can_copy = true;
|
||||
|
||||
// Ignore leading 1s
|
||||
int i = 0;
|
||||
int64_t i = 0;
|
||||
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
|
||||
;
|
||||
|
||||
@@ -91,7 +91,7 @@ void gather(
|
||||
can_copy = true;
|
||||
|
||||
// Ignore trailing 1s
|
||||
int i = slice_sizes.size() - 1;
|
||||
int64_t i = slice_sizes.size() - 1;
|
||||
for (; i >= 0 && slice_sizes[i] == 1; --i)
|
||||
;
|
||||
|
||||
@@ -101,11 +101,11 @@ void gather(
|
||||
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||
}
|
||||
}
|
||||
size_t slice_size = 1;
|
||||
int64_t slice_size = 1;
|
||||
for (auto s : slice_sizes) {
|
||||
slice_size *= s;
|
||||
}
|
||||
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
||||
int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
||||
const T* src_ptr = src.data<T>();
|
||||
T* dst_ptr = out.data<T>();
|
||||
|
||||
@@ -115,10 +115,10 @@ void gather(
|
||||
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
|
||||
}
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
int64_t out_idx = 0;
|
||||
for (int64_t idx = 0; idx < ind_size; idx++) {
|
||||
int64_t src_idx = 0;
|
||||
for (int ii = 0; ii < std::ssize(inds); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = its[ii].loc;
|
||||
its[ii].step();
|
||||
@@ -134,7 +134,7 @@ void gather(
|
||||
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
for (int64_t jj = 0; jj < slice_size; jj++) {
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
|
||||
src_it.step();
|
||||
}
|
||||
@@ -403,11 +403,11 @@ void scatter(
|
||||
const std::vector<int>& axes) {
|
||||
int nind = inds.size();
|
||||
auto inds_ndim = updates.ndim() - out.ndim();
|
||||
size_t n_updates = nind ? inds[0].size() : 1;
|
||||
int64_t n_updates = nind ? inds[0].size() : 1;
|
||||
|
||||
Shape update_shape(
|
||||
updates.shape().begin() + inds_ndim, updates.shape().end());
|
||||
size_t update_size = 1;
|
||||
int64_t update_size = 1;
|
||||
for (auto us : update_shape) {
|
||||
update_size *= us;
|
||||
}
|
||||
@@ -418,9 +418,9 @@ void scatter(
|
||||
|
||||
auto out_ptr = out.data<InT>();
|
||||
auto upd_ptr = updates.data<InT>();
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < inds.size(); ++j) {
|
||||
for (int64_t i = 0; i < n_updates; ++i) {
|
||||
int64_t out_offset = 0;
|
||||
for (int j = 0; j < std::ssize(inds); ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
@@ -429,7 +429,7 @@ void scatter(
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
for (int64_t j = 0; j < update_size; ++j) {
|
||||
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
|
||||
@@ -122,7 +122,7 @@ void inverse_impl(
|
||||
stream);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
const int64_t num_matrices = a.size() / (N * N);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(inv);
|
||||
@@ -130,13 +130,13 @@ void inverse_impl(
|
||||
auto inv_ptr = inv.data<T>();
|
||||
if (tri) {
|
||||
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
tri_inv<T>(inv_ptr + N * N * i, N, upper);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([inv_ptr, N, num_matrices]() {
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
general_inv<T>(inv_ptr + N * N * i, N);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_cpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_cpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ inline void mask_matrix(
|
||||
const int64_t Y_data_str,
|
||||
const int64_t X_mask_str,
|
||||
const int64_t Y_mask_str,
|
||||
const size_t mask_offset) {
|
||||
const int64_t mask_offset) {
|
||||
int tX = (X + block_size - 1) / block_size;
|
||||
int tY = (Y + block_size - 1) / block_size;
|
||||
|
||||
@@ -61,13 +61,13 @@ inline void segmented_mm(
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides,
|
||||
size_t num_segments,
|
||||
int64_t num_segments,
|
||||
const Shape& segments_shape,
|
||||
const Strides& segments_strides) {
|
||||
int ndim = a_shape.size();
|
||||
@@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
return std::make_tuple(true, sty, arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
||||
int64_t stx = arr.shape(-1);
|
||||
array arr_copy = contiguous_copy_cpu(arr, s);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
};
|
||||
@@ -150,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [b_transposed, ldb, b, b_copied] =
|
||||
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
int64_t M = a.shape(-2);
|
||||
int64_t N = b.shape(-1);
|
||||
int64_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
@@ -173,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
int batch_idx,
|
||||
int X,
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str,
|
||||
int64_t X_data_str,
|
||||
int64_t Y_data_str,
|
||||
const Shape& mask_shape,
|
||||
const Strides& mask_strides,
|
||||
bool is_bool) {
|
||||
@@ -216,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
const void* a_mask_ptr;
|
||||
const void* b_mask_ptr;
|
||||
const void* out_mask_ptr;
|
||||
const void* a_mask_ptr = nullptr;
|
||||
const void* b_mask_ptr = nullptr;
|
||||
const void* out_mask_ptr = nullptr;
|
||||
Shape a_mask_shape;
|
||||
Shape b_mask_shape;
|
||||
Shape out_mask_shape;
|
||||
Strides a_mask_strides;
|
||||
Strides b_mask_strides;
|
||||
Strides out_mask_strides;
|
||||
bool a_mask_bool;
|
||||
bool b_mask_bool;
|
||||
bool out_mask_bool;
|
||||
bool a_mask_bool = false;
|
||||
bool b_mask_bool = false;
|
||||
bool out_mask_bool = false;
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[inputs.size() - 2];
|
||||
auto& b_mask = inputs[inputs.size() - 1];
|
||||
@@ -254,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto a_ptr = a.data<float>();
|
||||
auto b_ptr = b.data<float>();
|
||||
auto out_ptr = out.data<float>();
|
||||
size_t num_matrices = out.size() / (M * size_t(N));
|
||||
int64_t num_matrices = out.size() / (M * int64_t(N));
|
||||
auto ldc = out.shape(-1);
|
||||
|
||||
encoder.dispatch([a_ptr,
|
||||
@@ -395,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
int64_t M = a.shape(-2);
|
||||
int64_t N = b.shape(-1);
|
||||
int64_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
@@ -414,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Get batch dims
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
size_t matrix_stride_out = M * N;
|
||||
int64_t matrix_stride_out = M * N;
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
@@ -424,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
auto batch_shape = get_batch_dims(out.shape());
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
auto batch_shape_A = get_batch_dims(a.shape());
|
||||
auto batch_strides_A = get_batch_dims(a.strides());
|
||||
|
||||
@@ -91,7 +91,6 @@ void matmul_general(
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -108,6 +107,9 @@ void matmul_general(
|
||||
} else if (out.dtype() == float64) {
|
||||
matmul_dispatch<double>(
|
||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
|
||||
} else if (out.dtype() == complex64) {
|
||||
matmul_dispatch<complex64_t>(
|
||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
|
||||
} else {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
|
||||
}
|
||||
@@ -128,10 +130,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
|
||||
@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
|
||||
auto compute_offset =
|
||||
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
|
||||
int64_t offset_ = 0;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(axes); ++i) {
|
||||
offset_ += indices[i] * strides[axes[i]];
|
||||
}
|
||||
offset[0] = offset_;
|
||||
@@ -124,6 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
(void)inputs;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
@@ -193,9 +194,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(inputs); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
int64_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
@@ -205,7 +206,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
constexpr int64_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
@@ -254,8 +255,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
copy_cpu(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes_.size(); i++) {
|
||||
int64_t data_offset = 0;
|
||||
for (int i = 0; i < std::ssize(axes_); i++) {
|
||||
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size_[i];
|
||||
}
|
||||
@@ -274,10 +275,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
size_t num_keys = keys.size() / 2;
|
||||
int64_t num_keys = keys.size() / 2;
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
int64_t elems_per_key = out.size() / num_keys;
|
||||
int64_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
@@ -291,8 +292,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
num_keys,
|
||||
kshape = keys.shape(),
|
||||
kstrides = keys.strides()]() mutable {
|
||||
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
auto half_size = out_skip / 2;
|
||||
int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
|
||||
uintptr_t half_size = out_skip / 2;
|
||||
bool even = out_skip % 2 == 0;
|
||||
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
|
||||
auto ptr = reinterpret_cast<uint32_t*>(cptr);
|
||||
|
||||
@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int lda = M;
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
int64_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// Copy A to inplace input and make it col-contiguous
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
auto work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
// Solve
|
||||
geqrf<T>(
|
||||
&M,
|
||||
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
}
|
||||
allocator::free(work);
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
/// num_reflectors x N
|
||||
for (int j = 0; j < num_reflectors; ++j) {
|
||||
for (int k = 0; k < j; ++k) {
|
||||
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
work = allocator::malloc(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
// Compute Q
|
||||
orgqr<T>(
|
||||
&M,
|
||||
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
&info);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
for (int64_t i = 0; i < num_matrices; ++i) {
|
||||
// M x num_reflectors
|
||||
for (int j = 0; j < M; ++j) {
|
||||
for (int k = 0; k < num_reflectors; ++k) {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
@@ -13,6 +11,35 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
const static float MXFP4_LUT[16] = {
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
+1.5f,
|
||||
+2.0f,
|
||||
+3.0f,
|
||||
+4.0f,
|
||||
+6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f};
|
||||
|
||||
template <typename T>
|
||||
static inline T dequantize_scale(uint8_t s) {
|
||||
using FOrI = union {
|
||||
bfloat16_t f;
|
||||
uint16_t i;
|
||||
};
|
||||
FOrI out;
|
||||
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
|
||||
return static_cast<T>(out.f);
|
||||
}
|
||||
|
||||
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
@@ -407,6 +434,229 @@ void _qmm_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = get_pack_factor(4, 8);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const uint8_t* scales_local = scales;
|
||||
|
||||
std::fill(result, result + N, 0);
|
||||
|
||||
for (int k = 0; k < K; k++) {
|
||||
T* result_local = result;
|
||||
T xi = *x++;
|
||||
|
||||
for (int n = 0; n < N; n += group_size) {
|
||||
T scale = dequantize_scale<T>(*scales_local++);
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) +=
|
||||
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
|
||||
wi >>= 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result += N;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_t(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = get_pack_factor(4, 8);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const uint8_t* scales_local = scales;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const T* x_local = x;
|
||||
T sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = dequantize_scale<T>(*scales_local++);
|
||||
|
||||
T gsum = 0;
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
|
||||
wi >>= 4;
|
||||
}
|
||||
}
|
||||
sum += scale * gsum;
|
||||
}
|
||||
*result = sum;
|
||||
result++;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
template <int S>
|
||||
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
|
||||
if constexpr (S == 8) {
|
||||
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
|
||||
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
|
||||
auto wi = simd::Simd<uint32_t, S>(*w);
|
||||
wi = wi >> shifts;
|
||||
wi = wi & 0xf;
|
||||
simd::Simd<float, S> w_out;
|
||||
for (int i = 0; i < S; ++i) {
|
||||
w_out[i] = MXFP4_LUT[wi[i]];
|
||||
}
|
||||
return w_out;
|
||||
} else {
|
||||
// Appease compiler.. but should never get here
|
||||
throw std::runtime_error("Unsupported combination for simd qmm.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_t_simd(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = 32 / 4;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
constexpr int S = simd::max_size<T>;
|
||||
static_assert(
|
||||
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
|
||||
constexpr int packs_per_simd = S / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const uint8_t* scales_local = scales;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
simd::Simd<float, S> acc(0);
|
||||
auto x_local = x;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = dequantize_scale<T>(*scales_local++);
|
||||
|
||||
simd::Simd<float, S> g_acc(0);
|
||||
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
|
||||
// Extract bits
|
||||
auto wf = mxfp4_extract_bits_simd<S>(w_local);
|
||||
w_local += packs_per_simd;
|
||||
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
|
||||
g_acc = g_acc + x_simd * wf;
|
||||
x_local += S;
|
||||
}
|
||||
acc = acc + scale * g_acc;
|
||||
}
|
||||
|
||||
*result = T(simd::sum(acc));
|
||||
result++;
|
||||
}
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_dispatch_transpose(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
bool transposed_w) {
|
||||
if (transposed_w) {
|
||||
// the simd size must be a multiple of the number of elements per word
|
||||
if constexpr (simd::max_size<T> % 8 == 0) {
|
||||
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
|
||||
} else {
|
||||
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
|
||||
}
|
||||
} else {
|
||||
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<uint8_t>();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
mxfp4_qmm_dispatch_transpose<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void mxfp4_qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case bfloat16:
|
||||
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
|
||||
break;
|
||||
case float32:
|
||||
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _bs_qmm_dispatch_typed(
|
||||
array& out,
|
||||
@@ -513,115 +763,198 @@ void _bs_qmm_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_bs_qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<uint8_t>();
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices.shape(), lhs_indices.strides())];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices.shape(), rhs_indices.strides())];
|
||||
mxfp4_qmm_dispatch_transpose<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
|
||||
scales_ptr +
|
||||
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void mxfp4_bs_qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
mxfp4_bs_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
mxfp4_bs_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
std::vector<array> temps;
|
||||
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_cpy, CopyType::General, s);
|
||||
encoder.add_temporary(arr_cpy);
|
||||
return arr_cpy;
|
||||
}
|
||||
};
|
||||
|
||||
auto x = ensure_row_contiguous(x_pre);
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[3]);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
transpose_ = transpose_]() mutable {
|
||||
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
auto& lhs_indices = inputs[4];
|
||||
auto& rhs_indices = inputs[5];
|
||||
auto& lhs_indices = inputs[inputs.size() - 2];
|
||||
auto& rhs_indices = inputs[inputs.size() - 1];
|
||||
|
||||
std::vector<array> temps;
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto ensure_row_contiguous_last_dims = [s = stream(),
|
||||
&temps](const array& arr) {
|
||||
&encoder](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_cpy, CopyType::General, s);
|
||||
encoder.add_temporary(arr_cpy);
|
||||
return arr_cpy;
|
||||
}
|
||||
};
|
||||
|
||||
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
transpose_ = transpose_]() mutable {
|
||||
mxfp4_bs_qmm_dispatch(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@@ -705,16 +1038,14 @@ void dispatch_quantize(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
void fast::Quantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto ensure_row_contiguous = [s = stream()](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return std::make_pair(arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
||||
return std::make_pair(arr_copy, true);
|
||||
return std::make_pair(contiguous_copy_cpu(arr, s), true);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -766,7 +1097,7 @@ void fast::AffineQuantize::eval_cpu(
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
|
||||
@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Ensure contiguity
|
||||
auto in = inputs[0];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_cpu(in, arr_copy, CopyType::General, stream());
|
||||
in = arr_copy;
|
||||
encoder.add_temporary(arr_copy);
|
||||
in = contiguous_copy_cpu(in, stream());
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
// There seems to be a bug in sims/base.h
|
||||
// There seems to be a bug in simd/base_simd.h
|
||||
// __XROS_2_0 is not defined, the expression evaluates
|
||||
// to true instead of false setting the SIMD library
|
||||
// higher than it should be even on macOS < 15
|
||||
@@ -234,6 +234,7 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
|
||||
|
||||
template <typename MaskT, typename T1, typename T2, int N>
|
||||
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
|
||||
static_assert(std::is_same_v<MaskT, bool>);
|
||||
if constexpr (sizeof(T1) == 1) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 2) {
|
||||
@@ -251,9 +252,13 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
|
||||
return asd::pow(base.value, exp.value);
|
||||
} else {
|
||||
Simd<T, N> res = 1;
|
||||
while (any(exp)) {
|
||||
res = select(exp & 1, res * base, res);
|
||||
base = select(exp, base * base, base);
|
||||
// Raising an integer to a negative power is undefined
|
||||
if (any(exp < static_cast<T>(0))) {
|
||||
return 0;
|
||||
}
|
||||
while (any(exp > static_cast<T>(0))) {
|
||||
res = select((exp & 1) != 0, res * base, res);
|
||||
base = select(exp > static_cast<T>(0), base * base, base);
|
||||
exp = exp >> 1;
|
||||
}
|
||||
return res;
|
||||
|
||||
@@ -79,7 +79,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
|
||||
|
||||
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
|
||||
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
|
||||
auto poly_mask = (emm2 & 2) != 0;
|
||||
auto poly_mask =
|
||||
(emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
|
||||
|
||||
// The magic pass: "Extended precision modular arithmetic"
|
||||
// x = ((x - y * DP1) - y * DP2) - y * DP3
|
||||
@@ -87,8 +88,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
|
||||
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
|
||||
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
|
||||
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
|
||||
|
||||
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
|
||||
// and the second polynom (Pi/4 <= x <= 0) in y2
|
||||
|
||||
@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_cpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_cpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
@@ -8,13 +8,25 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// NaN-aware comparator that places NaNs at the end
|
||||
template <typename T>
|
||||
bool nan_aware_less(T a, T b) {
|
||||
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
|
||||
if (std::isnan(a))
|
||||
return false;
|
||||
if (std::isnan(b))
|
||||
return true;
|
||||
}
|
||||
return a < b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct StridedIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
@@ -27,7 +39,7 @@ struct StridedIterator {
|
||||
StridedIterator() = default;
|
||||
|
||||
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
|
||||
: ptr_(ptr + offset * stride), stride_(stride) {}
|
||||
: stride_(stride), ptr_(ptr + offset * stride) {}
|
||||
|
||||
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
|
||||
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
|
||||
@@ -108,8 +120,8 @@ template <typename T>
|
||||
void sort(array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
int64_t in_size = out.size();
|
||||
int64_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -124,13 +136,13 @@ void sort(array& out, int axis) {
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
std::stable_sort(st, ed, nan_aware_less<T>);
|
||||
src_it.step();
|
||||
}
|
||||
}
|
||||
@@ -139,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
int64_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
@@ -164,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
|
||||
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
|
||||
// Handle NaNs (place them at the end)
|
||||
if (std::is_floating_point<T>::value) {
|
||||
if (std::isnan(v1))
|
||||
return false;
|
||||
if (std::isnan(v2))
|
||||
return true;
|
||||
}
|
||||
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
@@ -193,8 +214,8 @@ template <typename T>
|
||||
void partition(array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
int64_t in_size = out.size();
|
||||
int64_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -211,7 +232,7 @@ void partition(array& out, int axis, int kth) {
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed);
|
||||
std::nth_element(st, md, ed, nan_aware_less<T>);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
|
||||
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
int64_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
@@ -256,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int64_t i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
|
||||
// Handle NaNs (place them at the end)
|
||||
if (std::is_floating_point<T>::value) {
|
||||
if (std::isnan(v1))
|
||||
return false;
|
||||
if (std::isnan(v2))
|
||||
return true;
|
||||
}
|
||||
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
@@ -333,47 +363,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
int axis = axis_;
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
|
||||
// Copy input to output
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
});
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
|
||||
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||
sort<MLX_GET_TYPE(type_tag)>(out, axis);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
@@ -27,7 +27,7 @@ void svd_impl(
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
int64_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
@@ -81,40 +81,26 @@ void svd_impl(
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
auto job_u = (u_ptr) ? "V" : "N";
|
||||
auto job_vt = (u_ptr) ? "V" : "N";
|
||||
static constexpr auto range = "A";
|
||||
auto jobz = (u_ptr) ? "A" : "N";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const T ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
@@ -135,21 +121,14 @@ void svd_impl(
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
for (int64_t i = 0; i < num_matrices; i++) {
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in_ptr + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s_ptr + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||
@@ -167,13 +146,6 @@ void svd_impl(
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(in);
|
||||
@@ -181,10 +153,10 @@ void svd_impl(
|
||||
|
||||
template <typename T>
|
||||
void compute_svd(
|
||||
const array& a,
|
||||
bool compute_uv,
|
||||
std::vector<array>& outputs,
|
||||
Stream stream) {}
|
||||
const array& /* a */,
|
||||
bool /* compute_uv */,
|
||||
std::vector<array>& /* outputs */,
|
||||
Stream /* stream */) {}
|
||||
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
|
||||
@@ -136,7 +136,7 @@ void ternary_op(
|
||||
if (topt == TernaryOpType::ScalarScalarScalar) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
} else if (topt == TernaryOpType::VectorVectorVector) {
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
for (int64_t i = 0; i < out.size(); ++i) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
|
||||
for (int64_t i = 0; i < shape; i += 1) {
|
||||
out[i] = Op{}(*a);
|
||||
a += stride;
|
||||
}
|
||||
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
|
||||
src++;
|
||||
}
|
||||
} else {
|
||||
size_t shape = ndim > 0 ? a.shape().back() : 1;
|
||||
size_t stride = ndim > 0 ? a.strides().back() : 1;
|
||||
int64_t shape = ndim > 0 ? a.shape().back() : 1;
|
||||
int64_t stride = ndim > 0 ? a.strides().back() : 1;
|
||||
if (ndim <= 1) {
|
||||
unary_op<T, U, Op>(src, dst, shape, stride);
|
||||
return;
|
||||
}
|
||||
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
|
||||
for (size_t elem = 0; elem < a.size(); elem += shape) {
|
||||
for (int64_t elem = 0; elem < a.size(); elem += shape) {
|
||||
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
|
||||
@@ -77,7 +77,8 @@ struct Real {
|
||||
struct Sigmoid {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
return 1.0f / (1.0f + simd::exp(-x));
|
||||
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
|
||||
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
@@ -15,18 +15,26 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||
@@ -35,16 +43,28 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
|
||||
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||
else()
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
|
||||
# Embed kernel sources in binary for JIT compilation.
|
||||
@@ -87,11 +107,18 @@ endif()
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"70;80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
|
||||
# and requires drivers released after CUDA 12.4.
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||
endif()
|
||||
|
||||
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
||||
# managed memory.
|
||||
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||
set(MLX_CUDA_ARCHITECTURES "native")
|
||||
endif()
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
@@ -123,9 +150,30 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||
# Use NVRTC and driver APIs.
|
||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
|
||||
# Use the frontend APIs of cuDNN.
|
||||
FetchContent_Declare(
|
||||
cudnn
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||
GIT_TAG v1.14.0
|
||||
GIT_SHALLOW TRUE
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
|
||||
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
|
||||
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
|
||||
FetchContent_MakeAvailable(cudnn)
|
||||
target_link_libraries(mlx PRIVATE cudnn_frontend)
|
||||
# Link with the actual cuDNN libraries.
|
||||
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
|
||||
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
# Supress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -17,18 +16,82 @@ namespace cu {
|
||||
|
||||
constexpr int page_size = 16384;
|
||||
|
||||
// Any allocations smaller than this will try to use the small pool
|
||||
constexpr int small_block_size = 8;
|
||||
|
||||
// The small pool size in bytes. This should be a multiple of the host page
|
||||
// size and small_block_size.
|
||||
constexpr int small_pool_size = 4 * page_size;
|
||||
|
||||
SmallSizePool::SmallSizePool() {
|
||||
auto num_blocks = small_pool_size / small_block_size;
|
||||
buffer_ = new Block[num_blocks];
|
||||
|
||||
next_free_ = buffer_;
|
||||
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||
|
||||
int device_count = 0;
|
||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = i;
|
||||
#else
|
||||
int loc = i;
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
||||
}
|
||||
|
||||
auto curr = next_free_;
|
||||
for (size_t i = 1; i < num_blocks; ++i) {
|
||||
curr->next = buffer_ + i;
|
||||
curr = curr->next;
|
||||
}
|
||||
curr->next = nullptr;
|
||||
}
|
||||
|
||||
SmallSizePool::~SmallSizePool() {
|
||||
CHECK_CUDA_ERROR(cudaFree(data_));
|
||||
delete[] buffer_;
|
||||
}
|
||||
|
||||
CudaBuffer* SmallSizePool::malloc() {
|
||||
if (next_free_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
Block* b = next_free_;
|
||||
uint64_t i = next_free_ - buffer_;
|
||||
next_free_ = next_free_->next;
|
||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||
b->buf.size = small_block_size;
|
||||
return &b->buf;
|
||||
}
|
||||
|
||||
void SmallSizePool::free(CudaBuffer* buf) {
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
b->next = next_free_;
|
||||
next_free_ = b;
|
||||
}
|
||||
|
||||
bool SmallSizePool::in_pool(CudaBuffer* buf) {
|
||||
constexpr int num_blocks = (small_pool_size / small_block_size);
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
int64_t block_num = b - buffer_;
|
||||
return block_num >= 0 && block_num < num_blocks;
|
||||
}
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
memory_limit_ = total * 0.95;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
@@ -36,7 +99,9 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
auto orig_size = size;
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size < page_size) {
|
||||
if (size <= small_block_size) {
|
||||
size = 8;
|
||||
} else if (size < page_size) {
|
||||
size = next_power_of_2(size);
|
||||
} else {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
@@ -44,19 +109,25 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
if (mem_required >= memory_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||
int64_t mem_to_free =
|
||||
get_active_memory() + get_cache_memory() + size - memory_limit_;
|
||||
if (mem_to_free > 0) {
|
||||
buffer_cache_.release_cached_buffers(mem_to_free);
|
||||
}
|
||||
|
||||
// Try the scalar pool first
|
||||
if (size <= small_block_size) {
|
||||
buf = scalar_pool_.malloc();
|
||||
}
|
||||
lock.unlock();
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
if (!buf) {
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
@@ -67,7 +138,6 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
@@ -82,9 +152,7 @@ void CudaAllocator::free(Buffer buffer) {
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
cuda_free(buf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,27 +164,14 @@ size_t CudaAllocator::size(Buffer buffer) const {
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void CudaAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void CudaAllocator::cuda_free(void* buf) {
|
||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
// This must be called with mutex_ aquired
|
||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
if (scalar_pool_.in_pool(buf)) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
cudaFree(buf);
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
|
||||
@@ -7,13 +7,10 @@
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
@@ -22,21 +19,35 @@ struct CudaBuffer {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class SmallSizePool {
|
||||
private:
|
||||
union Block {
|
||||
Block* next;
|
||||
CudaBuffer buf;
|
||||
};
|
||||
|
||||
Block* buffer_{nullptr};
|
||||
void* data_{nullptr};
|
||||
Block* next_free_{nullptr};
|
||||
|
||||
public:
|
||||
SmallSizePool();
|
||||
~SmallSizePool();
|
||||
|
||||
SmallSizePool(const SmallSizePool&) = delete;
|
||||
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||
|
||||
CudaBuffer* malloc();
|
||||
void free(CudaBuffer* buf);
|
||||
bool in_pool(CudaBuffer* buf);
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
// Call cudaFree in the safe thread.
|
||||
void cuda_free(void* buf);
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
@@ -47,19 +58,18 @@ class CudaAllocator : public allocator::Allocator {
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
SmallSizePool scalar_pool_;
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
69
mlx/backend/cuda/arange.cu
Normal file
69
mlx/backend/cuda/arange.cu
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename IdxT, int N_WRITES>
|
||||
__global__ void arange(T* out, IdxT size, T start, T step) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_WRITES > size) {
|
||||
for (IdxT i = index * N_WRITES; i < size; ++i) {
|
||||
out[i] = start + i * step;
|
||||
}
|
||||
} else {
|
||||
AlignedVector<T, N_WRITES> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_WRITES; ++i) {
|
||||
out_vec[i] = start + (index * N_WRITES + i) * step;
|
||||
}
|
||||
|
||||
store_vector<N_WRITES>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_WRITES = 16 / sizeof(OutType);
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
|
||||
encoder.add_kernel_node(
|
||||
cu::arange<OutType, IdxT, N_WRITES>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
static_cast<CTYPE>(start_),
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -44,8 +44,11 @@ struct ArgMin {
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
__device__ IndexValPair<T> reduce_many(
|
||||
IndexValPair<T> best,
|
||||
const AlignedVector<T, N>& vals,
|
||||
uint32_t offset) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
@@ -74,8 +77,11 @@ struct ArgMax {
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
__device__ IndexValPair<T> reduce_many(
|
||||
IndexValPair<T> best,
|
||||
const AlignedVector<T, N>& vals,
|
||||
uint32_t offset) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
@@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
|
||||
|
||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||
in += in_idx;
|
||||
|
||||
Op op;
|
||||
T init = op.init();
|
||||
IndexValPair<T> best{0, init};
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T vals[N_READS];
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||
cub::LoadDirectBlocked(
|
||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
}
|
||||
|
||||
@@ -166,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
|
||||
21
mlx/backend/cuda/binary/CMakeLists.txt
Normal file
21
mlx/backend/cuda/binary/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)
|
||||
7
mlx/backend/cuda/binary/add.cu
Normal file
7
mlx/backend/cuda/binary/add.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Add)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/arctan2.cu
Normal file
7
mlx/backend/cuda/binary/arctan2.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(ArcTan2)
|
||||
} // namespace mlx::core
|
||||
@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = Op{}(a[0], b[0]);
|
||||
out_vec[i] = Op{}(a[0], b[0]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
|
||||
out_vec[i] = Op{}(a[0], b_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
|
||||
out_vec[i] = Op{}(a_vec[i], b[0]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
@@ -92,46 +92,96 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
template <
|
||||
typename Op,
|
||||
typename In,
|
||||
typename Out,
|
||||
typename IdxT,
|
||||
int NDIM,
|
||||
int N_READS>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
IdxT size_rest,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto a_stride_x = a_strides[NDIM - 1];
|
||||
auto b_stride_x = b_strides[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
IdxT size_rest,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto a_stride_x = a_strides[ndim - 1];
|
||||
auto b_stride_x = b_strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
ndim);
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
@@ -209,36 +259,61 @@ void binary_op_gpu_inplace(
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::
|
||||
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
auto kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant(),
|
||||
1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant(),
|
||||
4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
rest,
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
@@ -248,8 +323,7 @@ void binary_op_gpu_inplace(
|
||||
} else {
|
||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
// TODO: Choose optimized value based on type size.
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int N_READS = 16 / sizeof(InType);
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||
@@ -259,16 +333,12 @@ void binary_op_gpu_inplace(
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out.data_size(),
|
||||
out.shape(),
|
||||
out.strides(),
|
||||
large(),
|
||||
N_READS);
|
||||
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
@@ -306,54 +376,4 @@ void binary_op_gpu(
|
||||
binary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
27
mlx/backend/cuda/binary/bitwise_binary.cu
Normal file
27
mlx/backend/cuda/binary/bitwise_binary.cu
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/divide.cu
Normal file
7
mlx/backend/cuda/binary/divide.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Divide)
|
||||
} // namespace mlx::core
|
||||
15
mlx/backend/cuda/binary/equal.cu
Normal file
15
mlx/backend/cuda/binary/equal.cu
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/greater.cu
Normal file
7
mlx/backend/cuda/binary/greater.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Greater)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/greater_equal.cu
Normal file
7
mlx/backend/cuda/binary/greater_equal.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(GreaterEqual)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/less.cu
Normal file
7
mlx/backend/cuda/binary/less.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Less)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/less_equal.cu
Normal file
7
mlx/backend/cuda/binary/less_equal.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LessEqual)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/log_add_exp.cu
Normal file
7
mlx/backend/cuda/binary/log_add_exp.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LogAddExp)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/logical_and.cu
Normal file
7
mlx/backend/cuda/binary/logical_and.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LogicalAnd)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/logical_or.cu
Normal file
7
mlx/backend/cuda/binary/logical_or.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LogicalOr)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/maximum.cu
Normal file
7
mlx/backend/cuda/binary/maximum.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Maximum)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/minimum.cu
Normal file
7
mlx/backend/cuda/binary/minimum.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Minimum)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/multiply.cu
Normal file
7
mlx/backend/cuda/binary/multiply.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Multiply)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/not_equal.cu
Normal file
7
mlx/backend/cuda/binary/not_equal.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(NotEqual)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/power.cu
Normal file
7
mlx/backend/cuda/binary/power.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Power)
|
||||
} // namespace mlx::core
|
||||
7
mlx/backend/cuda/binary/remainder.cu
Normal file
7
mlx/backend/cuda/binary/remainder.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Remainder)
|
||||
} // namespace mlx::core
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user