Compare commits

..

14 Commits

Author SHA1 Message Date
Angelos Katharopoulos
43e336cff2 Bump the version (#47) 2023-12-07 06:40:55 -08:00
Awni Hannun
d895e38f2e Nits (#38)
* include 3.12, black format

* circle ci badge

* format
2023-12-06 13:32:41 -08:00
Diogo
d15dead35e add extra_require with libs for running tests (#36) 2023-12-06 12:21:48 -08:00
Jagrit Digani
2440fe0124 NPY loading segfault bug (#34)
* Fixed GIL semantics in loading and saving from Python file streams
2023-12-06 12:03:47 -08:00
Awni Hannun
170e4b2d43 fix links (#32) 2023-12-06 08:12:06 -08:00
Jagrit Digani
2629cc8682 Install docs update (#29)
* Add notes about MacOS version restrictions for mlx in install docs 
* Add notes about Xcode version requirements for building from source in install docs
* Let make detect the macosx sdk version being used 
* Throw error if trying to build metal kernels with macOS <= 13.4 
* Add metal-cpp for macOS 14.2
2023-12-06 08:10:51 -08:00
Ikko Eltociear Ashimine
9f4cf2e0fe Update extensions.rst (#26)
unecessary -> unnecessary
2023-12-06 07:18:28 -08:00
Markus Enzweiler
2ffaee0c0d Updated default argument for stride to 1 in Conv2d() in the docstring (#22) 2023-12-06 07:17:58 -08:00
Yingbo Ma
36b245b287 Fix benchmark example (#11) 2023-12-06 07:17:16 -08:00
Esakkivel Esakkiraja
8c96b9a890 Update README.md (#9)
- Fixed typo and other minor errors
2023-12-05 21:31:27 -08:00
Angelos Katharopoulos
07897a346d Bump the version (#8)
* Bump the version
* Change the version in the docs as well
2023-12-05 17:46:08 -08:00
Jagrit Digani
d518b3b6a5 Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Awni Hannun
49cda449b1 apple mlr (#7) 2023-12-05 14:10:59 -08:00
Awni Hannun
6449a8682a Doc theme (#5)
* change docs theme + links + logo

* move mlx intro to landing page
2023-12-05 12:08:05 -08:00
22 changed files with 669 additions and 270 deletions

View File

@@ -203,7 +203,7 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
nightly_build:
when: << pipeline.parameters.nightly_build >>
@@ -211,7 +211,7 @@ workflows:
- build_package:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
weekly_build:
when: << pipeline.parameters.weekly_build >>
@@ -219,5 +219,5 @@ workflows:
- build_dev_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]

.gitignore
View File

@@ -8,6 +8,7 @@ __pycache__/
# Metal libraries
*.metallib
venv/
# Distribution / packaging
python/mlx/share

View File

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.0.1)
set(MLX_VERSION 0.0.3)
endif()
# ----------------------------- Lib -----------------------------
@@ -41,16 +41,19 @@ elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)
execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'"
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION)
message(STATUS "Detected macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.0)
message(STATUS "Building with SDK for MacOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
message(FATAL_ERROR "MLX requires MacOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(

View File

@@ -2,15 +2,18 @@
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[**Examples**](#examples)
MLX is an array framework for machine learning on Apple silicon.
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API which closely follows NumPy.
MLX also has a fully featured C++ API which closely mirrors the Python API.
MLX has higher level packages like `mlx.nn` and `mlx.optimizers` with APIs
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
MLX also has a fully featured C++ API, which closely mirrors the Python API.
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Composable function transformations**: MLX has composable function
@@ -25,15 +28,15 @@ Some key features of MLX include:
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and GPU).
(currently, the CPU and GPU).
- **Unified memory**: A noteable difference from MLX and other frameworks
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without moving data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user friendly, but still efficient
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
@@ -46,10 +49,10 @@ The design of MLX is inspired by frameworks like
## Examples
The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
variety of examples including:
variety of examples, including:
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large scale text generation with
- Large-scale text generation with
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
@@ -63,7 +66,7 @@ in the documentation.
## Installation
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API run:
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API, run:
```
pip install mlx
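
The README changes above center on the NumPy-like Python API, lazy evaluation, and the unified memory model. As a quick illustration of that workflow, here is a minimal sketch using only the public mlx.core API; it is illustrative and not taken from this diff:

    import mlx.core as mx

    # Arrays live in unified memory; there is no explicit host/device transfer.
    a = mx.random.uniform(shape=(4, 4))
    b = mx.ones((4, 4))

    # Computation is lazy: c is only materialized when evaluated or printed.
    c = a @ b + 1.0
    mx.eval(c)

    # The same arrays can be used on either device without copying,
    # by choosing the stream/device per operation.
    d = mx.add(a, b, stream=mx.cpu)
    e = mx.add(a, b, stream=mx.gpu)
    mx.eval(d, e)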

View File

@@ -30,7 +30,7 @@ def time_batch_matmul():
time_fn(batch_vjp_second)
def time_unbatch_matmul(key):
def time_unbatch_matmul():
mx.random.seed(3)
a = mx.random.uniform(shape=(B * T, D))
b = mx.random.uniform(shape=(D, D))
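
For context, the change above simply drops the unused key parameter so the benchmark can be invoked like its siblings. A minimal sketch of the repaired function, assuming the B, T, D constants and the time_fn helper defined elsewhere in the script (following the time_fn(callable) pattern visible in the context lines):

    def time_unbatch_matmul():
        mx.random.seed(3)
        a = mx.random.uniform(shape=(B * T, D))
        b = mx.random.uniform(shape=(D, D))

        def unbatch_matmul():
            return mx.matmul(a, b)

        time_fn(unbatch_matmul)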

View File

@@ -7,7 +7,7 @@ for example with `conda`:
```
conda install sphinx
pip install sphinx-rtd-theme
pip install sphinx-book-theme
```
### Build

Binary image file added (not shown; 7.2 KiB).

View File

@@ -10,8 +10,8 @@ import subprocess
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = "0.0.0"
release = "0.0.0"
version = "0.0.4"
release = "0.0.4"
# -- General configuration ---------------------------------------------------
@@ -39,7 +39,17 @@ pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
html_theme = "sphinx_rtd_theme"
html_theme = "sphinx_book_theme"
html_theme_options = {
"show_toc_level": 2,
"repository_url": "https://github.com/ml-explore/mlx",
"use_repository_button": True,
"navigation_with_keys": False,
}
html_logo = "_static/mlx_logo.png"
# -- Options for HTMLHelp output ---------------------------------------------

View File

@@ -131,7 +131,7 @@ back and go to our example to give ourselves a more concrete image.
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unecessary allocations, the evaluation function
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -945,4 +945,4 @@ Scripts
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@@ -321,7 +321,7 @@ which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.
You can download the full example code in `mlx-examples <code>`_. Assuming, the
You can download the full example code in `mlx-examples`_. Assuming, the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):
@@ -369,9 +369,9 @@ Scripts
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
The full example code is available in `mlx-examples`_.
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llama
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv

View File

@@ -127,5 +127,5 @@ Finally, we put it all together by instantiating the model, the
This should not be confused with :func:`mlx.core.value_and_grad`.
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_
is available in the MLX GitHub repo.

View File

@@ -1,6 +1,30 @@
MLX
===
MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.
The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.
The main differences between MLX and NumPy are:
- **Composable function transformations**: MLX has composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
GPU, ...)
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
.. toctree::
:caption: Install
:maxdepth: 1
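
The landing-page text added above highlights composable function transformations and lazy computation. A small sketch of what that looks like in practice, using only the public mlx.core API (illustrative, not part of the diff):

    import mlx.core as mx

    def loss(w, x, y):
        return mx.mean((x @ w - y) ** 2)

    w = mx.zeros((3,))
    x = mx.random.normal(shape=(8, 3))
    y = mx.random.normal(shape=(8,))

    # mx.grad returns a new function computing d(loss)/dw; nothing runs yet.
    grad_fn = mx.grad(loss)
    g = grad_fn(w, x, y)

    # The computation graph is only materialized here.
    mx.eval(g)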

View File

@@ -11,6 +11,10 @@ silicon computer is
pip install mlx
.. note::
MLX is only available on devices running MacOS >= 13.3
It is highly recommended to use MacOS 14 (Sonoma)
Build from source
-----------------
@@ -19,6 +23,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above)
Python API
@@ -55,7 +60,7 @@ For developing use an editable install:
To make sure the install is working run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
C++ API
@@ -111,3 +116,21 @@ should point to the path to the built metal library.
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
MacOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version

View File

@@ -1,28 +1,6 @@
Quick Start Guide
=================
MLX is a NumPy-like array framework designed for efficient and flexible
machine learning on Apple silicon. The Python API closely follows NumPy with
a few exceptions. MLX also has a fully featured C++ API which closely follows
the Python API.
The main differences between MLX and NumPy are:
- **Composable function transformations**: MLX has composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
GPU, ...)
The design of MLX is strongly inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
Basics
------

View File

@@ -3,8 +3,9 @@
#include <metal_stdlib>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
@@ -12,52 +13,57 @@ using namespace metal;
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
static constant constexpr int SIMD_SIZE = 32;
#define MLX_MTL_CONST static constant constexpr const
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
MLX_MTL_CONST int SIMD_SIZE = 32;
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
struct GEMVKernel {
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Threadgroup in_vec cache
threadgroup T in_vec_block[BN][TN * 2];
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
// Thread local accumulation results
thread T result[TM] = {0};
@@ -69,7 +75,7 @@ template <typename T,
// Exit simdgroup if rows out of bound
if(out_row >= out_vec_size)
return;
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
@@ -79,62 +85,304 @@ template <typename T,
// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
for(int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[simd_lid][tn];
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)simd_gid;
(void)simd_lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
// Threadgroup accumulation results
threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN;
int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;
// Edgecase handling
if (out_col < out_vec_size) {
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
#pragma clang loop unroll(full)
for(int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides_vec [[buffer(7)]],
const device size_t* nc_strides_mat [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}
#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
@@ -145,28 +393,51 @@ template <typename T,
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
instantiate_gemv_blocks(float32, float)
instantiate_gemv_blocks(float16, half)
instantiate_gemv_blocks(bfloat16, bfloat16_t)
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel]] void gemv_t(
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
@@ -175,110 +446,77 @@ template <typename T,
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partialy overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
// Threadgroup accumulation results
threadgroup T tgp_results[BN][BM][TM];
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}
int out_col = (tid.x * BN + lid.x) * TN;
int in_row = lid.y * TM;
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides_vec [[buffer(7)]],
const device size_t* nc_strides_mat [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Edgecase handling
if (out_col < out_vec_size) {
// Edgecase handling
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
// Update batch offsets
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
out_vec += tid.z * out_vec_size;
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
for(int i = 0; i < TN; i++) {
tgp_results[lid.x][lid.y][i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if(lid.y == 0 && out_col < out_vec_size) {
// Threadgroup accumulation
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
for(int j = 0; j < TN; j++) {
result[j] += tgp_results[lid.x][i][j];
}
}
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
@@ -289,16 +527,39 @@ template <typename T,
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t_blocks(float32, float)
instantiate_gemv_t_blocks(float16, half)
instantiate_gemv_t_blocks(bfloat16, bfloat16_t)
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);

View File

@@ -9,7 +9,7 @@
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;\
using namespace metal;
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub

View File

@@ -343,10 +343,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0;
int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;
int batch_size_vec = vec.data_size() / in_vector_len;
int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0;
int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel =
(batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
(batch_size_mat == batch_size_vec ||
std::min(batch_size_mat, batch_size_vec) == 1));
int nc_dim = out.ndim() - 2;
// Determine dispatch kernel
int tm = 4, tn = 4;
@@ -383,6 +391,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
if (!contiguous_kernel) {
kname << "_nc";
}
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
@@ -398,8 +410,22 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
if (contiguous_kernel) {
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
} else {
// In case of complex broadcasting, we consider the shape[:-2] and
// strides [:-2] to determine the location of a batch
// nc_dim = out.ndim() - 2
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(
vec.strides().data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(
mat.strides().data(), nc_dim * sizeof(size_t), 8);
}
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
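
In the non-contiguous ("nc") path above, the host passes the batch shape (shape[:-2]) and the per-input strides, and the kernel converts the flat batch index tid.z into an element offset. A rough Python sketch of that index arithmetic; the kernel-side elem_to_loc helper is assumed to behave like this, and the numbers are illustrative only:

    def elem_to_loc(elem, shape, strides):
        # Walk the batch dimensions from the innermost outwards,
        # accumulating stride-weighted coordinates of the flat index.
        loc = 0
        for dim, stride in zip(reversed(shape), reversed(strides)):
            loc += (elem % dim) * stride
            elem //= dim
        return loc

    # Batch shape (2, 3) with one input broadcast along the first batch axis
    # (stride 0), so moving along that axis does not advance its offset.
    print(elem_to_loc(4, [2, 3], [0, 128]))  # 128: batch row 1, column 1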

View File

@@ -33,7 +33,7 @@ def silu(x):
def gelu(x):
"""Applies the Gaussian Error Linear Units function.
r"""Applies the Gaussian Error Linear Units function.
.. math::
\\textrm{GELU}(x) = x * \Phi(x)

View File

@@ -78,7 +78,7 @@ class Conv2d(Module):
out_channels (int): The number of output channels.
kernel_size (int or tuple): The size of the convolution filters.
stride (int or tuple, optional): The size of the stride when
applying the filter. Default: 0.
applying the filter. Default: 1.
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: 0.
bias (bool, optional): If ``True`` add a learnable bias to the
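
The docstring fix above aligns the documented default with the actual signature (stride defaults to 1). A minimal usage sketch, assuming the mlx.nn.Conv2d signature described in the surrounding docstring and MLX's channels-last (NHWC) input layout:

    import mlx.core as mx
    import mlx.nn as nn

    # stride defaults to 1 and padding to 0, matching the corrected docstring.
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

    x = mx.random.uniform(shape=(1, 32, 32, 3))  # NHWC input (assumed layout)
    y = conv(x)
    mx.eval(y)
    print(y.shape)  # (1, 30, 30, 8) with a 3x3 kernel, stride 1, no padding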

View File

@@ -11,8 +11,6 @@
#include <unordered_map>
#include <vector>
#include <iostream>
#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
@@ -99,26 +97,54 @@ class PyFileReader : public io::Reader {
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileReader() {
py::gil_scoped_acquire gil;
pyistream_.release().dec_ref();
readinto_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}
bool is_open() const override {
return !pyistream_.attr("closed").cast<bool>();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.attr("closed").cast<bool>();
}
return out;
}
bool good() const override {
return !pyistream_.is_none();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.is_none();
}
return out;
}
size_t tell() const override {
return tell_func_().cast<size_t>();
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
}
return out;
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}
void read(char* data, size_t n) override {
py::gil_scoped_acquire gil;
py::object bytes_read =
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream");
}
@@ -163,6 +189,7 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
// If we don't own the stream and it was passed to us, eval immediately
for (auto& [key, arr] : array_dict) {
py::gil_scoped_release gil;
arr.eval();
}
@@ -172,7 +199,10 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
arr.eval();
{
py::gil_scoped_release gil;
arr.eval();
}
return {arr};
}
@@ -192,26 +222,54 @@ class PyFileWriter : public io::Writer {
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileWriter() {
py::gil_scoped_acquire gil;
pyostream_.release().dec_ref();
write_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}
bool is_open() const override {
return !pyostream_.attr("closed").cast<bool>();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.attr("closed").cast<bool>();
}
return out;
}
bool good() const override {
return !pyostream_.is_none();
bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.is_none();
}
return out;
}
size_t tell() const override {
return tell_func_().cast<size_t>();
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
}
return out;
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}
void write(const char* data, size_t n) override {
py::gil_scoped_acquire gil;
py::object bytes_written =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream");
}
@@ -233,7 +291,12 @@ void mlx_save_helper(py::object file, array a, bool retain_graph) {
save(py::cast<std::string>(file), a, retain_graph);
return;
} else if (is_ostream_object(file)) {
save(std::make_shared<PyFileWriter>(file), a, retain_graph);
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
save(writer, a, retain_graph);
}
return;
}
@@ -285,7 +348,11 @@ void mlx_savez_helper(
for (auto [k, a] : arrays_dict) {
std::string fname = k + ".npy";
auto py_ostream = zipfile_object.open(fname, 'w');
save(std::make_shared<PyFileWriter>(py_ostream), a);
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release gil;
save(writer, a);
}
}
return;
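
The pattern in the change above is to take the GIL only around calls back into the Python stream (readinto, write, seek, tell) and to release it around MLX-side evaluation, so evaluation work that touches Python objects cannot deadlock against the caller. From Python, the visible effect is simply that saving to and loading from file objects works; a hedged sketch of that usage (an in-memory BytesIO stream is used here purely for illustration):

    import io
    import mlx.core as mx

    a = mx.reshape(mx.arange(16), (4, 4))

    # Save into a Python file object; PyFileWriter re-acquires the GIL
    # only for the write/seek/tell calls it makes on the stream.
    buf = io.BytesIO()
    mx.save(buf, a)

    # Load it back through PyFileReader and force evaluation.
    buf.seek(0)
    b = mx.load(buf)
    mx.eval(b)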

View File

@@ -340,6 +340,7 @@ class TestBlas(mlx_tests.MLXTestCase):
((32, 128, 64), (32, 64, 1)),
((128, 64), (32, 64, 1)),
((32, 128, 64), (64, 1)),
((2, 1, 8, 1, 6, 128), (2, 1, 8, 4, 128, 1)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
@@ -350,6 +351,7 @@ class TestBlas(mlx_tests.MLXTestCase):
((32, 1, 128), (32, 128, 64)),
((32, 1, 128), (128, 64)),
((1, 128), (32, 128, 64)),
((1, 8, 4, 1, 128), (1, 8, 1, 128, 6)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
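
The shape pairs added above exercise gemv with higher-dimensional broadcast batching, i.e. the new _nc kernel path. A rough sketch of the kind of check such a test performs for one of those pairs, with NumPy as the reference (the __gemv_test helper itself is not shown in this diff, so this is illustrative only):

    import numpy as np
    import mlx.core as mx

    np.random.seed(0)
    mat_np = np.random.normal(size=(2, 1, 8, 1, 6, 128)).astype(np.float32)
    vec_np = np.random.normal(size=(2, 1, 8, 4, 128, 1)).astype(np.float32)

    out_np = np.matmul(mat_np, vec_np)  # broadcast batched matrix-vector product
    out_mx = mx.matmul(mx.array(mat_np), mx.array(vec_np))

    assert np.allclose(out_np, np.array(out_mx), atol=1e-5)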

View File

@@ -136,7 +136,7 @@ if __name__ == "__main__":
setup(
name="mlx",
version=get_version("0.0.2"),
version=get_version("0.0.4"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple Silicon.",
@@ -145,6 +145,7 @@ if __name__ == "__main__":
package_dir=package_dir,
package_data=package_data,
include_package_data=True,
extras_require={"testing": ["numpy", "torch"]},
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,
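
The new "testing" extra bundles the dependencies the Python test suite needs (numpy and torch). As the install-docs change earlier in this compare shows, the intended workflow is to run pip install ".[testing]" and then python -m unittest discover python/tests to verify an install.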