mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 17:44:38 +08:00
Compare commits
176 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
006d01ba42 | ||
![]() |
46dc24d835 | ||
![]() |
c9934fe8a4 | ||
![]() |
975e265f74 | ||
![]() |
c92a134b0d | ||
![]() |
3b4f066dac | ||
![]() |
b7f905787e | ||
![]() |
e3e933c6bc | ||
![]() |
1d90a76d63 | ||
![]() |
961435a243 | ||
![]() |
e9ca65c939 | ||
![]() |
753867123d | ||
![]() |
f099ebe535 | ||
![]() |
f45f70f133 | ||
![]() |
0b8aeddac6 | ||
![]() |
432ee5650b | ||
![]() |
73321b8097 | ||
![]() |
022a944367 | ||
![]() |
026ef9aae4 | ||
![]() |
a611b0bc82 | ||
![]() |
449b43762e | ||
![]() |
6ea6b4258d | ||
![]() |
48f6ca8c3a | ||
![]() |
c6d2878c1a | ||
![]() |
b34bf5d52b | ||
![]() |
608bd43604 | ||
![]() |
4c48f6460d | ||
![]() |
1331fa19f6 | ||
![]() |
dfdb284e16 | ||
![]() |
d8f41a5c0f | ||
![]() |
b9e415d19c | ||
![]() |
c82a8cc526 | ||
![]() |
75dc537e44 | ||
![]() |
cf88db44b5 | ||
![]() |
16856a0160 | ||
![]() |
d752f8e142 | ||
![]() |
d2467c320d | ||
![]() |
0d31128a44 | ||
![]() |
1ac18eac20 | ||
![]() |
526466dd09 | ||
![]() |
e7f5059fe4 | ||
![]() |
d7ac050f4b | ||
![]() |
c7edafb729 | ||
![]() |
dff4a3833f | ||
![]() |
0782a4573a | ||
![]() |
af66a09bde | ||
![]() |
436bec9fd9 | ||
![]() |
99c80a2c8b | ||
![]() |
295ce9db09 | ||
![]() |
44c1ce5e6a | ||
![]() |
144ecff849 | ||
![]() |
350095ce6e | ||
![]() |
e09bf35b28 | ||
![]() |
99c20f523e | ||
![]() |
e3b8da2a49 | ||
![]() |
a020a2d49d | ||
![]() |
930b159885 | ||
![]() |
5ad8fb7268 | ||
![]() |
2aedf3e791 | ||
![]() |
473b6b43b4 | ||
![]() |
d29770eeaa | ||
![]() |
040c3bafab | ||
![]() |
05767b026f | ||
![]() |
a83d5d60bd | ||
![]() |
ff2b58e299 | ||
![]() |
4417e37ede | ||
![]() |
79c95b6919 | ||
![]() |
1f6ab6a556 | ||
![]() |
6b0d30bb85 | ||
![]() |
447bc089b9 | ||
![]() |
fc4e5b476b | ||
![]() |
d58ac083f3 | ||
![]() |
a123c3c7d2 | ||
![]() |
9e6b8c9f48 | ||
![]() |
22fee5a383 | ||
![]() |
7365d142a3 | ||
![]() |
8b227fa9af | ||
![]() |
8c3da54c7d | ||
![]() |
acf1721b98 | ||
![]() |
f91f450141 | ||
![]() |
cd3616a463 | ||
![]() |
d35fa1db41 | ||
![]() |
e8deca84e0 | ||
![]() |
8385f93cea | ||
![]() |
2118c3dbfa | ||
![]() |
a002797d52 | ||
![]() |
1d053e0d1d | ||
![]() |
0aa65c7a6b | ||
![]() |
794feb83df | ||
![]() |
2c7df6795e | ||
![]() |
b3916cbf2b | ||
![]() |
57fe918cf8 | ||
![]() |
4912ff3ec2 | ||
![]() |
f40d17047d | ||
![]() |
2807c6aff0 | ||
![]() |
de892cb66c | ||
![]() |
37024d899c | ||
![]() |
137f55bf28 | ||
![]() |
e549f84532 | ||
![]() |
dfa9f4bc58 | ||
![]() |
e6872a4149 | ||
![]() |
f4f6e17d45 | ||
![]() |
4d4af12c6f | ||
![]() |
477397bc98 | ||
![]() |
18cca64c81 | ||
![]() |
0e5807bbcb | ||
![]() |
8eb56beb3a | ||
![]() |
ee0c2835c5 | ||
![]() |
90d04072b7 | ||
![]() |
52e1589a52 | ||
![]() |
eebd7c275d | ||
![]() |
a67bbfe745 | ||
![]() |
104c34f906 | ||
![]() |
dc2edc762c | ||
![]() |
2e02acdc83 | ||
![]() |
83f266c44c | ||
![]() |
f24200db2c | ||
![]() |
e28b57e371 | ||
![]() |
e5851e52b1 | ||
![]() |
f55908bc48 | ||
![]() |
b93c4cf378 | ||
![]() |
1e0c78b970 | ||
![]() |
76e1af0e02 | ||
![]() |
c3272d4917 | ||
![]() |
50f5d14b11 | ||
![]() |
d14a0e4ff9 | ||
![]() |
fb675de30d | ||
![]() |
25f70d4ca4 | ||
![]() |
02de234ef0 | ||
![]() |
f5df47ec6e | ||
![]() |
b9226c367c | ||
![]() |
3214629601 | ||
![]() |
072044e28f | ||
![]() |
e080290ba4 | ||
![]() |
69505b4e9b | ||
![]() |
f4ddd7dc44 | ||
![]() |
b0cd092b7f | ||
![]() |
71d1fff90a | ||
![]() |
0cfbfc9904 | ||
![]() |
2d0130f80f | ||
![]() |
c1e1c1443f | ||
![]() |
68bf1d7867 | ||
![]() |
600db7d754 | ||
![]() |
ef7b8756c0 | ||
![]() |
0b28399638 | ||
![]() |
ac6dc5d3eb | ||
![]() |
89b90dcfec | ||
![]() |
fd836d891b | ||
![]() |
976e8babbe | ||
![]() |
2520dbcf0a | ||
![]() |
430bfb4944 | ||
![]() |
08d51bf232 | ||
![]() |
cb9e585b8e | ||
![]() |
641d316484 | ||
![]() |
2b714714e1 | ||
![]() |
69a24e6a1e | ||
![]() |
5b9be57ac3 | ||
![]() |
e89c571de7 | ||
![]() |
209404239b | ||
![]() |
4e3bdb560c | ||
![]() |
86b614afcd | ||
![]() |
cfc39d84b7 | ||
![]() |
d11d77e581 | ||
![]() |
bf410cb85e | ||
![]() |
2e126aeb7e | ||
![]() |
dfbc52ce56 | ||
![]() |
43e336cff2 | ||
![]() |
d895e38f2e | ||
![]() |
d15dead35e | ||
![]() |
2440fe0124 | ||
![]() |
170e4b2d43 | ||
![]() |
2629cc8682 | ||
![]() |
9f4cf2e0fe | ||
![]() |
2ffaee0c0d | ||
![]() |
36b245b287 | ||
![]() |
8c96b9a890 |
@@ -62,6 +62,7 @@ jobs:
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
pip install unittest-xml-reporting
|
||||
- run:
|
||||
name: Build python package
|
||||
@@ -79,6 +80,13 @@ jobs:
|
||||
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
mkdir -p build && cd build && cmake .. && make -j
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
|
||||
build_release:
|
||||
machine: true
|
||||
@@ -104,7 +112,7 @@ jobs:
|
||||
pip install numpy
|
||||
pip install twine
|
||||
- run:
|
||||
name: Build pacakge
|
||||
name: Build package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
@@ -140,7 +148,7 @@ jobs:
|
||||
pip install numpy
|
||||
pip install twine
|
||||
- run:
|
||||
name: Build pacakge
|
||||
name: Build package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
@@ -176,7 +184,7 @@ jobs:
|
||||
pip install numpy
|
||||
pip install twine
|
||||
- run:
|
||||
name: Build pacakge
|
||||
name: Build package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
@@ -203,7 +211,7 @@ workflows:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
macos_version: ["13", "14"]
|
||||
nightly_build:
|
||||
when: << pipeline.parameters.nightly_build >>
|
||||
@@ -211,7 +219,7 @@ workflows:
|
||||
- build_package:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
macos_version: ["13", "14"]
|
||||
weekly_build:
|
||||
when: << pipeline.parameters.weekly_build >>
|
||||
@@ -219,5 +227,5 @@ workflows:
|
||||
- build_dev_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
macos_version: ["13", "14"]
|
||||
|
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report about an issue you've encountered
|
||||
title: "[BUG] "
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
|
||||
Include code snippet
|
||||
```python
|
||||
|
||||
```
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS Version: [e.g. MacOS 14.1.2]
|
||||
- Version [e.g. 0.7.0]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
12
.github/pull_request_template.md
vendored
Normal file
12
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
## Proposed changes
|
||||
|
||||
Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
|
||||
|
||||
## Checklist
|
||||
|
||||
Put an `x` in the boxes that apply.
|
||||
|
||||
- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
|
||||
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
|
||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] I have updated the necessary documentation (if needed)
|
20
.github/workflows/pull_request.yml
vendored
Normal file
20
.github/workflows/pull_request.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
check_lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.8
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pre-commit black isort clang-format
|
||||
- name: Run lint
|
||||
run: |
|
||||
pre-commit run --all-files
|
9
.gitignore
vendored
9
.gitignore
vendored
@@ -6,10 +6,16 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# tensor files
|
||||
*.safe
|
||||
*.safetensors
|
||||
|
||||
# Metal libraries
|
||||
*.metallib
|
||||
venv/
|
||||
|
||||
# Distribution / packaging
|
||||
python/mlx/core
|
||||
python/mlx/share
|
||||
python/mlx/include
|
||||
.Python
|
||||
@@ -73,3 +79,6 @@ build/
|
||||
# VSCode
|
||||
.vscode/
|
||||
.DS_Store
|
||||
|
||||
# Jetbrains
|
||||
.cache
|
||||
|
@@ -1,9 +1,16 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v14.0.6
|
||||
rev: v17.0.6
|
||||
hooks:
|
||||
- id: clang-format
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.10.0
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 23.12.1
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
- --profile=black
|
||||
|
@@ -1,3 +1,24 @@
|
||||
# Individual Contributors
|
||||
|
||||
If you wish to be acknowledged for your contributions, please list your name
|
||||
with a short description of your contribution(s) below. For example:
|
||||
|
||||
- Jane Smith: Added the `foo` and `bar` ops.
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
||||
# Third-Party Software
|
||||
|
||||
MLX leverages several third-party software, listed here together with
|
||||
their license copied verbatim.
|
||||
|
||||
@@ -231,4 +252,4 @@ Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
limitations under the License.
|
@@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
|
||||
project(mlx LANGUAGES CXX)
|
||||
project(mlx LANGUAGES C CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
@@ -18,7 +18,28 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.0.1)
|
||||
set(MLX_VERSION 0.0.9)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
|
||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
message(WARNING
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, "
|
||||
" make sure you are building for arm64.")
|
||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
else()
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
@@ -37,20 +58,26 @@ endif()
|
||||
|
||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
elseif (MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
add_compile_definitions(_METAL_)
|
||||
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'"
|
||||
OUTPUT_VARIABLE MACOS_VERSION)
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
message(STATUS "Detected macOS version ${MACOS_VERSION}")
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
||||
else()
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
|
||||
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
@@ -72,13 +99,13 @@ elseif (MLX_BUILD_METAL)
|
||||
endif()
|
||||
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if (ACCELERATE_LIBRARY)
|
||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
message(STATUS "Accelerate not found, using default backend.")
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
#set(BLA_VENDOR Generic)
|
||||
find_package(BLAS REQUIRED)
|
||||
@@ -125,6 +152,8 @@ if (MLX_BUILD_BENCHMARKS)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
# ----------------------------- Installation -----------------------------
|
||||
include(GNUInstallDirs)
|
||||
|
||||
@@ -194,4 +223,4 @@ install(
|
||||
install(
|
||||
DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||
)
|
||||
)
|
||||
|
58
README.md
58
README.md
@@ -2,39 +2,41 @@
|
||||
|
||||
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
|
||||
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
|
||||
[**Examples**](#examples)
|
||||
[**Examples**](#examples)
|
||||
|
||||
[](https://circleci.com/gh/ml-explore/mlx)
|
||||
|
||||
MLX is an array framework for machine learning on Apple silicon, brought to you
|
||||
by Apple machine learning research.
|
||||
|
||||
Some key features of MLX include:
|
||||
|
||||
- **Familiar APIs**: MLX has a Python API which closely follows NumPy.
|
||||
MLX also has a fully featured C++ API which closely mirrors the Python API.
|
||||
MLX has higher level packages like `mlx.nn` and `mlx.optimizers` with APIs
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
|
||||
MLX also has a fully featured C++ API, which closely mirrors the Python API.
|
||||
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
|
||||
that closely follow PyTorch to simplify building more complex models.
|
||||
|
||||
- **Composable function transformations**: MLX has composable function
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are built
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and GPU).
|
||||
(currently the CPU and the GPU).
|
||||
|
||||
- **Unified memory**: A noteable difference from MLX and other frameworks
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without moving data.
|
||||
device types without transferring data.
|
||||
|
||||
MLX is designed by machine learning researchers for machine learning
|
||||
researchers. The framework is intended to be user friendly, but still efficient
|
||||
researchers. The framework is intended to be user-friendly, but still efficient
|
||||
to train and deploy models. The design of the framework itself is also
|
||||
conceptually simple. We intend to make it easy for researchers to extend and
|
||||
improve MLX with the goal of quickly exploring new ideas.
|
||||
@@ -47,11 +49,11 @@ The design of MLX is inspired by frameworks like
|
||||
## Examples
|
||||
|
||||
The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
|
||||
variety of examples including:
|
||||
variety of examples, including:
|
||||
|
||||
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
|
||||
- Large scale text generation with
|
||||
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
|
||||
- Large-scale text generation with
|
||||
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
|
||||
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
|
||||
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
|
||||
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
|
||||
@@ -59,12 +61,12 @@ variety of examples including:
|
||||
## Quickstart
|
||||
|
||||
See the [quick start
|
||||
guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
|
||||
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
|
||||
in the documentation.
|
||||
|
||||
## Installation
|
||||
|
||||
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API run:
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
||||
|
||||
```
|
||||
pip install mlx
|
||||
@@ -77,4 +79,28 @@ for more information on building the C++ and Python APIs from source.
|
||||
## Contributing
|
||||
|
||||
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
||||
on contributing to MLX.
|
||||
on contributing to MLX. See the
|
||||
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
||||
information on building from source, and running tests.
|
||||
|
||||
We are grateful for all of [our
|
||||
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||
to MLX and wish to be acknowledged, please add your name to the list in your
|
||||
pull request.
|
||||
|
||||
## Citing MLX
|
||||
|
||||
The MLX software suite was initially developed with equal contribution by Awni
|
||||
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
||||
MLX useful in your research and wish to cite it, please use the following
|
||||
BibTex entry:
|
||||
|
||||
```
|
||||
@software{mlx2023,
|
||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||
url = {https://github.com/ml-explore},
|
||||
version = {0.0},
|
||||
year = {2023},
|
||||
}
|
||||
```
|
||||
|
@@ -233,6 +233,20 @@ void time_gather_scatter() {
|
||||
TIME(single_element_add);
|
||||
}
|
||||
|
||||
void time_divmod() {
|
||||
auto a = random::normal({1000});
|
||||
auto b = random::normal({1000});
|
||||
eval({a, b});
|
||||
|
||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
||||
TIME(divmod_fused);
|
||||
|
||||
auto divmod_separate = [&a, &b]() {
|
||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
||||
};
|
||||
TIME(divmod_separate);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
time_creation_ops();
|
||||
@@ -246,4 +260,5 @@ int main() {
|
||||
time_matmul();
|
||||
time_reductions();
|
||||
time_gather_scatter();
|
||||
time_divmod();
|
||||
}
|
||||
|
@@ -1,7 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
|
@@ -1,8 +1,8 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
B = 8
|
||||
@@ -30,7 +30,7 @@ def time_batch_matmul():
|
||||
time_fn(batch_vjp_second)
|
||||
|
||||
|
||||
def time_unbatch_matmul(key):
|
||||
def time_unbatch_matmul():
|
||||
mx.random.seed(3)
|
||||
a = mx.random.uniform(shape=(B * T, D))
|
||||
b = mx.random.uniform(shape=(D, D))
|
||||
|
@@ -1,13 +1,14 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
import time
|
||||
import torch
|
||||
import os
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
device_name = device_name.decode("utf-8").strip("\n")
|
||||
|
@@ -1,14 +1,14 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
import time
|
||||
import torch
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
results_dir = "./results"
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
|
||||
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
||||
|
||||
|
||||
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
||||
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
|
||||
np_dtype = getattr(np, dtype)
|
||||
mlx_gb_s = []
|
||||
mlx_gflops = []
|
||||
@@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
||||
ax.legend()
|
||||
|
||||
|
||||
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
|
||||
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
||||
np_dtype = getattr(np, dtype)
|
||||
mlx_gb_s = []
|
||||
mlx_gflops = []
|
||||
|
@@ -4,8 +4,10 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def int_or_list(x):
|
||||
@@ -22,6 +24,16 @@ def none_or_list(x):
|
||||
return [int(xi) for xi in x.split(",")]
|
||||
|
||||
|
||||
def dtype_from_str(x):
|
||||
if x == "":
|
||||
return mx.float32
|
||||
else:
|
||||
dt = getattr(mx, x)
|
||||
if not isinstance(dt, mx.Dtype):
|
||||
raise ValueError(f"{x} is not an mlx dtype")
|
||||
return dt
|
||||
|
||||
|
||||
def bench(f, *args):
|
||||
for i in range(10):
|
||||
f(*args)
|
||||
@@ -48,6 +60,23 @@ def matmul(x, y):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def _quant_matmul(x, w, s, b, group_size, bits):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
quant_matmul = {
|
||||
"quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
|
||||
"quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
|
||||
"quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
|
||||
"quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
|
||||
"quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
|
||||
"quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
|
||||
}
|
||||
|
||||
|
||||
def conv1d(x, y):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
@@ -95,7 +124,77 @@ def softmax_fused(axis, x):
|
||||
def relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = mx.maximum(y, 0)
|
||||
y = nn.relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def leaky_relu(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.leaky_relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def prelu(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.prelu(y, mx.ones(1))
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def softplus(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.softplus(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def mish(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.mish(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def leaky_relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.leaky_relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def elu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.elu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def relu6(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.relu6(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.softplus(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def celu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.celu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def log_sigmoid(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.log_sigmoid(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
@@ -180,6 +279,20 @@ def topk(axis, x):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def step_function(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.step(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def selu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.selu(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||
@@ -211,9 +324,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--fused", action="store_true", help="Use fused functions where possible"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
|
||||
)
|
||||
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -230,11 +341,15 @@ if __name__ == "__main__":
|
||||
mx.set_default_device(mx.cpu)
|
||||
else:
|
||||
mx.set_default_device(mx.gpu)
|
||||
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
|
||||
args.dtype
|
||||
]
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
types = [mx.float32]
|
||||
if len(types) < len(args.size):
|
||||
types = types + [types[0]] * (len(args.size) - len(types))
|
||||
|
||||
xs = []
|
||||
for size in args.size:
|
||||
for size, dtype in zip(args.size, types):
|
||||
xs.append(mx.random.normal(size).astype(dtype))
|
||||
for i, t in enumerate(args.transpose):
|
||||
if t is None:
|
||||
@@ -250,6 +365,9 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "matmul":
|
||||
print(bench(matmul, *xs))
|
||||
|
||||
elif args.benchmark.startswith("quant_matmul"):
|
||||
print(bench(quant_matmul[args.benchmark], *xs))
|
||||
|
||||
elif args.benchmark == "linear":
|
||||
print(bench(linear, *xs))
|
||||
|
||||
@@ -277,6 +395,26 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
elif args.benchmark == "prelu":
|
||||
print(bench(prelu, x))
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
elif args.benchmark == "mish":
|
||||
print(bench(mish, x))
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
@@ -311,5 +449,11 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "topk":
|
||||
print(bench(topk, axis, x))
|
||||
|
||||
elif args.benchmark == "step":
|
||||
print(bench(step_function, x))
|
||||
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown benchmark")
|
||||
|
@@ -22,6 +22,16 @@ def none_or_list(x):
|
||||
return [int(xi) for xi in x.split(",")]
|
||||
|
||||
|
||||
def dtype_from_str(x):
|
||||
if x == "":
|
||||
return torch.float32
|
||||
else:
|
||||
dt = getattr(torch, x)
|
||||
if not isinstance(dt, torch.dtype):
|
||||
raise ValueError(f"{x} is not a torch dtype")
|
||||
return dt
|
||||
|
||||
|
||||
def bench(f, *args):
|
||||
for i in range(10):
|
||||
f(*args)
|
||||
@@ -115,6 +125,70 @@ def relu(x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def leaky_relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.leaky_relu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def elu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.elu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def celu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.celu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def relu6(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.relu6(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
def softplus(x):
    """Benchmark kernel: apply ``softplus`` 100 times in a chain, starting from ``x``."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.softplus(out)
    # Hand the input tensor to sync_if_needed (defined elsewhere in this file)
    # once the whole chain has been issued.
    sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
def log_sigmoid(x):
    """Benchmark kernel: apply ``logsigmoid`` 100 times in a chain, starting from ``x``."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.logsigmoid(out)
    # Hand the input tensor to sync_if_needed (defined elsewhere in this file)
    # once the whole chain has been issued.
    sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
def prelu(x: torch.Tensor) -> None:
    """Benchmark kernel: apply ``prelu`` 100 times in a chain, starting from ``x``.

    Each application uses a freshly built one-element weight of ones placed on
    the same device as the running value. Nothing is returned; the function is
    called for its side effect of issuing the work.
    """
    y = x
    for _ in range(100):
        # NOTE(review): the weight tensor is rebuilt on every iteration, so its
        # allocation and device transfer are part of what gets timed. Left
        # unchanged in case the counterpart benchmark does the same -- confirm
        # before hoisting it out of the loop.
        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
    # Hand the input tensor to sync_if_needed (defined elsewhere in this file)
    # once the whole chain has been issued.
    sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
def mish(x: torch.Tensor) -> None:
    """Benchmark kernel: apply ``mish`` 100 times in a chain, starting from ``x``.

    Nothing is returned; like the sibling kernels in this file, the function is
    called for its side effect of issuing the work and then synchronizing.
    """
    y = x
    for _ in range(100):
        # Feed the result back in. The previous version used ``return`` here,
        # which left the function after a single application and skipped the
        # sync_if_needed call below, so the timing was not comparable with the
        # other 100-iteration kernels.
        y = torch.nn.functional.mish(y)
    # Hand the input tensor to sync_if_needed (defined elsewhere in this file)
    # once the whole chain has been issued.
    sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def scalar_mult(x):
|
||||
y = x
|
||||
@@ -209,6 +283,14 @@ def topk(axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
def selu(x):
    """Benchmark kernel: apply ``selu`` 100 times in a chain, starting from ``x``."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.selu(out)
    # Hand the input tensor to sync_if_needed (defined elsewhere in this file)
    # once the whole chain has been issued.
    sync_if_needed(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||
@@ -240,7 +322,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--fused", action="store_true", help="Use fused functions where possible"
|
||||
)
|
||||
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
|
||||
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -255,9 +337,15 @@ if __name__ == "__main__":
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = "cpu" if args.cpu else "mps"
|
||||
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
types = [torch.float32]
|
||||
if len(types) < len(args.size):
|
||||
types = types + [types[0]] * (len(args.size) - len(types))
|
||||
|
||||
xs = []
|
||||
for size in args.size:
|
||||
for size, dtype in zip(args.size, types):
|
||||
xs.append(torch.randn(*size).to(device).to(dtype))
|
||||
for i, t in enumerate(args.transpose):
|
||||
if t is None:
|
||||
@@ -302,6 +390,28 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "prelu":
|
||||
print(bench(prelu, x))
|
||||
elif args.benchmark == "mish":
|
||||
print(bench(mish, x))
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
@@ -62,7 +62,7 @@ def make_predicate(positive_filter, negative_filter):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
|
||||
parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
|
||||
parser.add_argument(
|
||||
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
|
||||
)
|
||||
@@ -125,6 +125,14 @@ if __name__ == "__main__":
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
|
||||
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
|
||||
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
|
||||
compare_filtered("argmax --size 10x1024x128 --axis 1")
|
||||
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
|
||||
@@ -193,6 +201,27 @@ if __name__ == "__main__":
|
||||
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
|
||||
compare_filtered("relu --size 32x16x1024")
|
||||
compare_filtered("relu --size 32x16x1024 --cpu")
|
||||
compare_filtered("leaky_relu --size 32x16x1024")
|
||||
compare_filtered("leaky_relu --size 32x16x1024 --cpu")
|
||||
compare_filtered("elu --size 32x16x1024")
|
||||
compare_filtered("elu --size 32x16x1024 --cpu")
|
||||
compare_filtered("relu6 --size 32x16x1024")
|
||||
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
||||
compare_filtered("softplus --size 32x16x1024")
|
||||
compare_filtered("softplus --size 32x16x1024 --cpu")
|
||||
compare_filtered("celu --size 32x16x1024")
|
||||
compare_filtered("celu --size 32x16x1024 --cpu")
|
||||
compare_filtered("log_sigmoid --size 32x16x1024")
|
||||
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
||||
compare_filtered("step --size 32x16x1024")
|
||||
compare_filtered("step --size 32x16x1024 --cpu")
|
||||
compare_filtered("selu --size 32x16x1024")
|
||||
compare_filtered("selu --size 32x16x1024 --cpu")
|
||||
# compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
|
||||
compare_filtered("mish --size 32x16x1024 --cpu")
|
||||
compare_filtered("prelu --size 32x16x1024")
|
||||
compare_filtered("prelu --size 32x16x1024 --cpu")
|
||||
|
||||
compare_filtered("scalar_mul --size 32x16x1024")
|
||||
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||
compare_filtered("cross_entropy --size 256x1024")
|
||||
|
@@ -4,8 +4,8 @@ import math
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.mps
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
|
@@ -1,8 +1,8 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
|
@@ -12,7 +12,7 @@ include(CMakeParseArguments)
|
||||
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
||||
# SOURCES: List of source files
|
||||
# INCLUDE_DIRS: List of include dirs
|
||||
# DEPS: List of depedency files (like headers)
|
||||
# DEPS: List of dependency files (like headers)
|
||||
#
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
@@ -32,7 +32,7 @@ macro(mlx_build_metallib)
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
||||
|
||||
# Prepare metllib build command
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
|
1
docs/.gitignore
vendored
1
docs/.gitignore
vendored
@@ -1 +1,2 @@
|
||||
src/python/_autosummary*/
|
||||
src/python/nn/_autosummary*/
|
||||
|
@@ -26,7 +26,7 @@ python -m http.server <port>
|
||||
|
||||
and point your browser to `http://localhost:<port>`.
|
||||
|
||||
### Push to Github Pages
|
||||
### Push to GitHub Pages
|
||||
|
||||
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
|
||||
the docs. Then force add the `build/html` directory:
|
||||
|
33
docs/src/_templates/module-base-class.rst
Normal file
33
docs/src/_templates/module-base-class.rst
Normal file
@@ -0,0 +1,33 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. add toctree option to make autodoc generate the pages
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
|
||||
{% block attributes %}
|
||||
{% if attributes %}
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
:toctree: .
|
||||
{% for item in attributes %}
|
||||
~{{ fullname }}.{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block methods %}
|
||||
{% if methods %}
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
:toctree: .
|
||||
{% for item in methods %}
|
||||
{%- if item not in inherited_members and item != '__init__' %}
|
||||
~{{ fullname }}.{{ item }}
|
||||
{%- endif -%}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
@@ -10,8 +10,8 @@ import subprocess
|
||||
project = "MLX"
|
||||
copyright = "2023, MLX Contributors"
|
||||
author = "MLX Contributors"
|
||||
version = "0.0.3"
|
||||
release = "0.0.3"
|
||||
version = "0.0.9"
|
||||
release = "0.0.9"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
|
@@ -15,7 +15,7 @@ Introducing the Example
|
||||
-----------------------
|
||||
|
||||
Let's say that you would like an operation that takes in two arrays,
|
||||
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
|
||||
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
|
||||
respectively, and then adds them together to get the result
|
||||
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||
writing out a function as follows:
|
||||
@@ -69,7 +69,7 @@ C++ API:
|
||||
.. code-block:: C++
|
||||
|
||||
/**
|
||||
* Scale and sum two vectors elementwise
|
||||
* Scale and sum two vectors element-wise
|
||||
* z = alpha * x + beta * y
|
||||
*
|
||||
* Follow numpy style broadcasting between x and y
|
||||
@@ -131,7 +131,7 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
* for the given inputs and populate the output array.
|
||||
*
|
||||
* To avoid unecessary allocations, the evaluation function
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -150,7 +150,7 @@ back and go to our example to give ourselves a more concrete image.
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself accross
|
||||
* The primitive must know how to vectorize itself across
|
||||
* the given axes. The output is a pair containing the array
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
@@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
|
||||
This operation now handles the following:
|
||||
|
||||
#. Upcast inputs and resolve the the output data type.
|
||||
#. Upcast inputs and resolve the output data type.
|
||||
#. Broadcast the inputs and resolve the output shape.
|
||||
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
|
||||
#. Construct the output :class:`array` using the primitive and the inputs.
|
||||
@@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Do the elementwise operation for each output
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additonal mapping
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
}
|
||||
@@ -305,7 +305,7 @@ if we encounter an unexpected type.
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Check the inputs (registered in the op while contructing the out array)
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
@@ -485,7 +485,7 @@ each data type.
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bflot16, bfloat16_t);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
||||
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||
@@ -537,7 +537,7 @@ below.
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel decelaration at axpby.metal
|
||||
// those in the kernel declaration at axpby.metal
|
||||
int ndim = out.ndim();
|
||||
size_t nelem = out.size();
|
||||
|
||||
@@ -568,7 +568,7 @@ below.
|
||||
// Fix the 3D size of the launch grid (in terms of threads)
|
||||
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||
|
||||
// Launch the grid with the given number of threads divded among
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
@@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
|
||||
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and commiting the associated
|
||||
for synchronization. MLX also handles enqueuing and committing the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
:class:`metal::Device` if you would like to study this routine further.
|
||||
|
||||
@@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the the primitive can built with ops
|
||||
// that are scheduled on the same stream as the primtive
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
// jvp is just the tangent scaled by alpha
|
||||
@@ -642,7 +642,7 @@ own :class:`Primitive`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Vectorize primitve along given axis */
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
|
||||
| └── setup.py
|
||||
|
||||
* ``extensions/axpby/`` defines the C++ extension library
|
||||
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated python package
|
||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
@@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
Scale and sum two vectors elementwise
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
@@ -840,7 +840,7 @@ This will result in a directory structure as follows:
|
||||
| ...
|
||||
|
||||
When you try to install using the command ``python -m pip install .``
|
||||
(in ``extensions/``), the package will be installed with the same strucutre as
|
||||
(in ``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as ``package_data``.
|
||||
|
||||
@@ -945,4 +945,4 @@ Scripts
|
||||
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
|
||||
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
|
||||
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
||||
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
||||
|
@@ -321,7 +321,7 @@ which can then be used to update the model. Note that the method above incurs
|
||||
several unnecessary copies from disk to numpy and then from numpy to MLX. It
|
||||
will be replaced in the future with direct loading to MLX.
|
||||
|
||||
You can download the full example code in `mlx-examples <code>`_. Assuming, the
|
||||
You can download the full example code in `mlx-examples`_. Assuming the
|
||||
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
|
||||
directory we can play around with our inference script as follows (the timings
|
||||
are representative of an M1 Ultra and the 7B parameter Llama model):
|
||||
@@ -369,9 +369,9 @@ Scripts
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx-examples <code>`_.
|
||||
The full example code is available in `mlx-examples`_.
|
||||
|
||||
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
|
||||
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
|
||||
|
||||
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
|
||||
Roformer: Enhanced transformer with rotary position embedding. arXiv
|
||||
|
@@ -61,7 +61,10 @@ set:
|
||||
def eval_fn(model, X, y):
|
||||
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
||||
|
||||
Next, setup the problem parameters and load the data:
|
||||
Next, set up the problem parameters and load the data. To load the data, you need our
|
||||
`mnist data loader
|
||||
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
|
||||
we will import as ``mnist``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -127,5 +130,5 @@ Finally, we put it all together by instantiating the model, the
|
||||
This should not be confused with :func:`mlx.core.value_and_grad`.
|
||||
|
||||
The model should train to a decent accuracy (about 95%) after just a few passes
|
||||
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
|
||||
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_
|
||||
is available in the MLX GitHub repo.
|
||||
|
@@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:
|
||||
|
||||
The design of MLX is inspired by frameworks like `PyTorch
|
||||
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
|
||||
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
|
||||
`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
|
||||
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
|
||||
memory. Operations on MLX arrays can be performed on any of the supported
|
||||
device types without performing data copies. Currently supported device types
|
||||
@@ -35,8 +35,13 @@ are the CPU and GPU.
|
||||
:caption: Usage
|
||||
:maxdepth: 1
|
||||
|
||||
quick_start
|
||||
using_streams
|
||||
usage/quick_start
|
||||
usage/lazy_evaluation
|
||||
usage/unified_memory
|
||||
usage/indexing
|
||||
usage/saving_and_loading
|
||||
usage/numpy
|
||||
usage/using_streams
|
||||
|
||||
.. toctree::
|
||||
:caption: Examples
|
||||
@@ -56,6 +61,7 @@ are the CPU and GPU.
|
||||
python/random
|
||||
python/transforms
|
||||
python/fft
|
||||
python/linalg
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/tree_utils
|
||||
|
@@ -11,6 +11,33 @@ silicon computer is
|
||||
|
||||
pip install mlx
|
||||
|
||||
To install from PyPI you must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.8
|
||||
- macOS >= 13.3
|
||||
|
||||
.. note::
|
||||
MLX is only available on devices running macOS >= 13.3
|
||||
It is highly recommended to use macOS 14 (Sonoma)
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
*My OS and Python versions are in the required range but pip still does not find
|
||||
a matching distribution.*
|
||||
|
||||
Probably you are using a non-native Python. The output of
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
python -c "import platform; print(platform.processor())"
|
||||
|
||||
should be ``arm``. If it is ``i386`` (and you have an M series machine) then you
|
||||
are using a non-native Python. Switch your Python to a native Python. A good
|
||||
way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
|
||||
|
||||
|
||||
Build from source
|
||||
-----------------
|
||||
|
||||
@@ -19,7 +46,11 @@ Build Requirements
|
||||
|
||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
|
||||
|
||||
.. note::
|
||||
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
|
||||
the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.
|
||||
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
@@ -56,8 +87,16 @@ To make sure the install is working run the tests with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install ".[testing]"
|
||||
python -m unittest discover python/tests
|
||||
|
||||
Optional: Install stubs to enable auto completions and type checking from your IDE:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install ".[dev]"
|
||||
python setup.py generate_stubs
|
||||
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
@@ -111,3 +150,66 @@ should point to the path to the built metal library.
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
If you have multiple Xcode installations and wish to use
|
||||
a specific one while building, you can do so by adding the
|
||||
following environment variable before building
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
|
||||
|
||||
Further, you can use the following command to find out which
|
||||
macOS SDK will be used
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Metal not found
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
You see the following error when you try to build:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
error: unable to find utility "metal", not a developer tool or in PATH
|
||||
|
||||
To fix this, first make sure you have Xcode installed:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
xcode-select --install
|
||||
|
||||
Then set the active developer directory:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
|
||||
|
||||
x86 Shell
|
||||
~~~~~~~~~
|
||||
|
||||
.. _build shell:
|
||||
|
||||
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
||||
Rosetta instead of natively.
|
||||
|
||||
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
||||
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
|
||||
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
|
||||
terminal.
|
||||
|
||||
Verify the terminal is now running natively with the following command:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
$ uname -p
|
||||
arm
|
||||
|
@@ -34,6 +34,7 @@ Array
|
||||
array.prod
|
||||
array.reciprocal
|
||||
array.reshape
|
||||
array.round
|
||||
array.rsqrt
|
||||
array.sin
|
||||
array.split
|
||||
|
@@ -29,9 +29,9 @@ The default floating point type is ``float32`` and the default integer type is
|
||||
* - ``uint32``
|
||||
- 4
|
||||
- 32-bit unsigned integer
|
||||
* - ``uint32``
|
||||
* - ``uint64``
|
||||
- 8
|
||||
- 32-bit unsigned integer
|
||||
- 64-bit unsigned integer
|
||||
* - ``int8``
|
||||
- 1
|
||||
- 8-bit signed integer
|
||||
|
11
docs/src/python/linalg.rst
Normal file
11
docs/src/python/linalg.rst
Normal file
@@ -0,0 +1,11 @@
|
||||
.. _linalg:
|
||||
|
||||
Linear Algebra
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core.linalg
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
norm
|
@@ -64,7 +64,6 @@ Quick Start with Neural Networks
|
||||
# gradient with respect to `mlp.trainable_parameters()`
|
||||
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
|
||||
|
||||
|
||||
.. _module_class:
|
||||
|
||||
The Module Class
|
||||
@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
|
||||
:meth:`Module.parameters` can be used to extract a nested dictionary with all
|
||||
the parameters of a module and its submodules.
|
||||
|
||||
A :class:`Module` can also keep track of "frozen" parameters.
|
||||
:meth:`Module.trainable_parameters` returns only the subset of
|
||||
:meth:`Module.parameters` that is not frozen. When using
|
||||
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
|
||||
trainable parameters.
|
||||
A :class:`Module` can also keep track of "frozen" parameters. See the
|
||||
:meth:`Module.freeze` method for more details. When using :meth:`mlx.nn.value_and_grad`
|
||||
the gradients returned will be with respect to these trainable parameters.
|
||||
|
||||
Updating the parameters
|
||||
|
||||
Updating the Parameters
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
MLX modules allow accessing and updating individual parameters. However, most
|
||||
times we need to update large subsets of a module's parameters. This action is
|
||||
performed by :meth:`Module.update`.
|
||||
performed by :meth:`Module.update`.
|
||||
|
||||
Value and grad
|
||||
|
||||
Inspecting Modules
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The simplest way to see the model architecture is to print it. Following along with
|
||||
the above example, you can print the ``MLP`` with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
print(mlp)
|
||||
|
||||
This will display:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
MLP(
|
||||
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
|
||||
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
|
||||
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
|
||||
)
|
||||
|
||||
To get more detailed information on the arrays in a :class:`Module` you can use
|
||||
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
|
||||
all the parameters in a :class:`Module` do:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mlx.utils import tree_map
|
||||
shapes = tree_map(lambda p: p.shape, mlp.parameters())
|
||||
|
||||
As another example, you can count the number of parameters in a :class:`Module`
|
||||
with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mlx.utils import tree_flatten
|
||||
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
|
||||
|
||||
|
||||
Value and Grad
|
||||
--------------
|
||||
|
||||
Using a :class:`Module` does not preclude using MLX's high order function
|
||||
@@ -137,36 +174,9 @@ In detail:
|
||||
|
||||
value_and_grad
|
||||
|
||||
Neural Network Layers
|
||||
---------------------
|
||||
.. toctree::
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: nn-module-template.rst
|
||||
|
||||
Embedding
|
||||
ReLU
|
||||
GELU
|
||||
SiLU
|
||||
Linear
|
||||
Conv1d
|
||||
Conv2d
|
||||
LayerNorm
|
||||
RMSNorm
|
||||
GroupNorm
|
||||
RoPE
|
||||
MultiHeadAttention
|
||||
Sequential
|
||||
|
||||
Layers without parameters (e.g. activation functions) are also provided as
|
||||
simple functions.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary_functions
|
||||
:template: nn-module-template.rst
|
||||
|
||||
gelu
|
||||
gelu_approx
|
||||
gelu_fast_approx
|
||||
relu
|
||||
silu
|
||||
nn/module
|
||||
nn/layers
|
||||
nn/functions
|
||||
nn/losses
|
||||
|
23
docs/src/python/nn/functions.rst
Normal file
23
docs/src/python/nn/functions.rst
Normal file
@@ -0,0 +1,23 @@
|
||||
.. _nn_functions:
|
||||
|
||||
.. currentmodule:: mlx.nn
|
||||
|
||||
Functions
|
||||
---------
|
||||
|
||||
Layers without parameters (e.g. activation functions) are also provided as
|
||||
simple functions.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary_functions
|
||||
:template: nn-module-template.rst
|
||||
|
||||
gelu
|
||||
gelu_approx
|
||||
gelu_fast_approx
|
||||
relu
|
||||
prelu
|
||||
silu
|
||||
step
|
||||
selu
|
||||
mish
|
37
docs/src/python/nn/layers.rst
Normal file
37
docs/src/python/nn/layers.rst
Normal file
@@ -0,0 +1,37 @@
|
||||
.. _layers:
|
||||
|
||||
.. currentmodule:: mlx.nn
|
||||
|
||||
Layers
|
||||
------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: nn-module-template.rst
|
||||
|
||||
Sequential
|
||||
ReLU
|
||||
PReLU
|
||||
GELU
|
||||
SiLU
|
||||
Step
|
||||
SELU
|
||||
Mish
|
||||
Embedding
|
||||
Linear
|
||||
QuantizedLinear
|
||||
Conv1d
|
||||
Conv2d
|
||||
BatchNorm
|
||||
LayerNorm
|
||||
RMSNorm
|
||||
GroupNorm
|
||||
InstanceNorm
|
||||
Dropout
|
||||
Dropout2d
|
||||
Dropout3d
|
||||
Transformer
|
||||
MultiHeadAttention
|
||||
ALiBi
|
||||
RoPE
|
||||
SinusoidalPositionalEncoding
|
23
docs/src/python/nn/losses.rst
Normal file
23
docs/src/python/nn/losses.rst
Normal file
@@ -0,0 +1,23 @@
|
||||
.. _losses:
|
||||
|
||||
.. currentmodule:: mlx.nn.losses
|
||||
|
||||
Loss Functions
|
||||
--------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary_functions
|
||||
:template: nn-module-template.rst
|
||||
|
||||
binary_cross_entropy
|
||||
cross_entropy
|
||||
kl_div_loss
|
||||
l1_loss
|
||||
mse_loss
|
||||
nll_loss
|
||||
smooth_l1_loss
|
||||
triplet_loss
|
||||
hinge_loss
|
||||
huber_loss
|
||||
log_cosh_loss
|
||||
cosine_similarity_loss
|
@@ -1,7 +1,36 @@
|
||||
mlx.nn.Module
|
||||
=============
|
||||
Module
|
||||
======
|
||||
|
||||
.. currentmodule:: mlx.nn
|
||||
|
||||
.. autoclass:: Module
|
||||
:members:
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Module.training
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Module.apply
|
||||
Module.apply_to_modules
|
||||
Module.children
|
||||
Module.eval
|
||||
Module.filter_and_map
|
||||
Module.freeze
|
||||
Module.leaf_modules
|
||||
Module.load_weights
|
||||
Module.modules
|
||||
Module.named_modules
|
||||
Module.parameters
|
||||
Module.save_weights
|
||||
Module.train
|
||||
Module.trainable_parameters
|
||||
Module.unfreeze
|
||||
Module.update
|
||||
Module.update_modules
|
||||
|
@@ -26,23 +26,34 @@ Operations
|
||||
argsort
|
||||
array_equal
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
convolve
|
||||
conv1d
|
||||
conv2d
|
||||
cos
|
||||
cosh
|
||||
dequantize
|
||||
divide
|
||||
divmod
|
||||
equal
|
||||
erf
|
||||
erfinv
|
||||
exp
|
||||
expand_dims
|
||||
eye
|
||||
flatten
|
||||
floor
|
||||
floor_divide
|
||||
full
|
||||
greater
|
||||
greater_equal
|
||||
identity
|
||||
inner
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
load
|
||||
log
|
||||
log2
|
||||
@@ -50,6 +61,8 @@ Operations
|
||||
log1p
|
||||
logaddexp
|
||||
logical_not
|
||||
logical_and
|
||||
logical_or
|
||||
logsumexp
|
||||
matmul
|
||||
max
|
||||
@@ -57,19 +70,27 @@ Operations
|
||||
mean
|
||||
min
|
||||
minimum
|
||||
moveaxis
|
||||
multiply
|
||||
negative
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
partition
|
||||
pad
|
||||
prod
|
||||
quantize
|
||||
quantized_matmul
|
||||
reciprocal
|
||||
repeat
|
||||
reshape
|
||||
round
|
||||
rsqrt
|
||||
save
|
||||
savez
|
||||
savez_compressed
|
||||
save_gguf
|
||||
save_safetensors
|
||||
sigmoid
|
||||
sign
|
||||
sin
|
||||
@@ -80,14 +101,20 @@ Operations
|
||||
sqrt
|
||||
square
|
||||
squeeze
|
||||
stack
|
||||
stop_gradient
|
||||
subtract
|
||||
sum
|
||||
swapaxes
|
||||
take
|
||||
take_along_axis
|
||||
tan
|
||||
tanh
|
||||
tensordot
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
triu
|
||||
var
|
||||
where
|
||||
zeros
|
||||
|
@@ -38,4 +38,10 @@ model's parameters and the **optimizer state**.
|
||||
OptimizerState
|
||||
Optimizer
|
||||
SGD
|
||||
RMSprop
|
||||
Adagrad
|
||||
AdaDelta
|
||||
Adam
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
|
@@ -14,3 +14,4 @@ Transforms
|
||||
jvp
|
||||
vjp
|
||||
vmap
|
||||
simplify
|
||||
|
123
docs/src/usage/indexing.rst
Normal file
123
docs/src/usage/indexing.rst
Normal file
@@ -0,0 +1,123 @@
|
||||
.. _indexing:
|
||||
|
||||
Indexing Arrays
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
For the most part, indexing an MLX :obj:`array` works the same as indexing a
|
||||
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
|
||||
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
|
||||
how that works.
|
||||
|
||||
For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> arr = mx.arange(10)
|
||||
>>> arr[3]
|
||||
array(3, dtype=int32)
|
||||
>>> arr[-2] # negative indexing works
|
||||
array(8, dtype=int32)
|
||||
>>> arr[2:8:2] # start, stop, stride
|
||||
array([2, 4, 6], dtype=int32)
|
||||
|
||||
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> arr = mx.arange(8).reshape(2, 2, 2)
|
||||
>>> arr[:, :, 0]
|
||||
array([[0, 2],
|
||||
[4, 6]], dtype=int32)
|
||||
>>> arr[..., 0]
|
||||
array([[0, 2],
|
||||
[4, 6]], dtype=int32)
|
||||
|
||||
You can index with ``None`` to create a new axis:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> arr = mx.arange(8)
|
||||
>>> arr.shape
|
||||
[8]
|
||||
>>> arr[None].shape
|
||||
[1, 8]
|
||||
|
||||
|
||||
You can also use an :obj:`array` to index another :obj:`array`:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> arr = mx.arange(10)
|
||||
>>> idx = mx.array([5, 7])
|
||||
>>> arr[idx]
|
||||
array([5, 7], dtype=int32)
|
||||
|
||||
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
|
||||
works just as in NumPy.
|
||||
|
||||
Other functions which may be useful for indexing arrays are :func:`take` and
|
||||
:func:`take_along_axis`.
|
||||
|
||||
Differences from NumPy
|
||||
----------------------
|
||||
|
||||
.. Note::
|
||||
|
||||
MLX indexing is different from NumPy indexing in two important ways:
|
||||
|
||||
* Indexing does not perform bounds checking. Indexing out of bounds is
|
||||
undefined behavior.
|
||||
* Boolean mask based indexing is not yet supported.
|
||||
|
||||
The reason for the lack of bounds checking is that exceptions cannot propagate
|
||||
from the GPU. Performing bounds checking for array indices before launching the
|
||||
kernel would be extremely inefficient.
|
||||
|
||||
Indexing with boolean masks is something that MLX may support in the future. In
|
||||
general, MLX has limited support for operations for which output
|
||||
*shapes* are dependent on input *data*. Other examples of these types of
|
||||
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||
single input version of :func:`numpy.where`.
|
||||
|
||||
In Place Updates
|
||||
----------------
|
||||
|
||||
In place updates to indexed arrays are possible in MLX. For example:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[2] = 0
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
Just as in NumPy, in place updates will be reflected in all references to the
|
||||
same array:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> b = a
|
||||
>>> b[2] = 0
|
||||
>>> b
|
||||
array([1, 2, 0], dtype=int32)
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, idx):
|
||||
x[idx] = 2.0
|
||||
return x.sum()
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
|
||||
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
|
||||
|
||||
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
|
||||
and ones elsewhere.
|
144
docs/src/usage/lazy_evaluation.rst
Normal file
144
docs/src/usage/lazy_evaluation.rst
Normal file
@@ -0,0 +1,144 @@
|
||||
.. _lazy eval:
|
||||
|
||||
Lazy Evaluation
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Why Lazy Evaluation
|
||||
-------------------
|
||||
|
||||
When you perform operations in MLX, no computation actually happens. Instead a
|
||||
compute graph is recorded. The actual computation only happens if an
|
||||
:func:`eval` is performed.
|
||||
|
||||
MLX uses lazy evaluation because it has some nice features, some of which we
|
||||
describe below.
|
||||
|
||||
Transforming Compute Graphs
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Lazy evaluation lets us record a compute graph without actually doing any
|
||||
computations. This is useful for function transformations like :func:`grad` and
|
||||
:func:`vmap` and graph optimizations like :func:`simplify`.
|
||||
|
||||
Currently, MLX does not compile and rerun compute graphs. They are all
|
||||
generated dynamically. However, lazy evaluation makes it much easier to
|
||||
integrate compilation for future performance enhancements.
|
||||
|
||||
Only Compute What You Use
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
In MLX you do not need to worry as much about computing outputs that are never
|
||||
used. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
a = fun1(x)
|
||||
b = expensive_fun(a)
|
||||
return a, b
|
||||
|
||||
y, _ = fun(x)
|
||||
|
||||
Here, we never actually compute the output of ``expensive_fun``. Use this
|
||||
pattern with care though, as the graph of ``expensive_fun`` is still built, and
|
||||
that has some cost associated with it.
|
||||
|
||||
Similarly, lazy evaluation can be beneficial for saving memory while keeping
|
||||
code simple. Say you have a very large model ``Model`` derived from
|
||||
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
|
||||
Typically, this will initialize all of the weights as ``float32``, but the
|
||||
initialization does not actually compute anything until you perform an
|
||||
:func:`eval`. If you update the model with ``float16`` weights, your maximum
|
||||
consumed memory will be half that required if eager computation was used
|
||||
instead.
|
||||
|
||||
This pattern is simple to do in MLX thanks to lazy computation:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = Model() # no memory used yet
|
||||
model.load_weights("weights_fp16.safetensors")
|
||||
|
||||
When to Evaluate
|
||||
----------------
|
||||
|
||||
A common question is when to use :func:`eval`. The trade-off is between
|
||||
letting graphs get too large and not batching enough useful work.
|
||||
|
||||
For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
for _ in range(100):
|
||||
a = a + b
|
||||
mx.eval(a)
|
||||
b = b * 2
|
||||
mx.eval(b)
|
||||
|
||||
This is a bad idea because there is some fixed overhead with each graph
|
||||
evaluation. On the other hand, there is some slight overhead which grows with
|
||||
the compute graph size, so extremely large graphs (while computationally
|
||||
correct) can be costly.
|
||||
|
||||
Luckily, a wide range of compute graph sizes work pretty well with MLX:
|
||||
anything from a few tens of operations to many thousands of operations per
|
||||
evaluation should be okay.
|
||||
|
||||
Most numerical computations have an iterative outer loop (e.g. the iteration in
|
||||
stochastic gradient descent). A natural and usually efficient place to use
|
||||
:func:`eval` is at each iteration of this outer loop.
|
||||
|
||||
Here is a concrete example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
for batch in dataset:
|
||||
|
||||
# Nothing has been evaluated yet
|
||||
loss, grad = value_and_grad_fn(model, batch)
|
||||
|
||||
# Still nothing has been evaluated
|
||||
optimizer.update(model, grad)
|
||||
|
||||
# Evaluate the loss and the new parameters which will
|
||||
# run the full gradient computation and optimizer update
|
||||
mx.eval(loss, model.parameters())
|
||||
|
||||
|
||||
An important behavior to be aware of is when the graph will be implicitly
|
||||
evaluated. Anytime you ``print`` an array, convert it to an
|
||||
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
|
||||
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
|
||||
saving functions) will also evaluate the array.
|
||||
|
||||
|
||||
Calling :func:`array.item` on a scalar array will also evaluate it. In the
|
||||
example above, printing the loss (``print(loss)``) or adding the loss scalar to
|
||||
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
|
||||
these lines are before ``mx.eval(loss, model.parameters())`` then this
|
||||
will be a partial evaluation, computing only the forward pass.
|
||||
|
||||
Also, calling :func:`eval` on an array or set of arrays multiple times is
|
||||
perfectly fine. This is effectively a no-op.
|
||||
|
||||
.. warning::
|
||||
|
||||
Using scalar arrays for control-flow will cause an evaluation.
|
||||
|
||||
Here is an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
h, y = first_layer(x)
|
||||
if y > 0: # An evaluation is done here!
|
||||
z = second_layer_a(h)
|
||||
else:
|
||||
z = second_layer_b(h)
|
||||
return z
|
||||
|
||||
Using arrays for control flow should be done with care. The above example works
|
||||
and can even be used with gradient transformations. However, this can be very
|
||||
inefficient if evaluations are done too frequently.
|
108
docs/src/usage/numpy.rst
Normal file
108
docs/src/usage/numpy.rst
Normal file
@@ -0,0 +1,108 @@
|
||||
.. _numpy:
|
||||
|
||||
Conversion to NumPy and Other Frameworks
|
||||
========================================
|
||||
|
||||
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||
Let's convert an array to NumPy and back.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
a = mx.arange(3)
|
||||
b = np.array(a) # copy of a
|
||||
c = mx.array(b) # copy of b
|
||||
|
||||
.. note::
|
||||
|
||||
Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
|
||||
``np.array(a.astype(mx.float32))``.
|
||||
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
|
||||
|
||||
By default, NumPy copies data to a new array. This can be prevented by creating an array view:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a = mx.arange(3)
|
||||
a_view = np.array(a, copy=False)
|
||||
print(a_view.flags.owndata) # False
|
||||
a_view[0] = 1
|
||||
print(a[0].item()) # 1
|
||||
|
||||
A NumPy array view is a normal NumPy array, except that it does not own its memory.
|
||||
This means writing to the view is reflected in the original array.
|
||||
|
||||
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
|
||||
|
||||
Let's demonstrate this in an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def f(x):
|
||||
x_view = np.array(x, copy=False)
|
||||
x_view[:] *= x_view # modify memory without telling mx
|
||||
return x.sum()
|
||||
|
||||
x = mx.array([3.0])
|
||||
y, df = mx.value_and_grad(f)(x)
|
||||
print("f(x) = x² =", y.item()) # 9.0
|
||||
print("f'(x) = 2x !=", df.item()) # 1.0
|
||||
|
||||
|
||||
The function ``f`` indirectly modifies the array ``x`` through a memory view.
|
||||
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
|
||||
representing the gradient of the sum operation alone.
|
||||
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
|
||||
It's important to note that a similar issue arises during array conversion and copying.
|
||||
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
|
||||
even though no in-place operations on MLX memory are executed.
|
||||
|
||||
PyTorch
|
||||
-------
|
||||
|
||||
.. warning::
|
||||
|
||||
PyTorch Support for :obj:`memoryview` is experimental and can break for
|
||||
multi-dimensional arrays. Casting to NumPy first is advised for now.
|
||||
|
||||
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
||||
a = mx.arange(3)
|
||||
b = torch.tensor(memoryview(a))
|
||||
c = mx.array(b.numpy())
|
||||
|
||||
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
|
||||
|
||||
JAX
|
||||
---
|
||||
JAX fully supports the buffer protocol.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import jax.numpy as jnp
|
||||
|
||||
a = mx.arange(3)
|
||||
b = jnp.array(a)
|
||||
c = mx.array(b)
|
||||
|
||||
TensorFlow
|
||||
----------
|
||||
|
||||
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import tensorflow as tf
|
||||
|
||||
a = mx.arange(3)
|
||||
b = tf.constant(memoryview(a))
|
||||
c = mx.array(b)
|
@@ -40,6 +40,9 @@ automatically evaluate the array.
|
||||
>> np.array(c) # Also evaluates c
|
||||
array([2., 4., 6., 8.], dtype=float32)
|
||||
|
||||
|
||||
See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.
|
||||
|
||||
Function and Graph Transformations
|
||||
----------------------------------
|
||||
|
||||
@@ -62,10 +65,3 @@ and :func:`jvp` for Jacobian-vector products.
|
||||
|
||||
Use :func:`value_and_grad` to efficiently compute both a function's output and
|
||||
gradient with respect to the function's input.
|
||||
|
||||
|
||||
Devices and Streams
|
||||
-------------------
|
||||
|
||||
|
||||
|
81
docs/src/usage/saving_and_loading.rst
Normal file
81
docs/src/usage/saving_and_loading.rst
Normal file
@@ -0,0 +1,81 @@
|
||||
.. _saving_and_loading:
|
||||
|
||||
Saving and Loading Arrays
|
||||
=========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
MLX supports multiple array serialization formats.
|
||||
|
||||
.. list-table:: Serialization Formats
|
||||
:widths: 20 8 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Format
|
||||
- Extension
|
||||
- Function
|
||||
- Notes
|
||||
* - NumPy
|
||||
- ``.npy``
|
||||
- :func:`save`
|
||||
- Single arrays only
|
||||
* - NumPy archive
|
||||
- ``.npz``
|
||||
- :func:`savez` and :func:`savez_compressed`
|
||||
- Multiple arrays
|
||||
* - Safetensors
|
||||
- ``.safetensors``
|
||||
- :func:`save_safetensors`
|
||||
- Multiple arrays
|
||||
* - GGUF
|
||||
- ``.gguf``
|
||||
- :func:`save_gguf`
|
||||
- Multiple arrays
|
||||
|
||||
The :func:`load` function will load any of the supported serialization
|
||||
formats. It determines the format from the extensions. The output of
|
||||
:func:`load` depends on the format.
|
||||
|
||||
Here's an example of saving a single array to a file:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1.0])
|
||||
>>> mx.save("array", a)
|
||||
|
||||
The array ``a`` will be saved in the file ``array.npy`` (notice the extension
|
||||
is automatically added). Including the extension is optional; if it is missing
|
||||
it will be added. You can load the array with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> mx.load("array.npy")
|
||||
array([1], dtype=float32)
|
||||
|
||||
Here's an example of saving several arrays to a single file:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1.0])
|
||||
>>> b = mx.array([2.0])
|
||||
>>> mx.savez("arrays", a, b=b)
|
||||
|
||||
For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
|
||||
as arguments. If the keywords are missing, then default names will be
|
||||
provided. This can be loaded with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> mx.load("arrays.npz")
|
||||
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
|
||||
|
||||
In this case :func:`load` returns a dictionary of names to arrays.
|
||||
|
||||
The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
|
||||
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1.0])
|
||||
>>> b = mx.array([2.0])
|
||||
>>> mx.save_safetensors("arrays", {"a": a, "b": b})
|
78
docs/src/usage/unified_memory.rst
Normal file
78
docs/src/usage/unified_memory.rst
Normal file
@@ -0,0 +1,78 @@
|
||||
.. _unified_memory:
|
||||
|
||||
Unified Memory
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Apple silicon has a unified memory architecture. The CPU and GPU have direct
|
||||
access to the same memory pool. MLX is designed to take advantage of that.
|
||||
|
||||
Concretely, when you make an array in MLX you don't have to specify its location:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a = mx.random.normal((100,))
|
||||
b = mx.random.normal((100,))
|
||||
|
||||
Both ``a`` and ``b`` live in unified memory.
|
||||
|
||||
In MLX, rather than moving arrays to devices, you specify the device when you
|
||||
run the operation. Any device can perform any operation on ``a`` and ``b``
|
||||
without needing to move them from one memory location to another. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
mx.add(a, b, stream=mx.cpu)
|
||||
mx.add(a, b, stream=mx.gpu)
|
||||
|
||||
In the above, both the CPU and the GPU will perform the same add
|
||||
operation. The operations can (and likely will) be run in parallel since
|
||||
there are no dependencies between them. See :ref:`using_streams` for more
|
||||
information on the semantics of streams in MLX.
|
||||
|
||||
In the above ``add`` example, there are no dependencies between operations, so
|
||||
there is no possibility for race conditions. If there are dependencies, the
|
||||
MLX scheduler will automatically manage them. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
c = mx.add(a, b, stream=mx.cpu)
|
||||
d = mx.add(a, c, stream=mx.gpu)
|
||||
|
||||
In the above case, the second ``add`` runs on the GPU but it depends on the
|
||||
output of the first ``add`` which is running on the CPU. MLX will
|
||||
automatically insert a dependency between the two streams so that the second
|
||||
``add`` only starts executing after the first is complete and ``c`` is
|
||||
available.
|
||||
|
||||
A Simple Example
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is a more interesting (albeit slightly contrived) example of how unified
|
||||
memory can be helpful. Suppose we have the following computation:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(a, b, d1, d2):
|
||||
x = mx.matmul(a, b, stream=d1)
|
||||
for _ in range(500):
|
||||
b = mx.exp(b, stream=d2)
|
||||
return x, b
|
||||
|
||||
which we want to run with the following arguments:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a = mx.random.uniform(shape=(4096, 512))
|
||||
b = mx.random.uniform(shape=(512, 4))
|
||||
|
||||
The first ``matmul`` operation is a good fit for the GPU since it's more
|
||||
compute dense. The second sequence of operations is a better fit for the CPU,
|
||||
since they are very small and would probably be overhead bound on the GPU.
|
||||
|
||||
If we time the computation fully on the GPU, we get 2.8 milliseconds. But if we
|
||||
run the computation with ``d1=mx.gpu`` and ``d2=mx.cpu``, then the time is only
|
||||
about 1.4 milliseconds, about twice as fast. These times were measured on an M1
|
||||
Max.
|
@@ -1,3 +1,5 @@
|
||||
.. _using_streams:
|
||||
|
||||
Using Streams
|
||||
=============
|
||||
|
@@ -57,7 +57,7 @@ void array_basics() {
|
||||
assert(z.shape(0) == 2);
|
||||
assert(z.shape(1) == 2);
|
||||
|
||||
// To actually run the compuation you must evaluate `z`.
|
||||
// To actually run the computation you must evaluate `z`.
|
||||
// Under the hood, mlx records operations in a graph.
|
||||
// The variable `z` is a node in the graph which points to its operation
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
|
@@ -26,7 +26,7 @@ namespace mlx::core {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* Scale and sum two vectors elementwise
|
||||
* Scale and sum two vectors element-wise
|
||||
* z = alpha * x + beta * y
|
||||
*
|
||||
* Follow numpy style broadcasting between x and y
|
||||
@@ -91,21 +91,21 @@ void axpby_impl(
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Do the elementwise operation for each output
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additonal mapping
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
}
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Check the inputs (registered in the op while contructing the out array)
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
@@ -192,7 +192,7 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
#else // Accelerate not avaliable
|
||||
#else // Accelerate not available
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -254,7 +254,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel decelaration at axpby.metal
|
||||
// those in the kernel declaration at axpby.metal
|
||||
int ndim = out.ndim();
|
||||
size_t nelem = out.size();
|
||||
|
||||
@@ -287,7 +287,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Fix the 3D size of the launch grid (in terms of threads)
|
||||
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||
|
||||
// Launch the grid with the given number of threads divded among
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
@@ -311,8 +311,8 @@ array Axpby::jvp(
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the the primitive can built with ops
|
||||
// that are scheduled on the same stream as the primtive
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
// jvp is just the tangent scaled by alpha
|
||||
@@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitve along given axis */
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
@@ -12,7 +12,7 @@ namespace mlx::core {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* Scale and sum two vectors elementwise
|
||||
* Scale and sum two vectors element-wise
|
||||
* z = alpha * x + beta * y
|
||||
*
|
||||
* Follow numpy style broadcasting between x and y
|
||||
@@ -39,7 +39,7 @@ class Axpby : public Primitive {
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
* for the given inputs and populate the output array.
|
||||
*
|
||||
* To avoid unecessary allocations, the evaluation function
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -58,7 +58,7 @@ class Axpby : public Primitive {
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself accross
|
||||
* The primitive must know how to vectorize itself across
|
||||
* the given axes. The output is a pair containing the array
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
|
@@ -59,5 +59,5 @@ template <typename T>
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bflot16, bfloat16_t);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
@@ -23,7 +23,7 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
Scale and sum two vectors elementwise
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
|
@@ -1,4 +1,5 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .mlx_sample_extensions import *
|
||||
|
@@ -1,8 +1,9 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from mlx import extension
|
||||
from setuptools import setup
|
||||
|
||||
from mlx import extension
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup(
|
||||
name="mlx_sample_extensions",
|
||||
@@ -14,5 +15,5 @@ if __name__ == "__main__":
|
||||
package_dir={"": "."},
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
@@ -1,8 +1,9 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
num_features = 100
|
||||
num_examples = 1_000
|
||||
num_iters = 10_000
|
||||
|
@@ -1,8 +1,9 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
num_features = 100
|
||||
num_examples = 1_000
|
||||
num_iters = 10_000
|
||||
|
@@ -8,17 +8,17 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||
)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
else()
|
||||
|
@@ -9,7 +9,7 @@
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
@@ -22,7 +22,7 @@ void free(Buffer buffer) {
|
||||
return allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size) {
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
return Buffer{std::malloc(size)};
|
||||
}
|
||||
|
||||
@@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) {
|
||||
buffer = allocator().malloc(size);
|
||||
}
|
||||
|
||||
// Try swapping if needed
|
||||
if (size && !buffer.ptr()) {
|
||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||
}
|
||||
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||
|
@@ -37,9 +37,9 @@ void free(Buffer buffer);
|
||||
Buffer malloc_or_wait(size_t size);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base clase for a memory allocator. */
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
|
||||
Allocator() = default;
|
||||
@@ -55,7 +55,7 @@ Allocator& allocator();
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
|
||||
private:
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
||||
return {cum_prod, strides};
|
||||
}
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -32,7 +39,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
array::array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::unique_ptr<Primitive> primitive,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
shape,
|
||||
@@ -40,6 +47,23 @@ array::array(
|
||||
std::move(primitive),
|
||||
inputs)) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<array> outputs;
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
|
||||
}
|
||||
for (int i = 0; i < outputs.size(); ++i) {
|
||||
auto siblings = outputs;
|
||||
siblings.erase(siblings.begin() + i);
|
||||
outputs[i].set_siblings(std::move(siblings), i);
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
@@ -59,11 +83,17 @@ array::array(
|
||||
|
||||
/**
 * Detach the array from the graph: clear its inputs and siblings, reset its
 * output position to 0 and drop its primitive.
 */
void array::detach() {
  auto& desc = *array_desc_;
  desc.inputs.clear();
  desc.siblings.clear();
  desc.position = 0;
  desc.primitive = nullptr;
}
|
||||
|
||||
void array::eval(bool retain_graph /* = false */) {
|
||||
mlx::core::eval({*this}, retain_graph);
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
@@ -116,7 +146,7 @@ array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::unique_ptr<Primitive> primitive,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs)
|
||||
: shape(shape),
|
||||
dtype(dtype),
|
||||
@@ -128,10 +158,6 @@ array::ArrayDesc::ArrayDesc(
|
||||
}
|
||||
}
|
||||
|
||||
// Needed because the Primitive type used in array.h is incomplete and the
|
||||
// compiler needs to see the call to the destructor after the type is complete.
|
||||
array::ArrayDesc::~ArrayDesc() = default;
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
|
65
mlx/array.h
65
mlx/array.h
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
@@ -116,11 +115,11 @@ class array {
|
||||
};
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval(bool retain_graph = false);
|
||||
void eval();
|
||||
|
||||
/** Get the value from a scalar array. */
|
||||
template <typename T>
|
||||
T item(bool retain_graph = false);
|
||||
T item();
|
||||
|
||||
struct ArrayIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
@@ -154,8 +153,8 @@ class array {
|
||||
};
|
||||
|
||||
private:
|
||||
int idx;
|
||||
const array& arr;
|
||||
int idx;
|
||||
};
|
||||
|
||||
ArrayIterator begin() const {
|
||||
@@ -174,7 +173,13 @@ class array {
|
||||
array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::unique_ptr<Primitive> primitive,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
@@ -182,6 +187,11 @@ class array {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||
}
|
||||
|
||||
/** A unique identifier for an arrays primitive. */
|
||||
std::uintptr_t primitive_id() const {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
|
||||
}
|
||||
|
||||
struct Data {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
@@ -219,12 +229,32 @@ class array {
|
||||
return array_desc_->inputs;
|
||||
};
|
||||
|
||||
/** A non-const reference to the array's inputs so that they can be used to
|
||||
* edit the graph. */
|
||||
std::vector<array>& editable_inputs() {
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
|
||||
/**
 * Store the array's siblings together with the array's position in its
 * primitive's output list.
 *
 * `position` is taken as uint32_t to match the width of the descriptor's
 * `position` field, so large indices are not truncated on the way in.
 */
void set_siblings(std::vector<array> siblings, uint32_t position) {
  array_desc_->siblings = std::move(siblings);
  array_desc_->position = position;
}
|
||||
|
||||
/** The outputs of the array's primitive (i.e. this array and
|
||||
* its siblings) in the order the primitive expects. */
|
||||
std::vector<array> outputs() const {
|
||||
auto idx = array_desc_->position;
|
||||
std::vector<array> outputs;
|
||||
outputs.reserve(siblings().size() + 1);
|
||||
outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
|
||||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
};
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
|
||||
@@ -265,9 +295,7 @@ class array {
|
||||
array_desc_->is_tracer = is_tracer;
|
||||
}
|
||||
// Check if the array is a tracer array
|
||||
bool is_tracer() const {
|
||||
return array_desc_->is_tracer;
|
||||
}
|
||||
bool is_tracer() const;
|
||||
|
||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||
|
||||
@@ -301,7 +329,7 @@ class array {
|
||||
std::vector<size_t> strides;
|
||||
size_t size;
|
||||
Dtype dtype;
|
||||
std::unique_ptr<Primitive> primitive{nullptr};
|
||||
std::shared_ptr<Primitive> primitive{nullptr};
|
||||
|
||||
// Indicates an array is being used in a graph transform
|
||||
// and should not be detached from the graph
|
||||
@@ -323,16 +351,19 @@ class array {
|
||||
Flags flags;
|
||||
|
||||
std::vector<array> inputs;
|
||||
// An array to keep track of the siblings from a multi-output
|
||||
// primitive.
|
||||
std::vector<array> siblings;
|
||||
// The array's position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::unique_ptr<Primitive> primitive,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
~ArrayDesc();
|
||||
};
|
||||
|
||||
// The ArrayDesc contains the details of the materialized array including the
|
||||
@@ -381,11 +412,11 @@ array::array(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T array::item(bool retain_graph /* = false */) {
|
||||
T array::item() {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
eval(retain_graph);
|
||||
eval();
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
|
@@ -4,6 +4,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
)
|
||||
|
@@ -17,6 +17,12 @@
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
// Defines `primitive::eval_cpu` for a multi-output primitive by forwarding
// to that primitive's own `eval`.
#define DEFAULT_MULTI(primitive)                                    \
  void primitive::eval_cpu(                                         \
      const std::vector<array>& ins, std::vector<array>& outs) {    \
    primitive::eval(ins, outs);                                     \
  }
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Use the default implementation for the following primitives
|
||||
@@ -26,12 +32,14 @@ DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
@@ -39,12 +47,15 @@ DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
@@ -52,6 +63,7 @@ DEFAULT(Slice)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
@@ -322,6 +334,45 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Avoid code duplication with the common backend.
|
||||
/**
 * Element-wise remainder functor.
 *
 * Integral operands use the built-in `%` operator; every other operand type
 * is routed through `std::fmod`. The result is returned as `T`.
 *
 * NOTE: for integral `T` a zero denominator is undefined behaviour in C++;
 * this functor does not guard against it.
 */
struct RemainderFn {
  // const-qualified so the stateless functor can be invoked through const
  // references as well.
  template <typename T>
  T operator()(T numerator, T denominator) const {
    if constexpr (std::is_integral_v<T>) {
      return numerator % denominator;
    } else {
      return std::fmod(numerator, denominator);
    }
  }
};
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
RemainderFn{},
|
||||
UseDefaultBinaryOp(),
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
int num_el = n;
|
||||
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, RemainderFn{});
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -494,7 +545,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
b.flags().row_contiguous) {
|
||||
int size = a.size();
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
103
mlx/backend/accelerate/quantized.cpp
Normal file
103
mlx/backend/accelerate/quantized.cpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void _qmm_t_4_64(
|
||||
float* result,
|
||||
const float* x,
|
||||
const uint32_t* w,
|
||||
const float* scales,
|
||||
const float* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bits = 4;
|
||||
constexpr int group_size = 64;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Kg = K / group_size;
|
||||
const int Kw = K / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const float* scales_local = scales;
|
||||
const float* biases_local = biases;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||
// TODO: vectorize this properly
|
||||
simd_uint16 wi;
|
||||
for (int e = 0; e < 2; e++) {
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= bits;
|
||||
}
|
||||
}
|
||||
simd_float16 wf = simd_float(wi);
|
||||
wf *= scale;
|
||||
wf += bias;
|
||||
|
||||
sum += (*x_local) * wf;
|
||||
x_local++;
|
||||
}
|
||||
}
|
||||
|
||||
*result = simd_reduce_add(sum);
|
||||
result++;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& scales = inputs[2];
|
||||
auto& biases = inputs[3];
|
||||
|
||||
bool condition =
|
||||
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
|
||||
scales.flags().row_contiguous && biases.flags().row_contiguous &&
|
||||
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
|
||||
|
||||
if (condition) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int K = x.shape(-1);
|
||||
int M = x.size() / K;
|
||||
int N = out.shape(-1);
|
||||
_qmm_t_4_64(
|
||||
out.data<float>(),
|
||||
x.data<float>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<float>(),
|
||||
biases.data<float>(),
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -8,6 +8,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
|
@@ -6,6 +6,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/binary_two.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -75,6 +76,61 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
|
||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||
}
|
||||
|
||||
void DivMod::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, integral_op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, float_op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
@@ -82,6 +138,29 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
}
|
||||
|
||||
/**
 * Element-wise remainder functor.
 *
 * Integral operands use the built-in `%` operator; every other operand type
 * is routed through `std::fmod`. The result is returned as `T`.
 *
 * NOTE: for integral `T` a zero denominator is undefined behaviour in C++;
 * this functor does not guard against it.
 */
struct RemainderFn {
  // const-qualified so the stateless functor can be invoked through const
  // references as well.
  template <typename T>
  T operator()(T numerator, T denominator) const {
    if constexpr (std::is_integral_v<T>) {
      return numerator % denominator;
    } else {
      return std::fmod(numerator, denominator);
    }
  }
};
|
||||
|
||||
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, RemainderFn{});
|
||||
}
|
||||
|
||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (equal_nan_) {
|
||||
|
@@ -73,6 +73,12 @@ struct UseDefaultBinaryOp {
|
||||
// Should we throw? This should normally never be called.
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
// Should we throw? This should normally never be called.
|
||||
assert(false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -89,6 +95,18 @@ struct DefaultVectorScalar {
|
||||
a++;
|
||||
}
|
||||
}
|
||||
|
||||
// Two-output variant: apply `op` to each of the `size` elements of `a`
// paired with the single value `*b`, storing the pair's first member in
// `dst_a` and its second member in `dst_b`.
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
  const T rhs = *b;
  for (int i = 0; i < size; ++i) {
    auto result = op(a[i], rhs);
    dst_a[i] = result.first;
    dst_b[i] = result.second;
  }
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -105,6 +123,18 @@ struct DefaultScalarVector {
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
// Two-output variant: apply `op` to the single value `*a` paired with each
// of the `size` elements of `b`, storing the pair's first member in `dst_a`
// and its second member in `dst_b`.
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
  const T lhs = *a;
  for (int i = 0; i < size; ++i) {
    auto result = op(lhs, b[i]);
    dst_a[i] = result.first;
    dst_b[i] = result.second;
  }
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -121,6 +151,18 @@ struct DefaultVectorVector {
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
// Two-output variant: apply `op` to the `size` element pairs (a[i], b[i]),
// storing the pair's first member in `dst_a` and its second member in
// `dst_b`.
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
  for (int i = 0; i < size; ++i) {
    auto result = op(a[i], b[i]);
    dst_a[i] = result.first;
    dst_b[i] = result.second;
  }
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
|
536
mlx/backend/common/binary_two.h
Normal file
536
mlx/backend/common/binary_two.h
Normal file
@@ -0,0 +1,536 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < out_a.size(); ++i) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[i] = dst.first;
|
||||
dst_b[i] = dst.second;
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
dst_a += stride;
|
||||
dst_b += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[out_idx] = dst.first;
|
||||
dst_b[out_idx++] = dst.second;
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
dst_a += stride;
|
||||
dst_b += stride;
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims3(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[out_idx] = dst.first;
|
||||
dst_b[out_idx++] = dst.second;
|
||||
a_idx += a.strides()[2];
|
||||
b_idx += b.strides()[2];
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims4(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
dst_a[out_idx] = dst.first;
|
||||
dst_b[out_idx++] = dst.second;
|
||||
a_idx += a.strides()[3];
|
||||
b_idx += b.strides()[3];
|
||||
}
|
||||
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op) {
|
||||
switch (out_a.ndim()) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
case 4:
|
||||
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
for (size_t i = 0; i < out_a.size(); i++) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
int dim,
|
||||
int stride) {
|
||||
// Number of dimensions to loop over for vectorized ops
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst_a = out_a.data<U>();
|
||||
U* dst_b = out_b.data<U>();
|
||||
for (size_t i = 0; i < out_a.size(); i += stride) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||
dst_a += stride;
|
||||
dst_b += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename OpSV,
|
||||
typename OpVS,
|
||||
typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == ScalarScalar) {
|
||||
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
|
||||
op(*a.data<T>(), *b.data<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == ScalarVector) {
|
||||
opsv(
|
||||
a.data<T>(),
|
||||
b.data<T>(),
|
||||
out_a.data<U>(),
|
||||
out_b.data<U>(),
|
||||
b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == VectorScalar) {
|
||||
opvs(
|
||||
a.data<T>(),
|
||||
b.data<T>(),
|
||||
out_a.data<U>(),
|
||||
out_b.data<U>(),
|
||||
a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == VectorVector) {
|
||||
opvv(
|
||||
a.data<T>(),
|
||||
b.data<T>(),
|
||||
out_a.data<U>(),
|
||||
out_b.data<U>(),
|
||||
out_a.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto& strides = out_a.strides();
|
||||
auto leftmost_rc_dim = [&strides](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a);
|
||||
auto b_rc_dim = leftmost_rc_dim(b);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a);
|
||||
auto b_s_dim = leftmost_s_dim(b);
|
||||
|
||||
auto ndim = out_a.ndim();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
size_t stride;
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
stride = 1;
|
||||
bopt = General;
|
||||
dim = ndim;
|
||||
} else {
|
||||
stride = strides[dim - 1];
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case VectorVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
|
||||
break;
|
||||
case VectorScalar:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
|
||||
break;
|
||||
case ScalarVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv and opvs were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opvs was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
op,
|
||||
opsv,
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// All ops provided
|
||||
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
DefaultScalarVector<T, T, Op> opsv(op);
|
||||
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||
DefaultVectorVector<T, T, Op> opvv(op);
|
||||
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename... Ops>
|
||||
void binary(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Ops... ops) {
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, ops...);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, ops...);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, outputs, ops...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
@@ -357,7 +357,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Peform gemm
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
@@ -459,7 +459,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Peform gemm
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
|
@@ -1,6 +1,10 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <vecLib/cblas_new.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
@@ -12,6 +16,12 @@
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
// Defines primitive::eval_cpu for a primitive with multiple outputs by
// forwarding the inputs and the output vector to primitive::eval.
#define DEFAULT_MULTI(primitive) \
|
||||
void primitive::eval_cpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
primitive::eval(inputs, outputs); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Abs)
|
||||
@@ -29,17 +39,20 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
@@ -50,6 +63,8 @@ DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
@@ -59,9 +74,11 @@ DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
@@ -78,6 +95,7 @@ DEFAULT(Subtract)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/load.h"
|
||||
#include "mlx/io/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -13,7 +13,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <const uint8_t scalar_size>
|
||||
void swap_endianess(uint8_t* data_bytes, size_t N) {
|
||||
void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
struct Elem {
|
||||
uint8_t bytes[scalar_size];
|
||||
};
|
||||
@@ -39,13 +39,13 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
case 2:
|
||||
swap_endianess<2>(out.data<uint8_t>(), out.data_size());
|
||||
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 4:
|
||||
swap_endianess<4>(out.data<uint8_t>(), out.data_size());
|
||||
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
case 8:
|
||||
swap_endianess<8>(out.data<uint8_t>(), out.data_size());
|
||||
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/erf.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
@@ -167,6 +168,17 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::ceil(x); });
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
@@ -287,6 +299,17 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::floor(x); });
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -342,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
unary(in, out, [](auto x) { return !x; });
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -444,6 +481,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, RoundOp());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
268
mlx/backend/common/quantized.cpp
Normal file
268
mlx/backend/common/quantized.cpp
Normal file
@@ -0,0 +1,268 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
 * Quantized matrix multiply with non-transposed weights:
 * result (M x N) = x (M x K) * dequantize(w).
 *
 * Layout as read by the loops below:
 *  - `x` is consumed sequentially, K values per output row.
 *  - `w` holds, for every k, N quantized values packed `32 / bits` per
 *    uint32_t with the first value in the lowest bits.
 *  - `scales` / `biases` hold, for every k, one entry per group of
 *    `group_size` consecutive output columns; a quantized value q is
 *    dequantized as `scale * q + bias`.
 *
 * The inner loops always process whole groups, so N must be a multiple of
 * `group_size`; otherwise writes run past the end of the output row.
 */
template <typename T, int bits, int group_size>
void _qmm(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    // Every output row walks the full weight matrix from its beginning.
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    // The row is accumulated in place, so start from zero.
    std::fill(result, result + N, 0);

    for (int k = 0; k < K; k++) {
      T* result_local = result;
      const T xi = *x++;

      for (int n = 0; n < N; n += group_size) {
        const T scale = *scales_local++;
        const T bias = *biases_local++;
        for (int ng = 0; ng < packs_in_group; ng++) {
          uint32_t wi = *w_local++;

#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            (*result_local++) +=
                xi * (scale * static_cast<T>(wi & bitmask) + bias);
            // Shift the next packed value into the low bits.
            wi >>= bits;
          }
        }
      }
    }

    result += N;
  }
}
|
||||
|
||||
/**
 * Quantized matrix multiply with transposed weights:
 * result (M x N) = x (M x K) * dequantize(w)^T.
 *
 * Layout as read by the loops below:
 *  - `x` holds K values per input row.
 *  - `w` holds, for every output column n, K quantized values packed
 *    `32 / bits` per uint32_t with the first value in the lowest bits.
 *  - `scales` / `biases` hold, for every n, one entry per group of
 *    `group_size` consecutive k; a quantized value q is dequantized as
 *    `scale * q + bias`.
 *
 * The inner loops always process whole groups, so K must be a multiple of
 * `group_size`; otherwise reads run past the end of the input row.
 */
template <typename T, int bits, int group_size>
void _qmm_t(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    // Every input row walks the full weight matrix from its beginning.
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;

    for (int n = 0; n < N; n++) {
      const T* x_local = x;
      T sum = 0;
      for (int k = 0; k < K; k += group_size) {
        const T scale = *scales_local++;
        const T bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw++) {
          uint32_t wi = *w_local++;

#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
            // Shift the next packed value into the low bits.
            wi >>= bits;
          }
        }
      }
      *result = sum;
      result++;
    }

    x += K;
  }
}
|
||||
|
||||
template <typename T>
|
||||
void _qmm_dispatch_typed(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int group_size,
|
||||
int bits,
|
||||
bool transposed_w) {
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 4: {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
case 8: {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 128:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Quantization type not supported. Provided bits=" << bits
|
||||
<< " and group_size=" << group_size
|
||||
<< ". The supported options are bits in "
|
||||
<< "{2, 4, 8} and group_size in {64, 128}.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
/**
 * Dispatches the quantized matmul on the dtype of `x`.
 *
 * K is the last dimension of `x`, M the number of rows of `x`
 * (`x.size() / K`) and N the last dimension of `out`. Only float32, float16
 * and bfloat16 are handled; any other dtype throws std::invalid_argument.
 *
 * The sixth and seventh parameters are named in the order the values really
 * travel: the caller passes (group_size, bits) and they are forwarded
 * unchanged into the (group_size, bits) slots of `_qmm_dispatch_typed`. They
 * were previously named `bits, group_size`, which contradicted the values
 * they carried. The positional behavior is unchanged.
 */
void _qmm_dispatch(
    array out,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    int group_size,
    int bits,
    bool transposed_w) {
  const int K = x.shape(-1);
  const int M = x.size() / K;
  const int N = out.shape(-1);

  switch (x.dtype()) {
    case float32:
      _qmm_dispatch_typed<float>(
          out.data<float>(),
          x.data<float>(),
          w.data<uint32_t>(),
          scales.data<float>(),
          biases.data<float>(),
          M,
          N,
          K,
          group_size,
          bits,
          transposed_w);
      break;
    case float16:
      _qmm_dispatch_typed<float16_t>(
          out.data<float16_t>(),
          x.data<float16_t>(),
          w.data<uint32_t>(),
          scales.data<float16_t>(),
          biases.data<float16_t>(),
          M,
          N,
          K,
          group_size,
          bits,
          transposed_w);
      break;
    case bfloat16:
      _qmm_dispatch_typed<bfloat16_t>(
          out.data<bfloat16_t>(),
          x.data<bfloat16_t>(),
          w.data<uint32_t>(),
          scales.data<bfloat16_t>(),
          biases.data<bfloat16_t>(),
          M,
          N,
          K,
          group_size,
          bits,
          transposed_w);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
auto ensure_row_contiguous = [](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto x = ensure_row_contiguous(x_pre);
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -126,7 +126,7 @@ struct ReductionPlan {
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
(x.flags().row_contiguous || x.flags().col_contiguous)) {
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
|
@@ -53,6 +53,17 @@ struct SignOp {
|
||||
}
|
||||
};
|
||||
|
||||
struct RoundOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::round(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {std::round(x.real()), std::round(x.imag())};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
|
@@ -10,6 +10,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
|
@@ -26,13 +26,10 @@ namespace metal {
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(MTL::Device* device)
|
||||
: device_(device),
|
||||
head_(nullptr),
|
||||
tail_(nullptr),
|
||||
pool_size_(0),
|
||||
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
|
||||
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
||||
|
||||
// Destructor: holds a scoped memory pool obtained from
// metal::new_scoped_memory_pool() for the duration of the call, then calls
// clear() (presumably releasing every cached buffer; confirm against clear()).
BufferCache::~BufferCache() {
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
clear();
|
||||
}
|
||||
|
||||
@@ -54,12 +51,16 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
||||
|
||||
// Find the closest buffer in pool
|
||||
MTL::Buffer* pbuf = nullptr;
|
||||
|
||||
// Make sure we use most of the available memory
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
|
||||
// Make sure we use > 50% of the available memory
|
||||
while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) {
|
||||
// Make sure we use most of the available memory
|
||||
while (!pbuf && it != buffer_pool_.end() &&
|
||||
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
|
||||
// Collect from the cache
|
||||
pbuf = it->second->buf;
|
||||
|
||||
// Remove from cache
|
||||
remove_from_list(it->second);
|
||||
delete it->second;
|
||||
@@ -85,13 +86,9 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
||||
}
|
||||
}
|
||||
|
||||
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
|
||||
|
||||
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
size_t old_pool_size = pool_size_;
|
||||
clear();
|
||||
return old_pool_size;
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lk(cache_mutex_);
|
||||
size_t total_bytes_freed = 0;
|
||||
@@ -104,9 +101,7 @@ size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
}
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return total_bytes_freed;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,8 +120,9 @@ void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
|
||||
}
|
||||
|
||||
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
if (!to_remove)
|
||||
if (!to_remove) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If in the middle
|
||||
if (to_remove->prev && to_remove->next) {
|
||||
@@ -153,26 +149,37 @@ MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
buffer_cache_(device_),
|
||||
peak_allocated_size_(0),
|
||||
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}
|
||||
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
|
||||
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// Metal doesn't like empty buffers
|
||||
if (size == 0) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size) {
|
||||
// Align up memory
|
||||
if (size > vm_page_size) {
|
||||
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
|
||||
}
|
||||
|
||||
// Try the cache
|
||||
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
|
||||
// Prepare to allocate new memory as needed
|
||||
if (!buf) {
|
||||
// If we are under very high memoory pressure, we don't allocate further
|
||||
if (device_->currentAllocatedSize() >= block_limit_) {
|
||||
// If there is too much memory pressure, fail (likely causes a wait).
|
||||
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
// If we are still under memory pressure, try cleaning cache
|
||||
if (buffer_cache_.can_garbage_collect()) {
|
||||
buffer_cache_.release_cached_buffers(size);
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure, check if we can reclaim some memory
|
||||
// from the cache
|
||||
if (device_->currentAllocatedSize() + size >= gc_limit_) {
|
||||
size_t min_bytes_to_free =
|
||||
size + device_->currentAllocatedSize() - gc_limit_;
|
||||
buffer_cache_.release_cached_buffers(min_bytes_to_free);
|
||||
}
|
||||
|
||||
// Allocate new buffer if needed
|
||||
|
@@ -23,11 +23,7 @@ class BufferCache {
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
void recycle_to_cache(MTL::Buffer* buf);
|
||||
size_t release_cached_buffers(size_t min_bytes_to_free);
|
||||
|
||||
bool can_garbage_collect() {
|
||||
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
|
||||
}
|
||||
void release_cached_buffers(size_t min_bytes_to_free);
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
@@ -49,7 +45,6 @@ class BufferCache {
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
size_t gc_limit_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -57,7 +52,7 @@ class BufferCache {
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
|
||||
private:
|
||||
@@ -71,6 +66,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
// Allocation stats
|
||||
size_t peak_allocated_size_;
|
||||
size_t block_limit_;
|
||||
size_t gc_limit_;
|
||||
};
|
||||
|
||||
MetalAllocator& allocator();
|
||||
|
@@ -68,7 +68,7 @@ void explicit_gemm_conv_1D_gpu(
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Peform gemm
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
s,
|
||||
@@ -260,7 +260,7 @@ void explicit_gemm_conv_2D_gpu(
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Peform gemm
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
s,
|
||||
|
@@ -20,6 +20,9 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
|
@@ -17,8 +17,6 @@ namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
static Device metal_device_;
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
@@ -27,7 +25,8 @@ static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
static constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto load_device() {
|
||||
MTL::Device* device = MTL::CreateSystemDefaultDevice();
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0));
|
||||
if (!device) {
|
||||
throw std::runtime_error("Failed to load device");
|
||||
}
|
||||
@@ -44,6 +43,25 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
return std::make_pair(lib, error);
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
// Tries to load "default.metallib" from the bundle named
// "<SWIFTPM_BUNDLE>.bundle" located directly under `url`. Returns the loaded
// library, or nullptr when the bundle cannot be opened or the library fails
// to load.
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
  std::string bundle_path(url->fileSystemRepresentation());
  bundle_path += "/";
  bundle_path += SWIFTPM_BUNDLE;
  bundle_path += ".bundle";

  auto bundle = NS::Bundle::alloc()->init(
      NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding));
  if (bundle == nullptr) {
    return nullptr;
  }

  std::string resource_path(bundle->resourceURL()->fileSystemRepresentation());
  resource_path += "/";
  resource_path += "default.metallib";

  auto [lib, error] = load_library_from_path(device, resource_path.c_str());
  // `lib` is null when loading failed, which is exactly the failure result.
  return lib;
}
|
||||
#endif
|
||||
|
||||
MTL::Library* load_library(
|
||||
MTL::Device* device,
|
||||
const std::string& lib_name = "mlx",
|
||||
@@ -57,6 +75,26 @@ MTL::Library* load_library(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SWIFTPM_BUNDLE
|
||||
// try to load from a swiftpm resource bundle -- scan the available bundles to
|
||||
// find one that contains the named bundle
|
||||
{
|
||||
MTL::Library* library =
|
||||
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
auto bundles = NS::Bundle::allBundles();
|
||||
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
|
||||
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
|
||||
library = try_load_bundle(device, bundle->resourceURL());
|
||||
if (library != nullptr) {
|
||||
return library;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Couldn't find it so let's load it from default_mtllib_path
|
||||
{
|
||||
auto [lib, error] = load_library_from_path(device, lib_path);
|
||||
@@ -73,15 +111,23 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
Device::Device()
|
||||
: pool_(NS::AutoreleasePool::alloc()->init()),
|
||||
device_(load_device()),
|
||||
library_map_({{"mlx", load_library(device_)}}) {}
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_library(device_)}};
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& q : queue_map_) {
|
||||
q.second->release();
|
||||
}
|
||||
for (auto& b : buffer_map_) {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& e : encoder_map_) {
|
||||
e.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
@@ -89,10 +135,11 @@ Device::~Device() {
|
||||
l.second->release();
|
||||
}
|
||||
device_->release();
|
||||
pool_->release();
|
||||
}
|
||||
|
||||
void Device::new_queue(int index) {
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// Multiple threads can ask the device for queues
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
@@ -198,6 +245,7 @@ void Device::register_library(
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& name,
|
||||
const std::string& lib_name /* = "mlx" */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
@@ -240,17 +288,18 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
return metal_device_;
|
||||
static Device metal_device;
|
||||
return metal_device;
|
||||
}
|
||||
|
||||
NS::AutoreleasePool*& thread_autorelease_pool() {
|
||||
static thread_local NS::AutoreleasePool* p =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
return p;
|
||||
std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
auto dtor = [](void* ptr) {
|
||||
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
||||
};
|
||||
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
thread_autorelease_pool();
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
|
@@ -67,7 +67,6 @@ class Device {
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
private:
|
||||
NS::AutoreleasePool* pool_;
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
@@ -78,6 +77,5 @@ class Device {
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
NS::AutoreleasePool*& thread_autorelease_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
@@ -33,6 +32,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@@ -102,7 +104,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
}
|
||||
|
||||
// Allocate the argument bufer
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
@@ -110,14 +112,18 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||
}
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
|
||||
if (idx_ndim > 0) {
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
}
|
||||
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||
|
||||
// Set all the buffers
|
||||
@@ -163,6 +169,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
// Empty update
|
||||
if (inputs.back().size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get stream
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@@ -246,7 +257,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
}
|
||||
|
||||
// Allocate the argument bufer
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
@@ -254,14 +265,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||
}
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
|
||||
if (idx_ndim > 0) {
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
}
|
||||
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||
|
||||
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
|
||||
@@ -272,14 +287,32 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
set_array_buffer(compute_encoder, upd, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
|
@@ -14,10 +14,12 @@ set(
|
||||
"arange"
|
||||
"arg_reduce"
|
||||
"binary"
|
||||
"binary_two"
|
||||
"conv"
|
||||
"copy"
|
||||
"gemm"
|
||||
"gemv"
|
||||
"quantized"
|
||||
"random"
|
||||
"reduce"
|
||||
"scan"
|
||||
|
@@ -114,7 +114,7 @@ template <typename T, typename Op, int N_READS>
|
||||
// 4. Reduce among them and go to 3
|
||||
// 4. Reduce in each simd_group
|
||||
// 6. Write in the thread local memory
|
||||
// 6. Reduce them accross thread group
|
||||
// 6. Reduce them across thread group
|
||||
// 7. Write the output without need for atomic
|
||||
Op op;
|
||||
|
||||
|
@@ -14,6 +14,13 @@ struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
// Element-wise remainder functor.
// The generic overload uses the built-in `%` operator; float, half and
// bfloat16_t are specialized to call fmod instead.
struct Remainder {
  template <typename T>
  T operator()(T x, T y) {
    return x % y;
  }

  template <>
  float operator()(float x, float y) {
    return fmod(x, y);
  }

  template <>
  half operator()(half x, half y) {
    return fmod(x, y);
  }

  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return fmod(x, y);
  }
};
|
||||
|
||||
// Equality predicate: returns true when the operands compare equal with `==`.
struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y;
  }
};
|
||||
@@ -124,6 +131,16 @@ struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
// Logical conjunction functor: applies `&&` to the operands and converts the
// boolean outcome back to T.
struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  }
};
|
||||
|
||||
// Logical disjunction functor: applies `||` to the operands and converts the
// boolean outcome back to T.
struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    return x || y;
  }
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
device const T* a,
|
||||
@@ -350,7 +367,7 @@ template <typename T, typename U, typename Op>
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_float(div, Divide)
|
||||
instantiate_binary_types(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
instantiate_binary_types_bool(ge, Greater)
|
||||
instantiate_binary_types_bool(geq, GreaterEqual)
|
||||
@@ -363,9 +380,13 @@ instantiate_binary_types(min, Minimum)
|
||||
instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
instantiate_binary_types(rem, Remainder)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
259
mlx/backend/metal/kernels/binary_two.metal
Normal file
259
mlx/backend/metal/kernels/binary_two.metal
Normal file
@@ -0,0 +1,259 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
// Quotient functor used as the first operator of divmod.
// The generic overload uses the built-in `/`. The float, half and bfloat16_t
// specializations divide and then apply trunc(), i.e. they round the
// quotient toward zero.
struct FloorDivide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }

  template <>
  float operator()(float x, float y) {
    return trunc(x / y);
  }

  template <>
  half operator()(half x, half y) {
    return trunc(x / y);
  }

  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return trunc(x / y);
  }
};
|
||||
|
||||
// Remainder functor used as the second operator of divmod.
// The generic overload uses the built-in `%`; float, half and bfloat16_t are
// specialized to call fmod.
struct Remainder {
  template <typename T>
  T operator()(T x, T y) {
    return x % y;
  }

  template <>
  float operator()(float x, float y) {
    return fmod(x, y);
  }

  template <>
  half operator()(half x, half y) {
    return fmod(x, y);
  }

  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return fmod(x, y);
  }
};
|
||||
|
||||
// Two-output kernel for two single-element inputs: every thread applies both
// operators to a[0] and b[0] and stores the results at its own grid position.
// The body matches binary_op_ss; the instantiation macros visible in this
// file only reference the ss/sv/vs/vv variants.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  Op1 first_op;
  Op2 second_op;
  c[index] = first_op(a[0], b[0]);
  d[index] = second_op(a[0], b[0]);
}
|
||||
|
||||
|
||||
// Two-output kernel where both inputs are read at element 0 only: each
// thread writes Op1(a[0], b[0]) to c and Op2(a[0], b[0]) to d at its grid
// position.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  Op1 first_op;
  Op2 second_op;
  c[index] = first_op(a[0], b[0]);
  d[index] = second_op(a[0], b[0]);
}
|
||||
|
||||
// Two-output kernel with a single-element left operand: each thread combines
// a[0] with b[index] and writes the Op1 result to c and the Op2 result to d.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  Op1 first_op;
  Op2 second_op;
  c[index] = first_op(a[0], b[index]);
  d[index] = second_op(a[0], b[index]);
}
|
||||
|
||||
// Two-output kernel with a single-element right operand: each thread
// combines a[index] with b[0] and writes the Op1 result to c and the Op2
// result to d.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  Op1 first_op;
  Op2 second_op;
  c[index] = first_op(a[index], b[0]);
  d[index] = second_op(a[index], b[0]);
}
|
||||
|
||||
// Two-output element-wise kernel: each thread reads a[index] and b[index]
// and writes the Op1 result to c[index] and the Op2 result to d[index].
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  Op1 first_op;
  Op2 second_op;
  c[index] = first_op(a[index], b[index]);
  d[index] = second_op(a[index], b[index]);
}
|
||||
|
||||
// Two-output kernel for strided one-dimensional inputs. The input locations
// come from elem_to_loc_1 with each operand's stride; the outputs are
// written densely at the thread's grid position.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_loc = elem_to_loc_1(index, a_stride);
  auto b_loc = elem_to_loc_1(index, b_stride);
  Op1 first_op;
  Op2 second_op;
  c[index] = first_op(a[a_loc], b[b_loc]);
  d[index] = second_op(a[a_loc], b[b_loc]);
}
|
||||
|
||||
// Two-output kernel for strided two-dimensional inputs. Input locations come
// from elem_to_loc_2; the output location is the row-major flattening of the
// 2-D grid position, widened to size_t before multiplying.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_loc = elem_to_loc_2(index, a_strides);
  auto b_loc = elem_to_loc_2(index, b_strides);
  size_t flat = index.x + (size_t)grid_dim.x * index.y;
  Op1 first_op;
  Op2 second_op;
  c[flat] = first_op(a[a_loc], b[b_loc]);
  d[flat] = second_op(a[a_loc], b[b_loc]);
}
|
||||
|
||||
// Two-output kernel for strided three-dimensional inputs. Input locations
// come from elem_to_loc_3; the output location is the flattened 3-D grid
// position (x fastest, then y, then z), computed in size_t.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_loc = elem_to_loc_3(index, a_strides);
  auto b_loc = elem_to_loc_3(index, b_strides);
  size_t flat = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  Op1 first_op;
  Op2 second_op;
  c[flat] = first_op(a[a_loc], b[b_loc]);
  d[flat] = second_op(a[a_loc], b[b_loc]);
}
|
||||
|
||||
// Two-output kernel for strided inputs whose rank DIM is fixed at compile
// time. elem_to_loc_2_nd<DIM> returns both input locations at once (.x for
// a, .y for b); the output location is the flattened 3-D grid position.
template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto locs = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t flat = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  Op1 first_op;
  Op2 second_op;
  c[flat] = first_op(a[locs.x], b[locs.y]);
  d[flat] = second_op(a[locs.x], b[locs.y]);
}
|
||||
|
||||
// Two-output kernel for strided inputs whose rank is only known at run time.
//
// a, b       input buffers, addressed through the strides below
// c, d       output buffers; c receives Op1(a, b) and d receives Op2(a, b)
// shape      extents passed to elem_to_loc_2_nd
// a_strides  strides of `a`
// b_strides  strides of `b`
// ndim       number of entries in shape / strides
//
// elem_to_loc_2_nd returns both input locations (.x for a, .y for b). The
// output location is the flattened 3-D grid position.
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  // Widen to size_t before multiplying, as binary_op_g_nd3 / binary_op_g_nd
  // do; otherwise the product is evaluated in 32-bit uint and can wrap for
  // large grids.
  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
|
||||
|
||||
// Explicitly instantiates one of the flat two-output kernels
// (binary_op_<variant>) for the given input/output types and operator pair,
// and exposes it to the host under the name `kname`.
#define instantiate_binary(kname, in_t, out_t, first_op, second_op, variant) \
  template [[host_name(kname)]] \
  [[kernel]] void binary_op_##variant<in_t, out_t, first_op, second_op>( \
      device const in_t* a, \
      device const in_t* b, \
      device out_t* c, \
      device out_t* d, \
      uint index [[thread_position_in_grid]]);
|
||||
|
||||
// Explicitly instantiates binary_op_g_nd for a fixed rank. The host-visible
// name is `kname` followed by "_" and the rank.
#define instantiate_binary_g_dim(kname, in_t, out_t, first_op, second_op, rank) \
  template [[host_name(kname "_" #rank)]] \
  [[kernel]] void binary_op_g_nd<in_t, out_t, first_op, second_op, rank>( \
      device const in_t* a, \
      device const in_t* b, \
      device out_t* c, \
      device out_t* d, \
      constant const int shape[rank], \
      constant const size_t a_strides[rank], \
      constant const size_t b_strides[rank], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// Explicitly instantiates the fixed-rank strided kernels: the dedicated 1-D,
// 2-D and 3-D kernels (host names `kname` + "_1" / "_2" / "_3"), followed by
// the templated-rank kernel for ranks 4 and 5.
#define instantiate_binary_g_nd(kname, in_t, out_t, first_op, second_op) \
  template [[host_name(kname "_1")]] \
  [[kernel]] void binary_op_g_nd1<in_t, out_t, first_op, second_op>( \
      device const in_t* a, \
      device const in_t* b, \
      device out_t* c, \
      device out_t* d, \
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(kname "_2")]] \
  [[kernel]] void binary_op_g_nd2<in_t, out_t, first_op, second_op>( \
      device const in_t* a, \
      device const in_t* b, \
      device out_t* c, \
      device out_t* d, \
      constant const size_t a_strides[2], \
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(kname "_3")]] \
  [[kernel]] void binary_op_g_nd3<in_t, out_t, first_op, second_op>( \
      device const in_t* a, \
      device const in_t* b, \
      device out_t* c, \
      device out_t* d, \
      constant const size_t a_strides[3], \
      constant const size_t b_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(kname, in_t, out_t, first_op, second_op, 4) \
  instantiate_binary_g_dim(kname, in_t, out_t, first_op, second_op, 5)
|
||||
|
||||
|
||||
// Explicitly instantiates the run-time-rank kernel binary_op_g and exposes it
// to the host under `name`.
//
// name    host-visible kernel name (string literal)
// itype   element type of both inputs
// otype   element type of both outputs
// op1     functor whose result is written to output `c`
// op2     functor whose result is written to output `d`
//
// Both operators must be forwarded in order: passing op2 twice would make the
// general-rank kernel write the second operator's result to both outputs.
#define instantiate_binary_g(name, itype, otype, op1, op2) \
  template [[host_name(name)]] \
  [[kernel]] void binary_op_g<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// Instantiates every kernel variant for one (operation, type) pair. Host
// names are a variant prefix ("ss", "sv", "vs", "vv" or "g") followed by the
// stringized operation name and type name.
#define instantiate_binary_all(opname, tyname, in_t, out_t, first_op, second_op) \
  instantiate_binary("ss" #opname #tyname, in_t, out_t, first_op, second_op, ss) \
  instantiate_binary("sv" #opname #tyname, in_t, out_t, first_op, second_op, sv) \
  instantiate_binary("vs" #opname #tyname, in_t, out_t, first_op, second_op, vs) \
  instantiate_binary("vv" #opname #tyname, in_t, out_t, first_op, second_op, vv) \
  instantiate_binary_g("g" #opname #tyname, in_t, out_t, first_op, second_op) \
  instantiate_binary_g_nd("g" #opname #tyname, in_t, out_t, first_op, second_op)
|
||||
|
||||
// Instantiates all kernel variants for the three floating-point element
// types (half, float, bfloat16_t), with matching input and output types.
#define instantiate_binary_float(opname, first_op, second_op) \
  instantiate_binary_all(opname, float16, half, half, first_op, second_op) \
  instantiate_binary_all(opname, float32, float, float, first_op, second_op) \
  instantiate_binary_all(opname, bfloat16, bfloat16_t, bfloat16_t, first_op, second_op)
|
||||
|
||||
// Instantiates all kernel variants for bool, the unsigned and signed integer
// types, complex64_t and, via instantiate_binary_float, the floating-point
// types. Input and output element types are always the same.
#define instantiate_binary_types(opname, first_op, second_op) \
  instantiate_binary_all(opname, bool_, bool, bool, first_op, second_op) \
  instantiate_binary_all(opname, uint8, uint8_t, uint8_t, first_op, second_op) \
  instantiate_binary_all(opname, uint16, uint16_t, uint16_t, first_op, second_op) \
  instantiate_binary_all(opname, uint32, uint32_t, uint32_t, first_op, second_op) \
  instantiate_binary_all(opname, uint64, uint64_t, uint64_t, first_op, second_op) \
  instantiate_binary_all(opname, int8, int8_t, int8_t, first_op, second_op) \
  instantiate_binary_all(opname, int16, int16_t, int16_t, first_op, second_op) \
  instantiate_binary_all(opname, int32, int32_t, int32_t, first_op, second_op) \
  instantiate_binary_all(opname, int64, int64_t, int64_t, first_op, second_op) \
  instantiate_binary_all(opname, complex64, complex64_t, complex64_t, first_op, second_op) \
  instantiate_binary_float(opname, first_op, second_op)
|
||||
|
||||
// Instantiate the divmod kernels for every supported element type, passing
// FloorDivide as op1 and Remainder as op2.
instantiate_binary_types(divmod, FloorDivide, Remainder)
|
@@ -45,7 +45,7 @@ struct complex64_t {
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||
|
||||
// Converstions from complex64_t
|
||||
// Conversions from complex64_t
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
@@ -110,3 +110,16 @@ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
// Complex product: (a.real + i a.imag) * (b.real + i b.imag).
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  auto re = a.real * b.real - a.imag * b.imag;
  auto im = a.real * b.imag + a.imag * b.real;
  return {re, im};
}
|
||||
|
||||
// Complex quotient a / b, computed as a * conj(b) divided by |b|^2.
// No special handling is done for a zero divisor.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  auto norm_sq = b.real * b.real + b.imag * b.imag;
  auto re_num = a.real * b.real + a.imag * b.imag;
  auto im_num = a.imag * b.real - a.real * b.imag;
  return {re_num / norm_sq, im_num / norm_sq};
}
|
||||
|
||||
// Component-wise remainder: the real and imaginary parts are reduced
// independently, each by subtracting the divisor component multiplied by the
// integer-truncated quotient of the matching components.
// NOTE(review): a zero component in `b` divides by zero before the int64_t
// cast -- confirm callers never rely on that case.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto re_rem = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto im_rem = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
  return {re_rem, im_rem};
}
|
||||
|
@@ -105,7 +105,7 @@ struct Conv2DInputBlockLoader {
|
||||
}
|
||||
}
|
||||
|
||||
// Zero pad otherwize
|
||||
// Zero pad otherwise
|
||||
else {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
@@ -334,7 +334,7 @@ struct Conv2DBlockMMA {
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into resulr simdgroup matrices
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
|
@@ -93,13 +93,13 @@ struct BlockLoader {
|
||||
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
||||
}
|
||||
|
||||
// Read all valid indcies into tmp_val
|
||||
// Read all valid indices into tmp_val
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
||||
@@ -241,7 +241,7 @@ struct BlockMMA {
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into resulr simdgroup matrices
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
|
@@ -28,7 +28,7 @@ struct GEMVKernel {
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thread group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
@@ -42,7 +42,7 @@ struct GEMVKernel {
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
@@ -166,7 +166,7 @@ template <
|
||||
struct GEMVTKernel {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thread group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
@@ -180,7 +180,7 @@ struct GEMVTKernel {
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
|
||||
|
568
mlx/backend/metal/kernels/quantized.metal
Normal file
568
mlx/backend/metal/kernels/quantized.metal
Normal file
@@ -0,0 +1,568 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Shorthand qualifier list for compile-time constants placed in the
// `constant` address space.
#define MLX_MTL_CONST static constant constexpr const
|
||||
|
||||
// Simdgroup width assumed by the kernels below (checked against their block
// sizes with static_assert).
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
// Quantized matrix-vector product.
//
// w       packed weights; each uint32 holds 32 / bits quantized values
// scales  per-group scale factors, in_vec_size / group_size per output row
// biases  per-group biases, same layout as scales
// x       input vectors (one per tid.z), in_vec_size elements each
// y       output vectors (one per tid.z), out_vec_size elements each
//
// Each simdgroup produces one output row (tid.y * BM + simd_gid). Per
// iteration the threadgroup stages BN * (32 / bits) elements of x plus the
// matching scales and biases in threadgroup memory; every lane then
// dequantizes one packed word as scale * q + bias and accumulates its dot
// product with the matching slice of x. Lane 0 stores the simdgroup sum.
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");

  // Compile-time geometry of one block of the input vector.
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int el_per_thread = 32 / bits;
  constexpr int colgroup = BN * el_per_thread;
  constexpr int groups_per_block = colgroup / group_size;
  constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;

  // Staging buffers shared by the threadgroup.
  threadgroup T scales_block[BM * groups_per_block];
  threadgroup T biases_block[BM * groups_per_block];
  threadgroup T x_block[colgroup];

  thread T acc = 0;
  thread T x_regs[el_per_thread];

  // Move every pointer to the row / batch entry handled by this simdgroup.
  const int words_per_row = in_vec_size / el_per_thread;
  const int groups_per_row = in_vec_size / group_size;
  int out_row = tid.y * BM + simd_gid;
  w += out_row * words_per_row;
  scales += out_row * groups_per_row;
  biases += out_row * groups_per_row;
  x += tid.z * in_vec_size;
  y += tid.z * out_vec_size;

  // Walk the input vector one block of `colgroup` elements at a time.
  for (int col = 0; col < in_vec_size; col += colgroup) {
    // Wait until everyone is done with the previous block before overwriting.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid < simdgroups_fetching_vec) {
      x_block[lid] = x[lid + col];
    }
    if (simd_lid == 0) {
      // Lane 0 of each simdgroup stages that row's scales and biases.
      #pragma clang loop unroll(full)
      for (int g = 0; g < groups_per_block; g++) {
        scales_block[simd_gid * groups_per_block + g] = scales[col / group_size + g];
        biases_block[simd_gid * groups_per_block + g] = biases[col / group_size + g];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Copy this lane's slice of x and its group parameters into registers.
    #pragma clang loop unroll(full)
    for (int e = 0; e < el_per_thread; e++) {
      x_regs[e] = x_block[simd_lid * el_per_thread + e];
    }
    const T scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
    const T bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];

    // One packed word of weights per lane.
    uint32_t packed = w[col / el_per_thread + simd_lid];

    // Dequantize value by value (lowest bits first) and accumulate.
    #pragma clang loop unroll(full)
    for (int e = 0; e < el_per_thread; e++) {
      acc += (scale * static_cast<T>(packed & bitmask) + bias) * x_regs[e];
      packed >>= bits;
    }
  }

  // Combine the per-lane partial sums.
  acc = simd_sum(acc);

  // A single lane writes the row result.
  if (simd_lid == 0) {
    y[out_row] = acc;
  }
}
|
||||
|
||||
|
||||
// Quantized vector-matrix product.
//
// x       input vectors (one per tid.z), in_vec_size elements each
// w       packed weights; each uint32 holds 32 / bits quantized values and
//         a row spans out_vec_size / (32 / bits) words
// scales  per-group scale factors, out_vec_size / group_size per weight row
// biases  per-group biases, same layout as scales
// y       output vectors (one per tid.z), out_vec_size elements each
//
// Each simdgroup owns one packed word column, i.e. 32 / bits consecutive
// outputs. Per iteration BM rows are processed: simdgroup 0 stages BM
// elements of x and the scales / biases of those rows in threadgroup memory,
// then every lane dequantizes the word of its row and accumulates into
// 32 / bits partial results. The partials are summed over the simdgroup and
// stored by lane 0.
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qvm(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
  static_assert(BN == BM, "qvm expects a block size of 32x32");

  // lid is part of the signature but not needed here.
  (void)lid;

  // Compile-time geometry of one block.
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int el_per_int = 32 / bits;
  constexpr int colgroup = BN * el_per_int;
  constexpr int groups_per_block = colgroup / group_size;

  // Staging buffers shared by the threadgroup.
  threadgroup T scales_block[BM * groups_per_block];
  threadgroup T biases_block[BM * groups_per_block];
  threadgroup T x_block[BM];

  thread T acc[el_per_int] = {0};

  // Move every pointer to the column / batch entry handled by this simdgroup.
  const int words_per_row = out_vec_size / el_per_int;
  const int groups_per_row = out_vec_size / group_size;
  int out_col = (tid.y * BN + simd_gid) * el_per_int;
  w += out_col / el_per_int;
  scales += out_col / group_size;
  biases += out_col / group_size;
  x += tid.z * in_vec_size;
  y += tid.z * out_vec_size + out_col;

  // Simdgroups whose columns fall outside the output have nothing to do.
  if (out_col >= out_vec_size) {
    return;
  }

  // Walk the input vector BM rows at a time.
  for (int row0 = 0; row0 < in_vec_size; row0 += BM) {
    // Stage the current slice of x.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
      x_block[simd_lid] = x[simd_lid + row0];
    }

    // Stage the scales and biases of the BM rows, one row per lane.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
      #pragma clang loop unroll(full)
      for (int g = 0; g < groups_per_block; g++) {
        scales_block[simd_lid * groups_per_block + g] = scales[(row0 + simd_lid) * groups_per_row + g];
        biases_block[simd_lid * groups_per_block + g] = biases[(row0 + simd_lid) * groups_per_row + g];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pull this lane's input element and group parameters into registers.
    const T x_local = x_block[simd_lid];
    const T scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
    const T bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];

    // One packed word of weights from this lane's row.
    uint32_t packed = w[(row0 + simd_lid) * words_per_row];

    // Dequantize value by value (lowest bits first) and accumulate.
    #pragma clang loop unroll(full)
    for (int e = 0; e < el_per_int; e++) {
      acc[e] += (scale * static_cast<T>(packed & bitmask) + bias) * x_local;
      packed >>= bits;
    }
  }

  // Combine the per-lane partial sums for every output of the word.
  #pragma clang loop unroll(full)
  for (int e = 0; e < el_per_int; e++) {
    acc[e] = simd_sum(acc[e]);
  }

  // A single lane writes the results.
  if (simd_lid == 0) {
    #pragma clang loop unroll(full)
    for (int e = 0; e < el_per_int; e++) {
      y[e] = acc[e];
    }
  }
}
|
||||
|
||||
|
||||
// Quantized matrix-matrix product where each row of `w` belongs to one output
// column: the threadgroup at (tid.x, tid.y) produces the BM x BN tile of `y`
// starting at row tid.y * BM and column tid.x * BN.
//
// Template parameters:
//   T          element type of x, scales, biases and y
//   BM, BK, BN tile extents along M, K and N
//   group_size number of consecutive K elements sharing one scale/bias pair
//   bits       width of one quantized element; 32 / bits of them per uint32_t
//
// Buffers:
//   x       (M x K, row stride K)             activations
//   w       (row stride K / el_per_int)       packed quantized weights
//   scales  (row stride K / group_size)       per-group scale
//   biases  (row stride K / group_size)       per-group bias
//   y       (M x N, row stride N)             output
//
// Every weight is expanded as `scale * q + bias` into threadgroup memory
// before the block MMA consumes it.
//
// NOTE(review): only the M edge is guarded (load_safe / store_result_safe);
// presumably the host dispatches this kernel only when N % BN == 0 and
// K % BK == 0 -- confirm against the caller.
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  // Simdgroup layout and packing constants
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int el_per_int = 32 / bits;
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int ints_per_block = BK / el_per_int;
  // A group wider than the K tile still occupies one slot per row
  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
  // Rows of scales/biases each simdgroup is responsible for staging
  constexpr int groups_per_simd = BN / (WM * WN);
  // Packed words each thread expands per K step
  constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);

  const uint sg_index = lid / SIMD_SIZE;

  // Block MMA and x-tile loader types
  using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
  using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;

  // Threadgroup staging buffers
  threadgroup T scales_block[BN * groups_per_block];
  threadgroup T biases_block[BN * groups_per_block];
  threadgroup T Xs[BM * BK];
  threadgroup T Ws[BN * BK];

  // Move every pointer to the start of this threadgroup's tile
  const int K_w = K / el_per_int;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Rows of this tile that actually exist in x / y
  const short num_els = min(BM, M - y_row);
  const bool partial_rows = num_els < BM;

  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  for (int k = 0; k < K; k += BK) {
    // Previous iteration must be done reading Xs / Ws / scales / biases
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage the x tile
    if (partial_rows) {
      loader_x.load_safe(short2(BK, num_els));
    } else {
      loader_x.load_unsafe();
    }

    // Stage scales and biases: lane 0 of each simdgroup copies
    // groups_per_simd rows of groups_per_block entries each
    if (simd_lid == 0) {
      threadgroup T* scales_dst = scales_block + sg_index * groups_per_block * groups_per_simd;
      threadgroup T* biases_dst = biases_block + sg_index * groups_per_block * groups_per_simd;
      const device T* scales_src = scales + sg_index * groups_per_simd * K_g + k / group_size;
      const device T* biases_src = biases + sg_index * groups_per_simd * K_g + k / group_size;
      #pragma clang loop unroll(full)
      for (int gs = 0; gs < groups_per_simd; gs++) {
        #pragma clang loop unroll(full)
        for (int gc = 0; gc < groups_per_block; gc++) {
          scales_dst[gs * groups_per_block + gc] = scales_src[gs * K_g + gc];
          biases_dst[gs * groups_per_block + gc] = biases_src[gs * K_g + gc];
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Expand the packed w tile into Ws
    for (int wo = 0; wo < w_els_per_thread; wo++) {
      const int word = lid * w_els_per_thread + wo;
      const int row = word / ints_per_block;
      const int col = word % ints_per_block;

      uint32_t packed = w[row * K_w + col];

      // Scale and bias live at the same slot of their staging buffers
      const int q_slot = row * groups_per_block + col / (group_size / el_per_int);
      const T scale = scales_block[q_slot];
      const T bias = biases_block[q_slot];

      threadgroup T* dst = Ws + row * BK + col * el_per_int;
      #pragma clang loop unroll(full)
      for (int t = 0; t < el_per_int; t++) {
        dst[t] = scale * static_cast<T>(packed & bitmask) + bias;
        packed >>= bits;
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate the staged tiles
    mma_op.mma(Xs, Ws);

    // Advance along K. scales / biases stay put and are re-indexed through
    // k / group_size above, because a group may span more than one K tile.
    loader_x.next();
    w += ints_per_block;
  }

  // Write the accumulated tile back to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (partial_rows) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
|
||||
|
||||
|
||||
// Quantized matrix-matrix product where each row of `w` belongs to one K
// index: the threadgroup at (tid.x, tid.y) produces the BM x BN tile of `y`
// starting at row tid.y * BM and column tid.x * BN.
//
// Template parameters:
//   T          element type of x, scales, biases and y
//   BM, BK, BN tile extents along M, K and N
//   group_size number of consecutive N elements sharing one scale/bias pair
//   bits       width of one quantized element; 32 / bits of them per uint32_t
//
// Buffers:
//   x       (M x K, row stride K)             activations
//   w       (row stride N / el_per_int)       packed quantized weights
//   scales  (row stride N / group_size)       per-group scale
//   biases  (row stride N / group_size)       per-group bias
//   y       (M x N, row stride N)             output
//
// Every weight is expanded as `scale * q + bias` into threadgroup memory
// before the block MMA consumes it.
//
// NOTE(review): only the M edge is guarded (load_safe / store_result_safe);
// presumably the host dispatches this kernel only when N % BN == 0 and
// K % BK == 0 -- confirm against the caller.
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_n(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  // Simdgroup layout and packing constants
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int el_per_int = 32 / bits;
  constexpr int bitmask = (1 << bits) - 1;
  // Packed words in one row of the w tile
  constexpr int ints_per_row = BN / el_per_int;
  // A group wider than the N tile still occupies one slot per row
  constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
  // Rows of scales/biases each simdgroup is responsible for staging
  constexpr int groups_per_simd = BK / (WM * WN);
  // Packed words each thread expands per K step
  constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);

  const uint sg_index = lid / SIMD_SIZE;

  // Block MMA and x-tile loader types
  using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, false>;
  using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;

  // Threadgroup staging buffers
  threadgroup T scales_block[BK * groups_per_block];
  threadgroup T biases_block[BK * groups_per_block];
  threadgroup T Xs[BM * BK];
  threadgroup T Ws[BK * BN];

  // Move every pointer to the start of this threadgroup's tile
  const int N_w = N / el_per_int;
  const int N_g = N / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col / el_per_int;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Rows of this tile that actually exist in x / y
  const short num_els = min(BM, M - y_row);
  const bool partial_rows = num_els < BM;

  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  for (int k = 0; k < K; k += BK) {
    // Previous iteration must be done reading Xs / Ws / scales / biases
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage the x tile
    if (partial_rows) {
      loader_x.load_safe(short2(BK, num_els));
    } else {
      loader_x.load_unsafe();
    }

    // Stage scales and biases: lane 0 of each simdgroup copies
    // groups_per_simd rows of groups_per_block entries each
    if (simd_lid == 0) {
      threadgroup T* scales_dst = scales_block + sg_index * groups_per_block * groups_per_simd;
      threadgroup T* biases_dst = biases_block + sg_index * groups_per_block * groups_per_simd;
      const device T* scales_src = scales + sg_index * groups_per_simd * N_g;
      const device T* biases_src = biases + sg_index * groups_per_simd * N_g;
      #pragma clang loop unroll(full)
      for (int gs = 0; gs < groups_per_simd; gs++) {
        #pragma clang loop unroll(full)
        for (int gc = 0; gc < groups_per_block; gc++) {
          scales_dst[gs * groups_per_block + gc] = scales_src[gs * N_g + gc];
          biases_dst[gs * groups_per_block + gc] = biases_src[gs * N_g + gc];
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Expand the packed w tile into Ws
    for (int wo = 0; wo < w_els_per_thread; wo++) {
      const int word = lid * w_els_per_thread + wo;
      const int row = word / ints_per_row;
      const int col = word % ints_per_row;

      uint32_t packed = w[row * N_w + col];

      // Scale and bias live at the same slot of their staging buffers
      const int q_slot = row * groups_per_block + col / (group_size / el_per_int);
      const T scale = scales_block[q_slot];
      const T bias = biases_block[q_slot];

      threadgroup T* dst = Ws + row * BN + col * el_per_int;
      #pragma clang loop unroll(full)
      for (int t = 0; t < el_per_int; t++) {
        dst[t] = scale * static_cast<T>(packed & bitmask) + bias;
        packed >>= bits;
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate the staged tiles
    mma_op.mma(Xs, Ws);

    // Advance BK rows along K in x, w, scales and biases
    loader_x.next();
    w += BK * N_w;
    scales += BK * N_g;
    biases += BK * N_g;
  }

  // Write the accumulated tile back to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (partial_rows) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
|
||||
|
||||
|
||||
// Explicitly instantiates the qmv kernel for one element type and one
// (group size, bit width) pair. The kernel is exported under the host name
//   qmv_<tname>_gs_<gs>_b_<nbits>
#define instantiate_qmv(tname, dtype, gs, nbits) \
  template [[host_name("qmv_" #tname "_gs_" #gs "_b_" #nbits)]] \
  [[kernel]] void qmv<dtype, 32, 32, gs, nbits>( \
    const device uint32_t* w [[buffer(0)]], \
    const device dtype* scales [[buffer(1)]], \
    const device dtype* biases [[buffer(2)]], \
    const device dtype* x [[buffer(3)]], \
    device dtype* y [[buffer(4)]], \
    const constant int& in_vec_size [[buffer(5)]], \
    const constant int& out_vec_size [[buffer(6)]], \
    uint3 tid [[threadgroup_position_in_grid]], \
    uint lid [[thread_index_in_threadgroup]], \
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);

// Fans one (group size, bit width) pair out over float, half and bfloat16_t.
#define instantiate_qmv_types(gs, nbits) \
  instantiate_qmv(float32, float, gs, nbits) \
  instantiate_qmv(float16, half, gs, nbits) \
  instantiate_qmv(bfloat16, bfloat16_t, gs, nbits)

// Group sizes 128 and 64, each with 2-, 4- and 8-bit elements
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
instantiate_qmv_types(64, 2)
instantiate_qmv_types(64, 4)
instantiate_qmv_types(64, 8)
|
||||
|
||||
// Explicitly instantiates the qvm kernel for one element type and one
// (group size, bit width) pair. The kernel is exported under the host name
//   qvm_<tname>_gs_<gs>_b_<nbits>
#define instantiate_qvm(tname, dtype, gs, nbits) \
  template [[host_name("qvm_" #tname "_gs_" #gs "_b_" #nbits)]] \
  [[kernel]] void qvm<dtype, 32, 32, gs, nbits>( \
    const device dtype* x [[buffer(0)]], \
    const device uint32_t* w [[buffer(1)]], \
    const device dtype* scales [[buffer(2)]], \
    const device dtype* biases [[buffer(3)]], \
    device dtype* y [[buffer(4)]], \
    const constant int& in_vec_size [[buffer(5)]], \
    const constant int& out_vec_size [[buffer(6)]], \
    uint3 tid [[threadgroup_position_in_grid]], \
    uint lid [[thread_index_in_threadgroup]], \
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);

// Fans one (group size, bit width) pair out over float, half and bfloat16_t.
#define instantiate_qvm_types(gs, nbits) \
  instantiate_qvm(float32, float, gs, nbits) \
  instantiate_qvm(float16, half, gs, nbits) \
  instantiate_qvm(bfloat16, bfloat16_t, gs, nbits)

// Group sizes 128 and 64, each with 2-, 4- and 8-bit elements
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
instantiate_qvm_types(64, 2)
instantiate_qvm_types(64, 4)
instantiate_qvm_types(64, 8)
|
||||
|
||||
// Explicitly instantiates the qmm_t kernel with tile sizes BM = 32, BK = 64,
// BN = 32 for one element type and one (group size, bit width) pair. The
// kernel is exported under the host name
//   qmm_t_<tname>_gs_<gs>_b_<nbits>
#define instantiate_qmm_t(tname, dtype, gs, nbits) \
  template [[host_name("qmm_t_" #tname "_gs_" #gs "_b_" #nbits)]] \
  [[kernel]] void qmm_t<dtype, 32, 64, 32, gs, nbits>( \
    const device dtype* x [[buffer(0)]], \
    const device uint32_t* w [[buffer(1)]], \
    const device dtype* scales [[buffer(2)]], \
    const device dtype* biases [[buffer(3)]], \
    device dtype* y [[buffer(4)]], \
    const constant int& M [[buffer(5)]], \
    const constant int& N [[buffer(6)]], \
    const constant int& K [[buffer(7)]], \
    uint3 tid [[threadgroup_position_in_grid]], \
    uint lid [[thread_index_in_threadgroup]], \
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);

// Fans one (group size, bit width) pair out over float, half and bfloat16_t.
#define instantiate_qmm_t_types(gs, nbits) \
  instantiate_qmm_t(float32, float, gs, nbits) \
  instantiate_qmm_t(float16, half, gs, nbits) \
  instantiate_qmm_t(bfloat16, bfloat16_t, gs, nbits)

// Group sizes 128 and 64, each with 2-, 4- and 8-bit elements
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types(64, 2)
instantiate_qmm_t_types(64, 4)
instantiate_qmm_t_types(64, 8)
|
||||
|
||||
// Explicitly instantiates the qmm_n kernel with tile sizes BM = 32, BK = 32,
// BN = 64 for one element type and one (group size, bit width) pair. The
// kernel is exported under the host name
//   qmm_n_<tname>_gs_<gs>_b_<nbits>
#define instantiate_qmm_n(tname, dtype, gs, nbits) \
  template [[host_name("qmm_n_" #tname "_gs_" #gs "_b_" #nbits)]] \
  [[kernel]] void qmm_n<dtype, 32, 32, 64, gs, nbits>( \
    const device dtype* x [[buffer(0)]], \
    const device uint32_t* w [[buffer(1)]], \
    const device dtype* scales [[buffer(2)]], \
    const device dtype* biases [[buffer(3)]], \
    device dtype* y [[buffer(4)]], \
    const constant int& M [[buffer(5)]], \
    const constant int& N [[buffer(6)]], \
    const constant int& K [[buffer(7)]], \
    uint3 tid [[threadgroup_position_in_grid]], \
    uint lid [[thread_index_in_threadgroup]], \
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);

// Fans one (group size, bit width) pair out over float, half and bfloat16_t.
#define instantiate_qmm_n_types(gs, nbits) \
  instantiate_qmm_n(float32, float, gs, nbits) \
  instantiate_qmm_n(float16, half, gs, nbits) \
  instantiate_qmm_n(bfloat16, bfloat16_t, gs, nbits)

// Group sizes 128 and 64, each with 2-, 4- and 8-bit elements
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types(64, 2)
instantiate_qmm_n_types(64, 4)
instantiate_qmm_n_types(64, 8)
|
@@ -65,7 +65,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
in += grid_size * N_READS;
|
||||
}
|
||||
|
||||
// Sepate case for the last set as we close the reduction size
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
||||
if (curr_idx < in_size) {
|
||||
int max_reads = in_size - curr_idx;
|
||||
@@ -112,88 +112,33 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// General reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void general_reduce(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const device int *in_shape [[buffer(2)]],
|
||||
const device size_t *in_strides [[buffer(3)]],
|
||||
const device size_t *out_strides [[buffer(4)]],
|
||||
const device size_t& ndim [[buffer(5)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
|
||||
auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
|
||||
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM>
|
||||
[[kernel]] void general_reduce(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const device int *in_shape [[buffer(2)]],
|
||||
const device size_t *in_strides [[buffer(3)]],
|
||||
const device size_t *out_strides [[buffer(4)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
|
||||
auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
|
||||
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
|
||||
}
|
||||
|
||||
#define instantiate_general_reduce_helper(name, itype, otype, op) \
|
||||
template [[host_name("general_reduce_" #name)]] \
|
||||
[[kernel]] void general_reduce<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const device int *in_shape [[buffer(2)]], \
|
||||
const device size_t *in_strides [[buffer(3)]], \
|
||||
const device size_t *out_strides [[buffer(4)]], \
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
|
||||
template [[host_name("general_reduce_" #name "_dim_" #n)]] \
|
||||
[[kernel]] void general_reduce<itype, otype, op, n>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const device int *in_shape [[buffer(2)]], \
|
||||
const device size_t *in_strides [[buffer(3)]], \
|
||||
const device size_t *out_strides [[buffer(4)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_general_reduce(name, itype, otype, op) \
|
||||
instantiate_general_reduce_helper(name, itype, otype, op) \
|
||||
instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
|
||||
instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
|
||||
instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
|
||||
instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Row atomics
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce(
|
||||
[[kernel]] void row_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device size_t& reduction_size [[buffer(2)]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint tid [[threadgroup_position_in_grid]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
Op op;
|
||||
|
||||
// Each threadgroup handles 1 reduction
|
||||
in += tid * reduction_size + lid * N_READS;
|
||||
// Each threadgroup handles 1 reduction
|
||||
// TODO: Specializing elem_to_loc would be slightly faster
|
||||
int idx = tid.y * out_size + tid.x;
|
||||
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
|
||||
in += extra_offset + lid.x * N_READS;
|
||||
|
||||
// The reduction is accumulated here
|
||||
U total_val = Op::init;
|
||||
@@ -201,7 +146,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
|
||||
// Loop over the reduction size within thread group
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
|
||||
T vals[N_READS];
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[i];
|
||||
@@ -210,11 +155,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
total_val = op(static_cast<U>(vals[i]), total_val);
|
||||
}
|
||||
|
||||
in += lsize * N_READS;
|
||||
in += lsize.x * N_READS;
|
||||
}
|
||||
|
||||
// Sepate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
|
||||
if(reduction_index < reduction_size) {
|
||||
int max_reads = reduction_size - reduction_index;
|
||||
|
||||
@@ -240,26 +185,30 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
// Reduction within thread group
|
||||
// Only needed if multiple simd groups
|
||||
if(reduction_size > simd_size) {
|
||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
total_val = op.simd_reduce(total_val);
|
||||
}
|
||||
// Update output
|
||||
if (lid == 0) {
|
||||
out[tid] = total_val;
|
||||
if (lid.x == 0) {
|
||||
op.atomic_update(out, total_val, tid.x);
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_row_reduce(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_" #name)]] \
|
||||
[[kernel]] void row_reduce<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const device size_t& reduction_size [[buffer(2)]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_general_" #name)]] \
|
||||
[[kernel]] void row_reduce_general<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant int* shape [[buffer(4)]], \
|
||||
const constant size_t* strides [[buffer(5)]], \
|
||||
const constant int& ndim [[buffer(6)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
|
||||
@@ -311,148 +260,57 @@ inline void _contiguous_strided_reduce(
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce(
|
||||
[[kernel]] void col_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
uint2 tid [[threadgroup_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(
|
||||
out_idx + tid.z * out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim
|
||||
);
|
||||
|
||||
if(out_idx < out_size) {
|
||||
_contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
out,
|
||||
local_data,
|
||||
out_idx,
|
||||
in_idx,
|
||||
out_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid,
|
||||
lid,
|
||||
lsize);
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_col_reduce(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_" #name)]] \
|
||||
[[kernel]] void col_reduce<itype, otype, op>( \
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_general_" #name)]] \
|
||||
[[kernel]] void col_reduce_general<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||
uint2 tid [[threadgroup_position_in_grid]], \
|
||||
uint2 lid [[thread_position_in_threadgroup]], \
|
||||
uint2 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
|
||||
[[kernel]] void contiguous_strided_reduce(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const device int* in_shape [[buffer(5)]],
|
||||
const device size_t* in_strides [[buffer(6)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
uint2 tid [[threadgroup_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]]) {
|
||||
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
|
||||
|
||||
if(out_idx < out_size) {
|
||||
_contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
out,
|
||||
local_data,
|
||||
in_idx,
|
||||
out_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid,
|
||||
lid,
|
||||
lsize);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void contiguous_strided_reduce(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const device int* in_shape [[buffer(5)]],
|
||||
const device size_t* in_strides [[buffer(6)]],
|
||||
const device size_t& in_dim [[buffer(7)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
uint2 tid [[threadgroup_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]]) {
|
||||
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
|
||||
|
||||
if(out_idx < out_size) {
|
||||
_contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
out,
|
||||
local_data,
|
||||
in_idx,
|
||||
out_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid,
|
||||
lid,
|
||||
lsize);
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
|
||||
template [[host_name("contiguous_strided_reduce_" #name)]] \
|
||||
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const device int* in_shape [[buffer(5)]], \
|
||||
const device size_t* in_strides [[buffer(6)]], \
|
||||
const device size_t& in_dim [[buffer(7)]], \
|
||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||
uint2 tid [[threadgroup_position_in_grid]], \
|
||||
uint2 lid [[thread_position_in_threadgroup]], \
|
||||
uint2 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
|
||||
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
|
||||
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const device int* in_shape [[buffer(5)]], \
|
||||
const device size_t* in_strides [[buffer(6)]], \
|
||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||
uint2 tid [[threadgroup_position_in_grid]], \
|
||||
uint2 lid [[thread_position_in_threadgroup]], \
|
||||
uint2 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
#define instantiate_contiguous_strided(name, itype, otype, op) \
|
||||
instantiate_contiguous_strided_helper(name, itype, otype, op) \
|
||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
|
||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
|
||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
|
||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -461,10 +319,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
|
||||
#define instantiate_reduce(name, itype, otype, op) \
|
||||
instantiate_all_reduce(name, itype, otype, op) \
|
||||
instantiate_row_reduce(name, itype, otype, op) \
|
||||
instantiate_col_reduce(name, itype, otype, op) \
|
||||
instantiate_contiguous_strided(name, itype, otype, op) \
|
||||
instantiate_general_reduce(name, itype, otype, op)
|
||||
instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_col_reduce_general(name, itype, otype, op)
|
||||
|
||||
#define instantiate_same_reduce(name, tname, type, op) \
|
||||
instantiate_init_reduce(name ##tname, type, op<type>) \
|
||||
@@ -535,4 +391,4 @@ instantiate_same_reduce(max_, float16, half, Max)
|
||||
instantiate_same_reduce(max_, float32, float, Max)
|
||||
|
||||
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
|
||||
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
|
||||
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
|
||||
|
@@ -592,7 +592,7 @@ template <
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
@@ -777,8 +777,8 @@ template <
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
|
@@ -43,6 +43,19 @@ struct ArcTanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T> T operator()(T x) { return metal::ceil(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
|
||||
|
||||
@@ -83,6 +96,19 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T> T operator()(T x) { return metal::floor(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log(x); };
|
||||
};
|
||||
@@ -107,6 +133,11 @@ struct Negative {
|
||||
template <typename T> T operator()(T x) { return -x; };
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T> T operator()(T x) { return metal::round(x); };
|
||||
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@@ -253,9 +284,11 @@ instantiate_unary_float(arcsin, ArcSin)
|
||||
instantiate_unary_float(arcsinh, ArcSinh)
|
||||
instantiate_unary_float(arctan, ArcTan)
|
||||
instantiate_unary_float(arctanh, ArcTanh)
|
||||
instantiate_unary_types(ceil, Ceil)
|
||||
instantiate_unary_float(cos, Cos)
|
||||
instantiate_unary_float(cosh, Cosh)
|
||||
instantiate_unary_float(exp, Exp)
|
||||
instantiate_unary_types(floor, Floor)
|
||||
instantiate_unary_float(log, Log)
|
||||
instantiate_unary_float(log2, Log2)
|
||||
instantiate_unary_float(log10, Log10)
|
||||
@@ -272,6 +305,7 @@ instantiate_unary_float(sqrt, Sqrt)
|
||||
instantiate_unary_float(rsqrt, Rsqrt)
|
||||
instantiate_unary_float(tan, Tan)
|
||||
instantiate_unary_float(tanh, Tanh)
|
||||
instantiate_unary_float(round, Round)
|
||||
|
||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||
@@ -282,5 +316,6 @@ instantiate_unary_all(sin, complex64, complex64_t, Sin)
|
||||
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
|
||||
instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||
|
||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
||||
|
@@ -61,7 +61,7 @@ inline void mps_matmul(
|
||||
// 2. Only one of a or b has batch_size_out matrices worth of data and
|
||||
// the other has matrix worth of data
|
||||
|
||||
// The matrix dimsenisons of a and b are sure to be regularly strided
|
||||
// The matrix dimensions of a and b are sure to be regularly strided
|
||||
if (batch_size_out > 1) {
|
||||
// No broadcasting defaults
|
||||
auto batch_size_a = a.data_size() / (M * K);
|
||||
|
@@ -4,7 +4,6 @@
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@@ -46,44 +45,40 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
// Signal this thread to clear the pool on a synchroniztion.
|
||||
scheduler::enqueue(s, []() {
|
||||
thread_autorelease_pool()->release();
|
||||
thread_autorelease_pool() =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
});
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
auto outputs = arr.outputs();
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
for (auto s : arr.siblings()) {
|
||||
s.detach();
|
||||
}
|
||||
}
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
return task;
|
||||
}
|
||||
|
||||
|
@@ -20,11 +20,11 @@ constexpr bool is_available() {
|
||||
}
|
||||
|
||||
void new_stream(Stream stream);
|
||||
std::shared_ptr<void> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph);
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -19,6 +19,101 @@ namespace {
|
||||
|
||||
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case General:
|
||||
kname << "g";
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, outputs[0], 2);
|
||||
set_array_buffer(compute_encoder, outputs[1], 3);
|
||||
|
||||
if (bopt == General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
@@ -28,6 +123,9 @@ void binary_op(
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
@@ -84,9 +182,9 @@ void binary_op(
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
int rest = out.size() / (dim0 * dim1);
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
@@ -122,6 +220,9 @@ void unary_op(
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@@ -171,6 +272,9 @@ void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel("arange" + type_to_name(out));
|
||||
@@ -215,7 +319,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case bfloat16:
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
|
||||
arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
|
||||
}
|
||||
@@ -297,9 +402,18 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
|
||||
if (ndim == 0) {
|
||||
// Pass place holders so metal doesn't complain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
|
||||
@@ -363,6 +477,16 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
@@ -434,6 +558,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "lnot");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(
|
||||
inputs,
|
||||
out,
|
||||
"land"); // Assume "land" is the operation identifier for logical AND
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(
|
||||
inputs,
|
||||
out,
|
||||
"lor"); // Assume "lor" is the operation identifier for logical OR
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lae");
|
||||
}
|
||||
@@ -446,6 +584,14 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "floor");
|
||||
}
|
||||
|
||||
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "ceil");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
@@ -504,6 +650,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
size_t half_size = out_per_key / 2;
|
||||
@@ -551,6 +700,17 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sigmoid");
|
||||
}
|
||||
|
172
mlx/backend/metal/quantized.cpp
Normal file
172
mlx/backend/metal/quantized.cpp
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
std::vector<array> copies;
|
||||
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
auto x = ensure_row_contiguous(x_pre);
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
int D = x.shape(-1);
|
||||
int B = x.size() / D;
|
||||
int O = out.shape(-1);
|
||||
if (transpose_) {
|
||||
// Route to the qmv kernel
|
||||
if (B < 6) {
|
||||
std::ostringstream kname;
|
||||
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = 32;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
set_array_buffer(compute_encoder, biases, 2);
|
||||
set_array_buffer(compute_encoder, x, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmm_t kernel
|
||||
else {
|
||||
std::ostringstream kname;
|
||||
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 32;
|
||||
int bk = 64;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
} else {
|
||||
// Route to the qvm kernel
|
||||
if (B < 4) {
|
||||
std::ostringstream kname;
|
||||
kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = 32;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Route to the qmm_n kernel
|
||||
else {
|
||||
std::ostringstream kname;
|
||||
kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int wn = 2;
|
||||
int wm = 2;
|
||||
int bm = 32;
|
||||
int bn = 64;
|
||||
int bk = 32;
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
|
||||
|
||||
if ((O % bn) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] The output size should be divisible by "
|
||||
<< bn << " but received " << O << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, scales, 2);
|
||||
set_array_buffer(compute_encoder, biases, 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder->setBytes(&B, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user