Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-01 12:49:44 +08:00)

Compare commits: 254 commits, c35f4d089a...gh-pages
@@ -7,15 +7,9 @@ parameters:
   nightly_build:
     type: boolean
     default: false
-  weekly_build:
-    type: boolean
-    default: false
   test_release:
     type: boolean
     default: false
-  linux_release:
-    type: boolean
-    default: false
 
 jobs:
   build_documentation:
@@ -24,13 +18,14 @@ jobs:
         type: boolean
         default: false
     macos:
-      xcode: "16.2.0"
-    resource_class: m2pro.medium
+      xcode: "26.0.0"
+    resource_class: m4pro.medium
     steps:
       - checkout
       - run:
          name: Install
          command: |
+            xcodebuild -downloadComponent MetalToolchain
            brew install python@3.9
            brew install doxygen
            python3.9 -m venv env
@@ -38,7 +33,7 @@ jobs:
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+            pip install . -v
       - when:
          condition:
            not: << parameters.upload-docs >>
@@ -70,9 +65,9 @@ jobs:
            git push -f origin gh-pages
 
   linux_build_and_test:
-    docker:
-      - image: cimg/python:3.9
-
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
     steps:
       - checkout
       - run:
@@ -84,37 +79,37 @@ jobs:
       - run:
          name: Install dependencies
          command: |
-            pip install --upgrade cmake
-            pip install nanobind==2.4.0
-            pip install numpy
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
            sudo apt-get update
-            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            curl -LsSf https://astral.sh/uv/install.sh | sh
       - run:
          name: Install Python package
          command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python3 setup.py develop
+            uv venv
+            uv pip install cmake
+            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+              uv pip install -e ".[dev]" -v
       - run:
          name: Generate package stubs
          command: |
-            echo "stubs"
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
       - run:
          name: Run Python tests
          command: |
-            python3 -m unittest discover python/tests -v
+            source .venv/bin/activate
+            python -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
-            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
+            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
       - run:
          name: Build CPP only
          command: |
-            mkdir -p build && cd build
+            source .venv/bin/activate
+            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
       - run:
@@ -125,7 +120,7 @@ jobs:
    parameters:
      xcode_version:
        type: string
-        default: "16.2.0"
+        default: "26.0.0"
      macosx_deployment_target:
        type: string
        default: ""
@@ -133,57 +128,56 @@ jobs:
      xcode: << parameters.xcode_version >>
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
-    resource_class: m2pro.medium
+    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
-            brew install python@3.9
-            brew install openmpi
-            python3.9 -m venv env
-            source env/bin/activate
-            pip install --upgrade pip
-            pip install --upgrade cmake
-            pip install nanobind==2.4.0
-            pip install numpy
-            pip install torch
-            pip install tensorflow
-            pip install unittest-xml-reporting
+            xcodebuild -downloadComponent MetalToolchain
+            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
+              brew install openmpi uv
      - run:
          name: Install Python package
          command: |
-            source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
-            pip install -e . -v
+            uv venv --python 3.9
+            uv pip install \
+              nanobind==2.4.0 \
+              cmake \
+              numpy \
+              torch \
+              tensorflow \
+              unittest-xml-reporting
+            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+              uv pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
-            source env/bin/activate
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
-            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
+            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build example extension
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            cd examples/extensions
-            pip install -r requirements.txt
-            python setup.py build_ext -j8
+            uv pip install -r requirements.txt
+            uv run --no-project setup.py build_ext --inplace
+            uv run --no-project python test.py
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
@@ -192,7 +186,7 @@ jobs:
      - run:
          name: Build small binary
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
@@ -204,36 +198,74 @@ jobs:
      - run:
          name: Run Python tests with JIT
          command: |
-            source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-            pip install -e . -v
+            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+              uv pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
              METAL_DEBUG_ERROR_MODE=0 \
-              python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+              uv run --no-project python -m xmlrunner discover \
+                -v python/tests \
+                -o test-results/gpu_jit
 
  cuda_build_and_test:
+    parameters:
+      image_date:
+        type: string
+        default: "2023.11.1"
    machine:
-      image: linux-cuda-12:default
+      image: "linux-cuda-12:<< parameters.image_date >>"
    resource_class: gpu.nvidia.small.gen2
    steps:
      - checkout
+      - restore_cache:
+          keys:
+            - cuda-<< parameters.image_date >>-{{ arch }}-
+      - run:
+          name: Install dependencies
+          command: |
+            sudo apt-get update
+            sudo apt-get install libcudnn9-dev-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install libnccl2 libnccl-dev
+            curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
+            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
+            rm -rf ccache-4.11.3-linux-x86_64
+            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Install Python package
          command: |
-            sudo apt-get update
-            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
-            python -m venv env
-            source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-            pip install -e ".[dev]"
+            uv venv
+            uv pip install cmake
+            DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+              uv pip install -e ".[dev]" -v
      - run:
          name: Run Python tests
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+      - run:
+          name: Build CPP only
+          command: |
+            source .venv/bin/activate
+            cmake . -B build \
+              -DMLX_BUILD_CUDA=ON \
+              -DCMAKE_CUDA_COMPILER=`which nvcc` \
+              -DCMAKE_BUILD_TYPE=DEBUG
+            cmake --build build -j `nproc`
+      - run:
+          name: Run CPP tests
+          command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
+      - run:
+          name: CCache report
+          command: |
+            ccache --show-stats
+            ccache --zero-stats
+            ccache --max-size 400MB
+            ccache --cleanup
+      - save_cache:
+          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
+          paths:
+            - /home/circleci/.cache/ccache
 
  build_release:
    parameters:
@@ -242,7 +274,7 @@ jobs:
        default: "3.9"
      xcode_version:
        type: string
-        default: "16.2.0"
+        default: "26.0.0"
      build_env:
        type: string
        default: ""
@@ -251,7 +283,7 @@ jobs:
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
-    resource_class: m2pro.medium
+    resource_class: m4pro.medium
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
@@ -259,11 +291,15 @@ jobs:
      - run:
          name: Install dependencies
          command: |
-            brew install python@<< parameters.python_version >>
-            brew install openmpi
-            python<< parameters.python_version >> -m venv env
-            source env/bin/activate
-            pip install --upgrade pip
+            xcodebuild -downloadComponent MetalToolchain
+            mkdir -p ~/miniconda3
+            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
+            bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+            rm ~/miniconda3/miniconda.sh
+            source ~/miniconda3/bin/activate
+            conda init --all
+            conda create -n env python=<< parameters.python_version >> -y
+            conda activate env
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
@@ -273,30 +309,38 @@ jobs:
      - run:
          name: Install Python package
          command: |
-            source env/bin/activate
+            conda activate env
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
              pip install . -v
      - run:
          name: Generate package stubs
          command: |
-            source env/bin/activate
+            conda activate env
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
-            source env/bin/activate
-            << parameters.build_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            python -m build -w
+            conda activate env
+            python setup.py clean --all
+            << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
+      - when:
+          condition:
+            equal: ["3.9", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Build common package
+                command: |
+                  conda activate env
+                  python setup.py clean --all
+                  << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
-                  source env/bin/activate
+                  conda activate env
                  twine upload dist/*
      - store_artifacts:
          path: dist/
@@ -306,52 +350,100 @@ jobs:
      python_version:
        type: string
        default: "3.9"
-      extra_env:
+      build_env:
        type: string
-        default: "DEV_RELEASE=1"
-    docker:
-      - image: ubuntu:20.04
+        default: ""
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
-            apt-get update
-            apt-get upgrade -y
-            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
-            apt-get install -y apt-utils
-            apt-get install -y software-properties-common
-            add-apt-repository -y ppa:deadsnakes/ppa
-            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
-            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
-            apt-get install -y build-essential git
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
+            sudo apt-get update
+            TZ=Etc/UTC sudo apt-get -y install tzdata
+            sudo add-apt-repository -y ppa:deadsnakes/ppa
+            sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
+            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            pip install . -v
+            << parameters.build_env >> pip install ".[dev]" -v
            pip install typing_extensions
-            python setup.py generate_stubs
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python -m build --wheel
-            auditwheel show dist/*
-            auditwheel repair dist/* --plat manylinux_2_31_x86_64
+            python setup.py generate_stubs
+            python setup.py clean --all
+            MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
+            bash python/scripts/repair_linux.sh
+      - when:
+          condition:
+            equal: ["3.9", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Build common package
+                command: |
+                  source env/bin/activate
+                  python setup.py clean --all
+                  << parameters.build_env >> MLX_BUILD_STAGE=2 \
+                    python -m build -w
+                  auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
-                name: Upload package
+                name: Upload packages
                command: |
                  source env/bin/activate
-                  twine upload wheelhouse/*
+                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/
 
+  build_cuda_release:
+    parameters:
+      build_env:
+        type: string
+        default: ""
+    machine:
+      image: ubuntu-2204:current
+    resource_class: xlarge
+    steps:
+      - checkout
+      - run:
+          name: Build wheel
+          command: |
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
+            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
+            sudo dpkg -i cuda-keyring_1.1-1_all.deb
+            sudo apt-get update
+            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install zip
+            pip install auditwheel
+            pip install patchelf
+            pip install build
+            pip install twine
+            export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
+            export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+            << parameters.build_env >> MLX_BUILD_STAGE=2 \
+              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+              python -m build -w
+            bash python/scripts/repair_cuda.sh
+      - when:
+          condition: << parameters.build_env >>
+          steps:
+            - run:
+                name: Upload package
+                command: |
+                  twine upload wheelhouse/*.whl
+      - store_artifacts:
+          path: wheelhouse/
@@ -363,22 +455,23 @@ workflows:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
-        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
-              macosx_deployment_target: ["13.5", "14.0"]
+              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test
-      - cuda_build_and_test
+      - cuda_build_and_test:
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
      - build_documentation
 
  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
-        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
@@ -392,68 +485,7 @@ workflows:
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          macosx_deployment_target: ["13.5", "14.0", "15.0"]
          build_env: ["PYPI_RELEASE=1"]
-          xcode_version: ["16.2.0", "15.0.0"]
-        exclude:
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.9"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.10"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.11"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.12"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.13"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.9"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.10"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.11"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.12"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.13"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.9"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.10"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.11"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.12"
-            build_env: "PYPI_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.13"
-            build_env: "PYPI_RELEASE=1"
+          xcode_version: ["26.0.0"]
      - build_documentation:
          filters:
            tags:
@@ -461,6 +493,25 @@ workflows:
            branches:
              ignore: /.*/
          upload-docs: true
+      - build_linux_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              build_env: ["PYPI_RELEASE=1"]
+      - build_cuda_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              build_env: ["PYPI_RELEASE=1"]
 
  prb:
    when:
@@ -476,11 +527,14 @@ workflows:
          requires: [ hold ]
          matrix:
            parameters:
-              macosx_deployment_target: ["13.5", "14.0"]
+              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test:
          requires: [ hold ]
      - cuda_build_and_test:
          requires: [ hold ]
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
  nightly_build:
    when:
      and:
@@ -492,58 +546,18 @@ workflows:
        parameters:
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          macosx_deployment_target: ["13.5", "14.0", "15.0"]
-          xcode_version: ["16.2.0", "15.0.0"]
-        exclude:
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.9"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.10"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.11"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.12"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.13"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.9"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.10"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.11"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.12"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.13"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.9"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.10"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.11"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.12"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.13"
+          xcode_version: ["26.0.0"]
+      - build_linux_release:
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+      - build_cuda_release
 
-  weekly_build:
+  build_dev_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.weekly_build >>
+        - << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          matrix:
@@ -551,76 +565,13 @@ workflows:
          python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
          macosx_deployment_target: ["13.5", "14.0", "15.0"]
          build_env: ["DEV_RELEASE=1"]
-          xcode_version: ["16.2.0", "15.0.0"]
-        exclude:
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.9"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.10"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.11"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.12"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "13.5"
-            xcode_version: "16.2.0"
-            python_version: "3.13"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.9"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.10"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.11"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.12"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "14.0"
-            xcode_version: "15.0.0"
-            python_version: "3.13"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.9"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.10"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.11"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.12"
-            build_env: "DEV_RELEASE=1"
-          - macosx_deployment_target: "15.0"
-            xcode_version: "15.0.0"
-            python_version: "3.13"
-            build_env: "DEV_RELEASE=1"
-
-  linux_test_release:
-    when:
-      and:
-        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.linux_release >>
-    jobs:
-      - build_linux_release:
-          matrix:
-            parameters:
-              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-              extra_env: ["PYPI_RELEASE=1"]
+          xcode_version: ["26.0.0"]
+      - build_linux_release:
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              build_env: ["DEV_RELEASE=1"]
+      - build_cuda_release:
+          matrix:
+            parameters:
+              build_env: ["DEV_RELEASE=1"]
@@ -19,11 +19,17 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>
 
+# Organizations
+
+MLX has received contributions from the following companies:
+- NVIDIA Corporation & Affiliates
+
 # Third-Party Software
 
 MLX leverages several third-party software, listed here together with
@@ -41,7 +41,9 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
+option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
 
 # --------------------- Processor tests -------------------------
 message(
@@ -64,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     message(WARNING "Building for x86_64 arch is not officially supported.")
   endif()
 endif()
 
 else()
   set(MLX_BUILD_METAL OFF)
   message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
 
+if(MLX_USE_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  endif()
+endif()
+
 # ----------------------------- Lib -----------------------------
@@ -131,6 +140,12 @@ elseif(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()
 
+if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  # With newer clang/gcc versions following libs are implicitly linked, but when
+  # building on old distributions they need to be explicitly listed.
+  target_link_libraries(mlx PRIVATE dl pthread)
+endif()
+
 if(WIN32)
   if(MSVC)
     # GGUF does not build with MSVC.
@@ -234,12 +249,16 @@ target_include_directories(
 # Do not add mlx_EXPORTS define for shared library.
 set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
 
-FetchContent_Declare(
-  fmt
-  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
-  GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL)
-FetchContent_MakeAvailable(fmt)
+if(USE_SYSTEM_FMT)
+  find_package(fmt REQUIRED)
+else()
+  FetchContent_Declare(
+    fmt
+    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+    GIT_TAG 10.2.1
+    EXCLUDE_FROM_ALL)
+  FetchContent_MakeAvailable(fmt)
+endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 
 if(MLX_BUILD_PYTHON_BINDINGS)
|
21
README.md
21
README.md
@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
 
 Some key features of MLX include:
 
- - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
+ - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
    also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
    [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
-   the Python API. MLX has higher-level packages like `mlx.nn` and
+   the Python API. MLX has higher-level packages like `mlx.nn` and
    `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
    more complex models.
 
@@ -68,18 +68,23 @@ in the documentation.
 
 ## Installation
 
-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
+macOS, run:
 
-**With `pip`**:
-
-```
+```bash
 pip install mlx
 ```
 
-**With `conda`**:
+To install the CUDA backend on Linux, run:
 
-```
-conda install -c conda-forge mlx
+```bash
+pip install mlx[cuda]
+```
+
+To install a CPU-only Linux package, run:
+
+```bash
+pip install mlx[cpu]
 ```
 
 Checkout the
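The new install matrix above is easy to sanity-check. A minimal smoke test after `pip install mlx` (a sketch; the default device depends on your machine):

```python
import mlx.core as mx

# Build a small array, run a lazy op, and force evaluation.
a = mx.arange(6).reshape(2, 3)
b = (a + 1).sum(axis=0)
mx.eval(b)  # MLX evaluates lazily; eval() materializes the result
print(b)                    # array([5, 7, 9], dtype=int32)
print(mx.default_device())  # gpu on Apple silicon, cpu on the CPU-only wheel
```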
@@ -192,6 +192,22 @@ void time_reductions() {
 
   auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
   TIME(argmin_along_1);
+
+  auto indices = mx::array({1});
+  auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
+  std::vector<int> axes{0};
+  auto b = scatter(a, {indices}, updates, axes);
+  mx::eval(b);
+
+  auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
+  TIME(max_along_0);
+  auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
+  TIME(max_along_1);
+
+  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
+  TIME(min_along_0);
+  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
+  TIME(min_along_1);
 }
 
 void time_gather_scatter() {
@@ -5,6 +5,7 @@ import os
 import time
 
 import torch
+import torch.cuda
 import torch.mps
 
 
@@ -44,8 +45,10 @@ def bench(f, *args):
 
 
 def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
+    if x.device == torch.device("mps"):
         torch.mps.synchronize()
+    elif x.device == torch.device("cuda"):
+        torch.cuda.synchronize()
 
 
 @torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softmax(axis, x):
     ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
         args.axis.pop(0)
 
     torch.set_num_threads(1)
-    device = "cpu" if args.cpu else "mps"
+    device = "mps"
+    if torch.cuda.is_available():
+        device = "cuda"
+    if args.cpu:
+        device = "cpu"
 
     types = args.dtype
     if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
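The diff above adds a `sum_and_add` case to the PyTorch benchmark script. For comparison, a sketch of what the equivalent MLX benchmark body could look like (hypothetical, not part of this diff); since MLX is lazy, `mx.eval` plays the role of the explicit device synchronization above:

```python
import mlx.core as mx

def sum_and_add(axis, x, y):
    # Same access pattern as the torch version: one reduction,
    # then 50 broadcast-add + reduce steps.
    z = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)  # force evaluation of the lazily built graph

x = mx.random.uniform(shape=(1024, 1024))
y = mx.random.uniform(shape=(1024, 1024))
sum_and_add(0, x, y)
```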
@@ -51,6 +51,20 @@ def time_maximum():
     time_fn(mx.maximum, a, b)
 
 
+def time_max():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.max, a, 0)
+
+
+def time_min():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.min, a, 0)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":
 
     time_add()
     time_matmul()
+    time_min()
+    time_max()
     time_maximum()
     time_exp()
     time_negative()
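The new `time_max`/`time_min` cases plant a NaN in the input before timing the reduction, which exercises the NaN-handling path of `mx.max`/`mx.min`. A small illustration of the semantics being timed (assuming these reductions propagate NaN, as the benchmark's setup suggests):

```python
import mlx.core as mx

a = mx.array([1.0, float("nan"), 3.0])
print(mx.max(a))  # expected: array(nan, dtype=float32)
print(mx.min(a))  # expected: array(nan, dtype=float32)
```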
cmake/FindNCCL.cmake (new file, 54 lines):
@@ -0,0 +1,54 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.

set(NCCL_ROOT_DIR
    $ENV{NCCL_ROOT_DIR}
    CACHE PATH "Folder contains NVIDIA NCCL")

find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

if($ENV{USE_STATIC_NCCL})
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  file(
    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
docs/build/html/.buildinfo (new vendored file, 4 lines):
@@ -0,0 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 6e9fcd3fd9a477c32d79521f0d5d7188
tags: 645f666f9bcd5a90fca523b33c5a78b7

docs/build/html/_images/capture.png (new vendored binary, 1.2 MiB; binary file not shown)
docs/build/html/_images/schema.png (new vendored binary, 746 KiB; binary file not shown)
docs/build/html/_sources/cpp/ops.rst (new vendored file, 7 lines):
@@ -0,0 +1,7 @@
.. _cpp_ops:

Operations
==========

.. doxygengroup:: ops
   :content-only:
docs/build/html/_sources/dev/custom_metal_kernels.rst (new vendored file, 445 lines):
@@ -0,0 +1,445 @@
.. _custom_metal_kernels:

Custom Metal Kernels
====================

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example
--------------

.. currentmodule:: mlx.core

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       T tmp = inp[elem];
       out[elem] = metal::exp(tmp);
   """

   kernel = mx.fast.metal_kernel(
       name="myexp",
       input_names=["inp"],
       output_names=["out"],
       source=source,
   )

   def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
           grid=(a.size, 1, 1),
           threadgroup=(256, 1, 1),
           output_shapes=[a.shape],
           output_dtypes=[a.dtype],
       )
       return outputs[0]

   a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
   b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))

Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.

.. note::
   Only pass the body of the Metal kernel in ``source``. The function
   signature is generated automatically.

The full function signature will be generated using:

* The shapes/dtypes of ``inputs``
  In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
  so we will add ``const device float16_t* inp`` to the signature.
  ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
  in ``source``.
* The list of ``output_dtypes``
  In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
  so we add ``device float16_t* out``.
* Template parameters passed using ``template``
  In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
  and instantiates the template with ``custom_kernel_myexp_float<float>``.
  Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
  These will be added as function arguments.
  All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.

Putting this all together, the generated function signature for ``myexp`` is as follows:

.. code-block:: cpp

   template <typename T>
   [[kernel]] void custom_kernel_myexp_float(
     const device float16_t* inp [[buffer(0)]],
     device float16_t* out [[buffer(1)]],
     uint3 thread_position_in_grid [[thread_position_in_grid]]) {

         uint elem = thread_position_in_grid.x;
         T tmp = inp[elem];
         out[elem] = metal::exp(tmp);

   }

   template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
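For example, a minimal sketch reusing the ``myexp`` kernel and the array ``a``
from above:

.. code-block:: python

   outputs = kernel(
       inputs=[a],
       template=[("T", mx.float32)],
       grid=(a.size, 1, 1),
       threadgroup=(256, 1, 1),
       output_shapes=[a.shape],
       output_dtypes=[a.dtype],
       verbose=True,  # print the generated Metal source for inspection
   )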
Using Shape/Strides
-------------------

:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.

If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:

.. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
       uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
       T tmp = inp[loc];
       // Output arrays are always row contiguous
       out[elem] = metal::exp(tmp);
   """

   kernel = mx.fast.metal_kernel(
       name="myexp_strided",
       input_names=["inp"],
       output_names=["out"],
       source=source,
       ensure_row_contiguous=False,
   )

   def exp_elementwise(a: mx.array):
       outputs = kernel(
           inputs=[a],
           template=[("T", mx.float32)],
           grid=(a.size, 1, 1),
           threadgroup=(256, 1, 1),
           output_shapes=[a.shape],
           output_dtypes=[a.dtype],
       )
       return outputs[0]

   a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
   # make non-contiguous
   a = a[::2]
   b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))

Complex Example
---------------

Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.

We'll start with the following MLX implementation using standard ops:

.. code-block:: python

   def grid_sample_ref(x, grid):
       N, H_in, W_in, _ = x.shape
       ix = ((grid[..., 0] + 1) * W_in - 1) / 2
       iy = ((grid[..., 1] + 1) * H_in - 1) / 2

       ix_nw = mx.floor(ix).astype(mx.int32)
       iy_nw = mx.floor(iy).astype(mx.int32)

       ix_ne = ix_nw + 1
       iy_ne = iy_nw

       ix_sw = ix_nw
       iy_sw = iy_nw + 1

       ix_se = ix_nw + 1
       iy_se = iy_nw + 1

       nw = (ix_se - ix) * (iy_se - iy)
       ne = (ix - ix_sw) * (iy_sw - iy)
       sw = (ix_ne - ix) * (iy - iy_ne)
       se = (ix - ix_nw) * (iy - iy_nw)

       I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
       I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
       I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
       I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

       mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
       mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
       mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
       mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

       I_nw *= mask_nw[..., None]
       I_ne *= mask_ne[..., None]
       I_sw *= mask_sw[..., None]
       I_se *= mask_se[..., None]

       output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

       return output

Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C / gH / gW * b_stride;
       int channel_idx = elem % C;
       int base_idx = batch_idx + channel_idx;

       T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
       T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
       T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
       T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

       I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
       I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
       I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
       I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

       out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
   """

   kernel = mx.fast.metal_kernel(
       name="grid_sample",
       input_names=["x", "grid"],
       output_names=["out"],
       source=source,
   )

   @mx.custom_function
   def grid_sample(x, grid):

       assert x.ndim == 4, "`x` must be 4D."
       assert grid.ndim == 4, "`grid` must be 4D."

       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape
       out_shape = (B, gN, gM, C)

       assert D == 2, "Last dim of `grid` must be size 2."

       outputs = kernel(
           inputs=[x, grid],
           template=[("T", x.dtype)],
           output_shapes=[out_shape],
           output_dtypes=[x.dtype],
           grid=(np.prod(out_shape), 1, 1),
           threadgroup=(256, 1, 1),
       )
       return outputs[0]

For a reasonably sized input such as:

.. code-block:: python

   x.shape = (8, 1024, 1024, 64)
   grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``
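Before moving on, it is worth checking the kernel against the pure-MLX
``grid_sample_ref`` defined earlier. A quick sanity check on small random
inputs (a sketch; the shapes here are illustrative):

.. code-block:: python

   x = mx.random.normal(shape=(2, 32, 32, 8))
   grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))
   assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)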
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
|
||||
* ``atomic_outputs=True``
|
||||
Designate all of the kernel outputs as ``atomic`` in the function signature.
|
||||
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
|
||||
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
|
||||
|
||||
We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        // Pad C to the nearest larger simdgroup size multiple
        int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

        int gH = grid_shape[1];
        int gW = grid_shape[2];

        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;

        uint grid_idx = elem / C_padded * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

        int ix_nw = floor(ix);
        int iy_nw = floor(iy);

        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;

        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;

        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;

        T nw = (ix_se - ix) * (iy_se - iy);
        T ne = (ix - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix) * (iy - iy_ne);
        T se = (ix - ix_nw) * (iy - iy_nw);

        int batch_idx = elem / C_padded / gH / gW * b_stride;
        int channel_idx = elem % C_padded;
        int base_idx = batch_idx + channel_idx;

        T gix = T(0);
        T giy = T(0);
        if (channel_idx < C) {
            int cot_index = elem / C_padded * C + channel_idx;
            T cot = cotangent[cot_index];
            if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                T I_nw = x[offset];
                gix -= I_nw * (iy_se - iy) * cot;
                giy -= I_nw * (ix_se - ix) * cot;
            }
            if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                T I_ne = x[offset];
                gix += I_ne * (iy_sw - iy) * cot;
                giy -= I_ne * (ix - ix_sw) * cot;
            }
            if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                T I_sw = x[offset];
                gix -= I_sw * (iy - iy_ne) * cot;
                giy += I_sw * (ix_ne - ix) * cot;
            }
            if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                T I_se = x[offset];
                gix += I_se * (iy - iy_nw) * cot;
                giy += I_se * (ix - ix_nw) * cot;
            }
        }

        T gix_mult = W / 2;
        T giy_mult = H / 2;

        // Reduce across each simdgroup first.
        // This is much faster than relying purely on atomics.
        gix = simd_sum(gix);
        giy = simd_sum(giy);

        if (thread_index_in_simdgroup == 0) {
            atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
            atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
        }
    """

    kernel = mx.fast.metal_kernel(
        name="grid_sample_grad",
        input_names=["x", "grid", "cotangent"],
        output_names=["x_grad", "grid_grad"],
        source=source,
        atomic_outputs=True,
    )

    @grid_sample.vjp
    def grid_sample_vjp(primals, cotangent, _):
        x, grid = primals
        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape

        assert D == 2, "Last dim of `grid` must be size 2."

        # pad the output channels to simd group size
        # so that our `simd_sum`s don't overlap.
        simdgroup_size = 32
        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
        grid_size = B * gN * gM * C_padded
        outputs = kernel(
            inputs=[x, grid, cotangent],
            template=[("T", x.dtype)],
            output_shapes=[x.shape, grid.shape],
            output_dtypes=[x.dtype, x.dtype],
            grid=(grid_size, 1, 1),
            threadgroup=(256, 1, 1),
            init_value=0,
        )
        return outputs[0], outputs[1]

There's an even larger speed up for the vjp:

``676.4ms -> 16.7ms => 40x speed up``

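Since the vjp is registered on ``grid_sample``, gradients flow through it like
any other op. A minimal sketch (illustrative shapes) that exercises the
backward kernel via :func:`grad`:

.. code-block:: python

    import mlx.core as mx

    def loss(x, grid):
        return grid_sample(x, grid).sum()

    x = mx.random.normal(shape=(2, 8, 8, 4))
    grid = mx.random.uniform(low=-1.0, high=1.0, shape=(2, 4, 4, 2))

    # Gradients with respect to both inputs; MLX calls grid_sample_vjp.
    dx, dgrid = mx.grad(loss, argnums=(0, 1))(x, grid)
    print(dx.shape, dgrid.shape)  # (2, 8, 8, 4) (2, 4, 4, 2)
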
.. source file: docs/build/html/_sources/dev/extensions.rst

Custom Extensions in MLX
========================

You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.

Introducing the Example
-----------------------

Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

This function performs that operation while leaving the implementation and
function transformations to MLX.

However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:

* The structure of the MLX library.
* Implementing a CPU operation.
* Implementing a GPU operation using Metal.
* Adding the ``vjp`` and ``jvp`` function transformations.
* Building a custom extension and binding it to Python.

Operations and Primitives
-------------------------

Operations in MLX build the computation graph. Primitives provide the rules
for evaluating and transforming the graph. Let's start by discussing
operations in more detail.

Operations
^^^^^^^^^^^

Operations are the front-end functions that operate on arrays. They are
defined in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds
them.

We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:

.. code-block:: C++

    /**
    * Scale and sum two vectors element-wise
    * z = alpha * x + beta * y
    *
    * Use NumPy-style broadcasting between x and y
    * Inputs are upcasted to floats if needed
    **/
    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );

The simplest way to implement this is with existing operations:

.. code-block:: C++

    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
      // Scale x and y on the provided stream
      auto ax = multiply(array(alpha), x, s);
      auto by = multiply(array(beta), y, s);

      // Add and return
      return add(ax, by, s);
    }

The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface that uses :class:`Primitive` building blocks.

Primitives
^^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to
be more concrete:

.. code-block:: C++

    class Axpby : public Primitive {
     public:
      explicit Axpby(Stream stream, float alpha, float beta)
          : Primitive(stream), alpha_(alpha), beta_(beta){};

      /**
      * A primitive must know how to evaluate itself on the CPU/GPU
      * for the given inputs and populate the output array.
      *
      * To avoid unnecessary allocations, the evaluation function
      * is responsible for allocating space for the array.
      */
      void eval_cpu(
          const std::vector<array>& inputs,
          std::vector<array>& outputs) override;
      void eval_gpu(
          const std::vector<array>& inputs,
          std::vector<array>& outputs) override;

      /** The Jacobian-vector product. */
      std::vector<array> jvp(
          const std::vector<array>& primals,
          const std::vector<array>& tangents,
          const std::vector<int>& argnums) override;

      /** The vector-Jacobian product. */
      std::vector<array> vjp(
          const std::vector<array>& primals,
          const std::vector<array>& cotangents,
          const std::vector<int>& argnums,
          const std::vector<array>& outputs) override;

      /**
      * The primitive must know how to vectorize itself across
      * the given axes. The output is a pair containing the arrays
      * representing the vectorized computation and the axes which
      * correspond to the output vectorized dimensions.
      */
      std::pair<std::vector<array>, std::vector<int>> vmap(
          const std::vector<array>& inputs,
          const std::vector<int>& axes) override;

      /** The name of the primitive. */
      const char* name() const override {
        return "Axpby";
      }

      /** Equivalence check **/
      bool is_equivalent(const Primitive& other) const override;

     private:
      float alpha_;
      float beta_;
    };

The :class:`Axpby` class derives from the base :class:`Primitive` class.
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.

Using the Primitive
^^^^^^^^^^^^^^^^^^^

Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.

Let's reimplement our operation now in terms of our :class:`Axpby` primitive.

.. code-block:: C++

    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
      // Promote dtypes between x and y as needed
      auto promoted_dtype = promote_types(x.dtype(), y.dtype());

      // Upcast to float32 for non-floating point inputs x and y
      auto out_dtype = issubdtype(promoted_dtype, float32)
          ? promoted_dtype
          : promote_types(promoted_dtype, float32);

      // Cast x and y up to the determined dtype (on the same stream s)
      auto x_casted = astype(x, out_dtype, s);
      auto y_casted = astype(y, out_dtype, s);

      // Broadcast the shapes of x and y (on the same stream s)
      auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
      auto out_shape = broadcasted_inputs[0].shape();

      // Construct the array as the output of the Axpby primitive
      // with the broadcasted and upcasted arrays as inputs
      return array(
          /* const std::vector<int>& shape = */ out_shape,
          /* Dtype dtype = */ out_dtype,
          /* std::shared_ptr<Primitive> primitive = */
          std::make_shared<Axpby>(to_stream(s), alpha, beta),
          /* const std::vector<array>& inputs = */ broadcasted_inputs);
    }

This operation now handles the following:

#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.

Implementing the Primitive
--------------------------

No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.

.. warning::
   When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
   no memory has been allocated for the output array. It falls on the
   implementation of these functions to allocate memory as needed.

Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by implementing :meth:`Axpby::eval_cpu`.

The method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.

.. code-block:: C++

    template <typename T>
    void axpby_impl(
        const mx::array& x,
        const mx::array& y,
        mx::array& out,
        float alpha_,
        float beta_,
        mx::Stream stream) {
      out.set_data(mx::allocator::malloc(out.nbytes()));

      // Get the CPU command encoder and register input and output arrays
      auto& encoder = mx::cpu::get_command_encoder(stream);
      encoder.set_input_array(x);
      encoder.set_input_array(y);
      encoder.set_output_array(out);

      // Launch the CPU kernel
      encoder.dispatch([x_ptr = x.data<T>(),
                        y_ptr = y.data<T>(),
                        out_ptr = out.data<T>(),
                        size = out.size(),
                        shape = out.shape(),
                        x_strides = x.strides(),
                        y_strides = y.strides(),
                        alpha_,
                        beta_]() {
        // Cast alpha and beta to the relevant types
        T alpha = static_cast<T>(alpha_);
        T beta = static_cast<T>(beta_);

        // Do the element-wise operation for each output
        for (size_t out_idx = 0; out_idx < size; out_idx++) {
          // Map linear indices to offsets in x and y
          auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
          auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

          // We allocate the output to be contiguous and regularly strided
          // (defaults to row major) and hence it doesn't need additional mapping
          out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
        }
      });
    }

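The ``elem_to_loc`` calls above are what make the kernel correct for
non-contiguous inputs (e.g. broadcast or transposed arrays). A minimal Python
sketch of the same mapping, assuming a row-major linear index that is
decomposed against ``shape`` and re-projected with ``strides``:

.. code-block:: python

    def elem_to_loc(idx, shape, strides):
        # Peel off coordinates starting from the fastest-moving dimension,
        # accumulating the strided offset for each one.
        loc = 0
        for dim, stride in zip(reversed(shape), reversed(strides)):
            loc += (idx % dim) * stride
            idx //= dim
        return loc

    # A 2x3 transposed view with strides (1, 2): linear index 3 is the
    # element at row 1, column 0, which lives at buffer offset 1.
    print(elem_to_loc(3, (2, 3), (1, 2)))  # 1
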
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.

.. code-block:: C++

    void Axpby::eval_cpu(
        const std::vector<mx::array>& inputs,
        std::vector<mx::array>& outputs) {
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Dispatch to the correct dtype
      if (out.dtype() == mx::float32) {
        return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
      } else if (out.dtype() == mx::float16) {
        return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
      } else if (out.dtype() == mx::bfloat16) {
        return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
      } else if (out.dtype() == mx::complex64) {
        return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
      } else {
        throw std::runtime_error(
            "Axpby is only supported for floating point types.");
      }
    }

Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here.

Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Apple silicon devices address their GPUs using the Metal_ shading language,
and GPU kernels in MLX are written using Metal.

.. note::

   Here are some helpful resources if you are new to Metal:

   * A walkthrough of the Metal compute pipeline: `Metal Example`_
   * Documentation for the Metal shading language: `Metal Specification`_
   * Using Metal from C++: `Metal-cpp`_

Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.

.. code-block:: C++

    template <typename T>
    [[kernel]] void axpby_general(
        device const T* x [[buffer(0)]],
        device const T* y [[buffer(1)]],
        device T* out [[buffer(2)]],
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
        constant const int64_t* x_strides [[buffer(6)]],
        constant const int64_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {
      // Convert linear indices to offsets in array
      auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
      auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

      // Do the operation and update the output
      out[index] =
          static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
    }

We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.

.. code-block:: C++

    instantiate_kernel("axpby_general_float32", axpby_general, float)
    instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
    instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
    instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

The logic to determine the kernel, set the inputs, resolve the grid
dimensions, and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu`
as shown below.

.. code-block:: C++

    /** Evaluate primitive on GPU */
    void Axpby::eval_gpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) {
      // Prepare inputs
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Each primitive carries the stream it should execute on
      // and each stream carries its device identifiers
      auto& s = stream();
      // We get the needed metal device using the stream
      auto& d = metal::device(s.device);

      // Allocate output memory
      out.set_data(allocator::malloc(out.nbytes()));

      // Resolve name of kernel
      std::string kname = "axpby_general_" + type_to_name(out);

      // Load the metal library
      auto lib = d.get_library("mlx_ext", current_binary_dir());

      // Make a kernel from this metal library
      auto kernel = d.get_kernel(kname, lib);

      // Prepare to encode kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      compute_encoder.set_compute_pipeline_state(kernel);

      // Kernel parameters are registered with buffer indices corresponding to
      // those in the kernel declaration at axpby.metal
      int ndim = out.ndim();
      size_t nelem = out.size();

      // Encode input arrays to kernel
      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(y, 1);

      // Encode output arrays to kernel
      compute_encoder.set_output_array(out, 2);

      // Encode alpha and beta
      compute_encoder.set_bytes(alpha_, 3);
      compute_encoder.set_bytes(beta_, 4);

      // Encode shape, strides and ndim
      compute_encoder.set_vector_bytes(x.shape(), 5);
      compute_encoder.set_vector_bytes(x.strides(), 6);
      compute_encoder.set_vector_bytes(y.strides(), 7);
      compute_encoder.set_bytes(ndim, 8);

      // We launch 1 thread for each output element and make sure that the
      // number of threads in any given threadgroup is not higher than the
      // max allowed
      size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

      // Fix the 3D size of each threadgroup (in terms of threads)
      MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

      // Fix the 3D size of the launch grid (in terms of threads)
      MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

      // Launch the grid with the given number of threads divided among
      // the given threadgroups
      compute_encoder.dispatch_threads(grid_dims, group_dims);
    }

We can now call the :meth:`axpby` operation on both the CPU and the GPU!

A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.

Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^

Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:

.. code-block:: C++

    /** The Jacobian-vector product. */
    std::vector<array> Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
      // Forward mode diff that pushes along the tangents
      // The jvp transform on the primitive can be built with ops
      // that are scheduled on the same stream as the primitive

      // If argnums = {0}, we only push along x in which case the
      // jvp is just the tangent scaled by alpha
      // Similarly, if argnums = {1}, the jvp is just the tangent
      // scaled by beta
      if (argnums.size() < 2) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
      }
      // If argnums = {0, 1}, we take contributions from both
      // which gives us jvp = tangent_x * alpha + tangent_y * beta
      else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
      }
    }

.. code-block:: C++

    /** The vector-Jacobian product. */
    std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<array>& /* unused */) {
      // Reverse mode diff
      std::vector<array> vjps;
      for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotangents[0].dtype());
        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
      }
      return vjps;
    }

Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.

.. code-block:: C++

    /** Vectorize primitive along given axis */
    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
      throw std::runtime_error("[Axpby] vmap not implemented.");
    }

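With the ``jvp`` and ``vjp`` rules in place, :class:`Axpby` composes with
MLX's function transformations. A minimal sketch (assuming the extension has
been built and installed as described in the next section):

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    def f(x, y):
        # sum(4 * x + 2 * y), so df/dx is 4 everywhere and df/dy is 2
        return axpby(x, y, 4.0, 2.0).sum()

    x = mx.ones((3,))
    y = mx.ones((3,))
    print(mx.grad(f)(x, y))             # array of 4s
    print(mx.grad(f, argnums=1)(x, y))  # array of 2s
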
Building and Binding
--------------------

Let's look at the overall directory structure first.

| extensions
| ├── axpby
| │   ├── axpby.cpp
| │   ├── axpby.h
| │   └── axpby.metal
| ├── mlx_sample_extensions
| │   └── __init__.py
| ├── bindings.cpp
| ├── CMakeLists.txt
| └── setup.py

* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
  associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
  Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
  the Python package

Binding to Python
^^^^^^^^^^^^^^^^^^

We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.

.. code-block:: C++

    NB_MODULE(_ext, m) {
      m.doc() = "Sample extension for MLX";

      m.def(
          "axpby",
          &axpby,
          "x"_a,
          "y"_a,
          "alpha"_a,
          "beta"_a,
          nb::kw_only(),
          "stream"_a = nb::none(),
          R"(
            Scale and sum two vectors element-wise
            ``z = alpha * x + beta * y``

            Follows numpy style broadcasting between ``x`` and ``y``
            Inputs are upcasted to floats if needed

            Args:
                x (array): Input array.
                y (array): Input array.
                alpha (float): Scaling factor for ``x``.
                beta (float): Scaling factor for ``y``.

            Returns:
                array: ``alpha * x + beta * y``
          )");
    }

Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.

.. warning::

   :mod:`mlx.core` must be imported before importing
   :mod:`mlx_sample_extensions` as defined by the nanobind module above to
   ensure that the casters for :mod:`mlx.core` components like
   :class:`mlx.core.array` are available.

.. _Building with CMake:

Building with CMake
^^^^^^^^^^^^^^^^^^^^

Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.

.. code-block:: cmake

    # Add library
    add_library(mlx_ext)

    # Add sources
    target_sources(
        mlx_ext
        PUBLIC
        ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
    )

    # Add include headers
    target_include_directories(
        mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
    )

    # Link to mlx
    target_link_libraries(mlx_ext PUBLIC mlx)

We also need to build the attached Metal library. For convenience, we provide
a :meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).

Here is what that looks like in practice:

.. code-block:: cmake

    # Build metallib
    if(MLX_BUILD_METAL)
      mlx_build_metallib(
          TARGET mlx_ext_metallib
          TITLE mlx_ext
          SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
          INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
          OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
      )

      add_dependencies(
          mlx_ext
          mlx_ext_metallib
      )
    endif()

Finally, we build the nanobind_ bindings:

.. code-block:: cmake

    nanobind_add_module(
        _ext
        NB_STATIC STABLE_ABI LTO NOMINSIZE
        NB_DOMAIN mlx
        ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
    )
    target_link_libraries(_ext PRIVATE mlx_ext)

    if(BUILD_SHARED_LIBS)
      target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
    endif()

Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:

.. code-block:: python

    from mlx import extension
    from setuptools import setup

    if __name__ == "__main__":
        setup(
            name="mlx_sample_extensions",
            version="0.0.0",
            description="Sample C++ and Metal extensions for MLX primitives.",
            ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
            cmdclass={"build_ext": extension.CMakeBuild},
            packages=["mlx_sample_extensions"],
            package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
            extras_require={"dev": []},
            zip_safe=False,
            python_requires=">=3.8",
        )

.. note::
   We treat ``extensions/mlx_sample_extensions`` as the package directory
   even though it only contains a ``__init__.py`` to ensure the following:

   * :mod:`mlx.core` must be imported before importing :mod:`_ext`
   * The C++ extension library and the Metal library are co-located with the
     Python bindings and copied together if the package is installed

To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).

This results in the directory structure:

| extensions
| ├── mlx_sample_extensions
| │   ├── __init__.py
| │   ├── libmlx_ext.dylib # C++ extension library
| │   ├── mlx_ext.metallib # Metal library
| │   └── _ext.cpython-3x-darwin.so # Python Binding
| ...

When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.

Usage
-----

After installing the extension as described above, you should be able to
simply import the Python package and play with it as you would any other MLX
operation.

Let's look at a simple script and its results:

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    a = mx.ones((3, 4))
    b = mx.ones((3, 4))
    c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
    print(f"c is correct: {mx.all(c == 6.0).item()}")

Output:

.. code-block::

    c shape: [3, 4]
    c dtype: float32
    c is correct: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined.

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby
    import time

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

    M = 4096
    N = 4096

    x = mx.random.normal((M, N))
    y = mx.random.normal((M, N))
    alpha = 4.0
    beta = 2.0

    mx.eval(x, y)

    def bench(f):
        # Warm up
        for i in range(5):
            z = f(x, y, alpha, beta)
            mx.eval(z)

        # Timed run
        s = time.time()
        for i in range(100):
            z = f(x, y, alpha, beta)
            mx.eval(z)
        e = time.time()
        return 1000 * (e - s) / 100

    simple_time = bench(simple_axpby)
    custom_time = bench(axpby)

    print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
modest improvements right away!

This operation can now be used to build other operations, in
:class:`mlx.nn.Module` calls, and as part of graph transformations like
:meth:`grad`.

Scripts
-------

.. admonition:: Download the code

   The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.

.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

.. source file: docs/build/html/_sources/dev/metal_debugger.rst

Metal Debugger
==============

.. currentmodule:: mlx.core

Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:

* Records source during Metal compilation, for later inspection while
  debugging.
* Labels Metal objects such as command queues, improving capture readability.

To build with debugging enabled in Python, prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.

The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.

.. note::

   To capture a GPU trace you must run the application with
   ``MTL_CAPTURE_ENABLED=1``.

.. code-block:: python

    import mlx.core as mx

    a = mx.random.uniform(shape=(512, 512))
    b = mx.random.uniform(shape=(512, 512))
    mx.eval(a, b)

    trace_file = "mlx_trace.gputrace"

    # Make sure to run with MTL_CAPTURE_ENABLED=1 and
    # that the path trace_file does not already exist.
    mx.metal.start_capture(trace_file)

    for _ in range(10):
        mx.eval(mx.add(a, b))

    mx.metal.stop_capture()

You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.

.. image:: ../_static/metal_debugger/capture.png
   :class: dark-light

Xcode Workflow
--------------

You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.

.. code-block::

    mkdir build && cd build
    cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
    open mlx.xcodeproj

Select the ``metal_capture`` example schema and run.

.. image:: ../_static/metal_debugger/schema.png
   :class: dark-light

.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

.. source file: docs/build/html/_sources/dev/mlx_in_cpp.rst

.. _mlx_in_cpp:

Using MLX in C++
================

You can use MLX in a C++ project with CMake.

.. note::

   This guide is based on the following `example using MLX in C++
   <https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_.

First install MLX:

.. code-block:: bash

    pip install -U mlx

You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.

Next make an example program in ``example.cpp``:

.. code-block:: C++

    #include <iostream>

    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      auto x = mx::array({1, 2, 3});
      auto y = mx::array({1, 2, 3});
      std::cout << x + y << std::endl;
      return 0;
    }

The next step is to set up a CMake file in ``CMakeLists.txt``:

.. code-block:: cmake

    cmake_minimum_required(VERSION 3.27)

    project(example LANGUAGES CXX)

    set(CMAKE_CXX_STANDARD 17)
    set(CMAKE_CXX_STANDARD_REQUIRED ON)

Depending on how you installed MLX, you may need to tell CMake where to
find it.

If you installed MLX with Python, then add the following to the CMake file:

.. code-block:: cmake

    find_package(
      Python 3.9
      COMPONENTS Interpreter Development.Module
      REQUIRED)
    execute_process(
      COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
      OUTPUT_STRIP_TRAILING_WHITESPACE
      OUTPUT_VARIABLE MLX_ROOT)

If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake
can't find MLX then set ``MLX_ROOT`` to the location where MLX is installed:

.. code-block:: cmake

    set(MLX_ROOT "/path/to/mlx/")

Next, instruct CMake to find MLX:

.. code-block:: cmake

    find_package(MLX CONFIG REQUIRED)

Finally, add the ``example.cpp`` program as an executable and link MLX.

.. code-block:: cmake

    add_executable(example example.cpp)
    target_link_libraries(example PRIVATE mlx)

You can build the example with:

.. code-block:: bash

    cmake -B build -DCMAKE_BUILD_TYPE=Release
    cmake --build build

And run it with:

.. code-block:: bash

    ./build/example

Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:

.. list-table:: Package Variables
   :widths: 20 20
   :header-rows: 1

   * - Variable
     - Description
   * - MLX_FOUND
     - ``True`` if MLX is found
   * - MLX_INCLUDE_DIRS
     - Include directory
   * - MLX_LIBRARIES
     - Libraries to link against
   * - MLX_CXX_FLAGS
     - Additional compiler flags
   * - MLX_BUILD_ACCELERATE
     - ``True`` if MLX was built with Accelerate
   * - MLX_BUILD_METAL
     - ``True`` if MLX was built with Metal

.. source file: docs/build/html/_sources/examples/linear_regression.rst

.. _linear_regression:

Linear Regression
-----------------

Let's implement a basic linear regression model as a starting point to
learn MLX. First import the core package and set up some problem metadata:

.. code-block:: python

    import mlx.core as mx

    num_features = 100
    num_examples = 1_000
    num_iters = 10_000  # iterations of SGD
    lr = 0.01  # learning rate for SGD

We'll generate a synthetic dataset by:

1. Sampling the design matrix ``X``.
2. Sampling a ground truth parameter vector ``w_star``.
3. Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.

.. code-block:: python

    # True parameters
    w_star = mx.random.normal((num_features,))

    # Input examples (design matrix)
    X = mx.random.normal((num_examples, num_features))

    # Noisy labels
    eps = 1e-2 * mx.random.normal((num_examples,))
    y = X @ w_star + eps

We will use SGD to find the optimal weights. To start, define the squared loss
and get the gradient function of the loss with respect to the parameters.

.. code-block:: python

    def loss_fn(w):
        return 0.5 * mx.mean(mx.square(X @ w - y))

    grad_fn = mx.grad(loss_fn)

Start the optimization by initializing the parameters ``w`` randomly. Then
repeatedly update the parameters for ``num_iters`` iterations.

.. code-block:: python

    w = 1e-2 * mx.random.normal((num_features,))

    for _ in range(num_iters):
        grad = grad_fn(w)
        w = w - lr * grad
        mx.eval(w)

Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.

.. code-block:: python

    loss = loss_fn(w)
    error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5

    print(
        f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
    )
    # Should print something close to: Loss 0.00005, |w-w*| = 0.00364

Complete `linear regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
and `logistic regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
examples are available in the MLX GitHub repo.

.. source file: docs/build/html/_sources/examples/llama-inference.rst

LLM inference
==============

MLX enables efficient inference of large-ish transformers on Apple silicon
without compromising on ease of use. In this example we will create an
inference script for the Llama family of transformer models in which the
model is defined in less than 200 lines of Python.

Implementing the model
----------------------

We will use the neural network building blocks defined in the :mod:`mlx.nn`
module to concisely define the model architecture.

Attention layer
^^^^^^^^^^^^^^^^

We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use
a key/value cache that will be concatenated with the provided keys and values
to support efficient inference.

Our implementation uses :class:`mlx.nn.Linear` for all the projections and
:class:`mlx.nn.RoPE` for the positional encoding.

.. code-block:: python

    import math

    import mlx.core as mx
    import mlx.nn as nn

    class LlamaAttention(nn.Module):
        def __init__(self, dims: int, num_heads: int):
            super().__init__()

            self.num_heads = num_heads

            self.rope = nn.RoPE(dims // num_heads, traditional=True)
            self.query_proj = nn.Linear(dims, dims, bias=False)
            self.key_proj = nn.Linear(dims, dims, bias=False)
            self.value_proj = nn.Linear(dims, dims, bias=False)
            self.out_proj = nn.Linear(dims, dims, bias=False)

        def __call__(self, queries, keys, values, mask=None, cache=None):
            queries = self.query_proj(queries)
            keys = self.key_proj(keys)
            values = self.value_proj(values)

            # Extract some shapes
            num_heads = self.num_heads
            B, L, D = queries.shape

            # Prepare the queries, keys and values for the attention computation
            queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
            keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
            values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

            # Add RoPE to the queries and keys and combine them with the cache
            if cache is not None:
                key_cache, value_cache = cache
                queries = self.rope(queries, offset=key_cache.shape[2])
                keys = self.rope(keys, offset=key_cache.shape[2])
                keys = mx.concatenate([key_cache, keys], axis=2)
                values = mx.concatenate([value_cache, values], axis=2)
            else:
                queries = self.rope(queries)
                keys = self.rope(keys)

            # Finally perform the attention computation
            scale = math.sqrt(1 / queries.shape[-1])
            scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
            if mask is not None:
                scores = scores + mask
            scores = mx.softmax(scores, axis=-1)
            values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

            # Note that we return the keys and values to possibly be used as a cache
            return self.out_proj(values_hat), (keys, values)

Encoder layer
^^^^^^^^^^^^^

The other component of the Llama model is the encoder layer which uses RMS
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.

.. code-block:: python

    class LlamaEncoderLayer(nn.Module):
        def __init__(self, dims: int, mlp_dims: int, num_heads: int):
            super().__init__()

            self.attention = LlamaAttention(dims, num_heads)

            self.norm1 = nn.RMSNorm(dims)
            self.norm2 = nn.RMSNorm(dims)

            self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
            self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
            self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

        def __call__(self, x, mask=None, cache=None):
            y = self.norm1(x)
            y, cache = self.attention(y, y, y, mask, cache)
            x = x + y

            y = self.norm2(x)
            a = self.linear1(y)
            b = self.linear2(y)
            y = a * mx.sigmoid(a) * b
            y = self.linear3(y)
            x = x + y

            return x, cache

Full model
^^^^^^^^^^

To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.

.. code-block:: python

    class Llama(nn.Module):
        def __init__(
            self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
        ):
            super().__init__()

            self.embedding = nn.Embedding(vocab_size, dims)
            self.layers = [
                LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
            ]
            self.norm = nn.RMSNorm(dims)
            self.out_proj = nn.Linear(dims, vocab_size, bias=False)

        def __call__(self, x):
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(self.embedding.weight.dtype)

            x = self.embedding(x)
            for l in self.layers:
                x, _ = l(x, mask)
            x = self.norm(x)
            return self.out_proj(x)

Note that in the implementation above we use a simple list to hold the encoder
layers; ``model.parameters()`` will still find the parameters of these layers,
as the sketch below illustrates.

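A minimal check (illustrative sizes) that list-held layers show up in the
parameter tree:

.. code-block:: python

    model = Llama(num_layers=2, vocab_size=128, dims=32, mlp_dims=64, num_heads=4)

    # parameters() returns a nested dict; modules stored in a plain Python
    # list appear as a list of sub-trees under the attribute name.
    params = model.parameters()
    print(len(params["layers"]))  # 2
    print(params["layers"][0]["attention"]["query_proj"]["weight"].shape)  # (32, 32)
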
Generation
^^^^^^^^^^^

Our ``Llama`` module can be used for training but not inference as the
``__call__`` method above processes one input, completely ignores the cache
and performs no sampling whatsoever. In the rest of this subsection, we will
implement the inference function as a Python generator that processes the
prompt and then autoregressively yields tokens one at a time.

.. code-block:: python

    class Llama(nn.Module):
        ...

        def generate(self, x, temp=1.0):
            cache = []

            # Make an additive causal mask. We will need that to process the prompt.
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(self.embedding.weight.dtype)

            # First we process the prompt x the same way as in __call__ but
            # save the caches in cache
            x = self.embedding(x)
            for l in self.layers:
                x, c = l(x, mask=mask)
                cache.append(c)  # <--- we store the per layer cache in a
                                 #      simple python list
            x = self.norm(x)
            y = self.out_proj(x[:, -1])  # <--- we only care about the last logits
                                         #      that generate the next token
            y = mx.random.categorical(y * (1 / temp))

            # y now has size [1]
            # Since MLX is lazily evaluated nothing is computed yet.
            # Calling y.item() would force the computation to happen at
            # this point but we can also choose not to do that and let the
            # user choose when to start the computation.
            yield y

            # Now we parsed the prompt and generated the first token we
            # need to feed it back into the model and loop to generate the
            # rest.
            while True:
                # Unsqueezing the last dimension to add a sequence length
                # dimension of 1
                x = y[:, None]

                x = self.embedding(x)
                for i in range(len(cache)):
                    # We are overwriting the arrays in the cache list. When
                    # the computation will happen, MLX will be discarding the
                    # old cache the moment it is not needed anymore.
                    x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
                x = self.norm(x)
                y = self.out_proj(x[:, -1])
                y = mx.random.categorical(y * (1 / temp))

                yield y

Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^

We now have everything we need to create a Llama model and sample tokens from
it. In the following code, we randomly initialize a small Llama model, process
6 tokens of prompt and generate 10 tokens.

.. code-block:: python

    model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)

    # Since MLX is lazily evaluated nothing has actually been materialized yet.
    # We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
    # code above would still run. Let's actually materialize the model.
    mx.eval(model.parameters())

    prompt = mx.array([[1, 10, 8, 32, 44, 7]])  # <-- Note the double brackets because we
                                                #     have a batch dimension even
                                                #     though it is 1 in this case

    generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]

    # Since we haven't evaluated anything, nothing is computed yet. The list
    # `generated` contains the arrays that hold the computation graph for the
    # full processing of the prompt and the generation of 10 tokens.
    #
    # We can evaluate them one at a time, or all together. Concatenate them or
    # print them. They would all result in very similar runtimes and give exactly
    # the same results.
    mx.eval(generated)

Converting the weights
----------------------

This section assumes that you have access to the original Llama weights and
the SentencePiece model that comes with them. We will write a small script to
convert the PyTorch weights to MLX compatible ones and write them in a NPZ
file that can be loaded directly by MLX.

.. code-block:: python

    import argparse
    from itertools import starmap

    import numpy as np
    import torch

    def map_torch_to_mlx(key, value):
        if "tok_embedding" in key:
            key = "embedding.weight"

        elif "norm" in key:
            key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")

        elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
            key = key.replace("wq", "query_proj")
            key = key.replace("wk", "key_proj")
            key = key.replace("wv", "value_proj")
            key = key.replace("wo", "out_proj")

        elif "w1" in key or "w2" in key or "w3" in key:
            # The FFN is a separate submodule in PyTorch
            key = key.replace("feed_forward.w1", "linear1")
            key = key.replace("feed_forward.w3", "linear2")
            key = key.replace("feed_forward.w2", "linear3")

        elif "output" in key:
            key = key.replace("output", "out_proj")

        elif "rope" in key:
            return None, None

        return key, value.numpy()


    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
        parser.add_argument("torch_weights")
        parser.add_argument("output_file")
        args = parser.parse_args()

        state = torch.load(args.torch_weights)
        np.savez(
            args.output_file,
            **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
        )

Weight loading and benchmarking
-------------------------------

After converting the weights to be compatible with our implementation, all
that is left is to load them from disk and we can finally use the LLM to
generate text. We can load numpy format files using the :func:`mlx.core.load`
operation.

To create a parameter dictionary from the key/value representation of NPZ
files we will use the :func:`mlx.utils.tree_unflatten` helper method as
follows:

.. code-block:: python

    from mlx.utils import tree_unflatten

    model.update(tree_unflatten(list(mx.load(weight_file).items())))

:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
like ``layers.2.attention.query_proj.weight`` and will transform them to

.. code-block:: python

    {"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}

which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.

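A tiny, self-contained illustration of that transformation (hypothetical keys
and values):

.. code-block:: python

    from mlx.utils import tree_unflatten

    flat = [
        ("layers.0.weight", 0),
        ("layers.1.weight", 1),
        ("norm.weight", 2),
    ]
    print(tree_unflatten(flat))
    # {'layers': [{'weight': 0}, {'weight': 1}], 'norm': {'weight': 2}}
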
You can download the full example code in `mlx-examples`_. Assuming the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory, we can play around with our inference script as follows (the
timings are representative of an M1 Ultra and the 7B parameter Llama model):

.. code-block:: bash

    $ python convert.py weights.pth llama-7B.mlx.npz
    $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
    [INFO] Loading model from disk: 5.247 s
    Press enter to start generation
    ------
    , having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
    ------
    [INFO] Prompt processing: 0.437 s
    [INFO] Full generation: 4.330 s

We observe that 4.3 seconds are required to generate 100 tokens and 0.4
seconds of those are spent processing the prompt. This amounts to roughly
**39 ms per token** ((4.330 s - 0.437 s) / 100 tokens ≈ 39 ms).

By running with a much bigger prompt we can see that the per-token generation
time, as well as the prompt processing time, remains almost constant.

.. code-block:: bash

   $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
   [INFO] Loading model from disk: 5.247 s
   Press enter to start generation
   ------
   take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
   ------
   [INFO] Prompt processing: 0.579 s
   [INFO] Full generation: 4.690 s
   $ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
   [INFO] Loading model from disk: 5.628 s
   Press enter to start generation
   ------
   take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
   ------
   [INFO] Prompt processing: 0.633 s
   [INFO] Full generation: 21.475 s

Scripts
-------

.. admonition:: Download the code

   The full example code is available in `mlx-examples`_.

.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama

.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
   Roformer: Enhanced transformer with rotary position embedding. arXiv
   preprint arXiv:2104.09864.
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
   Advances in Neural Information Processing Systems, 32.
.. [3] Shazeer, N., 2020. GLU variants improve transformer. arXiv preprint
   arXiv:2002.05202.

docs/build/html/_sources/examples/mlp.rst (vendored, new file, 134 lines)

.. _mlp:

Multi-Layer Perceptron
----------------------

In this example we'll learn to use ``mlx.nn`` by implementing a simple
multi-layer perceptron to classify MNIST.

As a first step import the MLX packages we need:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   import mlx.optimizers as optim

   import numpy as np


The model is defined as the ``MLP`` class which inherits from
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:

1. Define an ``__init__`` where the parameters and/or submodules are set up. See
   the :ref:`Module class docs<module_class>` for more information on how
   :class:`mlx.nn.Module` registers parameters.
2. Define a ``__call__`` where the computation is implemented.

.. code-block:: python

   class MLP(nn.Module):
       def __init__(
           self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
       ):
           super().__init__()
           layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
           self.layers = [
               nn.Linear(idim, odim)
               for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
           ]

       def __call__(self, x):
           for l in self.layers[:-1]:
               x = mx.maximum(l(x), 0.0)
           return self.layers[-1](x)

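As a quick sanity check of the module above (the dimensions here are
illustrative, not part of the example), a forward pass on a batch of inputs
should produce one logit per class:

.. code-block:: python

   model = MLP(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
   x = mx.zeros((4, 784))  # a batch of four flattened 28x28 images
   print(model(x).shape)   # (4, 10)
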
We define the loss function which takes the mean of the per-example cross
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
commonly used loss functions.

.. code-block:: python

   def loss_fn(model, X, y):
       return mx.mean(nn.losses.cross_entropy(model(X), y))

We also need a function to compute the accuracy of the model on the validation
set:

.. code-block:: python

   def eval_fn(model, X, y):
       return mx.mean(mx.argmax(model(X), axis=1) == y)

Next, set up the problem parameters and load the data. To load the data, you
need our `mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.

.. code-block:: python

   num_layers = 2
   hidden_dim = 32
   num_classes = 10
   batch_size = 256
   num_epochs = 10
   learning_rate = 1e-1

   # Load the data
   import mnist
   train_images, train_labels, test_images, test_labels = map(
       mx.array, mnist.mnist()
   )

Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:

.. code-block:: python

   def batch_iterate(batch_size, X, y):
       perm = mx.array(np.random.permutation(y.size))
       for s in range(0, y.size, batch_size):
           ids = perm[s : s + batch_size]
           yield X[ids], y[ids]

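A quick way to sanity check the iterator (the shapes shown are those produced
by the MNIST loader above):

.. code-block:: python

   for X, y in batch_iterate(batch_size, train_images, train_labels):
       print(X.shape, y.shape)  # e.g. (256, 784) (256,)
       break
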
Finally, we put it all together by instantiating the model, the
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:

.. code-block:: python

   # Instantiate the model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Get a function which gives the loss and gradient of the
   # loss with respect to the model's trainable parameters
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

   # Instantiate the optimizer
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the optimizer state and model parameters
           # in a single call
           optimizer.update(model, grads)

           # Force a graph evaluation
           mx.eval(model.parameters(), optimizer.state)

       accuracy = eval_fn(model, test_images, test_labels)
       print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")

.. note::
   The :func:`mlx.nn.value_and_grad` function is a convenience function to get
   the gradient of a loss with respect to the trainable parameters of a model.
   This should not be confused with :func:`mlx.core.value_and_grad`.

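The distinction, in a minimal sketch: :func:`mlx.nn.value_and_grad`
differentiates with respect to a model's trainable parameters, while
:func:`mlx.core.value_and_grad` differentiates with respect to explicit array
arguments.

.. code-block:: python

   # mx.value_and_grad: gradient with respect to the array argument x
   def f(x):
       return (x ** 2).sum()

   value, dfdx = mx.value_and_grad(f)(mx.array([1.0, 2.0]))

   # nn.value_and_grad: gradient with respect to model.trainable_parameters(),
   # returned as a tree that mirrors the parameter structure
   loss, grads = nn.value_and_grad(model, loss_fn)(model, X, y)
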
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example
<https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_ is available in
the MLX GitHub repo.

docs/build/html/_sources/index.rst (vendored, new file, 93 lines)

MLX
===

MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.

The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.

The main differences between MLX and NumPy are:

- **Composable function transformations**: MLX has composable function
  transformations for automatic differentiation, automatic vectorization,
  and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
  materialized when needed (see the sketch after this list).
- **Multi-device**: Operations can run on any of the supported devices (CPU,
  GPU, ...).

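A minimal sketch of the first two points (the function and values are only
illustrative):

.. code-block:: python

   import mlx.core as mx

   # Composable transformation: the gradient of f is itself a function
   def f(x):
       return mx.sum(mx.sin(x) ** 2)

   df = mx.grad(f)

   # Lazy computation: this builds a graph; nothing has run yet
   y = df(mx.arange(4.0))

   # The result is materialized on demand
   mx.eval(y)
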
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A notable difference between these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.

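For example, here is a minimal sketch of what unified memory enables: both
operations read the same array, with the ``stream`` argument selecting the
device and no copies in between.

.. code-block:: python

   import mlx.core as mx

   a = mx.random.normal((4096, 4096))

   # The same array feeds both devices without any host/device transfer
   b = mx.add(a, a, stream=mx.cpu)
   c = mx.add(a, a, stream=mx.gpu)
   mx.eval(b, c)
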
.. toctree::
   :caption: Install
   :maxdepth: 1

   install

.. toctree::
   :caption: Usage
   :maxdepth: 1

   usage/quick_start
   usage/lazy_evaluation
   usage/unified_memory
   usage/indexing
   usage/saving_and_loading
   usage/function_transforms
   usage/compile
   usage/numpy
   usage/distributed
   usage/using_streams
   usage/export

.. toctree::
   :caption: Examples
   :maxdepth: 1

   examples/linear_regression
   examples/mlp
   examples/llama-inference

.. toctree::
   :caption: Python API Reference
   :maxdepth: 1

   python/array
   python/data_types
   python/devices_and_streams
   python/export
   python/ops
   python/random
   python/transforms
   python/fast
   python/fft
   python/linalg
   python/metal
   python/cuda
   python/memory_management
   python/nn
   python/optimizers
   python/distributed
   python/tree_utils

.. toctree::
   :caption: C++ API Reference
   :maxdepth: 1

   cpp/ops

.. toctree::
   :caption: Further Reading
   :maxdepth: 1

   dev/extensions
   dev/metal_debugger
   dev/custom_metal_kernels
   dev/mlx_in_cpp

docs/build/html/_sources/install.rst (vendored, new file, 345 lines)

.. _build_and_install:

Build and Install
=================

Python Installation
-------------------

MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is

.. code-block:: shell

   pip install mlx

To install from PyPI your system must meet the following requirements:

- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5

.. note::
   MLX is only available on devices running macOS >= 13.5.
   It is highly recommended to use macOS 14 (Sonoma).

CUDA
^^^^

MLX has a CUDA backend which you can install with:

.. code-block:: shell

   pip install mlx[cuda]

To install the CUDA package from PyPI your system must meet the following
requirements:

- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9

CPU-only (Linux)
^^^^^^^^^^^^^^^^

For a CPU-only version of MLX that runs on Linux use:

.. code-block:: shell

   pip install mlx[cpu]

To install the CPU-only package from PyPI your system must meet the following
requirements:

- Linux distribution with glibc >= 2.35
- Python >= 3.9

Troubleshooting
^^^^^^^^^^^^^^^

*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*

You are probably using a non-native Python. The output of

.. code-block:: shell

   python -c "import platform; print(platform.processor())"

should be ``arm``. If it is ``i386`` (and you have an M series machine) then
you are using a non-native Python. Switch your Python to a native Python. A
good way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.

Build from source
-----------------

Build Requirements
^^^^^^^^^^^^^^^^^^

- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0

.. note::
   Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
   the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section
   <build shell>` below.

Python API
^^^^^^^^^^

.. _python install:

To build and install the MLX Python library from source, first clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:

.. code-block:: shell

   git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Then simply build and install MLX using pip:

.. code-block:: shell

   pip install .

For developing, install the package with development dependencies, and use an
editable install:

.. code-block:: shell

   pip install -e ".[dev]"

Once the development dependencies are installed, you can build faster with:

.. code-block:: shell

   python setup.py build_ext --inplace

Run the tests with:

.. code-block:: shell

   python -m unittest discover python/tests

Optional: Install stubs to enable auto completions and type checking from your
IDE:

.. code-block:: shell

   python setup.py generate_stubs

C++ API
^^^^^^^

.. _cpp install:

Currently, MLX must be built and installed from source.

Similarly to the Python library, to build and install the MLX C++ library start
by cloning MLX from `its GitHub repo
<https://github.com/ml-explore/mlx>`_:

.. code-block:: shell

   git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Create a build directory and run CMake and make:

.. code-block:: shell

   mkdir -p build && cd build
   cmake .. && make -j

Run tests with:

.. code-block:: shell

   make test

Install with:

.. code-block:: shell

   make install

Note that the built ``mlx.metallib`` file should either be in the same
directory as the executable statically linked to ``libmlx.a``, or the
preprocessor constant ``METAL_PATH`` should be defined at build time and
point to the path of the built Metal library.

.. list-table:: Build Options
   :widths: 25 8
   :header-rows: 1

   * - Option
     - Default
   * - MLX_BUILD_TESTS
     - ON
   * - MLX_BUILD_EXAMPLES
     - OFF
   * - MLX_BUILD_BENCHMARKS
     - OFF
   * - MLX_BUILD_METAL
     - ON
   * - MLX_BUILD_CPU
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF
   * - MLX_BUILD_SAFETENSORS
     - ON
   * - MLX_BUILD_GGUF
     - ON
   * - MLX_METAL_JIT
     - OFF

.. note::

   If you have multiple Xcode installations and wish to use
   a specific one while building, you can do so by adding the
   following environment variable before building:

   .. code-block:: shell

      export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

   Further, you can use the following command to find out which
   macOS SDK will be used:

   .. code-block:: shell

      xcrun -sdk macosx --show-sdk-version

Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:

.. code-block:: shell

   cmake .. \
     -DCMAKE_BUILD_TYPE=MinSizeRel \
     -DBUILD_SHARED_LIBS=ON \
     -DMLX_BUILD_CPU=OFF \
     -DMLX_BUILD_SAFETENSORS=OFF \
     -DMLX_BUILD_GGUF=OFF \
     -DMLX_METAL_JIT=ON

The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library, which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note that run-time compilation incurs a cold-start cost
which can be anywhere from a few hundred milliseconds to a few seconds,
depending on the application. Once a kernel is compiled, it will be cached by
the system. The Metal kernel cache persists across reboots.

Linux
^^^^^

To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:

.. code-block:: shell

   apt-get update -y
   apt-get install libblas-dev liblapack-dev liblapacke-dev -y

From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.

CUDA
^^^^

To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:

.. code-block:: shell

   wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
   dpkg -i cuda-keyring_1.1-1_all.deb
   apt-get update -y
   apt-get -y install cuda-toolkit-12-9
   apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y

When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:

.. code-block:: shell

   CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"

To build the C++ package run:

.. code-block:: shell

   mkdir -p build && cd build
   cmake .. -DMLX_BUILD_CUDA=ON && make -j

Troubleshooting
^^^^^^^^^^^^^^^

Metal not found
~~~~~~~~~~~~~~~

You see the following error when you try to build:

.. code-block:: shell

   error: unable to find utility "metal", not a developer tool or in PATH

To fix this, first make sure you have Xcode installed:

.. code-block:: shell

   xcode-select --install

Then set the active developer directory:

.. code-block:: shell

   sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer

x86 Shell
~~~~~~~~~

.. _build shell:

If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.

To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.

Verify the terminal is now running natively with the following command:

.. code-block:: shell

   $ uname -p
   arm

Also check that cmake is using the correct architecture:

.. code-block:: shell

   $ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
   CMAKE_HOST_SYSTEM_PROCESSOR "arm64"

If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported.",
wipe your build cache with ``rm -rf build/`` and try again.

docs/build/html/_sources/python/_autosummary/mlx.core.Device.rst (vendored, new file, 28 lines)

mlx.core.Device
===============

.. currentmodule:: mlx.core

.. autoclass:: Device


   .. automethod:: __init__


   .. rubric:: Methods

   .. autosummary::

      ~Device.__init__


   .. rubric:: Attributes

   .. autosummary::

      ~Device.type

docs/build/html/_sources/python/_autosummary/mlx.core.Dtype.rst (vendored, new file, 28 lines)

mlx.core.Dtype
==============

.. currentmodule:: mlx.core

.. autoclass:: Dtype


   .. automethod:: __init__


   .. rubric:: Methods

   .. autosummary::

      ~Dtype.__init__


   .. rubric:: Attributes

   .. autosummary::

      ~Dtype.size

docs/build/html/_sources/python/_autosummary/mlx.core.DtypeCategory.rst (vendored, new file, 29 lines)

mlx.core.DtypeCategory
======================

.. currentmodule:: mlx.core

.. autoclass:: DtypeCategory


   .. automethod:: __init__


   .. rubric:: Attributes

   .. autosummary::

      ~DtypeCategory.complexfloating
      ~DtypeCategory.floating
      ~DtypeCategory.inexact
      ~DtypeCategory.signedinteger
      ~DtypeCategory.unsignedinteger
      ~DtypeCategory.integer
      ~DtypeCategory.number
      ~DtypeCategory.generic

docs/build/html/_sources/python/_autosummary/mlx.core.abs.rst (vendored, new file, 6 lines)

mlx.core.abs
============

.. currentmodule:: mlx.core

.. autofunction:: abs

docs/build/html/_sources/python/_autosummary/mlx.core.add.rst (vendored, new file, 6 lines)

mlx.core.add
============

.. currentmodule:: mlx.core

.. autofunction:: add

docs/build/html/_sources/python/_autosummary/mlx.core.addmm.rst (vendored, new file, 6 lines)

mlx.core.addmm
==============

.. currentmodule:: mlx.core

.. autofunction:: addmm

docs/build/html/_sources/python/_autosummary/mlx.core.all.rst (vendored, new file, 6 lines)

mlx.core.all
============

.. currentmodule:: mlx.core

.. autofunction:: all

docs/build/html/_sources/python/_autosummary/mlx.core.allclose.rst (vendored, new file, 6 lines)

mlx.core.allclose
=================

.. currentmodule:: mlx.core

.. autofunction:: allclose

docs/build/html/_sources/python/_autosummary/mlx.core.any.rst (vendored, new file, 6 lines)

mlx.core.any
============

.. currentmodule:: mlx.core

.. autofunction:: any

docs/build/html/_sources/python/_autosummary/mlx.core.arange.rst (vendored, new file, 6 lines)

mlx.core.arange
===============

.. currentmodule:: mlx.core

.. autofunction:: arange

docs/build/html/_sources/python/_autosummary/mlx.core.arccos.rst (vendored, new file, 6 lines)

mlx.core.arccos
===============

.. currentmodule:: mlx.core

.. autofunction:: arccos

docs/build/html/_sources/python/_autosummary/mlx.core.arccosh.rst (vendored, new file, 6 lines)

mlx.core.arccosh
================

.. currentmodule:: mlx.core

.. autofunction:: arccosh

docs/build/html/_sources/python/_autosummary/mlx.core.arcsin.rst (vendored, new file, 6 lines)

mlx.core.arcsin
===============

.. currentmodule:: mlx.core

.. autofunction:: arcsin

docs/build/html/_sources/python/_autosummary/mlx.core.arcsinh.rst (vendored, new file, 6 lines)

mlx.core.arcsinh
================

.. currentmodule:: mlx.core

.. autofunction:: arcsinh

docs/build/html/_sources/python/_autosummary/mlx.core.arctan.rst (vendored, new file, 6 lines)

mlx.core.arctan
===============

.. currentmodule:: mlx.core

.. autofunction:: arctan

docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst (vendored, new file, 6 lines)

mlx.core.arctan2
================

.. currentmodule:: mlx.core

.. autofunction:: arctan2

docs/build/html/_sources/python/_autosummary/mlx.core.arctanh.rst (vendored, new file, 6 lines)

mlx.core.arctanh
================

.. currentmodule:: mlx.core

.. autofunction:: arctanh

docs/build/html/_sources/python/_autosummary/mlx.core.argmax.rst (vendored, new file, 6 lines)

mlx.core.argmax
===============

.. currentmodule:: mlx.core

.. autofunction:: argmax

docs/build/html/_sources/python/_autosummary/mlx.core.argmin.rst (vendored, new file, 6 lines)

mlx.core.argmin
===============

.. currentmodule:: mlx.core

.. autofunction:: argmin

docs/build/html/_sources/python/_autosummary/mlx.core.argpartition.rst (vendored, new file, 6 lines)

mlx.core.argpartition
=====================

.. currentmodule:: mlx.core

.. autofunction:: argpartition

docs/build/html/_sources/python/_autosummary/mlx.core.argsort.rst (vendored, new file, 6 lines)

mlx.core.argsort
================

.. currentmodule:: mlx.core

.. autofunction:: argsort

docs/build/html/_sources/python/_autosummary/mlx.core.array.T.rst (vendored, new file, 6 lines)

mlx.core.array.T
================

.. currentmodule:: mlx.core

.. autoproperty:: array.T

docs/build/html/_sources/python/_autosummary/mlx.core.array.abs.rst (vendored, new file, 6 lines)

mlx.core.array.abs
==================

.. currentmodule:: mlx.core

.. automethod:: array.abs

docs/build/html/_sources/python/_autosummary/mlx.core.array.all.rst (vendored, new file, 6 lines)

mlx.core.array.all
==================

.. currentmodule:: mlx.core

.. automethod:: array.all

docs/build/html/_sources/python/_autosummary/mlx.core.array.any.rst (vendored, new file, 6 lines)

mlx.core.array.any
==================

.. currentmodule:: mlx.core

.. automethod:: array.any

docs/build/html/_sources/python/_autosummary/mlx.core.array.argmax.rst (vendored, new file, 6 lines)

mlx.core.array.argmax
=====================

.. currentmodule:: mlx.core

.. automethod:: array.argmax

docs/build/html/_sources/python/_autosummary/mlx.core.array.argmin.rst (vendored, new file, 6 lines)

mlx.core.array.argmin
=====================

.. currentmodule:: mlx.core

.. automethod:: array.argmin

docs/build/html/_sources/python/_autosummary/mlx.core.array.astype.rst (vendored, new file, 6 lines)

mlx.core.array.astype
=====================

.. currentmodule:: mlx.core

.. automethod:: array.astype

docs/build/html/_sources/python/_autosummary/mlx.core.array.at.rst (vendored, new file, 6 lines)

mlx.core.array.at
=================

.. currentmodule:: mlx.core

.. autoproperty:: array.at

docs/build/html/_sources/python/_autosummary/mlx.core.array.conj.rst (vendored, new file, 6 lines)

mlx.core.array.conj
===================

.. currentmodule:: mlx.core

.. automethod:: array.conj

docs/build/html/_sources/python/_autosummary/mlx.core.array.cos.rst (vendored, new file, 6 lines)

mlx.core.array.cos
==================

.. currentmodule:: mlx.core

.. automethod:: array.cos

docs/build/html/_sources/python/_autosummary/mlx.core.array.cummax.rst (vendored, new file, 6 lines)

mlx.core.array.cummax
=====================

.. currentmodule:: mlx.core

.. automethod:: array.cummax

docs/build/html/_sources/python/_autosummary/mlx.core.array.cummin.rst (vendored, new file, 6 lines)

mlx.core.array.cummin
=====================

.. currentmodule:: mlx.core

.. automethod:: array.cummin

docs/build/html/_sources/python/_autosummary/mlx.core.array.cumprod.rst (vendored, new file, 6 lines)

mlx.core.array.cumprod
======================

.. currentmodule:: mlx.core

.. automethod:: array.cumprod

docs/build/html/_sources/python/_autosummary/mlx.core.array.cumsum.rst (vendored, new file, 6 lines)

mlx.core.array.cumsum
=====================

.. currentmodule:: mlx.core

.. automethod:: array.cumsum

docs/build/html/_sources/python/_autosummary/mlx.core.array.diag.rst (vendored, new file, 6 lines)

mlx.core.array.diag
===================

.. currentmodule:: mlx.core

.. automethod:: array.diag

docs/build/html/_sources/python/_autosummary/mlx.core.array.diagonal.rst (vendored, new file, 6 lines)

mlx.core.array.diagonal
=======================

.. currentmodule:: mlx.core

.. automethod:: array.diagonal

docs/build/html/_sources/python/_autosummary/mlx.core.array.dtype.rst (vendored, new file, 6 lines)

mlx.core.array.dtype
====================

.. currentmodule:: mlx.core

.. autoproperty:: array.dtype

docs/build/html/_sources/python/_autosummary/mlx.core.array.exp.rst (vendored, new file, 6 lines)

mlx.core.array.exp
==================

.. currentmodule:: mlx.core

.. automethod:: array.exp

docs/build/html/_sources/python/_autosummary/mlx.core.array.flatten.rst (vendored, new file, 6 lines)

mlx.core.array.flatten
======================

.. currentmodule:: mlx.core

.. automethod:: array.flatten

docs/build/html/_sources/python/_autosummary/mlx.core.array.imag.rst (vendored, new file, 6 lines)

mlx.core.array.imag
===================

.. currentmodule:: mlx.core

.. autoproperty:: array.imag

docs/build/html/_sources/python/_autosummary/mlx.core.array.item.rst (vendored, new file, 6 lines)

mlx.core.array.item
===================

.. currentmodule:: mlx.core

.. automethod:: array.item

docs/build/html/_sources/python/_autosummary/mlx.core.array.itemsize.rst (vendored, new file, 6 lines)

mlx.core.array.itemsize
=======================

.. currentmodule:: mlx.core

.. autoproperty:: array.itemsize

docs/build/html/_sources/python/_autosummary/mlx.core.array.log.rst (vendored, new file, 6 lines)

mlx.core.array.log
==================

.. currentmodule:: mlx.core

.. automethod:: array.log

docs/build/html/_sources/python/_autosummary/mlx.core.array.log10.rst (vendored, new file, 6 lines)

mlx.core.array.log10
====================

.. currentmodule:: mlx.core

.. automethod:: array.log10

docs/build/html/_sources/python/_autosummary/mlx.core.array.log1p.rst (vendored, new file, 6 lines)

mlx.core.array.log1p
====================

.. currentmodule:: mlx.core

.. automethod:: array.log1p

docs/build/html/_sources/python/_autosummary/mlx.core.array.log2.rst (vendored, new file, 6 lines)

mlx.core.array.log2
===================

.. currentmodule:: mlx.core

.. automethod:: array.log2

docs/build/html/_sources/python/_autosummary/mlx.core.array.logcumsumexp.rst (vendored, new file, 6 lines)

mlx.core.array.logcumsumexp
===========================

.. currentmodule:: mlx.core

.. automethod:: array.logcumsumexp

docs/build/html/_sources/python/_autosummary/mlx.core.array.logsumexp.rst (vendored, new file, 6 lines)

mlx.core.array.logsumexp
========================

.. currentmodule:: mlx.core

.. automethod:: array.logsumexp

docs/build/html/_sources/python/_autosummary/mlx.core.array.max.rst (vendored, new file, 6 lines)

mlx.core.array.max
==================

.. currentmodule:: mlx.core

.. automethod:: array.max

docs/build/html/_sources/python/_autosummary/mlx.core.array.mean.rst (vendored, new file, 6 lines)

mlx.core.array.mean
===================

.. currentmodule:: mlx.core

.. automethod:: array.mean

docs/build/html/_sources/python/_autosummary/mlx.core.array.min.rst (vendored, new file, 6 lines)

mlx.core.array.min
==================

.. currentmodule:: mlx.core

.. automethod:: array.min

docs/build/html/_sources/python/_autosummary/mlx.core.array.moveaxis.rst (vendored, new file, 6 lines)

mlx.core.array.moveaxis
=======================

.. currentmodule:: mlx.core

.. automethod:: array.moveaxis

docs/build/html/_sources/python/_autosummary/mlx.core.array.nbytes.rst (vendored, new file, 6 lines)

mlx.core.array.nbytes
=====================

.. currentmodule:: mlx.core

.. autoproperty:: array.nbytes

docs/build/html/_sources/python/_autosummary/mlx.core.array.ndim.rst (vendored, new file, 6 lines)

mlx.core.array.ndim
===================

.. currentmodule:: mlx.core

.. autoproperty:: array.ndim

docs/build/html/_sources/python/_autosummary/mlx.core.array.prod.rst (vendored, new file, 6 lines)

mlx.core.array.prod
===================

.. currentmodule:: mlx.core

.. automethod:: array.prod

docs/build/html/_sources/python/_autosummary/mlx.core.array.real.rst (vendored, new file, 6 lines)

mlx.core.array.real
===================

.. currentmodule:: mlx.core

.. autoproperty:: array.real

docs/build/html/_sources/python/_autosummary/mlx.core.array.reciprocal.rst (vendored, new file, 6 lines)

mlx.core.array.reciprocal
=========================

.. currentmodule:: mlx.core

.. automethod:: array.reciprocal

docs/build/html/_sources/python/_autosummary/mlx.core.array.reshape.rst (vendored, new file, 6 lines)

mlx.core.array.reshape
======================

.. currentmodule:: mlx.core

.. automethod:: array.reshape

docs/build/html/_sources/python/_autosummary/mlx.core.array.round.rst (vendored, new file, 6 lines)

mlx.core.array.round
====================

.. currentmodule:: mlx.core

.. automethod:: array.round

docs/build/html/_sources/python/_autosummary/mlx.core.array.rsqrt.rst (vendored, new file, 6 lines)

mlx.core.array.rsqrt
====================

.. currentmodule:: mlx.core

.. automethod:: array.rsqrt

docs/build/html/_sources/python/_autosummary/mlx.core.array.rst (vendored, new file, 81 lines)

mlx.core.array
==============

.. currentmodule:: mlx.core

.. autoclass:: array


   .. automethod:: __init__


   .. rubric:: Methods

   .. autosummary::

      ~array.__init__
      ~array.abs
      ~array.all
      ~array.any
      ~array.argmax
      ~array.argmin
      ~array.astype
      ~array.conj
      ~array.cos
      ~array.cummax
      ~array.cummin
      ~array.cumprod
      ~array.cumsum
      ~array.diag
      ~array.diagonal
      ~array.exp
      ~array.flatten
      ~array.item
      ~array.log
      ~array.log10
      ~array.log1p
      ~array.log2
      ~array.logcumsumexp
      ~array.logsumexp
      ~array.max
      ~array.mean
      ~array.min
      ~array.moveaxis
      ~array.prod
      ~array.reciprocal
      ~array.reshape
      ~array.round
      ~array.rsqrt
      ~array.sin
      ~array.split
      ~array.sqrt
      ~array.square
      ~array.squeeze
      ~array.std
      ~array.sum
      ~array.swapaxes
      ~array.tolist
      ~array.transpose
      ~array.var
      ~array.view


   .. rubric:: Attributes

   .. autosummary::

      ~array.T
      ~array.at
      ~array.dtype
      ~array.imag
      ~array.itemsize
      ~array.nbytes
      ~array.ndim
      ~array.real
      ~array.shape
      ~array.size

docs/build/html/_sources/python/_autosummary/mlx.core.array.shape.rst (vendored, new file, 6 lines)

mlx.core.array.shape
====================

.. currentmodule:: mlx.core

.. autoproperty:: array.shape

docs/build/html/_sources/python/_autosummary/mlx.core.array.sin.rst (vendored, new file, 6 lines)

mlx.core.array.sin
==================

.. currentmodule:: mlx.core

.. automethod:: array.sin

docs/build/html/_sources/python/_autosummary/mlx.core.array.size.rst (vendored, new file, 6 lines)

mlx.core.array.size
===================

.. currentmodule:: mlx.core

.. autoproperty:: array.size

docs/build/html/_sources/python/_autosummary/mlx.core.array.split.rst (vendored, new file, 6 lines)

mlx.core.array.split
====================

.. currentmodule:: mlx.core

.. automethod:: array.split

docs/build/html/_sources/python/_autosummary/mlx.core.array.sqrt.rst (vendored, new file, 6 lines)

mlx.core.array.sqrt
===================

.. currentmodule:: mlx.core

.. automethod:: array.sqrt

docs/build/html/_sources/python/_autosummary/mlx.core.array.square.rst (vendored, new file, 6 lines)

mlx.core.array.square
=====================

.. currentmodule:: mlx.core

.. automethod:: array.square

docs/build/html/_sources/python/_autosummary/mlx.core.array.squeeze.rst (vendored, new file, 6 lines)

mlx.core.array.squeeze
======================

.. currentmodule:: mlx.core

.. automethod:: array.squeeze

docs/build/html/_sources/python/_autosummary/mlx.core.array.std.rst (vendored, new file, 6 lines)

mlx.core.array.std
==================

.. currentmodule:: mlx.core

.. automethod:: array.std

docs/build/html/_sources/python/_autosummary/mlx.core.array.sum.rst (vendored, new file, 6 lines)

mlx.core.array.sum
==================

.. currentmodule:: mlx.core

.. automethod:: array.sum

docs/build/html/_sources/python/_autosummary/mlx.core.array.swapaxes.rst (vendored, new file, 6 lines)

mlx.core.array.swapaxes
=======================

.. currentmodule:: mlx.core

.. automethod:: array.swapaxes

docs/build/html/_sources/python/_autosummary/mlx.core.array.tolist.rst (vendored, new file, 6 lines)

mlx.core.array.tolist
=====================

.. currentmodule:: mlx.core

.. automethod:: array.tolist

docs/build/html/_sources/python/_autosummary/mlx.core.array.transpose.rst (vendored, new file, 6 lines)

mlx.core.array.transpose
========================

.. currentmodule:: mlx.core

.. automethod:: array.transpose

docs/build/html/_sources/python/_autosummary/mlx.core.array.var.rst (vendored, new file, 6 lines)

mlx.core.array.var
==================

.. currentmodule:: mlx.core

.. automethod:: array.var

docs/build/html/_sources/python/_autosummary/mlx.core.array.view.rst (vendored, new file, 6 lines)

mlx.core.array.view
===================

.. currentmodule:: mlx.core

.. automethod:: array.view

docs/build/html/_sources/python/_autosummary/mlx.core.array_equal.rst (vendored, new file, 6 lines)

mlx.core.array\_equal
=====================

.. currentmodule:: mlx.core

.. autofunction:: array_equal

docs/build/html/_sources/python/_autosummary/mlx.core.as_strided.rst (vendored, new file, 6 lines)

mlx.core.as\_strided
====================

.. currentmodule:: mlx.core

.. autofunction:: as_strided

docs/build/html/_sources/python/_autosummary/mlx.core.async_eval.rst (vendored, new file, 6 lines)

mlx.core.async\_eval
====================

.. currentmodule:: mlx.core

.. autofunction:: async_eval