Indexing Arrays
For the most part, indexing an MLX array works the same as indexing a
NumPy numpy.ndarray. See the NumPy documentation for more details on
how that works.
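For instance, the usual NumPy-style patterns carry over directly. A quick sketch using the standard MLX Python API:

.. code-block:: python

    import mlx.core as mx

    a = mx.arange(10)
    a[3]        # single element
    a[-2]       # negative indexing from the end
    a[2:8:2]    # slice with start, stop, and stride
    a[None]     # add a leading axis, shape (1, 10)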
diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo
index 0bdadc036..f0399f624 100644
--- a/docs/build/html/.buildinfo
+++ b/docs/build/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 78e86a9caf7acb193f064f97ea2f4572
+config: 38bd5d82efdab9011af8239531f26d1f
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst
index 9aae931a3..0a134e7f5 100644
--- a/docs/build/html/_sources/dev/extensions.rst
+++ b/docs/build/html/_sources/dev/extensions.rst
@@ -15,7 +15,7 @@ Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
-``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
+``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
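The function body falls outside this hunk; a minimal sketch of what the passage describes, written with plain MLX ops (the signature is illustrative):

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        # z = alpha * x + beta * y, with numpy-style broadcasting between x and y
        return alpha * x + beta * y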
@@ -69,7 +69,7 @@ C++ API:
.. code-block:: C++
/**
- * Scale and sum two vectors elementwise
+ * Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
This operation now handles the following:
-#. Upcast inputs and resolve the the output data type.
+#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
@@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
   T alpha = static_cast<T>(alpha_);
   T beta = static_cast<T>(beta_);

-  // Do the elementwise operation for each output
+  // Do the element-wise operation for each output
Let’s say that you would like an operation that takes in two arrays,
-x and y, scales them both by some coefficents alpha and beta
+x and y, scales them both by some coefficients alpha and beta
respectively, and then adds them together to get the result
z = alpha * x + beta * y. Well, you can very easily do that by just
writing out a function as follows:
beta. This is how we would define it in the C++ API:
/**
-* Scale and sum two vectors elementwise
+* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@@ -833,7 +871,7 @@ data type, shape, the
-Upcast inputs and resolve the the output data type.
+Upcast inputs and resolve the output data type.
Broadcast the inputs and resolve the output shape.
Construct the primitive Axpby using the given stream, alpha, and beta.
Construct the output array using the primitive and the inputs.
@@ -883,14 +921,14 @@ pointwise. This is captured in the templated function
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
- // Do the elementwise operation for each output
+ // Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
- // (defaults to row major) and hence it doesn't need additonal mapping
+ // (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
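The offset computation inside elem_to_loc is not shown in this hunk. A rough Python sketch of the idea, as a hypothetical re-implementation rather than MLX's actual helper:

.. code-block:: python

    def elem_to_loc(out_idx, shape, strides):
        # Peel off the coordinate along each axis, innermost first,
        # and weight it by that axis's stride.
        loc = 0
        for dim in reversed(range(len(shape))):
            loc += (out_idx % shape[dim]) * strides[dim]
            out_idx //= shape[dim]
        return loc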
@@ -902,7 +940,7 @@ for all incoming floating point arrays. Accordingly, we add dispatches for
if we encounter an unexpected type.
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
- // Check the inputs (registered in the op while contructing the out array)
+ // Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -1071,7 +1109,7 @@ each data type.
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
-instantiate_axpby(bflot16, bfloat16_t);
+instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
@@ -1120,7 +1158,7 @@ below.
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
- // those in the kernel decelaration at axpby.metal
+ // those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
@@ -1151,7 +1189,7 @@ below.
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
- // Launch the grid with the given number of threads divded among
+ // Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
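The split of that grid into threadgroups is plain ceiling division. A sketch with illustrative numbers (the real limit comes from the pipeline's maxTotalThreadsPerThreadgroup, queried at runtime):

.. code-block:: python

    nelem = 4096                # one thread per output element
    max_per_group = 1024        # hypothetical pipeline limit
    group_size = min(nelem, max_per_group)
    num_groups = (nelem + group_size - 1) // group_size  # ceiling division
    assert num_groups * group_size >= nelem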
@@ -1164,7 +1202,7 @@ to give us the active metal compute command encoder instead of building a
new one and calling compute_encoder->end_encoding()
at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
-for synchronization. MLX also handles enqueuing and commiting the associated
+for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
metal::Device
if you would like to study this routine further.
@@ -1180,8 +1218,8 @@ us the following
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
- // The jvp transform on the the primitive can built with ops
- // that are scheduled on the same stream as the primtive
+ // The jvp transform on the primitive can be built with ops
+ // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
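That rule is easy to sanity-check from Python with mx.jvp on the pure-Python version of the op (a sketch; simple_axpby is the helper from earlier, not the compiled extension):

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x, y, alpha=2.0, beta=3.0):
        return alpha * x + beta * y

    x, y = mx.ones((4,)), mx.ones((4,))
    tx = mx.full((4,), 0.5)
    # Push a tangent along x only: the jvp should equal alpha * tx
    outs, jvps = mx.jvp(lambda x: simple_axpby(x, y), [x], [tx])
    # jvps[0] == 2.0 * 0.5 == 1.0 everywhere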
@@ -1218,7 +1256,7 @@ us the following
Finally, you need not have a transformation fully defined to start using your
own Primitive.
-/** Vectorize primitve along given axis */
+/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
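On the Python side this method is what mx.vmap ends up calling. A sketch with the pure-Python equivalent (the compiled extension's axpby would batch the same way once its vmap is defined):

.. code-block:: python

    import mlx.core as mx

    axpby_like = lambda x, y: 2.0 * x + 3.0 * y
    # Map over axis 0 of both inputs: a batch of 8 length-4 vectors
    batched = mx.vmap(axpby_like, in_axes=(0, 0))
    z = batched(mx.ones((8, 4)), mx.ones((8, 4)))  # shape (8, 4)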
@@ -1245,7 +1283,7 @@ own
extensions/axpby/
defines the C++ extension library
-extensions/mlx_sample_extensions sets out the strucutre for the
+extensions/mlx_sample_extensions sets out the structure for the
associated python package
extensions/bindings.cpp
provides python bindings for our operation
extensions/CMakeLists.txt
holds CMake rules to build the library and
@@ -1272,7 +1310,7 @@ are already provided, adding our
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
- Scale and sum two vectors elementwise
+ Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
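Once the extension is built and installed, calling the binding would look roughly like this (a sketch; the package and function names follow the tutorial's layout and are assumptions here):

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    x = mx.ones((3, 4))
    y = mx.full((4,), 2.0)       # broadcasts against x, numpy-style
    z = axpby(x, y, 2.0, 1.0)    # z = 2 * x + 1 * y, shape (3, 4)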
@@ -1405,7 +1443,7 @@ bindings and copied together if the package is installed
…
When you try to install using the command python -m pip install .
-(in extensions/), the package will be installed with the same strucutre as
+(in extensions/), the package will be installed with the same structure as
extensions/mlx_sample_extensions
and the C++ and metal library will be
copied along with the python binding since they are specified as package_data.
@@ -1482,7 +1520,7 @@ with the naive
We see some modest improvements right away!
This operation is now good to be used to build other operations,
-in mlx.nn.Module calls, and also as a part of graph
+in mlx.nn.Module calls, and also as a part of graph
transformations such as grad() and simplify()!
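As a quick illustration of that last point, the op composes with transforms like any other; a sketch using the pure-Python equivalent:

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x, y, alpha=2.0, beta=1.0):
        return alpha * x + beta * y

    x, y = mx.ones((3,)), mx.ones((3,))
    # d/dx sum(alpha * x + beta * y) is alpha for every element
    dx = mx.grad(lambda x: simple_axpby(x, y).sum())(x)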
diff --git a/docs/build/html/examples/linear_regression.html b/docs/build/html/examples/linear_regression.html
index 54ac2b5ec..ced953d14 100644
--- a/docs/build/html/examples/linear_regression.html
+++ b/docs/build/html/examples/linear_regression.html
@@ -9,7 +9,7 @@
- Linear Regression — MLX 0.0.6 documentation
+ Linear Regression — MLX 0.0.7 documentation
@@ -278,12 +278,14 @@
mlx.core.quantize
mlx.core.quantized_matmul
mlx.core.reciprocal
+mlx.core.repeat
mlx.core.reshape
mlx.core.round
mlx.core.rsqrt
mlx.core.save
mlx.core.savez
mlx.core.savez_compressed
+mlx.core.save_safetensors
mlx.core.sigmoid
mlx.core.sign
mlx.core.sin
@@ -303,6 +305,7 @@
mlx.core.take_along_axis
mlx.core.tan
mlx.core.tanh
+mlx.core.tensordot
mlx.core.transpose
mlx.core.tri
mlx.core.tril
@@ -351,11 +354,35 @@
mlx.core.fft.irfftn
-Neural Networks
+- Linear Algebra
+
+- Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Module
-- Layers
-- mlx.nn.Embedding
+- Module
+- mlx.nn.Module.training
+- mlx.nn.Module.apply
+- mlx.nn.Module.apply_to_modules
+- mlx.nn.Module.children
+- mlx.nn.Module.eval
+- mlx.nn.Module.filter_and_map
+- mlx.nn.Module.freeze
+- mlx.nn.Module.leaf_modules
+- mlx.nn.Module.load_weights
+- mlx.nn.Module.modules
+- mlx.nn.Module.named_modules
+- mlx.nn.Module.parameters
+- mlx.nn.Module.save_weights
+- mlx.nn.Module.train
+- mlx.nn.Module.trainable_parameters
+- mlx.nn.Module.unfreeze
+- mlx.nn.Module.update
+- mlx.nn.Module.update_modules
+
+
+- Layers
+- mlx.nn.Sequential
- mlx.nn.ReLU
- mlx.nn.PReLU
- mlx.nn.GELU
@@ -363,19 +390,27 @@
- mlx.nn.Step
- mlx.nn.SELU
- mlx.nn.Mish
+- mlx.nn.Embedding
- mlx.nn.Linear
+- mlx.nn.QuantizedLinear
- mlx.nn.Conv1d
- mlx.nn.Conv2d
+- mlx.nn.BatchNorm
- mlx.nn.LayerNorm
- mlx.nn.RMSNorm
- mlx.nn.GroupNorm
-- mlx.nn.RoPE
+- mlx.nn.InstanceNorm
+- mlx.nn.Dropout
+- mlx.nn.Dropout2d
+- mlx.nn.Dropout3d
+- mlx.nn.Transformer
- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.QuantizedLinear
+- mlx.nn.ALiBi
+- mlx.nn.RoPE
+- mlx.nn.SinusoidalPositionalEncoding
-- Functions
+- Functions
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
@@ -387,7 +422,7 @@
- mlx.nn.mish
-- Loss Functions
-- Optimizers
+- Optimizers
- mlx.optimizers.OptimizerState
- mlx.optimizers.Optimizer
- mlx.optimizers.SGD
@@ -413,7 +451,7 @@
- mlx.optimizers.Lion
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
diff --git a/docs/build/html/examples/llama-inference.html b/docs/build/html/examples/llama-inference.html
index afba02217..f26a3d212 100644
--- a/docs/build/html/examples/llama-inference.html
+++ b/docs/build/html/examples/llama-inference.html
@@ -9,7 +9,7 @@
- LLM inference — MLX 0.0.6 documentation
+ LLM inference — MLX 0.0.7 documentation
@@ -278,12 +278,14 @@
mlx.core.quantize
mlx.core.quantized_matmul
mlx.core.reciprocal
+mlx.core.repeat
mlx.core.reshape
mlx.core.round
mlx.core.rsqrt
mlx.core.save
mlx.core.savez
mlx.core.savez_compressed
+mlx.core.save_safetensors
mlx.core.sigmoid
mlx.core.sign
mlx.core.sin
@@ -303,6 +305,7 @@
mlx.core.take_along_axis
mlx.core.tan
mlx.core.tanh
+mlx.core.tensordot
mlx.core.transpose
mlx.core.tri
mlx.core.tril
@@ -351,11 +354,35 @@
mlx.core.fft.irfftn
-Neural Networks
+- Linear Algebra
+
+- Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Module
-- Layers
-- mlx.nn.Embedding
+- Module
+- mlx.nn.Module.training
+- mlx.nn.Module.apply
+- mlx.nn.Module.apply_to_modules
+- mlx.nn.Module.children
+- mlx.nn.Module.eval
+- mlx.nn.Module.filter_and_map
+- mlx.nn.Module.freeze
+- mlx.nn.Module.leaf_modules
+- mlx.nn.Module.load_weights
+- mlx.nn.Module.modules
+- mlx.nn.Module.named_modules
+- mlx.nn.Module.parameters
+- mlx.nn.Module.save_weights
+- mlx.nn.Module.train
+- mlx.nn.Module.trainable_parameters
+- mlx.nn.Module.unfreeze
+- mlx.nn.Module.update
+- mlx.nn.Module.update_modules
+
+
+- Layers
+- mlx.nn.Sequential
- mlx.nn.ReLU
- mlx.nn.PReLU
- mlx.nn.GELU
@@ -363,19 +390,27 @@
- mlx.nn.Step
- mlx.nn.SELU
- mlx.nn.Mish
+- mlx.nn.Embedding
- mlx.nn.Linear
+- mlx.nn.QuantizedLinear
- mlx.nn.Conv1d
- mlx.nn.Conv2d
+- mlx.nn.BatchNorm
- mlx.nn.LayerNorm
- mlx.nn.RMSNorm
- mlx.nn.GroupNorm
-- mlx.nn.RoPE
+- mlx.nn.InstanceNorm
+- mlx.nn.Dropout
+- mlx.nn.Dropout2d
+- mlx.nn.Dropout3d
+- mlx.nn.Transformer
- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.QuantizedLinear
+- mlx.nn.ALiBi
+- mlx.nn.RoPE
+- mlx.nn.SinusoidalPositionalEncoding
-- Functions
+- Functions
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
@@ -387,7 +422,7 @@
- mlx.nn.mish
-- Loss Functions
-- Optimizers
+- Optimizers
- mlx.optimizers.OptimizerState
- mlx.optimizers.Optimizer
- mlx.optimizers.SGD
@@ -413,7 +451,7 @@
- mlx.optimizers.Lion
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
diff --git a/docs/build/html/examples/mlp.html b/docs/build/html/examples/mlp.html
index 99e83333b..c58dbaf32 100644
--- a/docs/build/html/examples/mlp.html
+++ b/docs/build/html/examples/mlp.html
@@ -9,7 +9,7 @@
- Multi-Layer Perceptron — MLX 0.0.6 documentation
+ Multi-Layer Perceptron — MLX 0.0.7 documentation
@@ -278,12 +278,14 @@
mlx.core.quantize
mlx.core.quantized_matmul
mlx.core.reciprocal
+mlx.core.repeat
mlx.core.reshape
mlx.core.round
mlx.core.rsqrt
mlx.core.save
mlx.core.savez
mlx.core.savez_compressed
+mlx.core.save_safetensors
mlx.core.sigmoid
mlx.core.sign
mlx.core.sin
@@ -303,6 +305,7 @@
mlx.core.take_along_axis
mlx.core.tan
mlx.core.tanh
+mlx.core.tensordot
mlx.core.transpose
mlx.core.tri
mlx.core.tril
@@ -351,11 +354,35 @@
mlx.core.fft.irfftn
-Neural Networks
+- Linear Algebra
+
+- Neural Networks
- mlx.nn.value_and_grad
-- mlx.nn.Module
-- Layers
-- mlx.nn.Embedding
+- Module
+- mlx.nn.Module.training
+- mlx.nn.Module.apply
+- mlx.nn.Module.apply_to_modules
+- mlx.nn.Module.children
+- mlx.nn.Module.eval
+- mlx.nn.Module.filter_and_map
+- mlx.nn.Module.freeze
+- mlx.nn.Module.leaf_modules
+- mlx.nn.Module.load_weights
+- mlx.nn.Module.modules
+- mlx.nn.Module.named_modules
+- mlx.nn.Module.parameters
+- mlx.nn.Module.save_weights
+- mlx.nn.Module.train
+- mlx.nn.Module.trainable_parameters
+- mlx.nn.Module.unfreeze
+- mlx.nn.Module.update
+- mlx.nn.Module.update_modules
+
+
+- Layers
+- mlx.nn.Sequential
- mlx.nn.ReLU
- mlx.nn.PReLU
- mlx.nn.GELU
@@ -363,19 +390,27 @@
- mlx.nn.Step
- mlx.nn.SELU
- mlx.nn.Mish
+- mlx.nn.Embedding
- mlx.nn.Linear
+- mlx.nn.QuantizedLinear
- mlx.nn.Conv1d
- mlx.nn.Conv2d
+- mlx.nn.BatchNorm
- mlx.nn.LayerNorm
- mlx.nn.RMSNorm
- mlx.nn.GroupNorm
-- mlx.nn.RoPE
+- mlx.nn.InstanceNorm
+- mlx.nn.Dropout
+- mlx.nn.Dropout2d
+- mlx.nn.Dropout3d
+- mlx.nn.Transformer
- mlx.nn.MultiHeadAttention
-- mlx.nn.Sequential
-- mlx.nn.QuantizedLinear
+- mlx.nn.ALiBi
+- mlx.nn.RoPE
+- mlx.nn.SinusoidalPositionalEncoding
-- Functions
+- Functions
- mlx.nn.gelu
- mlx.nn.gelu_approx
- mlx.nn.gelu_fast_approx
@@ -387,7 +422,7 @@
- mlx.nn.mish
-- Loss Functions
-- Optimizers
+- Optimizers
- mlx.optimizers.OptimizerState
- mlx.optimizers.Optimizer
- mlx.optimizers.SGD
@@ -413,7 +451,7 @@
- mlx.optimizers.Lion
-- Tree Utils
+- Tree Utils
- mlx.utils.tree_flatten
- mlx.utils.tree_unflatten
- mlx.utils.tree_map
@@ -605,11 +643,11 @@ multi-layer perceptron to classify MNIST.
The model is defined as the MLP class, which inherits from
-mlx.nn.Module. We follow the standard idiom to make a new module:
+mlx.nn.Module. We follow the standard idiom to make a new module:
Define an __init__ where the parameters and/or submodules are set up. See
the Module class docs for more information on how
-mlx.nn.Module registers parameters.
+mlx.nn.Module registers parameters.
Define a __call__ where the computation is implemented.
class MLP(nn.Module):
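The class body is cut off in this diff; a minimal sketch of the idiom the two steps describe (the layer sizes and names are illustrative, not the example's actual model):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    class MLP(nn.Module):
        def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
            super().__init__()
            # Step 1: parameters and submodules are set up in __init__
            self.l1 = nn.Linear(in_dim, hidden_dim)
            self.l2 = nn.Linear(hidden_dim, out_dim)

        def __call__(self, x):
            # Step 2: the computation lives in __call__
            return self.l2(mx.maximum(self.l1(x), 0.0))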
diff --git a/docs/build/html/genindex.html b/docs/build/html/genindex.html
index 3cfe9a7df..1f553269c 100644
--- a/docs/build/html/genindex.html
+++ b/docs/build/html/genindex.html
@@ -8,7 +8,7 @@
- Index — MLX 0.0.6 documentation
+ Index — MLX 0.0.7 documentation