mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
docs
This commit is contained in:
parent
aede70e81d
commit
fbd10a48d4
4
docs/build/html/.buildinfo
vendored
Normal file
4
docs/build/html/.buildinfo
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
# Sphinx build info version 1
|
||||
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
|
||||
config: 5b1e2769941ea186df86318a9fdcef20
|
||||
tags: 645f666f9bcd5a90fca523b33c5a78b7
|
6
docs/build/html/_sources/cpp/ops.rst.txt
vendored
Normal file
6
docs/build/html/_sources/cpp/ops.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
.. _cpp_ops:
|
||||
|
||||
Operations
|
||||
==========
|
||||
|
||||
|
948
docs/build/html/_sources/dev/extensions.rst.txt
vendored
Normal file
948
docs/build/html/_sources/dev/extensions.rst.txt
vendored
Normal file
@ -0,0 +1,948 @@
|
||||
Developer Documentation
|
||||
=======================
|
||||
|
||||
MLX provides a open and flexible backend to which users may add operations
|
||||
and specialized implementations without much hassle. While the library supplies
|
||||
efficient operations that can be used and composed for any number of
|
||||
applications, there may arise cases where new functionalities or highly
|
||||
optimized implementations are needed. For such cases, you may design and
|
||||
implement your own operations that link to and build on top of :mod:`mlx.core`.
|
||||
We will introduce the inner-workings of MLX and go over a simple example to
|
||||
learn the steps involved in adding new operations to MLX with your own CPU
|
||||
and GPU implementations.
|
||||
|
||||
Introducing the Example
|
||||
-----------------------
|
||||
|
||||
Let's say that you would like an operation that takes in two arrays,
|
||||
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
|
||||
respectively, and then adds them together to get the result
|
||||
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||
writing out a function as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||
return alpha * x + beta * y
|
||||
|
||||
This function performs that operation while leaving the implementations and
|
||||
differentiation to MLX.
|
||||
|
||||
However, you work with vector math libraries often and realize that the
|
||||
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
|
||||
You would really like the part of your applications that does this operation
|
||||
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
|
||||
our assumptions on to you, let's also assume that you want to learn how add
|
||||
your own implementation for the gradients of your new operation while going
|
||||
over the ins-and-outs of the MLX framework.
|
||||
|
||||
Well, what a coincidence! You are in the right place. Over the course of this
|
||||
example, we will learn:
|
||||
|
||||
* The structure of the MLX library from the frontend API to the backend implementations.
|
||||
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
|
||||
* How to implement your own GPU implementation using metal.
|
||||
* How to add your own ``vjp`` and ``jvp``.
|
||||
* How to build your implementations, link them to MLX, and bind them to python.
|
||||
|
||||
Operations and Primitives
|
||||
-------------------------
|
||||
|
||||
In one sentence, operations in MLX build the computation graph, and primitives
|
||||
provide the rules for evaluation and transformations of said graph. Let's start
|
||||
by discussing operations in more detail.
|
||||
|
||||
Operations
|
||||
^^^^^^^^^^^
|
||||
|
||||
Operations are the frontend functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
|
||||
operations in the Python API (:ref:`ops`).
|
||||
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
|
||||
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
|
||||
C++ API:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/**
|
||||
* Scale and sum two vectors elementwise
|
||||
* z = alpha * x + beta * y
|
||||
*
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
|
||||
This operation itself can call other operations within it if needed. So, the
|
||||
simplest way to go about implementing this operation would be do so in terms
|
||||
of existing operations.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
) {
|
||||
// Scale x and y on the provided stream
|
||||
auto ax = multiply(array(alpha), x, s);
|
||||
auto by = multiply(array(beta), y, s);
|
||||
|
||||
// Add and return
|
||||
return add(ax, by, s);
|
||||
}
|
||||
|
||||
However, as we discussed earlier, this is not our goal. The operations themselves
|
||||
do not contain the implementations that act on the data, nor do they contain the
|
||||
rules of transformations. Rather, they are an easy to use interface that build
|
||||
on top of the building blocks we call :class:`Primitive`.
|
||||
|
||||
Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create an output given a set of input :class:`array` . Further,
|
||||
a :class:`Primitive` is a class that contains rules on how it is evaluated
|
||||
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
|
||||
``jvp``. These words on their own can be a bit abstract, so lets take a step
|
||||
back and go to our example to give ourselves a more concrete image.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
class Axpby : public Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
* for the given inputs and populate the output array.
|
||||
*
|
||||
* To avoid unecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself accross
|
||||
* the given axes. The output is a pair containing the array
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class and
|
||||
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
|
||||
``beta`` as parameters. It then provides implementations of how the array ``out``
|
||||
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
|
||||
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
|
||||
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
|
||||
|
||||
Using the Primitives
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to
|
||||
the computation graph. An :class:`array` can be constructed by providing its
|
||||
data type, shape, the :class:`Primitive` that computes it, and the
|
||||
:class:`array` inputs that are passed to the primitive.
|
||||
|
||||
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
) {
|
||||
// Promote dtypes between x and y as needed
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = is_floating_point(promoted_dtype)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
|
||||
// Cast x and y up to the determined dtype (on the same stream s)
|
||||
auto x_casted = astype(x, out_dtype, s);
|
||||
auto y_casted = astype(y, out_dtype, s);
|
||||
|
||||
// Broadcast the shapes of x and y (on the same stream s)
|
||||
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||
auto out_shape = broadcasted_inputs[0].shape();
|
||||
|
||||
// Construct the array as the output of the Axpby primitive
|
||||
// with the broadcasted and upcasted arrays as inputs
|
||||
return array(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
|
||||
This operation now handles the following:
|
||||
|
||||
#. Upcast inputs and resolve the the output data type.
|
||||
#. Broadcast the inputs and resolve the output shape.
|
||||
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
|
||||
#. Construct the output :class:`array` using the primitive and the inputs.
|
||||
|
||||
Implementing the Primitive
|
||||
--------------------------
|
||||
|
||||
No computation happens when we call the operation alone. In effect, the
|
||||
operation only builds the computation graph. When we evaluate the output
|
||||
array, MLX schedules the execution of the computation graph, and calls
|
||||
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
|
||||
stream/device specified by the user.
|
||||
|
||||
.. warning::
|
||||
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
|
||||
no memory has been allocated for the output array. It falls on the implementation
|
||||
of these functions to allocate memory as needed
|
||||
|
||||
Implementing the CPU Backend
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's start by trying to implement a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
const T* y_ptr = y.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Do the elementwise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additonal mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
}
|
||||
|
||||
Now, we would like our implementation to be able to do this pointwise operation
|
||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
|
||||
if we encounter an unexpected type.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Check the inputs (registered in the op while contructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
We have a fallback implementation! Now, to do what we are really here to do.
|
||||
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
|
||||
framework? Well, there are 3 complications to keep in mind:
|
||||
|
||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||
floats. We can only direct to it for ``float32`` types
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
|
||||
have fixed strides between them. Possibly due to broadcasts and transposes,
|
||||
we aren't guaranteed that the inputs fit this requirement. We can
|
||||
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
|
||||
column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
|
||||
MLX expects to write out the answer to a new array. We must copy the elements
|
||||
of ``y`` into the output array and use that as an input to ``axpby``
|
||||
|
||||
Let's write out an implementation that uses Accelerate in the right conditions.
|
||||
It must simply allocate data for the output, copy elements of ``y`` into it,
|
||||
and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
|
||||
// This specialization requires both x and y be contiguous in the same mode
|
||||
// i.e: corresponding linear indices in both point to corresponding elements
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
T* y_ptr = out.data<T>();
|
||||
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Call the inplace accelerate operator
|
||||
catlas_saxpby(
|
||||
/* N = */ out.size(),
|
||||
/* ALPHA = */ alpha,
|
||||
/* X = */ x_ptr,
|
||||
/* INCX = */ 1,
|
||||
/* BETA = */ beta,
|
||||
/* Y = */ y_ptr,
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||
Luckily, we can always just direct back to :meth:`Axpby::eval`.
|
||||
|
||||
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
We have now hit a milestone! Just this much is enough to run the operation
|
||||
:meth:`axpby` on a CPU stream!
|
||||
|
||||
If you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
Implementing the GPU Backend
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
all GPU kernels in MLX are written using metal.
|
||||
|
||||
.. note::
|
||||
|
||||
Here are some helpful resources if you are new to metal!
|
||||
|
||||
* A walkthrough of the metal compute pipeline: `Metal Example`_
|
||||
* Documentation for metal shading language: `Metal Specification`_
|
||||
* Using metal from C++: `Metal-cpp`_
|
||||
|
||||
Let's keep the GPU algorithm simple. We will launch exactly as many threads
|
||||
as there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
|
||||
element in the output.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void axpby_general(
|
||||
device const T* x [[buffer(0)]],
|
||||
device const T* y [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
constant const int* shape [[buffer(5)]],
|
||||
constant const size_t* x_strides [[buffer(6)]],
|
||||
constant const size_t* y_strides [[buffer(7)]],
|
||||
constant const int& ndim [[buffer(8)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
// Convert linear indices to offsets in array
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
|
||||
// Do the operation and update the output
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
We then need to instantiate this template for all floating point types and give
|
||||
each instantiation a unique host name so we can identify the right kernel for
|
||||
each data type.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bflot16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
||||
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||
will see later in :ref:`Building with CMake`. In the following example, we
|
||||
assume that the library ``mlx_ext.metallib`` will always be co-located with
|
||||
the executable/ shared-library calling the :meth:`register_library` function.
|
||||
The :meth:`register_library` function takes the library's name and potential
|
||||
path (or in this case, a function that can produce the path of the metal
|
||||
library) and tries to load that library if it hasn't already been registered
|
||||
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
|
||||
it is important to package your C++ library with the metal library. We will
|
||||
go over this process in more detail later.
|
||||
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
below.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
auto& s = stream();
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available and look for it
|
||||
// in the same folder as this executable if needed
|
||||
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel decelaration at axpby.metal
|
||||
int ndim = out.ndim();
|
||||
size_t nelem = out.size();
|
||||
|
||||
// Encode input arrays to kernel
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, y, 1);
|
||||
|
||||
// Encode output arrays to kernel
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
|
||||
// We launch 1 thread for each input and make sure that the number of
|
||||
// threads in any given threadgroup is not higher than the max allowed
|
||||
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Fix the 3D size of each threadgroup (in terms of threads)
|
||||
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
|
||||
|
||||
// Fix the 3D size of the launch grid (in terms of threads)
|
||||
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||
|
||||
// Launch the grid with the given number of threads divded among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
|
||||
A few things to note about MLX and metal before moving on. MLX keeps track
|
||||
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
|
||||
to give us the active metal compute command encoder instead of building a
|
||||
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and commiting the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
:class:`metal::Device` if you would like to study this routine further.
|
||||
|
||||
Primitive Transforms
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Now that we have come this far, let's also learn how to add implementations to
|
||||
transformations in a :class:`Primitive`. These transformations can be built on
|
||||
top of our operations, including the one we just defined now. Which then gives
|
||||
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the the primitive can built with ops
|
||||
// that are scheduled on the same stream as the primtive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
// jvp is just the tangent scaled by alpha
|
||||
// Similarly, if argnums = {1}, the jvp is just the tangent
|
||||
// scaled by beta
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return multiply(scale_arr, tangents[0], stream());
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
||||
}
|
||||
}
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotan.dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
Finally, you need not have a transformation fully defined to start using your
|
||||
own :class:`Primitive`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Vectorize primitve along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
}
|
||||
|
||||
Building and Binding
|
||||
--------------------
|
||||
|
||||
Let's look at the overall directory structure first.
|
||||
|
||||
| extensions
|
||||
| ├── axpby
|
||||
| │ ├── axpby.cpp
|
||||
| │ ├── axpby.h
|
||||
| │ └── axpby.metal
|
||||
| ├── mlx_sample_extensions
|
||||
| │ └── __init__.py
|
||||
| ├── bindings.cpp
|
||||
| ├── CMakeLists.txt
|
||||
| └── setup.py
|
||||
|
||||
* ``extensions/axpby/`` defines the C++ extension library
|
||||
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
|
||||
associated python package
|
||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
python bindings
|
||||
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
|
||||
the python package
|
||||
|
||||
Binding to Python
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings
|
||||
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
|
||||
are already provided, adding our :meth:`axpby` becomes very simple!
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
Scale and sum two vectors elementwise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
y (array): Input array.
|
||||
alpha (float): Scaling factor for ``x``.
|
||||
beta (float): Scaling factor for ``y``.
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
whistles such as the literal names and doc-strings.
|
||||
|
||||
.. warning::
|
||||
|
||||
:mod:`mlx.core` needs to be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:class:`mlx.core.array` are available.
|
||||
|
||||
.. _Building with CMake:
|
||||
|
||||
Building with CMake
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Building the C++ extension library itself is simple, it only requires that you
|
||||
``find_package(MLX CONFIG)`` and then link it to your library.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
# Add library
|
||||
add_library(mlx_ext)
|
||||
|
||||
# Add sources
|
||||
target_sources(
|
||||
mlx_ext
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
|
||||
)
|
||||
|
||||
# Add include headers
|
||||
target_include_directories(
|
||||
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
# Link to mlx
|
||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
We also need to build the attached metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with MLX package).
|
||||
|
||||
Here is what that looks like in practice!
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
# Build metallib
|
||||
if(MLX_BUILD_METAL)
|
||||
|
||||
mlx_build_metallib(
|
||||
TARGET mlx_ext_metallib
|
||||
TITLE mlx_ext
|
||||
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
|
||||
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
|
||||
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
add_dependencies(
|
||||
mlx_ext
|
||||
mlx_ext_metallib
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
Finally, we build the Pybind11_ bindings
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
||||
Building with ``setuptools``
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Once we have set out the CMake build rules as described above, we can use the
|
||||
build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mlx import extension
|
||||
from setuptools import setup
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup(
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages = ["mlx_sample_extensions"],
|
||||
package_dir = {"": "mlx_sample_extensions"},
|
||||
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.7",
|
||||
)
|
||||
|
||||
.. note::
|
||||
We treat ``extensions/mlx_sample_extensions`` as the package directory
|
||||
even though it only contains a ``__init__.py`` to ensure the following:
|
||||
|
||||
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
You can build inplace for development using
|
||||
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
|
||||
|
||||
This will result in a directory structure as follows:
|
||||
|
||||
| extensions
|
||||
| ├── mlx_sample_extensions
|
||||
| │ ├── __init__.py
|
||||
| │ ├── libmlx_ext.dylib # C++ extension library
|
||||
| │ ├── mlx_ext.metallib # Metal library
|
||||
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
|
||||
| ...
|
||||
|
||||
When you try to install using the command ``python -m pip install .``
|
||||
(in ``extensions/``), the package will be installed with the same strucutre as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as ``package_data``.
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the python package and play with it as you would any other MLX operation!
|
||||
|
||||
Let's looks at a simple script and it's results!
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correctness: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
.. code-block::
|
||||
|
||||
c shape: [3, 4]
|
||||
c dtype: float32
|
||||
c correctness: True
|
||||
|
||||
Results
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
import time
|
||||
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||
return alpha * x + beta * y
|
||||
|
||||
M = 256
|
||||
N = 512
|
||||
|
||||
x = mx.random.normal((M, N))
|
||||
y = mx.random.normal((M, N))
|
||||
alpha = 4.0
|
||||
beta = 2.0
|
||||
|
||||
mx.eval((x, y))
|
||||
|
||||
def bench(f):
|
||||
# Warm up
|
||||
for i in range(100):
|
||||
z = f(x, y, alpha, beta)
|
||||
mx.eval(z)
|
||||
|
||||
# Timed run
|
||||
s = time.time()
|
||||
for i in range(5000):
|
||||
z = f(x, y, alpha, beta)
|
||||
mx.eval(z)
|
||||
e = time.time()
|
||||
return e - s
|
||||
|
||||
simple_time = bench(simple_axpby)
|
||||
custom_time = bench(axpby)
|
||||
|
||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
||||
|
||||
Results:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Simple axpby: 0.114 s | Custom axpby: 0.109 s
|
||||
|
||||
We see some modest improvements right away!
|
||||
|
||||
This operation is now good to be used to build other operations,
|
||||
in :class:`mlx.nn.Module` calls, and also as a part of graph
|
||||
transformations such as :meth:`grad` and :meth:`simplify`!
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx-examples <code>`_.
|
||||
|
||||
.. code: `TODO_LINK/extensions`_
|
||||
|
||||
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
|
||||
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
|
||||
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
|
||||
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
|
||||
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
77
docs/build/html/_sources/examples/linear_regression.rst.txt
vendored
Normal file
77
docs/build/html/_sources/examples/linear_regression.rst.txt
vendored
Normal file
@ -0,0 +1,77 @@
|
||||
.. _linear_regression:
|
||||
|
||||
Linear Regression
|
||||
-----------------
|
||||
|
||||
Let's implement a basic linear regression model as a starting point to
|
||||
learn MLX. First import the core package and setup some problem metadata:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
num_features = 100
|
||||
num_examples = 1_000
|
||||
num_iters = 10_000 # iterations of SGD
|
||||
lr = 0.01 # learning rate for SGD
|
||||
|
||||
|
||||
We'll generate a synthetic dataset by:
|
||||
|
||||
1. Sampling the design matrix ``X``.
|
||||
2. Sampling a ground truth parameter vector ``w_star``.
|
||||
3. Compute the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# True parameters
|
||||
w_star = mx.random.normal((num_features,))
|
||||
|
||||
# Input examples (design matrix)
|
||||
X = mx.random.normal((num_examples, num_features))
|
||||
|
||||
# Noisy labels
|
||||
eps = 1e-2 * mx.random.normal((num_examples,))
|
||||
y = X @ w_star + eps
|
||||
|
||||
|
||||
We will use SGD to find the optimal weights. To start, define the squared loss
|
||||
and get the gradient function of the loss with respect to the parameters.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def loss_fn(w):
|
||||
return 0.5 * mx.mean(mx.square(X @ w - y))
|
||||
|
||||
grad_fn = mx.grad(loss_fn)
|
||||
|
||||
Start the optimization by initializing the parameters ``w`` randomly. Then
|
||||
repeatedly update the parameters for ``num_iters`` iterations.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
w = 1e-2 * mx.random.normal((num_features,))
|
||||
|
||||
for _ in range(num_iters):
|
||||
grad = grad_fn(w)
|
||||
w = w - lr * grad
|
||||
mx.eval(w)
|
||||
|
||||
Finally, compute the loss of the learned parameters and verify that they are
|
||||
close to the ground truth parameters.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
loss = loss_fn(w)
|
||||
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
|
||||
|
||||
print(
|
||||
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
|
||||
)
|
||||
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364
|
||||
|
||||
Complete `linear regression
|
||||
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
|
||||
and `logistic regression
|
||||
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
|
||||
examples are available in the MLX GitHub repo.
|
382
docs/build/html/_sources/examples/llama-inference.rst.txt
vendored
Normal file
382
docs/build/html/_sources/examples/llama-inference.rst.txt
vendored
Normal file
@ -0,0 +1,382 @@
|
||||
LLM inference
|
||||
==============
|
||||
|
||||
MLX enables efficient inference of large-ish transformers on Apple silicon
|
||||
without compromising on ease of use. In this example we will create an
|
||||
inference script for the Llama family of transformer models in which the model
|
||||
is defined in less than 200 lines of python.
|
||||
|
||||
Implementing the model
|
||||
----------------------
|
||||
|
||||
We will use the neural network building blocks defined in the :mod:`mlx.nn`
|
||||
module to concisely define the model architecture.
|
||||
|
||||
Attention layer
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
We will start with the llama attention layer which notably uses the RoPE
|
||||
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
||||
key/value cache that will be concatenated with the provided keys and values to
|
||||
support efficient inference.
|
||||
|
||||
Our implementation uses :class:`mlx.nn.Linear` for all the projections and
|
||||
:class:`mlx.nn.RoPE` for the positional encoding.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(dims // num_heads, traditional=True)
|
||||
self.query_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.key_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.value_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.out_proj = nn.Linear(dims, dims, bias=False)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
# Note that we return the keys and values to possibly be used as a cache
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
Encoder layer
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
The other component of the Llama model is the encoder layer which uses RMS
|
||||
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
|
||||
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class LlamaEncoderLayer(nn.Module):
|
||||
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.attention = LlamaAttention(dims, num_heads)
|
||||
|
||||
self.norm1 = nn.RMSNorm(dims)
|
||||
self.norm2 = nn.RMSNorm(dims)
|
||||
|
||||
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
y = self.norm1(x)
|
||||
y, cache = self.attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.norm2(x)
|
||||
a = self.linear1(y)
|
||||
b = self.linear2(y)
|
||||
y = a * mx.sigmoid(a) * b
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
Full model
|
||||
^^^^^^^^^^
|
||||
|
||||
To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
|
||||
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(
|
||||
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, dims)
|
||||
self.layers = [
|
||||
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(dims)
|
||||
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embedding.weight.dtype)
|
||||
|
||||
x = self.embedding(x)
|
||||
for l in self.layers:
|
||||
x, _ = l(x, mask)
|
||||
x = self.norm(x)
|
||||
return self.out_proj(x)
|
||||
|
||||
Note that in the implementation above we use a simple list to hold the encoder
|
||||
layers but using ``model.parameters()`` will still consider these layers.
|
||||
|
||||
Generation
|
||||
^^^^^^^^^^^
|
||||
|
||||
Our ``Llama`` module can be used for training but not inference as the
|
||||
``__call__`` method above processes one input, completely ignores the cache and
|
||||
performs no sampling whatsoever. In the rest of this subsection, we will
|
||||
implement the inference function as a python generator that processes the
|
||||
prompt and then autoregressively yields tokens one at a time.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Llama(nn.Module):
|
||||
...
|
||||
|
||||
def generate(self, x, temp=1.0):
|
||||
cache = []
|
||||
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embedding.weight.dtype)
|
||||
|
||||
# First we process the prompt x the same way as in __call__ but
|
||||
# save the caches in cache
|
||||
x = self.embedding(x)
|
||||
for l in self.layers:
|
||||
x, c = l(x, mask=mask)
|
||||
cache.append(c) # <--- we store the per layer cache in a
|
||||
# simple python list
|
||||
x = self.norm(x)
|
||||
y = self.out_proj(x[:, -1]) # <--- we only care about the last logits
|
||||
# that generate the next token
|
||||
y = mx.random.categorical(y * (1/temp))
|
||||
|
||||
# y now has size [1]
|
||||
# Since MLX is lazily evaluated nothing is computed yet.
|
||||
# Calling y.item() would force the computation to happen at
|
||||
# this point but we can also choose not to do that and let the
|
||||
# user choose when to start the computation.
|
||||
yield y
|
||||
|
||||
# Now we parsed the prompt and generated the first token we
|
||||
# need to feed it back into the model and loop to generate the
|
||||
# rest.
|
||||
while True:
|
||||
# Unsqueezing the last dimension to add a sequence length
|
||||
# dimension of 1
|
||||
x = y[:, None]
|
||||
|
||||
x = self.embedding(x)
|
||||
for i in range(len(cache)):
|
||||
# We are overwriting the arrays in the cache list. When
|
||||
# the computation will happen, MLX will be discarding the
|
||||
# old cache the moment it is not needed anymore.
|
||||
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||
x = self.norm(x)
|
||||
y = self.out_proj(x[:, -1])
|
||||
y = mx.random.categorical(y * (1/temp))
|
||||
|
||||
yield y
|
||||
|
||||
Putting it all together
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We now have everything we need to create a Llama model and sample tokens from
|
||||
it. In the following code, we randomly initialize a small Llama model, process
|
||||
6 tokens of prompt and generate 10 tokens.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
|
||||
|
||||
# Since MLX is lazily evaluated nothing has actually been materialized yet.
|
||||
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
|
||||
# code above would still run. Let's actually materialize the model.
|
||||
mx.eval(model.parameters())
|
||||
|
||||
prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we
|
||||
# have a batch dimension even
|
||||
# though it is 1 in this case
|
||||
|
||||
generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
|
||||
|
||||
# Since we haven't evaluated anything, nothing is computed yet. The list
|
||||
# `generated` contains the arrays that hold the computation graph for the
|
||||
# full processing of the prompt and the generation of 10 tokens.
|
||||
#
|
||||
# We can evaluate them one at a time, or all together. Concatenate them or
|
||||
# print them. They would all result in very similar runtimes and give exactly
|
||||
# the same results.
|
||||
mx.eval(generated)
|
||||
|
||||
Converting the weights
|
||||
----------------------
|
||||
|
||||
This section assumes that you have access to the original Llama weights and the
|
||||
SentencePiece model that comes with them. We will write a small script to
|
||||
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
|
||||
that can be loaded directly by MLX.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import argparse
|
||||
from itertools import starmap
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def map_torch_to_mlx(key, value):
|
||||
if "tok_embedding" in key:
|
||||
key = "embedding.weight"
|
||||
|
||||
elif "norm" in key:
|
||||
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||
|
||||
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||
key = key.replace("wq", "query_proj")
|
||||
key = key.replace("wk", "key_proj")
|
||||
key = key.replace("wv", "value_proj")
|
||||
key = key.replace("wo", "out_proj")
|
||||
|
||||
elif "w1" in key or "w2" in key or "w3" in key:
|
||||
# The FFN is a separate submodule in PyTorch
|
||||
key = key.replace("feed_forward.w1", "linear1")
|
||||
key = key.replace("feed_forward.w3", "linear2")
|
||||
key = key.replace("feed_forward.w2", "linear3")
|
||||
|
||||
elif "output" in key:
|
||||
key = key.replace("output", "out_proj")
|
||||
|
||||
elif "rope" in key:
|
||||
return None, None
|
||||
|
||||
return key, value.numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument("torch_weights")
|
||||
parser.add_argument("output_file")
|
||||
args = parser.parse_args()
|
||||
|
||||
state = torch.load(args.torch_weights)
|
||||
np.savez(
|
||||
args.output_file,
|
||||
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
|
||||
)
|
||||
|
||||
|
||||
Weight loading and benchmarking
|
||||
-------------------------------
|
||||
|
||||
After converting the weights to be compatible to our implementation, all that is
|
||||
left is to load them from disk and we can finally use the LLM to generate text.
|
||||
We can load numpy format files using the :func:`mlx.core.load` operation.
|
||||
|
||||
To create a parameter dictionary from the key/value representation of NPZ files
|
||||
we will use the :func:`mlx.utils.tree_unflatten` helper method as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
model.update(tree_unflatten(list(mx.load(weight_file).items())))
|
||||
|
||||
:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
|
||||
like ``layers.2.attention.query_proj.weight`` and will transform them to
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
|
||||
|
||||
which can then be used to update the model. Note that the method above incurs
|
||||
several unnecessary copies from disk to numpy and then from numpy to MLX. It
|
||||
will be replaced in the future with direct loading to MLX.
|
||||
|
||||
You can download the full example code in `mlx-examples <code>`_. Assuming, the
|
||||
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
|
||||
directory we can play around with our inference script as follows (the timings
|
||||
are representative of an M1 Ultra and the 7B parameter Llama model):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python convert.py weights.pth llama-7B.mlx.npz
|
||||
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
|
||||
[INFO] Loading model from disk: 5.247 s
|
||||
Press enter to start generation
|
||||
------
|
||||
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
|
||||
------
|
||||
[INFO] Prompt processing: 0.437 s
|
||||
[INFO] Full generation: 4.330 s
|
||||
|
||||
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
|
||||
of those are spent processing the prompt. This amounts to a little over **39 ms
|
||||
per token**.
|
||||
|
||||
By running with a much bigger prompt we can see that the per token generation
|
||||
time as well as the prompt processing time remains almost constant.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
|
||||
[INFO] Loading model from disk: 5.247 s
|
||||
Press enter to start generation
|
||||
------
|
||||
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
|
||||
------
|
||||
[INFO] Prompt processing: 0.579 s
|
||||
[INFO] Full generation: 4.690 s
|
||||
$ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
|
||||
[INFO] Loading model from disk: 5.628 s
|
||||
Press enter to start generation
|
||||
------
|
||||
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
|
||||
------
|
||||
[INFO] Prompt processing: 0.633 s
|
||||
[INFO] Full generation: 21.475 s
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx-examples <code>`_.
|
||||
|
||||
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
|
||||
|
||||
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
|
||||
Roformer: Enhanced transformer with rotary position embedding. arXiv
|
||||
preprint arXiv:2104.09864.
|
||||
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
|
||||
Advances in Neural Information Processing Systems, 32.
|
||||
.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
|
||||
arXiv:2002.05202.
|
131
docs/build/html/_sources/examples/mlp.rst.txt
vendored
Normal file
131
docs/build/html/_sources/examples/mlp.rst.txt
vendored
Normal file
@ -0,0 +1,131 @@
|
||||
.. _mlp:
|
||||
|
||||
Multi-Layer Perceptron
|
||||
----------------------
|
||||
|
||||
In this example we'll learn to use ``mlx.nn`` by implementing a simple
|
||||
multi-layer perceptron to classify MNIST.
|
||||
|
||||
As a first step import the MLX packages we need:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
The model is defined as the ``MLP`` class which inherits from
|
||||
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:
|
||||
|
||||
1. Define an ``__init__`` where the parameters and/or submodules are setup. See
|
||||
the :ref:`Module class docs<module_class>` for more information on how
|
||||
:class:`mlx.nn.Module` registers parameters.
|
||||
2. Define a ``__call__`` where the computation is implemented.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||
self.layers = [
|
||||
nn.Linear(idim, odim)
|
||||
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = mx.maximum(l(x), 0.0)
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
We define the loss function which takes the mean of the per-example cross
|
||||
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
|
||||
commonly used loss functions.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def loss_fn(model, X, y):
|
||||
return mx.mean(nn.losses.cross_entropy(model(X), y))
|
||||
|
||||
We also need a function to compute the accuracy of the model on the validation
|
||||
set:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def eval_fn(model, X, y):
|
||||
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
||||
|
||||
Next, setup the problem parameters and load the data:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
num_layers = 2
|
||||
hidden_dim = 32
|
||||
num_classes = 10
|
||||
batch_size = 256
|
||||
num_epochs = 10
|
||||
learning_rate = 1e-1
|
||||
|
||||
# Load the data
|
||||
import mnist
|
||||
train_images, train_labels, test_images, test_labels = map(
|
||||
mx.array, mnist.mnist()
|
||||
)
|
||||
|
||||
Since we're using SGD, we need an iterator which shuffles and constructs
|
||||
minibatches of examples in the training set:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def batch_iterate(batch_size, X, y):
|
||||
perm = mx.array(np.random.permutation(y.size))
|
||||
for s in range(0, y.size, batch_size):
|
||||
ids = perm[s : s + batch_size]
|
||||
yield X[ids], y[ids]
|
||||
|
||||
|
||||
Finally, we put it all together by instantiating the model, the
|
||||
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Load the model
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# Get a function which gives the loss and gradient of the
|
||||
# loss with respect to the model's trainable parameters
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
|
||||
# Instantiate the optimizer
|
||||
optimizer = optim.SGD(learning_rate=learning_rate)
|
||||
|
||||
for e in range(num_epochs):
|
||||
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
||||
loss, grads = loss_and_grad_fn(model, X, y)
|
||||
|
||||
# Update the optimizer state and model parameters
|
||||
# in a single call
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Force a graph evaluation
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
accuracy = eval_fn(model, test_images, test_labels)
|
||||
print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
|
||||
|
||||
|
||||
.. note::
|
||||
The :func:`mlx.nn.value_and_grad` function is a convenience function to get
|
||||
the gradient of a loss with respect to the trainable parameters of a model.
|
||||
This should not be confused with :func:`mlx.core.value_and_grad`.
|
||||
|
||||
The model should train to a decent accuracy (about 95%) after just a few passes
|
||||
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
|
||||
is available in the MLX GitHub repo.
|
49
docs/build/html/_sources/index.rst.txt
vendored
Normal file
49
docs/build/html/_sources/index.rst.txt
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
MLX
|
||||
===
|
||||
|
||||
.. toctree::
|
||||
:caption: Install
|
||||
:maxdepth: 1
|
||||
|
||||
install
|
||||
|
||||
.. toctree::
|
||||
:caption: Usage
|
||||
:maxdepth: 1
|
||||
|
||||
quick_start
|
||||
using_streams
|
||||
|
||||
.. toctree::
|
||||
:caption: Examples
|
||||
:maxdepth: 1
|
||||
|
||||
examples/linear_regression
|
||||
examples/mlp
|
||||
examples/llama-inference
|
||||
|
||||
.. toctree::
|
||||
:caption: Further Reading
|
||||
:maxdepth: 1
|
||||
|
||||
dev/extensions
|
||||
|
||||
.. toctree::
|
||||
:caption: Python API Reference
|
||||
:maxdepth: 1
|
||||
|
||||
python/array
|
||||
python/devices_and_streams
|
||||
python/ops
|
||||
python/random
|
||||
python/transforms
|
||||
python/fft
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/tree_utils
|
||||
|
||||
.. toctree::
|
||||
:caption: C++ API Reference
|
||||
:maxdepth: 1
|
||||
|
||||
cpp/ops
|
102
docs/build/html/_sources/install.rst.txt
vendored
Normal file
102
docs/build/html/_sources/install.rst.txt
vendored
Normal file
@ -0,0 +1,102 @@
|
||||
Build and Install
|
||||
=================
|
||||
|
||||
Install from PyPI
|
||||
-----------------
|
||||
|
||||
MLX is available at Apple's internal PyPI repository. All you have to do to use
|
||||
MLX with your own Apple silicon computer is
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install apple-mlx -i https://pypi.apple.com/simple
|
||||
|
||||
Build from source
|
||||
-----------------
|
||||
|
||||
Build Requirements
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||
|
||||
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
|
||||
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "pybind11[global]"
|
||||
conda install pybind11
|
||||
brew install pybind11
|
||||
|
||||
Then simply build and install it using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
||||
|
||||
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
by cloning MLX from `its GitHub repo
|
||||
<https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Create a build directory and run CMake and make:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. && make -j
|
||||
|
||||
Run tests with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
make test
|
||||
|
||||
Install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
make install
|
||||
|
||||
Note that the built ``mlx.metallib`` file should be either at the same
|
||||
directory as the executable statically linked to ``libmlx.a`` or the
|
||||
preprocessor constant ``METAL_PATH`` should be defined at build time and it
|
||||
should point to the path to the built metal library.
|
||||
|
||||
.. list-table:: Build Options
|
||||
:widths: 25 8
|
||||
:header-rows: 1
|
||||
|
||||
* - Option
|
||||
- Default
|
||||
* - MLX_BUILD_TESTS
|
||||
- ON
|
||||
* - MLX_BUILD_EXAMPLES
|
||||
- OFF
|
||||
* - MLX_BUILD_BENCHMARKS
|
||||
- OFF
|
||||
* - MLX_BUILD_METAL
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
28
docs/build/html/_sources/python/_autosummary/mlx.core.Device.rst.txt
vendored
Normal file
28
docs/build/html/_sources/python/_autosummary/mlx.core.Device.rst.txt
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
mlx.core.Device
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoclass:: Device
|
||||
|
||||
|
||||
.. automethod:: __init__
|
||||
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~Device.__init__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~Device.type
|
||||
|
||||
|
28
docs/build/html/_sources/python/_autosummary/mlx.core.Dtype.rst.txt
vendored
Normal file
28
docs/build/html/_sources/python/_autosummary/mlx.core.Dtype.rst.txt
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
mlx.core.Dtype
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoclass:: Dtype
|
||||
|
||||
|
||||
.. automethod:: __init__
|
||||
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~Dtype.__init__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~Dtype.size
|
||||
|
||||
|
28
docs/build/html/_sources/python/_autosummary/mlx.core.Stream.rst.txt
vendored
Normal file
28
docs/build/html/_sources/python/_autosummary/mlx.core.Stream.rst.txt
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
mlx.core.Stream
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoclass:: Stream
|
||||
|
||||
|
||||
.. automethod:: __init__
|
||||
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~Stream.__init__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~Stream.device
|
||||
|
||||
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.abs.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.abs.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.abs
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: abs
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.add.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.add.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.add
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: add
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.all.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.all.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.all
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: all
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.allclose.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.allclose.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.allclose
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: allclose
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.any.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.any.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.any
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: any
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arange.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arange.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arange
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arange
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arccos.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arccos.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arccos
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arccos
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arccosh.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arccosh.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arccosh
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arccosh
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arcsin.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arcsin.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arcsin
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arcsin
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arcsinh.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arcsinh.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arcsinh
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arcsinh
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arctan.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arctan.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arctan
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arctan
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.arctanh.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.arctanh.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.arctanh
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: arctanh
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.argmax.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.argmax.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.argmax
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: argmax
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.argmin.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.argmin.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.argmin
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: argmin
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.argpartition.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.argpartition.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.argpartition
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: argpartition
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.argsort.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.argsort.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.argsort
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: argsort
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.T.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.T.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.T
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoproperty:: array.T
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.abs.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.abs.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.abs
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.abs
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.all.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.all.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.all
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.all
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.any.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.any.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.any
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.any
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.argmax.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.argmax.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.argmax
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.argmax
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.argmin.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.argmin.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.argmin
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.argmin
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.astype.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.astype.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.astype
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.astype
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.cos.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.cos.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.cos
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.cos
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.dtype.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.dtype.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.dtype
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoproperty:: array.dtype
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.exp.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.exp.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.exp
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.exp
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.item.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.item.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.item
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.item
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.log.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.log.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.log
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.log
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.log1p.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.log1p.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.log1p
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.log1p
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.logsumexp.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.logsumexp.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.logsumexp
|
||||
========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.logsumexp
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.max.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.max.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.max
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.max
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.mean.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.mean.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.mean
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.mean
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.min.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.min.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.min
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.min
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.ndim.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.ndim.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.ndim
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoproperty:: array.ndim
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.prod.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.prod.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.prod
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.prod
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.reciprocal.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.reciprocal.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.reciprocal
|
||||
=========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.reciprocal
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.reshape.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.reshape.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.reshape
|
||||
======================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.reshape
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.rsqrt.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.rsqrt.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.rsqrt
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.rsqrt
|
66
docs/build/html/_sources/python/_autosummary/mlx.core.array.rst.txt
vendored
Normal file
66
docs/build/html/_sources/python/_autosummary/mlx.core.array.rst.txt
vendored
Normal file
@ -0,0 +1,66 @@
|
||||
mlx.core.array
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoclass:: array
|
||||
|
||||
|
||||
.. automethod:: __init__
|
||||
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~array.__init__
|
||||
~array.abs
|
||||
~array.all
|
||||
~array.any
|
||||
~array.argmax
|
||||
~array.argmin
|
||||
~array.astype
|
||||
~array.cos
|
||||
~array.cummax
|
||||
~array.cummin
|
||||
~array.cumprod
|
||||
~array.cumsum
|
||||
~array.exp
|
||||
~array.item
|
||||
~array.log
|
||||
~array.log10
|
||||
~array.log1p
|
||||
~array.log2
|
||||
~array.logsumexp
|
||||
~array.max
|
||||
~array.mean
|
||||
~array.min
|
||||
~array.prod
|
||||
~array.reciprocal
|
||||
~array.reshape
|
||||
~array.rsqrt
|
||||
~array.sin
|
||||
~array.split
|
||||
~array.sqrt
|
||||
~array.square
|
||||
~array.squeeze
|
||||
~array.sum
|
||||
~array.tolist
|
||||
~array.transpose
|
||||
~array.var
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
|
||||
~array.T
|
||||
~array.dtype
|
||||
~array.ndim
|
||||
~array.shape
|
||||
~array.size
|
||||
|
||||
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.shape.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.shape.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.shape
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoproperty:: array.shape
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.sin.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.sin.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.sin
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.sin
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.size.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.size.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.size
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autoproperty:: array.size
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.split.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.split.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.split
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.split
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.sqrt.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.sqrt.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.sqrt
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.sqrt
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.square.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.square.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.square
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.square
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.sum.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.sum.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.sum
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.sum
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.tolist.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.tolist.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.tolist
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.tolist
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.transpose.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.transpose.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.transpose
|
||||
========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.transpose
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.var.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array.var.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array.var
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. automethod:: array.var
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.array_equal.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.array_equal.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.array\_equal
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: array_equal
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.broadcast_to.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.broadcast_to.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.broadcast\_to
|
||||
======================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: broadcast_to
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.concatenate.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.concatenate.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.concatenate
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: concatenate
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.conv1d.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.conv1d.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.conv1d
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: conv1d
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.conv2d.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.conv2d.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.conv2d
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: conv2d
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.convolve.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.convolve.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.convolve
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: convolve
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.cos.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.cos.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.cos
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: cos
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.cosh.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.cosh.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.cosh
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: cosh
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.default_device.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.default_device.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.default\_device
|
||||
========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: default_device
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.default_stream.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.default_stream.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.default\_stream
|
||||
========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: default_stream
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.divide.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.divide.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.divide
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: divide
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.equal.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.equal.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.equal
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: equal
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.erf.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.erf.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.erf
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: erf
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.erfinv.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.erfinv.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.erfinv
|
||||
===============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: erfinv
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.eval.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.eval.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.eval
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: eval
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.exp.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.exp.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.exp
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: exp
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.expand_dims.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.expand_dims.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.expand\_dims
|
||||
=====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: expand_dims
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.fft.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.fft.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.fft
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: fft
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.fft2.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.fft2.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.fft2
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: fft2
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.fftn.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.fftn.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.fftn
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: fftn
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.ifft.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.ifft.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.ifft
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: ifft
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.ifft2.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.ifft2.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.ifft2
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: ifft2
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.ifftn.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.ifftn.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.ifftn
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: ifftn
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.irfft.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.irfft.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.irfft
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: irfft
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.irfft2.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.irfft2.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.irfft2
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: irfft2
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.irfftn.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.irfftn.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.irfftn
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: irfftn
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.rfft.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.rfft.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.rfft
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: rfft
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.rfft2.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.rfft2.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.rfft2
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: rfft2
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.rfftn.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.fft.rfftn.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.fft.rfftn
|
||||
==================
|
||||
|
||||
.. currentmodule:: mlx.core.fft
|
||||
|
||||
.. autofunction:: rfftn
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.full.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.full.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.full
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: full
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.grad.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.grad.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.grad
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: grad
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.greater.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.greater.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.greater
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: greater
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.greater_equal.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.greater_equal.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.greater\_equal
|
||||
=======================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: greater_equal
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.jvp.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.jvp.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.jvp
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: jvp
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.less.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.less.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.less
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: less
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.less_equal.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.less_equal.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.less\_equal
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: less_equal
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.load.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.load.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.load
|
||||
=============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: load
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.log.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.log.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.log
|
||||
============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: log
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.log10.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.log10.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.log10
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: log10
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.log1p.rst.txt
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.log1p.rst.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
mlx.core.log1p
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: log1p
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user