mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Merge branch 'main' into losses
This commit is contained in:
commit
7111d5b889
@ -104,7 +104,7 @@ jobs:
|
|||||||
pip install numpy
|
pip install numpy
|
||||||
pip install twine
|
pip install twine
|
||||||
- run:
|
- run:
|
||||||
name: Build pacakge
|
name: Build package
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
eval "$(conda shell.bash hook)"
|
||||||
conda activate runner-env
|
conda activate runner-env
|
||||||
@ -140,7 +140,7 @@ jobs:
|
|||||||
pip install numpy
|
pip install numpy
|
||||||
pip install twine
|
pip install twine
|
||||||
- run:
|
- run:
|
||||||
name: Build pacakge
|
name: Build package
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
eval "$(conda shell.bash hook)"
|
||||||
conda activate runner-env
|
conda activate runner-env
|
||||||
@ -176,7 +176,7 @@ jobs:
|
|||||||
pip install numpy
|
pip install numpy
|
||||||
pip install twine
|
pip install twine
|
||||||
- run:
|
- run:
|
||||||
name: Build pacakge
|
name: Build package
|
||||||
command: |
|
command: |
|
||||||
eval "$(conda shell.bash hook)"
|
eval "$(conda shell.bash hook)"
|
||||||
conda activate runner-env
|
conda activate runner-env
|
||||||
|
@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
|||||||
|
|
||||||
MLX was developed with contributions from the following individuals:
|
MLX was developed with contributions from the following individuals:
|
||||||
|
|
||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added tri, tril, triu and safetensor support
|
- Diogo Da Cruz: Added tri, tril, triu and safetensor support
|
||||||
|
@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
|
|||||||
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
||||||
|
|
||||||
|
|
||||||
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
mlx_gb_s = []
|
mlx_gb_s = []
|
||||||
mlx_gflops = []
|
mlx_gflops = []
|
||||||
@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
|
|||||||
ax.legend()
|
ax.legend()
|
||||||
|
|
||||||
|
|
||||||
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
|
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
mlx_gb_s = []
|
mlx_gb_s = []
|
||||||
mlx_gflops = []
|
mlx_gflops = []
|
||||||
|
@ -62,7 +62,7 @@ def make_predicate(positive_filter, negative_filter):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
|
parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
|
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
|
||||||
)
|
)
|
||||||
|
@ -12,7 +12,7 @@ include(CMakeParseArguments)
|
|||||||
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
||||||
# SOURCES: List of source files
|
# SOURCES: List of source files
|
||||||
# INCLUDE_DIRS: List of include dirs
|
# INCLUDE_DIRS: List of include dirs
|
||||||
# DEPS: List of depedency files (like headers)
|
# DEPS: List of dependency files (like headers)
|
||||||
#
|
#
|
||||||
macro(mlx_build_metallib)
|
macro(mlx_build_metallib)
|
||||||
# Parse args
|
# Parse args
|
||||||
@ -32,7 +32,7 @@ macro(mlx_build_metallib)
|
|||||||
# Collect compile options
|
# Collect compile options
|
||||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
||||||
|
|
||||||
# Prepare metllib build command
|
# Prepare metallib build command
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||||
COMMAND xcrun -sdk macosx metal
|
COMMAND xcrun -sdk macosx metal
|
||||||
|
@ -26,7 +26,7 @@ python -m http.server <port>
|
|||||||
|
|
||||||
and point your browser to `http://localhost:<port>`.
|
and point your browser to `http://localhost:<port>`.
|
||||||
|
|
||||||
### Push to Github Pages
|
### Push to GitHub Pages
|
||||||
|
|
||||||
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
|
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
|
||||||
the docs. Then force add the `build/html` directory:
|
the docs. Then force add the `build/html` directory:
|
||||||
|
@ -15,7 +15,7 @@ Introducing the Example
|
|||||||
-----------------------
|
-----------------------
|
||||||
|
|
||||||
Let's say that you would like an operation that takes in two arrays,
|
Let's say that you would like an operation that takes in two arrays,
|
||||||
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
|
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
|
||||||
respectively, and then adds them together to get the result
|
respectively, and then adds them together to get the result
|
||||||
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||||
writing out a function as follows:
|
writing out a function as follows:
|
||||||
@ -69,7 +69,7 @@ C++ API:
|
|||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Scale and sum two vectors elementwise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Follow numpy style broadcasting between x and y
|
||||||
@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
|||||||
|
|
||||||
This operation now handles the following:
|
This operation now handles the following:
|
||||||
|
|
||||||
#. Upcast inputs and resolve the the output data type.
|
#. Upcast inputs and resolve the output data type.
|
||||||
#. Broadcast the inputs and resolve the output shape.
|
#. Broadcast the inputs and resolve the output shape.
|
||||||
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
|
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
|
||||||
#. Construct the output :class:`array` using the primitive and the inputs.
|
#. Construct the output :class:`array` using the primitive and the inputs.
|
||||||
@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
|||||||
T alpha = static_cast<T>(alpha_);
|
T alpha = static_cast<T>(alpha_);
|
||||||
T beta = static_cast<T>(beta_);
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
// Do the elementwise operation for each output
|
// Do the element-wise operation for each output
|
||||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||||
// Map linear indices to offsets in x and y
|
// Map linear indices to offsets in x and y
|
||||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||||
|
|
||||||
// We allocate the output to be contiguous and regularly strided
|
// We allocate the output to be contiguous and regularly strided
|
||||||
// (defaults to row major) and hence it doesn't need additonal mapping
|
// (defaults to row major) and hence it doesn't need additional mapping
|
||||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -305,7 +305,7 @@ if we encounter an unexpected type.
|
|||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
/** Fall back implementation for evaluation on CPU */
|
||||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||||
// Check the inputs (registered in the op while contructing the out array)
|
// Check the inputs (registered in the op while constructing the out array)
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
@ -485,7 +485,7 @@ each data type.
|
|||||||
|
|
||||||
instantiate_axpby(float32, float);
|
instantiate_axpby(float32, float);
|
||||||
instantiate_axpby(float16, half);
|
instantiate_axpby(float16, half);
|
||||||
instantiate_axpby(bflot16, bfloat16_t);
|
instantiate_axpby(bfloat16, bfloat16_t);
|
||||||
instantiate_axpby(complex64, complex64_t);
|
instantiate_axpby(complex64, complex64_t);
|
||||||
|
|
||||||
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||||
@ -537,7 +537,7 @@ below.
|
|||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
// those in the kernel decelaration at axpby.metal
|
// those in the kernel declaration at axpby.metal
|
||||||
int ndim = out.ndim();
|
int ndim = out.ndim();
|
||||||
size_t nelem = out.size();
|
size_t nelem = out.size();
|
||||||
|
|
||||||
@ -568,7 +568,7 @@ below.
|
|||||||
// Fix the 3D size of the launch grid (in terms of threads)
|
// Fix the 3D size of the launch grid (in terms of threads)
|
||||||
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||||
|
|
||||||
// Launch the grid with the given number of threads divded among
|
// Launch the grid with the given number of threads divided among
|
||||||
// the given threadgroups
|
// the given threadgroups
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
|
|||||||
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||||
until some specified limit is hit or the compute encoder needs to be flushed
|
until some specified limit is hit or the compute encoder needs to be flushed
|
||||||
for synchronization. MLX also handles enqueuing and commiting the associated
|
for synchronization. MLX also handles enqueuing and committing the associated
|
||||||
command buffers as needed. We suggest taking a deeper dive into
|
command buffers as needed. We suggest taking a deeper dive into
|
||||||
:class:`metal::Device` if you would like to study this routine further.
|
:class:`metal::Device` if you would like to study this routine further.
|
||||||
|
|
||||||
@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
|||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the the primitive can built with ops
|
// The jvp transform on the primitive can built with ops
|
||||||
// that are scheduled on the same stream as the primtive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
// jvp is just the tangent scaled by alpha
|
// jvp is just the tangent scaled by alpha
|
||||||
@ -642,7 +642,7 @@ own :class:`Primitive`.
|
|||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
/** Vectorize primitve along given axis */
|
/** Vectorize primitive along given axis */
|
||||||
std::pair<array, int> Axpby::vmap(
|
std::pair<array, int> Axpby::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
|
|||||||
| └── setup.py
|
| └── setup.py
|
||||||
|
|
||||||
* ``extensions/axpby/`` defines the C++ extension library
|
* ``extensions/axpby/`` defines the C++ extension library
|
||||||
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
|
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||||
associated python package
|
associated python package
|
||||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||||
@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = py::none(),
|
"stream"_a = py::none(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Scale and sum two vectors elementwise
|
Scale and sum two vectors element-wise
|
||||||
``z = alpha * x + beta * y``
|
``z = alpha * x + beta * y``
|
||||||
|
|
||||||
Follows numpy style broadcasting between ``x`` and ``y``
|
Follows numpy style broadcasting between ``x`` and ``y``
|
||||||
@ -840,7 +840,7 @@ This will result in a directory structure as follows:
|
|||||||
| ...
|
| ...
|
||||||
|
|
||||||
When you try to install using the command ``python -m pip install .``
|
When you try to install using the command ``python -m pip install .``
|
||||||
(in ``extensions/``), the package will be installed with the same strucutre as
|
(in ``extensions/``), the package will be installed with the same structure as
|
||||||
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||||
copied along with the python binding since they are specified as ``package_data``.
|
copied along with the python binding since they are specified as ``package_data``.
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:
|
|||||||
|
|
||||||
The design of MLX is inspired by frameworks like `PyTorch
|
The design of MLX is inspired by frameworks like `PyTorch
|
||||||
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
|
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
|
||||||
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
|
`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
|
||||||
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
|
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
|
||||||
memory. Operations on MLX arrays can be performed on any of the supported
|
memory. Operations on MLX arrays can be performed on any of the supported
|
||||||
device types without performing data copies. Currently supported device types
|
device types without performing data copies. Currently supported device types
|
||||||
|
@ -28,6 +28,7 @@ Layers
|
|||||||
GroupNorm
|
GroupNorm
|
||||||
Dropout
|
Dropout
|
||||||
Dropout2d
|
Dropout2d
|
||||||
|
Dropout3d
|
||||||
Transformer
|
Transformer
|
||||||
MultiHeadAttention
|
MultiHeadAttention
|
||||||
ALiBi
|
ALiBi
|
||||||
|
@ -57,7 +57,7 @@ void array_basics() {
|
|||||||
assert(z.shape(0) == 2);
|
assert(z.shape(0) == 2);
|
||||||
assert(z.shape(1) == 2);
|
assert(z.shape(1) == 2);
|
||||||
|
|
||||||
// To actually run the compuation you must evaluate `z`.
|
// To actually run the computation you must evaluate `z`.
|
||||||
// Under the hood, mlx records operations in a graph.
|
// Under the hood, mlx records operations in a graph.
|
||||||
// The variable `z` is a node in the graph which points to its operation
|
// The variable `z` is a node in the graph which points to its operation
|
||||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||||
|
@ -26,7 +26,7 @@ namespace mlx::core {
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Scale and sum two vectors elementwise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Follow numpy style broadcasting between x and y
|
||||||
@ -91,21 +91,21 @@ void axpby_impl(
|
|||||||
T alpha = static_cast<T>(alpha_);
|
T alpha = static_cast<T>(alpha_);
|
||||||
T beta = static_cast<T>(beta_);
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
// Do the elementwise operation for each output
|
// Do the element-wise operation for each output
|
||||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||||
// Map linear indices to offsets in x and y
|
// Map linear indices to offsets in x and y
|
||||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||||
|
|
||||||
// We allocate the output to be contiguous and regularly strided
|
// We allocate the output to be contiguous and regularly strided
|
||||||
// (defaults to row major) and hence it doesn't need additonal mapping
|
// (defaults to row major) and hence it doesn't need additional mapping
|
||||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
/** Fall back implementation for evaluation on CPU */
|
||||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||||
// Check the inputs (registered in the op while contructing the out array)
|
// Check the inputs (registered in the op while constructing the out array)
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
@ -192,7 +192,7 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // Accelerate not avaliable
|
#else // Accelerate not available
|
||||||
|
|
||||||
/** Evaluate primitive on CPU falling back to common backend */
|
/** Evaluate primitive on CPU falling back to common backend */
|
||||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -254,7 +254,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
// those in the kernel decelaration at axpby.metal
|
// those in the kernel declaration at axpby.metal
|
||||||
int ndim = out.ndim();
|
int ndim = out.ndim();
|
||||||
size_t nelem = out.size();
|
size_t nelem = out.size();
|
||||||
|
|
||||||
@ -287,7 +287,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Fix the 3D size of the launch grid (in terms of threads)
|
// Fix the 3D size of the launch grid (in terms of threads)
|
||||||
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||||
|
|
||||||
// Launch the grid with the given number of threads divded among
|
// Launch the grid with the given number of threads divided among
|
||||||
// the given threadgroups
|
// the given threadgroups
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
@ -311,8 +311,8 @@ array Axpby::jvp(
|
|||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the the primitive can built with ops
|
// The jvp transform on the primitive can built with ops
|
||||||
// that are scheduled on the same stream as the primtive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
// jvp is just the tangent scaled by alpha
|
// jvp is just the tangent scaled by alpha
|
||||||
@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
|
|||||||
return vjps;
|
return vjps;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Vectorize primitve along given axis */
|
/** Vectorize primitive along given axis */
|
||||||
std::pair<array, int> Axpby::vmap(
|
std::pair<array, int> Axpby::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
@ -12,7 +12,7 @@ namespace mlx::core {
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Scale and sum two vectors elementwise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Follow numpy style broadcasting between x and y
|
||||||
@ -39,7 +39,7 @@ class Axpby : public Primitive {
|
|||||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||||
* for the given inputs and populate the output array.
|
* for the given inputs and populate the output array.
|
||||||
*
|
*
|
||||||
* To avoid unecessary allocations, the evaluation function
|
* To avoid unnecessary allocations, the evaluation function
|
||||||
* is responsible for allocating space for the array.
|
* is responsible for allocating space for the array.
|
||||||
*/
|
*/
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
@ -59,5 +59,5 @@ template <typename T>
|
|||||||
|
|
||||||
instantiate_axpby(float32, float);
|
instantiate_axpby(float32, float);
|
||||||
instantiate_axpby(float16, half);
|
instantiate_axpby(float16, half);
|
||||||
instantiate_axpby(bflot16, bfloat16_t);
|
instantiate_axpby(bfloat16, bfloat16_t);
|
||||||
instantiate_axpby(complex64, complex64_t);
|
instantiate_axpby(complex64, complex64_t);
|
@ -23,7 +23,7 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = py::none(),
|
"stream"_a = py::none(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Scale and sum two vectors elementwise
|
Scale and sum two vectors element-wise
|
||||||
``z = alpha * x + beta * y``
|
``z = alpha * x + beta * y``
|
||||||
|
|
||||||
Follows numpy style broadcasting between ``x`` and ``y``
|
Follows numpy style broadcasting between ``x`` and ``y``
|
||||||
|
@ -37,7 +37,7 @@ void free(Buffer buffer);
|
|||||||
Buffer malloc_or_wait(size_t size);
|
Buffer malloc_or_wait(size_t size);
|
||||||
|
|
||||||
class Allocator {
|
class Allocator {
|
||||||
/** Abstract base clase for a memory allocator. */
|
/** Abstract base class for a memory allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) = 0;
|
virtual Buffer malloc(size_t size) = 0;
|
||||||
virtual void free(Buffer buffer) = 0;
|
virtual void free(Buffer buffer) = 0;
|
||||||
|
@ -129,7 +129,7 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Needed because the Primitive type used in array.h is incomplete and the
|
// Needed because the Primitive type used in array.h is incomplete and the
|
||||||
// compiler needs to see the call to the desctructor after the type is complete.
|
// compiler needs to see the call to the destructor after the type is complete.
|
||||||
array::ArrayDesc::~ArrayDesc() = default;
|
array::ArrayDesc::~ArrayDesc() = default;
|
||||||
|
|
||||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||||
|
@ -13,7 +13,7 @@ namespace mlx::core {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <const uint8_t scalar_size>
|
template <const uint8_t scalar_size>
|
||||||
void swap_endianess(uint8_t* data_bytes, size_t N) {
|
void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||||
struct Elem {
|
struct Elem {
|
||||||
uint8_t bytes[scalar_size];
|
uint8_t bytes[scalar_size];
|
||||||
};
|
};
|
||||||
@ -39,13 +39,13 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
if (swap_endianness_) {
|
if (swap_endianness_) {
|
||||||
switch (out.itemsize()) {
|
switch (out.itemsize()) {
|
||||||
case 2:
|
case 2:
|
||||||
swap_endianess<2>(out.data<uint8_t>(), out.data_size());
|
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
swap_endianess<4>(out.data<uint8_t>(), out.data_size());
|
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
||||||
break;
|
break;
|
||||||
case 8:
|
case 8:
|
||||||
swap_endianess<8>(out.data<uint8_t>(), out.data_size());
|
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,7 +165,7 @@ Buffer MetalAllocator::malloc(size_t size) {
|
|||||||
|
|
||||||
// Prepare to allocate new memory as needed
|
// Prepare to allocate new memory as needed
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we are under very high memoory pressure, we don't allocate further
|
// If we are under very high memory pressure, we don't allocate further
|
||||||
if (device_->currentAllocatedSize() >= block_limit_) {
|
if (device_->currentAllocatedSize() >= block_limit_) {
|
||||||
return Buffer{nullptr};
|
return Buffer{nullptr};
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ void explicit_gemm_conv_1D_gpu(
|
|||||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||||
|
|
||||||
// Peform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_padded, in_strided};
|
std::vector<array> copies = {in_padded, in_strided};
|
||||||
mlx_matmul(
|
mlx_matmul(
|
||||||
s,
|
s,
|
||||||
@ -260,7 +260,7 @@ void explicit_gemm_conv_2D_gpu(
|
|||||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||||
|
|
||||||
// Peform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_padded, in_strided};
|
std::vector<array> copies = {in_padded, in_strided};
|
||||||
mlx_matmul(
|
mlx_matmul(
|
||||||
s,
|
s,
|
||||||
|
@ -102,7 +102,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate the argument bufer
|
// Allocate the argument buffer
|
||||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||||
|
|
||||||
// Register data with the encoder
|
// Register data with the encoder
|
||||||
@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate the argument bufer
|
// Allocate the argument buffer
|
||||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||||
|
|
||||||
// Register data with the encoder
|
// Register data with the encoder
|
||||||
|
@ -114,7 +114,7 @@ template <typename T, typename Op, int N_READS>
|
|||||||
// 4. Reduce among them and go to 3
|
// 4. Reduce among them and go to 3
|
||||||
// 4. Reduce in each simd_group
|
// 4. Reduce in each simd_group
|
||||||
// 6. Write in the thread local memory
|
// 6. Write in the thread local memory
|
||||||
// 6. Reduce them accross thread group
|
// 6. Reduce them across thread group
|
||||||
// 7. Write the output without need for atomic
|
// 7. Write the output without need for atomic
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ struct complex64_t {
|
|||||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||||
|
|
||||||
// Converstions from complex64_t
|
// Conversions from complex64_t
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
@ -105,7 +105,7 @@ struct Conv2DInputBlockLoader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Zero pad otherwize
|
// Zero pad otherwise
|
||||||
else {
|
else {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (short j = 0; j < vec_size; ++j) {
|
for (short j = 0; j < vec_size; ++j) {
|
||||||
@ -334,7 +334,7 @@ struct Conv2DBlockMMA {
|
|||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
// Multiply and accumulate into resulr simdgroup matrices
|
// Multiply and accumulate into result simdgroup matrices
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (short i = 0; i < TM; i++) {
|
for (short i = 0; i < TM; i++) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
|
@ -93,13 +93,13 @@ struct BlockLoader {
|
|||||||
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read all valid indcies into tmp_val
|
// Read all valid indices into tmp_val
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (short j = 0; j < vec_size; j++) {
|
for (short j = 0; j < vec_size; j++) {
|
||||||
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Zero out uneeded values
|
// Zero out unneeded values
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (short j = 0; j < vec_size; j++) {
|
for (short j = 0; j < vec_size; j++) {
|
||||||
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
||||||
@ -241,7 +241,7 @@ struct BlockMMA {
|
|||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
// Multiply and accumulate into resulr simdgroup matrices
|
// Multiply and accumulate into result simdgroup matrices
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (short i = 0; i < TM; i++) {
|
for (short i = 0; i < TM; i++) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
|
@ -28,7 +28,7 @@ struct GEMVKernel {
|
|||||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||||
|
|
||||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||||
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||||
// - Every thread works on a block of (TM, TN)
|
// - Every thread works on a block of (TM, TN)
|
||||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||||
//
|
//
|
||||||
@ -42,7 +42,7 @@ struct GEMVKernel {
|
|||||||
// Edge case handling:
|
// Edge case handling:
|
||||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||||
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||||
// such that the thread block fits exactly in the matrix
|
// such that the thread block fits exactly in the matrix
|
||||||
|
|
||||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||||
@ -166,7 +166,7 @@ template <
|
|||||||
struct GEMVTKernel {
|
struct GEMVTKernel {
|
||||||
|
|
||||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||||
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||||
// - Every thread works on a block of (TM, TN)
|
// - Every thread works on a block of (TM, TN)
|
||||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||||
//
|
//
|
||||||
@ -180,7 +180,7 @@ struct GEMVTKernel {
|
|||||||
// Edge case handling:
|
// Edge case handling:
|
||||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||||
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||||
// such that the thread block fits exactly in the matrix
|
// such that the thread block fits exactly in the matrix
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
|||||||
in += grid_size * N_READS;
|
in += grid_size * N_READS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sepate case for the last set as we close the reduction size
|
// Separate case for the last set as we close the reduction size
|
||||||
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
||||||
if (curr_idx < in_size) {
|
if (curr_idx < in_size) {
|
||||||
int max_reads = in_size - curr_idx;
|
int max_reads = in_size - curr_idx;
|
||||||
|
@ -592,7 +592,7 @@ template <
|
|||||||
bool ARG_SORT,
|
bool ARG_SORT,
|
||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD>
|
short N_PER_THREAD>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
||||||
device idx_t* block_partitions [[buffer(0)]],
|
device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals [[buffer(1)]],
|
const device val_t* dev_vals [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs [[buffer(2)]],
|
const device idx_t* dev_idxs [[buffer(2)]],
|
||||||
@ -777,8 +777,8 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(7)]], \
|
const device size_t* nc_strides [[buffer(7)]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||||
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||||
device itype* block_partitions [[buffer(0)]], \
|
device itype* block_partitions [[buffer(0)]], \
|
||||||
const device vtype* dev_vals [[buffer(1)]], \
|
const device vtype* dev_vals [[buffer(1)]], \
|
||||||
const device itype* dev_idxs [[buffer(2)]], \
|
const device itype* dev_idxs [[buffer(2)]], \
|
||||||
|
@ -61,7 +61,7 @@ inline void mps_matmul(
|
|||||||
// 2. Only one of a or b has batch_size_out matrices worth of data and
|
// 2. Only one of a or b has batch_size_out matrices worth of data and
|
||||||
// the other has matrix worth of data
|
// the other has matrix worth of data
|
||||||
|
|
||||||
// The matrix dimsenisons of a and b are sure to be regularly strided
|
// The matrix dimensions of a and b are sure to be regularly strided
|
||||||
if (batch_size_out > 1) {
|
if (batch_size_out > 1) {
|
||||||
// No broadcasting defaults
|
// No broadcasting defaults
|
||||||
auto batch_size_a = a.data_size() / (M * K);
|
auto batch_size_a = a.data_size() / (M * K);
|
||||||
|
@ -40,7 +40,7 @@ void all_reduce_dispatch(
|
|||||||
// Set grid dimensions
|
// Set grid dimensions
|
||||||
|
|
||||||
// We make sure each thread has enough to do by making it read in
|
// We make sure each thread has enough to do by making it read in
|
||||||
// atleast n_reads inputs
|
// at least n_reads inputs
|
||||||
int n_reads = REDUCE_N_READS;
|
int n_reads = REDUCE_N_READS;
|
||||||
|
|
||||||
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||||
@ -176,7 +176,7 @@ void strided_reduce_general_dispatch(
|
|||||||
|
|
||||||
// We spread outputs over the x dimension and inputs over the y dimension
|
// We spread outputs over the x dimension and inputs over the y dimension
|
||||||
// Threads with the same lid.x in a given threadgroup work on the same
|
// Threads with the same lid.x in a given threadgroup work on the same
|
||||||
// output and each thread in the y dimension accumlates for that output
|
// output and each thread in the y dimension accumulates for that output
|
||||||
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
||||||
uint threadgroup_dim_y =
|
uint threadgroup_dim_y =
|
||||||
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
||||||
|
@ -165,10 +165,10 @@ void multi_block_sort(
|
|||||||
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
||||||
ping = !ping;
|
ping = !ping;
|
||||||
|
|
||||||
// Do partiton
|
// Do partition
|
||||||
{
|
{
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
|
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
|
||||||
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
||||||
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
@ -18,7 +18,7 @@ void set_array_buffer(
|
|||||||
auto offset = a.data<char>() -
|
auto offset = a.data<char>() -
|
||||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||||
enc->setBuffer(a_buf, offset, idx);
|
enc->setBuffer(a_buf, offset, idx);
|
||||||
// MTL::Resource usage through argument buffer needs to be explicity
|
// MTL::Resource usage through argument buffer needs to be explicitly
|
||||||
// flagged to enable hazard tracking
|
// flagged to enable hazard tracking
|
||||||
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
|
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ array fft_impl(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// In the following shape manipulations there are three cases to consdier:
|
// In the following shape manipulations there are three cases to consider:
|
||||||
// 1. In a complex to complex transform (fftn / ifftn) the output
|
// 1. In a complex to complex transform (fftn / ifftn) the output
|
||||||
// and input shapes are the same.
|
// and input shapes are the same.
|
||||||
// 2. In a real to complex transform (rfftn) n specifies the input dims
|
// 2. In a real to complex transform (rfftn) n specifies the input dims
|
||||||
|
@ -155,7 +155,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
// Read and check version
|
// Read and check version
|
||||||
if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
|
if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load] Unsupport npy format version in " + in_stream->label());
|
"[load] Unsupported npy format version in " + in_stream->label());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read header len and header
|
// Read header len and header
|
||||||
|
12
mlx/ops.cpp
12
mlx/ops.cpp
@ -247,7 +247,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
||||||
if (x.ndim() < 2) {
|
if (x.ndim() < 2) {
|
||||||
throw std::invalid_argument("[tril] array must be atleast 2-D");
|
throw std::invalid_argument("[tril] array must be at least 2-D");
|
||||||
}
|
}
|
||||||
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
||||||
return where(mask, x, zeros_like(x, s), s);
|
return where(mask, x, zeros_like(x, s), s);
|
||||||
@ -255,7 +255,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
||||||
if (x.ndim() < 2) {
|
if (x.ndim() < 2) {
|
||||||
throw std::invalid_argument("[triu] array must be atleast 2-D");
|
throw std::invalid_argument("[triu] array must be at least 2-D");
|
||||||
}
|
}
|
||||||
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
||||||
return where(mask, zeros_like(x, s), x, s);
|
return where(mask, zeros_like(x, s), x, s);
|
||||||
@ -350,7 +350,7 @@ array squeeze(
|
|||||||
ax = ax < 0 ? ax + a.ndim() : ax;
|
ax = ax < 0 ? ax + a.ndim() : ax;
|
||||||
if (ax < 0 || ax >= a.ndim()) {
|
if (ax < 0 || ax >= a.ndim()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[squeeze] Invalid axies " << ax << " for array with " << a.ndim()
|
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
|
||||||
<< " dimensions.";
|
<< " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -405,7 +405,7 @@ array expand_dims(
|
|||||||
ax = ax < 0 ? ax + out_ndim : ax;
|
ax = ax < 0 ? ax + out_ndim : ax;
|
||||||
if (ax < 0 || ax >= out_ndim) {
|
if (ax < 0 || ax >= out_ndim) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[squeeze] Invalid axies " << ax << " for output array with "
|
msg << "[squeeze] Invalid axes " << ax << " for output array with "
|
||||||
<< a.ndim() << " dimensions.";
|
<< a.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -478,7 +478,7 @@ array slice(
|
|||||||
|
|
||||||
// If strides are negative, slice and then make a copy with axes flipped
|
// If strides are negative, slice and then make a copy with axes flipped
|
||||||
if (negatively_strided_axes.size() > 0) {
|
if (negatively_strided_axes.size() > 0) {
|
||||||
// First, take the slice of the positvely strided axes
|
// First, take the slice of the positively strided axes
|
||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
@ -517,7 +517,7 @@ array slice(
|
|||||||
// Gather moves the axis up, remainder needs to be squeezed
|
// Gather moves the axis up, remainder needs to be squeezed
|
||||||
out_reshape[i] = indices[i].size();
|
out_reshape[i] = indices[i].size();
|
||||||
|
|
||||||
// Gather moves the axis up, needs to be tranposed
|
// Gather moves the axis up, needs to be transposed
|
||||||
t_axes[ax] = i;
|
t_axes[ax] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
|||||||
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
||||||
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Repeate an array along an axis. */
|
/** Repeat an array along an axis. */
|
||||||
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
||||||
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ class Primitive {
|
|||||||
* A primitive must know how to evaluate itself on
|
* A primitive must know how to evaluate itself on
|
||||||
* the CPU/GPU for the given inputs and populate the output array.
|
* the CPU/GPU for the given inputs and populate the output array.
|
||||||
*
|
*
|
||||||
* To avoid unecessary allocations, the evaluation function
|
* To avoid unnecessary allocations, the evaluation function
|
||||||
* is responsible for allocating space for the array.
|
* is responsible for allocating space for the array.
|
||||||
*/
|
*/
|
||||||
virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
|
virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
|
||||||
@ -84,7 +84,7 @@ class Primitive {
|
|||||||
/** Print the primitive. */
|
/** Print the primitive. */
|
||||||
virtual void print(std::ostream& os) = 0;
|
virtual void print(std::ostream& os) = 0;
|
||||||
|
|
||||||
/** Equivalence check defaults to false unless overriden by the primitive */
|
/** Equivalence check defaults to false unless overridden by the primitive */
|
||||||
virtual bool is_equivalent(const Primitive& other) const {
|
virtual bool is_equivalent(const Primitive& other) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -232,7 +232,7 @@ array truncated_normal(
|
|||||||
auto u = uniform(a, b, shape, dtype, key, s);
|
auto u = uniform(a, b, shape, dtype, key, s);
|
||||||
auto out = multiply(sqrt2, erfinv(u, s), s);
|
auto out = multiply(sqrt2, erfinv(u, s), s);
|
||||||
|
|
||||||
// Clip in bouds
|
// Clip in bounds
|
||||||
return maximum(minimum(upper_t, out, s), lower_t, s);
|
return maximum(minimum(upper_t, out, s), lower_t, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ class KeySequence {
|
|||||||
void seed(uint64_t seed);
|
void seed(uint64_t seed);
|
||||||
array next();
|
array next();
|
||||||
|
|
||||||
// static defualt
|
// static default
|
||||||
static KeySequence& default_() {
|
static KeySequence& default_() {
|
||||||
static KeySequence ks(0);
|
static KeySequence ks(0);
|
||||||
return ks;
|
return ks;
|
||||||
|
@ -80,7 +80,7 @@ ValueAndGradFn value_and_grad(
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a function which computes the value and gradient of the input
|
* Returns a function which computes the value and gradient of the input
|
||||||
* function with repsect to a single input array.
|
* function with respect to a single input array.
|
||||||
**/
|
**/
|
||||||
ValueAndGradFn inline value_and_grad(
|
ValueAndGradFn inline value_and_grad(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
@ -132,7 +132,7 @@ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a function which computes the gradient of the input function with
|
* Returns a function which computes the gradient of the input function with
|
||||||
* repsect to a single input array.
|
* respect to a single input array.
|
||||||
*
|
*
|
||||||
* The function being differentiated takes a vector of arrays and returns an
|
* The function being differentiated takes a vector of arrays and returns an
|
||||||
* array. The optional `argnum` index specifies which the argument to compute
|
* array. The optional `argnum` index specifies which the argument to compute
|
||||||
|
@ -68,7 +68,7 @@ struct _MLX_Float16 {
|
|||||||
inf_scale.u = uint32_t(0x77800000);
|
inf_scale.u = uint32_t(0x77800000);
|
||||||
zero_scale.u = uint32_t(0x08800000);
|
zero_scale.u = uint32_t(0x08800000);
|
||||||
|
|
||||||
// Combine with magic and let addition do rouding
|
// Combine with magic and let addition do rounding
|
||||||
magic_bits.u = x_expo_32;
|
magic_bits.u = x_expo_32;
|
||||||
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
|
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
|
|||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
from mlx.nn.layers.containers import Sequential
|
from mlx.nn.layers.containers import Sequential
|
||||||
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||||
from mlx.nn.layers.dropout import Dropout, Dropout2d
|
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||||
from mlx.nn.layers.embedding import Embedding
|
from mlx.nn.layers.embedding import Embedding
|
||||||
from mlx.nn.layers.linear import Linear
|
from mlx.nn.layers.linear import Linear
|
||||||
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
|
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
|
||||||
|
@ -86,3 +86,52 @@ class Dropout2d(Module):
|
|||||||
|
|
||||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||||
return (1 / self._p_1) * mask * x
|
return (1 / self._p_1) * mask * x
|
||||||
|
|
||||||
|
|
||||||
|
class Dropout3d(Module):
|
||||||
|
r"""Apply 3D channel-wise dropout during training.
|
||||||
|
|
||||||
|
Randomly zero out entire channels independently with probability :math:`p`.
|
||||||
|
This layer expects the channels to be last, i.e., the input shape should be
|
||||||
|
`NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
|
||||||
|
`H` is the input image height, `W` is the input image width, and `C` is
|
||||||
|
the number of input channels.
|
||||||
|
|
||||||
|
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
|
||||||
|
maintain the expected value of each element. Unlike traditional dropout,
|
||||||
|
which zeros individual entries, this layer zeros entire channels. This is
|
||||||
|
often beneficial for convolutional layers processing 3D data, like in
|
||||||
|
medical imaging or video processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p (float): Probability of zeroing a channel during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, p: float = 0.5):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if p < 0 or p >= 1:
|
||||||
|
raise ValueError(f"The dropout probability {p} is not in [0, 1)")
|
||||||
|
|
||||||
|
self._p_1 = 1 - p
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return f"p={1-self._p_1}"
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
if x.ndim not in (4, 5):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._p_1 == 1 or not self.training:
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Dropout is applied on the whole channel
|
||||||
|
# 4D input: (1, 1, 1, C)
|
||||||
|
# 5D input: (B, 1, 1, 1, C)
|
||||||
|
mask_shape = list(x.shape)
|
||||||
|
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
|
||||||
|
|
||||||
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||||
|
return (1 / self._p_1) * mask * x
|
||||||
|
@ -198,7 +198,7 @@ class BatchNorm(Module):
|
|||||||
batch, ``C`` is the number of features or channels, and ``L`` is the
|
batch, ``C`` is the number of features or channels, and ``L`` is the
|
||||||
sequence length. The output has the same shape as the input. For
|
sequence length. The output has the same shape as the input. For
|
||||||
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
|
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
|
||||||
the height and width respecitvely.
|
the height and width respectively.
|
||||||
|
|
||||||
For more information on Batch Normalization, see the original paper `Batch
|
For more information on Batch Normalization, see the original paper `Batch
|
||||||
Normalization: Accelerating Deep Network Training by Reducing Internal
|
Normalization: Accelerating Deep Network Training by Reducing Internal
|
||||||
|
@ -24,13 +24,21 @@ class RoPE(Module):
|
|||||||
implementation which is slightly less efficient. Default: ``False``.
|
implementation which is slightly less efficient. Default: ``False``.
|
||||||
base (float, optional): The base used to compute angular frequency for
|
base (float, optional): The base used to compute angular frequency for
|
||||||
each dimension in the positional encodings. Default: ``10000``.
|
each dimension in the positional encodings. Default: ``10000``.
|
||||||
|
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
traditional: bool = False,
|
||||||
|
base: float = 10000,
|
||||||
|
scale: float = 1.0,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.traditional = traditional
|
self.traditional = traditional
|
||||||
self.base = base
|
self.base = base
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return f"{self.dims}, traditional={self.traditional}"
|
return f"{self.dims}, traditional={self.traditional}"
|
||||||
@ -68,7 +76,7 @@ class RoPE(Module):
|
|||||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||||
N = x.shape[1] + offset
|
N = x.shape[1] + offset
|
||||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||||
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
|
N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
rope = (
|
rope = (
|
||||||
@ -80,10 +88,15 @@ class RoPE(Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_cos_sin_theta(
|
def create_cos_sin_theta(
|
||||||
N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
|
N: int,
|
||||||
|
D: int,
|
||||||
|
offset: int = 0,
|
||||||
|
base: float = 10000,
|
||||||
|
scale: float = 1.0,
|
||||||
|
dtype=mx.float32,
|
||||||
):
|
):
|
||||||
D = D // 2
|
D = D // 2
|
||||||
positions = mx.arange(offset, N, dtype=dtype)
|
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||||
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
||||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||||
return mx.cos(theta), mx.sin(theta)
|
return mx.cos(theta), mx.sin(theta)
|
||||||
|
@ -253,7 +253,7 @@ class AdaDelta(Optimizer):
|
|||||||
rho (float, optional): The coefficient :math:`\rho` used for computing a
|
rho (float, optional): The coefficient :math:`\rho` used for computing a
|
||||||
running average of squared gradients. Default: ``0.9``
|
running average of squared gradients. Default: ``0.9``
|
||||||
eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
|
eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
|
||||||
numerical stability. Ddefault: `1e-8`
|
numerical stability. Default: `1e-8`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
|
def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
|
||||||
|
@ -64,7 +64,7 @@ auto to_scalar(array& a) {
|
|||||||
case float32:
|
case float32:
|
||||||
return py::cast(a.item<float>(retain_graph));
|
return py::cast(a.item<float>(retain_graph));
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
|
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
|
||||||
case complex64:
|
case complex64:
|
||||||
return py::cast(a.item<std::complex<float>>(retain_graph));
|
return py::cast(a.item<std::complex<float>>(retain_graph));
|
||||||
}
|
}
|
||||||
@ -507,7 +507,7 @@ void init_array(py::module_& m) {
|
|||||||
|
|
||||||
array_class
|
array_class
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
|
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
|
"ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
@ -559,7 +559,7 @@ void init_array(py::module_& m) {
|
|||||||
If the array has more than one dimension then the result is a nested
|
If the array has more than one dimension then the result is a nested
|
||||||
list of lists.
|
list of lists.
|
||||||
|
|
||||||
The value type of the list correpsonding to the last dimension is either
|
The value type of the list corresponding to the last dimension is either
|
||||||
``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
|
``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def("__array__", &mlx_array_to_np)
|
.def("__array__", &mlx_array_to_np)
|
||||||
|
@ -1263,7 +1263,7 @@ void init_ops(py::module_& m) {
|
|||||||
If the axis is not specified the array is treated as a flattened
|
If the axis is not specified the array is treated as a flattened
|
||||||
1-D array prior to performing the take.
|
1-D array prior to performing the take.
|
||||||
|
|
||||||
As an example, if the ``axis=1`` this is equialent to ``a[:, indices, ...]``.
|
As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
@ -1742,7 +1742,7 @@ void init_ops(py::module_& m) {
|
|||||||
"a"_a,
|
"a"_a,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
"source"_a,
|
"source"_a,
|
||||||
"destiantion"_a,
|
"destination"_a,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
@ -2253,7 +2253,7 @@ void init_ops(py::module_& m) {
|
|||||||
will be of elements less or equal to the element at the ``kth``
|
will be of elements less or equal to the element at the ``kth``
|
||||||
index and all indices after will be of elements greater or equal
|
index and all indices after will be of elements greater or equal
|
||||||
to the element at the ``kth`` index.
|
to the element at the ``kth`` index.
|
||||||
axis (int or None, optional): Optional axis to partiton over.
|
axis (int or None, optional): Optional axis to partition over.
|
||||||
If ``None``, this partitions over the flattened array.
|
If ``None``, this partitions over the flattened array.
|
||||||
If unspecified, it defaults to ``-1``.
|
If unspecified, it defaults to ``-1``.
|
||||||
|
|
||||||
@ -2426,13 +2426,13 @@ void init_ops(py::module_& m) {
|
|||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Repeate an array along a specified axis.
|
Repeat an array along a specified axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
array (array): Input array.
|
array (array): Input array.
|
||||||
repeats (int): The number of repetitions for each element.
|
repeats (int): The number of repetitions for each element.
|
||||||
axis (int, optional): The axis in which to repeat the array along. If
|
axis (int, optional): The axis in which to repeat the array along. If
|
||||||
unspecified it uses the flattened array of the input and repeates
|
unspecified it uses the flattened array of the input and repeats
|
||||||
along axis 0.
|
along axis 0.
|
||||||
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||||
|
|
||||||
@ -3050,7 +3050,7 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Round to the given number of decimals.
|
Round to the given number of decimals.
|
||||||
|
|
||||||
Bascially performs:
|
Basically performs:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -212,7 +212,7 @@ void init_random(py::module_& parent_module) {
|
|||||||
upper (scalar or array): Upper bound of the domain.
|
upper (scalar or array): Upper bound of the domain.
|
||||||
shape (list(int), optional): The shape of the output.
|
shape (list(int), optional): The shape of the output.
|
||||||
Default is ``()``.
|
Default is ``()``.
|
||||||
dtype (Dtype, optinoal): The data type of the output.
|
dtype (Dtype, optional): The data type of the output.
|
||||||
Default is ``float32``.
|
Default is ``float32``.
|
||||||
key (array, optional): A PRNG key. Default: None.
|
key (array, optional): A PRNG key. Default: None.
|
||||||
|
|
||||||
|
@ -102,6 +102,9 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(x.item(), 1)
|
self.assertEqual(x.item(), 1)
|
||||||
self.assertTrue(isinstance(x.item(), int))
|
self.assertTrue(isinstance(x.item(), int))
|
||||||
|
|
||||||
|
x = mx.array(1, mx.bfloat16)
|
||||||
|
self.assertEqual(x.item(), 1.0)
|
||||||
|
|
||||||
x = mx.array(1.0)
|
x = mx.array(1.0)
|
||||||
self.assertEqual(x.size, 1)
|
self.assertEqual(x.size, 1)
|
||||||
self.assertEqual(x.ndim, 0)
|
self.assertEqual(x.ndim, 0)
|
||||||
@ -949,7 +952,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
b_mx = a_mx[25:-50:-3]
|
b_mx = a_mx[25:-50:-3]
|
||||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||||
|
|
||||||
# Negatie slice and ascending bounds
|
# Negative slice and ascending bounds
|
||||||
b_np = a_np[0:20:-3]
|
b_np = a_np[0:20:-3]
|
||||||
b_mx = a_mx[0:20:-3]
|
b_mx = a_mx[0:20:-3]
|
||||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||||
|
@ -53,10 +53,10 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
for dtype in self.dtypes:
|
for dtype in self.dtypes:
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
base_shapes = [4, 8, 16, 32, 64, 128]
|
base_shapes = [4, 8, 16, 32, 64, 128]
|
||||||
pertubations = [-2, -1, 0, 1, 2]
|
perturbations = [-2, -1, 0, 1, 2]
|
||||||
|
|
||||||
for dim in base_shapes:
|
for dim in base_shapes:
|
||||||
for p in pertubations:
|
for p in perturbations:
|
||||||
shape_a = (dim + p, dim + p)
|
shape_a = (dim + p, dim + p)
|
||||||
shape_b = (dim + p, dim + p)
|
shape_b = (dim + p, dim + p)
|
||||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||||
@ -81,12 +81,12 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
for B, M, N, K in shapes:
|
for B, M, N, K in shapes:
|
||||||
|
|
||||||
with self.subTest(tranpose="nn"):
|
with self.subTest(transpose="nn"):
|
||||||
shape_a = (B, M, K)
|
shape_a = (B, M, K)
|
||||||
shape_b = (B, K, N)
|
shape_b = (B, K, N)
|
||||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||||
|
|
||||||
with self.subTest(tranpose="nt"):
|
with self.subTest(transpose="nt"):
|
||||||
shape_a = (B, M, K)
|
shape_a = (B, M, K)
|
||||||
shape_b = (B, N, K)
|
shape_b = (B, N, K)
|
||||||
self.__gemm_test(
|
self.__gemm_test(
|
||||||
@ -97,7 +97,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.subTest(tranpose="tn"):
|
with self.subTest(transpose="tn"):
|
||||||
shape_a = (B, K, M)
|
shape_a = (B, K, M)
|
||||||
shape_b = (B, K, N)
|
shape_b = (B, K, N)
|
||||||
self.__gemm_test(
|
self.__gemm_test(
|
||||||
@ -108,7 +108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.subTest(tranpose="tt"):
|
with self.subTest(transpose="tt"):
|
||||||
shape_a = (B, K, M)
|
shape_a = (B, K, M)
|
||||||
shape_b = (B, N, K)
|
shape_b = (B, N, K)
|
||||||
self.__gemm_test(
|
self.__gemm_test(
|
||||||
@ -191,7 +191,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
# Batched matmul with simple broadast
|
# Batched matmul with simple broadcast
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||||
c_npy = a_npy @ b_npy
|
c_npy = a_npy @ b_npy
|
||||||
@ -213,7 +213,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
||||||
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
||||||
|
|
||||||
# Batched and transposed matmul with simple broadast
|
# Batched and transposed matmul with simple broadcast
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||||
a_mlx = mx.array(a_npy)
|
a_mlx = mx.array(a_npy)
|
||||||
|
@ -749,7 +749,7 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
def test_rope(self):
|
def test_rope(self):
|
||||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
|
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
|
||||||
rope = nn.RoPE(4, **kwargs)
|
rope = nn.RoPE(4, **kwargs)
|
||||||
shape = (1, 3, 4)
|
shape = (1, 3, 4)
|
||||||
x = mx.random.uniform(shape=shape)
|
x = mx.random.uniform(shape=shape)
|
||||||
@ -864,6 +864,54 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
||||||
|
|
||||||
|
def test_dropout(self):
|
||||||
|
x = mx.ones((2, 4))
|
||||||
|
y = nn.Dropout(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
x = mx.ones((2, 4), dtype=mx.bfloat16)
|
||||||
|
y = nn.Dropout(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.bfloat16)
|
||||||
|
|
||||||
|
x = mx.ones((2, 4), dtype=mx.float16)
|
||||||
|
y = nn.Dropout(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.float16)
|
||||||
|
|
||||||
|
def test_dropout2d(self):
|
||||||
|
x = mx.ones((2, 4, 4, 4))
|
||||||
|
y = nn.Dropout2d(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
|
||||||
|
y = nn.Dropout2d(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.bfloat16)
|
||||||
|
|
||||||
|
x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
|
||||||
|
y = nn.Dropout2d(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.float16)
|
||||||
|
|
||||||
|
def test_dropout3d(self):
|
||||||
|
x = mx.ones((2, 4, 4, 4, 4))
|
||||||
|
y = nn.Dropout3d(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
|
||||||
|
y = nn.Dropout3d(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.bfloat16)
|
||||||
|
|
||||||
|
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
|
||||||
|
y = nn.Dropout3d(0.5)(x)
|
||||||
|
self.assertTrue(y.shape, x.shape)
|
||||||
|
self.assertTrue(y.dtype, mx.float16)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -88,7 +88,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(a.dtype, mx.float32)
|
self.assertEqual(a.dtype, mx.float32)
|
||||||
self.assertEqual(a.item(), 3.0)
|
self.assertEqual(a.item(), 3.0)
|
||||||
|
|
||||||
# Check comibinations with mlx arrays
|
# Check combinations with mlx arrays
|
||||||
a = mx.add(mx.array(True), False)
|
a = mx.add(mx.array(True), False)
|
||||||
self.assertEqual(a.dtype, mx.bool_)
|
self.assertEqual(a.dtype, mx.bool_)
|
||||||
self.assertEqual(a.item(), True)
|
self.assertEqual(a.item(), True)
|
||||||
|
1
setup.py
1
setup.py
@ -5,7 +5,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import sysconfig
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import run
|
from subprocess import run
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ TEST_CASE("test arg reduce small") {
|
|||||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||||
|
|
||||||
if (!metal::is_available()) {
|
if (!metal::is_available()) {
|
||||||
INFO("Skiping arg reduction gpu tests");
|
INFO("Skipping arg reduction gpu tests");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ TEST_CASE("test arg reduce small") {
|
|||||||
|
|
||||||
TEST_CASE("test arg reduce against cpu") {
|
TEST_CASE("test arg reduce against cpu") {
|
||||||
if (!metal::is_available()) {
|
if (!metal::is_available()) {
|
||||||
INFO("Skiping arg reduction gpu tests");
|
INFO("Skipping arg reduction gpu tests");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ void test_arg_reduce_small_bool(
|
|||||||
|
|
||||||
TEST_CASE("test arg reduce bool") {
|
TEST_CASE("test arg reduce bool") {
|
||||||
if (!metal::is_available()) {
|
if (!metal::is_available()) {
|
||||||
INFO("Skiping arg reduction gpu tests");
|
INFO("Skipping arg reduction gpu tests");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto x = array(
|
auto x = array(
|
||||||
@ -201,7 +201,7 @@ TEST_CASE("test arg reduce irregular strides") {
|
|||||||
Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
|
Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
|
||||||
|
|
||||||
if (!metal::is_available()) {
|
if (!metal::is_available()) {
|
||||||
INFO("Skiping arg reduction gpu tests");
|
INFO("Skipping arg reduction gpu tests");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -989,7 +989,7 @@ TEST_CASE("test as_strided grads") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test jvp from vjp") {
|
TEST_CASE("test jvp from vjp") {
|
||||||
// Unary elementwise ops
|
// Unary element-wise ops
|
||||||
{
|
{
|
||||||
auto x = random::uniform({5, 10});
|
auto x = random::uniform({5, 10});
|
||||||
eval(x);
|
eval(x);
|
||||||
@ -1022,7 +1022,7 @@ TEST_CASE("test jvp from vjp") {
|
|||||||
CHECK(compute_derivs(mlx::core::rsqrt));
|
CHECK(compute_derivs(mlx::core::rsqrt));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Binary elementwise ops
|
// Binary element-wise ops
|
||||||
{
|
{
|
||||||
auto x = random::uniform({5, 10});
|
auto x = random::uniform({5, 10});
|
||||||
auto y = random::uniform({5, 10});
|
auto y = random::uniform({5, 10});
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
TEST_CASE("test arange") {
|
TEST_CASE("test arange") {
|
||||||
// Check type is inferred correclty
|
// Check type is inferred correctly
|
||||||
{
|
{
|
||||||
auto x = arange(10);
|
auto x = arange(10);
|
||||||
CHECK_EQ(x.dtype(), int32);
|
CHECK_EQ(x.dtype(), int32);
|
||||||
|
@ -1411,7 +1411,7 @@ TEST_CASE("test broadcast") {
|
|||||||
x.eval();
|
x.eval();
|
||||||
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0, 1});
|
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0, 1});
|
||||||
|
|
||||||
// Broadcast on transposed arrray works
|
// Broadcast on transposed array works
|
||||||
x = array({0, 1, 2, 3, 4, 5}, {2, 3});
|
x = array({0, 1, 2, 3, 4, 5}, {2, 3});
|
||||||
x = broadcast_to(transpose(x), {2, 3, 2});
|
x = broadcast_to(transpose(x), {2, 3, 2});
|
||||||
CHECK_EQ(x.shape(), std::vector<int>{2, 3, 2});
|
CHECK_EQ(x.shape(), std::vector<int>{2, 3, 2});
|
||||||
@ -1733,7 +1733,7 @@ TEST_CASE("test scatter") {
|
|||||||
out = scatter(in, inds, updates, 0);
|
out = scatter(in, inds, updates, 0);
|
||||||
CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
|
CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
|
||||||
|
|
||||||
// Irregular strided index and reduce collison test
|
// Irregular strided index and reduce collision test
|
||||||
in = zeros({10}, float32);
|
in = zeros({10}, float32);
|
||||||
inds = broadcast_to(array(3), {10});
|
inds = broadcast_to(array(3), {10});
|
||||||
updates = ones({10, 1}, float32);
|
updates = ones({10, 1}, float32);
|
||||||
@ -1750,7 +1750,7 @@ TEST_CASE("test scatter") {
|
|||||||
out = scatter_max(array(1), {}, array(2), std::vector<int>{});
|
out = scatter_max(array(1), {}, array(2), std::vector<int>{});
|
||||||
CHECK_EQ(out.item<int>(), 2);
|
CHECK_EQ(out.item<int>(), 2);
|
||||||
|
|
||||||
// Irregularaly strided updates test
|
// Irregularly strided updates test
|
||||||
in = ones({3, 3});
|
in = ones({3, 3});
|
||||||
updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
|
updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
|
||||||
inds = array({0});
|
inds = array({0});
|
||||||
|
Loading…
Reference in New Issue
Block a user