mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-07 17:44:38 +08:00

Compare commits

14 Commits
43e336cff2
d895e38f2e
d15dead35e
2440fe0124
170e4b2d43
2629cc8682
9f4cf2e0fe
2ffaee0c0d
36b245b287
8c96b9a890
07897a346d
d518b3b6a5
49cda449b1
6449a8682a
@@ -203,7 +203,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
nightly_build:
when: << pipeline.parameters.nightly_build >>
@@ -211,7 +211,7 @@ workflows:
- build_package:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
weekly_build:
when: << pipeline.parameters.weekly_build >>
@@ -219,5 +219,5 @@ workflows:
- build_dev_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
.gitignore (vendored, 1 line changed)
@@ -8,6 +8,7 @@ __pycache__/

# Metal libraries
*.metallib
venv/

# Distribution / packaging
python/mlx/share
@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
set(MLX_VERSION 0.0.1)
set(MLX_VERSION 0.0.3)
endif()

# ----------------------------- Lib -----------------------------
@@ -41,16 +41,19 @@ elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)

execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'"
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION)

message(STATUS "Detected macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.0)
message(STATUS "Building with SDK for MacOS version ${MACOS_VERSION}")

if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
message(FATAL_ERROR "MLX requires MacOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()

FetchContent_Declare(
README.md (25 lines changed)
@@ -2,15 +2,18 @@

[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[**Examples**](#examples)

MLX is an array framework for machine learning on Apple silicon.
[](https://circleci.com/gh/ml-explore/mlx)

MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.

Some key features of MLX include:

- **Familiar APIs**: MLX has a Python API which closely follows NumPy.
  MLX also has a fully featured C++ API which closely mirrors the Python API.
  MLX has higher level packages like `mlx.nn` and `mlx.optimizers` with APIs
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
  MLX also has a fully featured C++ API, which closely mirrors the Python API.
  MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
  that closely follow PyTorch to simplify building more complex models.

- **Composable function transformations**: MLX has composable function
@@ -25,15 +28,15 @@ Some key features of MLX include:
  slow compilations, and debugging is simple and intuitive.

- **Multi-device**: Operations can run on any of the supported devices
  (currently the CPU and GPU).
  (currently, the CPU and GPU).

- **Unified memory**: A noteable difference from MLX and other frameworks
- **Unified memory**: A notable difference from MLX and other frameworks
  is the *unified memory model*. Arrays in MLX live in shared memory.
  Operations on MLX arrays can be performed on any of the supported
  device types without moving data.

MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user friendly, but still efficient
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
@@ -46,10 +49,10 @@ The design of MLX is inspired by frameworks like
## Examples

The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
variety of examples including:
variety of examples, including:

- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large scale text generation with
- Large-scale text generation with
  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
  finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
@@ -63,7 +66,7 @@ in the documentation.

## Installation

MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API run:
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API, run:

```
pip install mlx
```
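The README excerpt above describes a NumPy-like, lazily evaluated Python API. A minimal sketch of that usage, assuming `mlx.core` installed via `pip install mlx` and the `mx.array`, `mx.eval`, `mx.sum`, and `mx.grad` names used in the MLX documentation:

```python
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

# Lazy computation: c is a graph node until it is evaluated or printed.
c = a + b
mx.eval(c)

# Composable function transformation: gradient of a scalar-valued function.
def loss(x):
    return mx.sum(x * x)

grad_fn = mx.grad(loss)
print(grad_fn(a))  # expected 2 * a
```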
@@ -30,7 +30,7 @@ def time_batch_matmul():
time_fn(batch_vjp_second)


def time_unbatch_matmul(key):
def time_unbatch_matmul():
mx.random.seed(3)
a = mx.random.uniform(shape=(B * T, D))
b = mx.random.uniform(shape=(D, D))
@@ -7,7 +7,7 @@ for example with `conda`:

```
conda install sphinx
pip install sphinx-rtd-theme
pip install sphinx-book-theme
```

### Build
BIN docs/src/_static/mlx_logo.png (new binary file, 7.2 KiB; not shown)
@@ -10,8 +10,8 @@ import subprocess
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = "0.0.0"
release = "0.0.0"
version = "0.0.4"
release = "0.0.4"

# -- General configuration ---------------------------------------------------

@@ -39,7 +39,17 @@ pygments_style = "sphinx"

# -- Options for HTML output -------------------------------------------------

html_theme = "sphinx_rtd_theme"
html_theme = "sphinx_book_theme"

html_theme_options = {
    "show_toc_level": 2,
    "repository_url": "https://github.com/ml-explore/mlx",
    "use_repository_button": True,
    "navigation_with_keys": False,
}

html_logo = "_static/mlx_logo.png"


# -- Options for HTMLHelp output ---------------------------------------------
@@ -131,7 +131,7 @@ back and go to our example to give ourselves a more concrete image.
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unecessary allocations, the evaluation function
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -945,4 +945,4 @@ Scripts
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
@@ -321,7 +321,7 @@ which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.

You can download the full example code in `mlx-examples <code>`_. Assuming, the
You can download the full example code in `mlx-examples`_. Assuming, the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):
@@ -369,9 +369,9 @@ Scripts

.. admonition:: Download the code

   The full example code is available in `mlx-examples <code>`_.
   The full example code is available in `mlx-examples`_.

.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llama

.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
   Roformer: Enhanced transformer with rotary position embedding. arXiv
@@ -127,5 +127,5 @@ Finally, we put it all together by instantiating the model, the
   This should not be confused with :func:`mlx.core.value_and_grad`.

The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_
is available in the MLX GitHub repo.
@@ -1,6 +1,30 @@
MLX
===

MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.

The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.

The main differences between MLX and NumPy are:

- **Composable function transformations**: MLX has composable function
  transformations for automatic differentiation, automatic vectorization,
  and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
  materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
  GPU, ...)

The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.

.. toctree::
   :caption: Install
   :maxdepth: 1
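The unified memory model described in the index page above means the same arrays can feed operations on either device without copies. A short sketch, assuming the `stream=` argument that MLX operations accept in the documentation:

```python
import mlx.core as mx

a = mx.random.normal(shape=(64, 64))
b = mx.random.normal(shape=(64, 64))

# The same arrays are used on both devices; no data movement is needed
# because MLX arrays live in shared (unified) memory.
c_cpu = mx.add(a, b, stream=mx.cpu)
c_gpu = mx.add(a, b, stream=mx.gpu)
mx.eval(c_cpu, c_gpu)
```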
@@ -11,6 +11,10 @@ silicon computer is

   pip install mlx

.. note::
   MLX is only available on devices running MacOS >= 13.3
   It is highly recommended to use MacOS 14 (Sonoma)

Build from source
-----------------

@@ -19,6 +23,7 @@ Build Requirements

- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above)


Python API
@@ -55,7 +60,7 @@ For developing use an editable install:
To make sure the install is working run the tests with:

.. code-block:: shell

   pip install ".[testing]"
   python -m unittest discover python/tests

C++ API
@@ -111,3 +116,21 @@ should point to the path to the built metal library.
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF


.. note::

   If you have multiple Xcode installations and wish to use
   a specific one while building, you can do so by adding the
   following environment variable before building

   .. code-block:: shell

      export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

   Further, you can use the following command to find out which
   MacOS SDK will be used

   .. code-block:: shell

      xcrun -sdk macosx --show-sdk-version
@@ -1,28 +1,6 @@
Quick Start Guide
=================

MLX is a NumPy-like array framework designed for efficient and flexible
machine learning on Apple silicon. The Python API closely follows NumPy with
a few exceptions. MLX also has a fully featured C++ API which closely follows
the Python API.

The main differences between MLX and NumPy are:

- **Composable function transformations**: MLX has composable function
  transformations for automatic differentiation, automatic vectorization,
  and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
  materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
  GPU, ...)

The design of MLX is strongly inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.

Basics
------
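The quickstart above lists composable transformations and lazy evaluation as the main departures from NumPy. A small sketch of composing them, assuming `mx.grad` and `mx.vmap` behave as documented:

```python
import mlx.core as mx

def f(s):                 # scalar -> scalar
    return mx.sin(s) ** 2

df = mx.grad(f)           # transformation: derivative of f
batched_df = mx.vmap(df)  # transformation: vectorize the derivative

x = mx.array([0.0, 0.5, 1.0])
y = batched_df(x)         # still lazy at this point
print(y)                  # materialized when printed
```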
@@ -3,8 +3,9 @@
#include <metal_stdlib>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

@@ -12,52 +13,57 @@ using namespace metal;
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

static constant constexpr int SIMD_SIZE = 32;
#define MLX_MTL_CONST static constant constexpr const

template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
MLX_MTL_CONST int SIMD_SIZE = 32;

static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
struct GEMVKernel {

// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");

// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix

MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;

static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

// Appease compiler
(void)lid;

// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;

// Threadgroup in_vec cache
threadgroup T in_vec_block[BN][TN * 2];
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;

// Thread local accumulation results
thread T result[TM] = {0};
@@ -69,7 +75,7 @@ template <typename T,

// Exit simdgroup if rows out of bound
if(out_row >= out_vec_size)
return;
return;

// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
@@ -79,62 +85,304 @@ template <typename T,

// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// Load for all rows
#pragma clang loop unroll(full)
threadgroup_barrier(mem_flags::mem_threadgroup);

// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {

#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}

} else { // Edgecase

#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}

}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}

// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {

// Load for the row
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[simd_lid][tn];
inter[tn] = mat[tm * in_vec_size + bn + tn];
}

// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}

// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}

}
}

// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
result[tm] = simd_sum(result[tm]);
}

// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}

#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}

}

}

};

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
struct GEMVTKernel {

// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix


MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;

static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

// Appease compiler
(void)simd_gid;
(void)simd_lid;

// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];

// Threadgroup accumulation results
threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN;

int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;

// Edgecase handling
if (out_col < out_vec_size) {

out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;

// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);

if(bm + TM <= in_vec_size) {

#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}

#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}

} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];

for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}

}
}
}

}

// Threadgroup collection

#pragma clang loop unroll(full)
for(int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col < out_vec_size) {

#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {

#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}

#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}

}


};

///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////

template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;

gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);

}

#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides_vec [[buffer(7)]],
const device size_t* nc_strides_mat [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

// Update batch offsets
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
out_vec += tid.z * out_vec_size;

gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);

}


#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
@@ -145,28 +393,51 @@ template <typename T,
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

instantiate_gemv_blocks(float32, float)
instantiate_gemv_blocks(float16, half)
instantiate_gemv_blocks(bfloat16, bfloat16_t)
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)

#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////

template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv_t(
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
@@ -175,110 +446,77 @@ template <typename T,
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;

// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;

// Threadgroup accumulation results
threadgroup T tgp_results[BN][BM][TM];
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}

int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides_vec [[buffer(7)]],
const device size_t* nc_strides_mat [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {

// Edgecase handling
if (out_col < out_vec_size) {
// Edgecase handling
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
// Update batch offsets
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
out_vec += tid.z * out_vec_size;

if(bm + TM <= in_vec_size) {

#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}

#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}

} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];

for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}

}

// Threadgroup collection
for(int i = 0; i < TN; i++) {
tgp_results[lid.x][lid.y][i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);

if(lid.y == 0 && out_col < out_vec_size) {
// Threadgroup accumulation
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[lid.x][i][j];
}
}

for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);

}

#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
@@ -289,16 +527,39 @@ template <typename T,
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)

#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)

instantiate_gemv_t_blocks(float32, float)
instantiate_gemv_t_blocks(float16, half)
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
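The kernel comments above describe a tiling scheme in which the (out_vec_size, in_vec_size) matrix is covered by (BM * TM, BN * TN) blocks and each thread accumulates a (TM, TN) tile before a threadgroup reduction. A NumPy model of that blocking follows; it is not the Metal code itself, and BM/BN/TM/TN simply mirror the template parameters:

```python
import numpy as np

def gemv_blocked(mat, vec, BN=32, TM=4, TN=4):
    """Reference model of the row/column blocking described in the comments."""
    M, N = mat.shape
    out = np.zeros(M, dtype=mat.dtype)
    for row0 in range(0, M, TM):                 # one "thread" owns TM rows
        rows = slice(row0, min(row0 + TM, M))
        acc = np.zeros(rows.stop - rows.start, dtype=mat.dtype)
        for col0 in range(0, N, BN * TN):        # loop over in_vec in BN*TN blocks
            cols = slice(col0, min(col0 + BN * TN, N))
            acc += mat[rows, cols] @ vec[cols]   # per-thread multiply-accumulate
        out[rows] = acc                          # stands in for simd_sum + write-out
    return out

mat = np.random.rand(37, 53).astype(np.float32)
vec = np.random.rand(53).astype(np.float32)
assert np.allclose(gemv_blocked(mat, vec), mat @ vec, atol=1e-4)
```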
@@ -9,7 +9,7 @@
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")

using namespace metal;\
using namespace metal;

// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
@@ -343,10 +343,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;

int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;

int batch_size_vec = vec.data_size() / in_vector_len;
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;

// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel =
(batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
(batch_size_mat == batch_size_vec ||
std::min(batch_size_mat, batch_size_vec) == 1));

int nc_dim = out.ndim() - 2;

// Determine dispatch kernel
int tm = 4, tn = 4;
@@ -383,6 +391,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;

if (!contiguous_kernel) {
kname << "_nc";
}

// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
@@ -398,8 +410,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);

if (contiguous_kernel) {
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
} else {
// In case of complex broadcasting, we consider the shape[:-2] and
// strides [:-2] to determine the location of a batch
// nc_dim = out.ndim() - 2
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(
vec.strides().data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(
mat.strides().data(), nc_dim * sizeof(size_t), 8);
}

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

d.get_command_buffer(s.index)->addCompletedHandler(
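The hunk above decides between the contiguous gemv kernels and the new `_nc` variants based on how the matrix and vector batches broadcast. A Python sketch of that decision, mirroring the quoted C++ for illustration only:

```python
def choose_gemv_dispatch(batch_size_out, batch_size_mat, batch_size_vec,
                         mat_rows, mat_cols, in_vector_len):
    # A broadcast input (batch of 1) advances with stride 0.
    stride_mat = 0 if batch_size_mat == 1 else mat_cols * mat_rows
    stride_vec = 0 if batch_size_vec == 1 else in_vector_len
    # Simple batching: the output batch matches the larger input batch and the
    # smaller input batch is either equal to it or 1.
    contiguous = (batch_size_out == max(batch_size_mat, batch_size_vec)
                  and (batch_size_mat == batch_size_vec
                       or min(batch_size_mat, batch_size_vec) == 1))
    return stride_mat, stride_vec, "" if contiguous else "_nc"

# A (32, 128, 64) matrix batch against a single (64, 1) vector:
print(choose_gemv_dispatch(32, 32, 1, 128, 64, 64))  # -> (8192, 0, '')
```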
@@ -33,7 +33,7 @@ def silu(x):


def gelu(x):
    """Applies the Gaussian Error Linear Units function.
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \\textrm{GELU}(x) = x * \Phi(x)
@@ -78,7 +78,7 @@ class Conv2d(Module):
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: 0.
            applying the filter. Default: 1.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the
@@ -11,8 +11,6 @@
#include <unordered_map>
#include <vector>

#include <iostream>

#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
@@ -99,26 +97,54 @@ class PyFileReader : public io::Reader {
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}

~PyFileReader() {
py::gil_scoped_acquire gil;

pyistream_.release().dec_ref();
readinto_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}

bool is_open() const override {
return !pyistream_.attr("closed").cast<bool>();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.attr("closed").cast<bool>();
}
return out;
}

bool good() const override {
return !pyistream_.is_none();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.is_none();
}
return out;
}

size_t tell() const override {
return tell_func_().cast<size_t>();
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
}
return out;
}

void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}

void read(char* data, size_t n) override {
py::gil_scoped_acquire gil;

py::object bytes_read =
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));

if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream");
}
@@ -163,6 +189,7 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {

// If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) {
py::gil_scoped_release gil;
arr.eval();
}

@@ -172,7 +199,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
arr.eval();
{
py::gil_scoped_release gil;
arr.eval();
}
return {arr};
}

@@ -192,26 +222,54 @@ class PyFileWriter : public io::Writer {
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}

~PyFileWriter() {
py::gil_scoped_acquire gil;

pyostream_.release().dec_ref();
write_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}

bool is_open() const override {
return !pyostream_.attr("closed").cast<bool>();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.attr("closed").cast<bool>();
}
return out;
}

bool good() const override {
return !pyostream_.is_none();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.is_none();
}
return out;
}

size_t tell() const override {
return tell_func_().cast<size_t>();
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
}
return out;
}

void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}

void write(const char* data, size_t n) override {
py::gil_scoped_acquire gil;

py::object bytes_written =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));

if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream");
}
@@ -233,7 +291,12 @@ void mlx_save_helper(py::object file, array a, bool retain_graph) {
save(py::cast<std::string>(file), a, retain_graph);
return;
} else if (is_ostream_object(file)) {
save(std::make_shared<PyFileWriter>(file), a, retain_graph);
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save(writer, a, retain_graph);
}

return;
}

@@ -285,7 +348,11 @@ void mlx_savez_helper(
for (auto [k, a] : arrays_dict) {
std::string fname = k + ".npy";
auto py_ostream = zipfile_object.open(fname, 'w');
save(std::make_shared<PyFileWriter>(py_ostream), a);
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release gil;
save(writer, a);
}
}

return;
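The pybind changes above acquire the GIL around every use of a Python file object and release it around `eval`/`save`; from Python, this path is exercised by saving and loading through a file-like object rather than a path. A small usage sketch, assuming the `mx.save`/`mx.load` API shown in the MLX docs:

```python
import io
import mlx.core as mx

a = mx.reshape(mx.arange(16), (4, 4))

buf = io.BytesIO()
mx.save(buf, a)      # goes through PyFileWriter
buf.seek(0)
b = mx.load(buf)     # goes through PyFileReader and is evaluated eagerly
assert mx.array_equal(a, b)
```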
@@ -340,6 +340,7 @@ class TestBlas(mlx_tests.MLXTestCase):
((32, 128, 64), (32, 64, 1)),
((128, 64), (32, 64, 1)),
((32, 128, 64), (64, 1)),
((2, 1, 8, 1, 6, 128), (2, 1, 8, 4, 128, 1)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
@@ -350,6 +351,7 @@ class TestBlas(mlx_tests.MLXTestCase):
((32, 1, 128), (32, 128, 64)),
((32, 1, 128), (128, 64)),
((1, 128), (32, 128, 64)),
((1, 8, 4, 1, 128), (1, 8, 1, 128, 6)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
setup.py (3 lines changed)
@@ -136,7 +136,7 @@ if __name__ == "__main__":

setup(
name="mlx",
version=get_version("0.0.2"),
version=get_version("0.0.4"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple Silicon.",
@@ -145,6 +145,7 @@ if __name__ == "__main__":
package_dir=package_dir,
package_data=package_data,
include_package_data=True,
extras_require={"testing": ["numpy", "torch"]},
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,