docs update

Awni Hannun
2024-01-03 20:14:05 -08:00
committed by CircleCI Docs
parent 6f71a74c87
commit bd5f469bac
330 changed files with 40495 additions and 3726 deletions


@@ -15,7 +15,7 @@ Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
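The function body itself falls outside this hunk; as a minimal sketch (assuming ``mlx.core`` is imported as ``mx`` and using a hypothetical name ``simple_axpby``), it could be written as:

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        # Scale each input by its coefficient and add element-wise;
        # broadcasting between x and y is handled by the ops themselves.
        return alpha * x + beta * y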
@@ -69,7 +69,7 @@ C++ API:
.. code-block:: C++
/**
* Scale and sum two vectors elementwise
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follows numpy style broadcasting between x and y
@@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
This operation now handles the following:
#. Upcast inputs and resolve the the output data type.
#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
@@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the elementwise operation for each output
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additonal mapping
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
@@ -305,7 +305,7 @@ if we encounter an unexpected type.
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while contructing the out array)
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -485,7 +485,7 @@ each data type.
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bflot16, bfloat16_t);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
@@ -537,7 +537,7 @@ below.
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel decelaration at axpby.metal
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
@@ -568,7 +568,7 @@ below.
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divded among
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
@@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and commiting the associated
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
@@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the the primitive can built with ops
// that are scheduled on the same stream as the primtive
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
@@ -642,7 +642,7 @@ own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitve along given axis */
/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
@@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors elementwise
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
@@ -840,7 +840,7 @@ This will result in a directory structure as follows:
| ...
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same strucutre as
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
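Once installed this way, the bound operation can be called directly from Python. A hedged usage sketch (the exact argument spelling of the binding is not shown in this hunk):

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    x = mx.ones((3, 4))
    y = mx.ones((3, 4))
    # z = 2.0 * x + 0.5 * y, computed by the custom primitive
    z = axpby(x, y, 2.0, 0.5)
    print(z)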


@@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
`ArrayFire <https://arrayfire.org/>`_. A notable difference between these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
@@ -57,6 +57,7 @@ are the CPU and GPU.
python/random
python/transforms
python/fft
python/linalg
python/nn
python/optimizers
python/tree_utils
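To illustrate the unified memory model mentioned above, here is a short sketch (assuming ``mlx.core`` imported as ``mx``) showing the same arrays used on either device via the ``stream`` argument:

.. code-block:: python

    import mlx.core as mx

    a = mx.random.normal((100,))
    b = mx.random.normal((100,))

    # The same arrays can be used on either device; no copies are made.
    c_cpu = mx.add(a, b, stream=mx.cpu)
    c_gpu = mx.add(a, b, stream=mx.gpu)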

docs/build/html/_sources/indexing.rst

@@ -0,0 +1,12 @@
.. _indexing:
Indexing Arrays
===============
.. currentmodule:: mlx.core
For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.
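For example, the familiar NumPy-style patterns apply (a brief sketch, assuming ``mlx.core`` imported as ``mx``):

.. code-block:: python

    import mlx.core as mx

    a = mx.arange(10)
    a[3]       # single element
    a[-2:]     # slice from the end
    a[::2]     # strided slice
    a[None]    # add a new axis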


@@ -63,6 +63,8 @@
~array.T
~array.dtype
~array.itemsize
~array.nbytes
~array.ndim
~array.shape
~array.size


@@ -0,0 +1,6 @@
mlx.core.linalg.norm
====================
.. currentmodule:: mlx.core.linalg
.. autofunction:: norm


@@ -0,0 +1,6 @@
mlx.core.repeat
===============
.. currentmodule:: mlx.core
.. autofunction:: repeat


@@ -0,0 +1,6 @@
mlx.core.save\_safetensors
==========================
.. currentmodule:: mlx.core
.. autofunction:: save_safetensors


@@ -0,0 +1,6 @@
mlx.core.tensordot
==================
.. currentmodule:: mlx.core
.. autofunction:: tensordot


@@ -0,0 +1,11 @@
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
norm
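A short hedged example of the newly documented :func:`norm`:

.. code-block:: python

    import mlx.core as mx

    v = mx.array([3.0, 4.0])
    mx.linalg.norm(v)           # vector 2-norm -> 5.0

    m = mx.array([[1.0, 2.0], [3.0, 4.0]])
    mx.linalg.norm(m, axis=1)   # per-row norms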


@@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use
all the parameters in a :class:`Module` do:
.. code-block:: python
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
@@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
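Putting the two snippets together on a small model (a sketch; ``mlp`` stands in for whatever :class:`Module` is being inspected):

.. code-block:: python

    import mlx.nn as nn
    from mlx.utils import tree_flatten, tree_map

    mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

    shapes = tree_map(lambda p: p.shape, mlp.parameters())
    num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
    print(shapes, num_params)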
@@ -170,14 +170,13 @@ In detail:
:meth:`mlx.core.value_and_grad`
.. autosummary::
:recursive:
:toctree: _autosummary
value_and_grad
Module
.. toctree::
nn/module
nn/layers
nn/functions
nn/losses


@@ -0,0 +1,8 @@
mlx.nn.ALiBi
============
.. currentmodule:: mlx.nn
.. autoclass:: ALiBi


@@ -0,0 +1,8 @@
mlx.nn.BatchNorm
================
.. currentmodule:: mlx.nn
.. autoclass:: BatchNorm


@@ -0,0 +1,8 @@
mlx.nn.Dropout
==============
.. currentmodule:: mlx.nn
.. autoclass:: Dropout


@@ -0,0 +1,8 @@
mlx.nn.Dropout2d
================
.. currentmodule:: mlx.nn
.. autoclass:: Dropout2d


@@ -0,0 +1,8 @@
mlx.nn.Dropout3d
================
.. currentmodule:: mlx.nn
.. autoclass:: Dropout3d


@@ -0,0 +1,8 @@
mlx.nn.InstanceNorm
===================
.. currentmodule:: mlx.nn
.. autoclass:: InstanceNorm


@@ -0,0 +1,6 @@
mlx.nn.Module.apply
===================
.. currentmodule:: mlx.nn
.. automethod:: Module.apply


@@ -0,0 +1,6 @@
mlx.nn.Module.apply\_to\_modules
================================
.. currentmodule:: mlx.nn
.. automethod:: Module.apply_to_modules


@@ -0,0 +1,6 @@
mlx.nn.Module.children
======================
.. currentmodule:: mlx.nn
.. automethod:: Module.children


@@ -0,0 +1,6 @@
mlx.nn.Module.eval
==================
.. currentmodule:: mlx.nn
.. automethod:: Module.eval


@@ -0,0 +1,6 @@
mlx.nn.Module.filter\_and\_map
==============================
.. currentmodule:: mlx.nn
.. automethod:: Module.filter_and_map


@@ -0,0 +1,6 @@
mlx.nn.Module.freeze
====================
.. currentmodule:: mlx.nn
.. automethod:: Module.freeze


@@ -0,0 +1,6 @@
mlx.nn.Module.leaf\_modules
===========================
.. currentmodule:: mlx.nn
.. automethod:: Module.leaf_modules


@@ -0,0 +1,6 @@
mlx.nn.Module.load\_weights
===========================
.. currentmodule:: mlx.nn
.. automethod:: Module.load_weights


@@ -0,0 +1,6 @@
mlx.nn.Module.modules
=====================
.. currentmodule:: mlx.nn
.. automethod:: Module.modules


@@ -0,0 +1,6 @@
mlx.nn.Module.named\_modules
============================
.. currentmodule:: mlx.nn
.. automethod:: Module.named_modules


@@ -0,0 +1,6 @@
mlx.nn.Module.parameters
========================
.. currentmodule:: mlx.nn
.. automethod:: Module.parameters


@@ -0,0 +1,6 @@
mlx.nn.Module.save\_weights
===========================
.. currentmodule:: mlx.nn
.. automethod:: Module.save_weights


@@ -0,0 +1,6 @@
mlx.nn.Module.train
===================
.. currentmodule:: mlx.nn
.. automethod:: Module.train


@@ -0,0 +1,6 @@
mlx.nn.Module.trainable\_parameters
===================================
.. currentmodule:: mlx.nn
.. automethod:: Module.trainable_parameters


@@ -0,0 +1,6 @@
mlx.nn.Module.training
======================
.. currentmodule:: mlx.nn
.. autoproperty:: Module.training


@@ -0,0 +1,6 @@
mlx.nn.Module.unfreeze
======================
.. currentmodule:: mlx.nn
.. automethod:: Module.unfreeze


@@ -0,0 +1,6 @@
mlx.nn.Module.update
====================
.. currentmodule:: mlx.nn
.. automethod:: Module.update


@@ -0,0 +1,6 @@
mlx.nn.Module.update\_modules
=============================
.. currentmodule:: mlx.nn
.. automethod:: Module.update_modules


@@ -0,0 +1,8 @@
mlx.nn.SinusoidalPositionalEncoding
===================================
.. currentmodule:: mlx.nn
.. autoclass:: SinusoidalPositionalEncoding


@@ -0,0 +1,8 @@
mlx.nn.Transformer
==================
.. currentmodule:: mlx.nn
.. autoclass:: Transformer


@@ -0,0 +1,8 @@
mlx.nn.losses.hinge\_loss
=========================
.. currentmodule:: mlx.nn.losses
.. autoclass:: hinge_loss


@@ -0,0 +1,8 @@
mlx.nn.losses.huber\_loss
=========================
.. currentmodule:: mlx.nn.losses
.. autoclass:: huber_loss


@@ -0,0 +1,8 @@
mlx.nn.losses.log\_cosh\_loss
=============================
.. currentmodule:: mlx.nn.losses
.. autoclass:: log_cosh_loss


@@ -9,7 +9,7 @@ Layers
:toctree: _autosummary
:template: nn-module-template.rst
Embedding
Sequential
ReLU
PReLU
GELU
@@ -17,13 +17,21 @@ Layers
Step
SELU
Mish
Embedding
Linear
QuantizedLinear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm
RoPE
InstanceNorm
Dropout
Dropout2d
Dropout3d
Transformer
MultiHeadAttention
Sequential
QuantizedLinear
ALiBi
RoPE
SinusoidalPositionalEncoding


@@ -16,4 +16,7 @@ Loss Functions
mse_loss
nll_loss
smooth_l1_loss
triplet_loss
hinge_loss
huber_loss
log_cosh_loss


@@ -0,0 +1,36 @@
Module
======
.. currentmodule:: mlx.nn
.. autoclass:: Module
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Module.training
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Module.apply
Module.apply_to_modules
Module.children
Module.eval
Module.filter_and_map
Module.freeze
Module.leaf_modules
Module.load_weights
Module.modules
Module.named_modules
Module.parameters
Module.save_weights
Module.train
Module.trainable_parameters
Module.unfreeze
Module.update
Module.update_modules


@@ -77,12 +77,14 @@ Operations
quantize
quantized_matmul
reciprocal
repeat
reshape
round
rsqrt
save
savez
savez_compressed
save_safetensors
sigmoid
sign
sin
@@ -102,6 +104,7 @@ Operations
take_along_axis
tan
tanh
tensordot
transpose
tri
tril