Args:
start (float or int, optional): Starting value which defaults to 0.
stop (float or int): Stopping value.
step (float or int, optional): Increment which defaults to 1.
dtype (Dtype, optional): Specifies the data type of the output. If unspecified it defaults to float32 if any of start, stop, or step are float. Otherwise it defaults to int32.

Note:
Following the NumPy convention, the actual increment used to generate numbers is dtype(start + step) - dtype(start). This can lead to unexpected results, for example if start + step is a fractional value and the dtype is integral.
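A minimal illustrative sketch of this caveat (not part of the original page; it assumes the mlx Python package is installed):

import mlx.core as mx

# With an integral dtype the increment is dtype(0 + 0.5) - dtype(0) == 0,
# so every generated value equals the start value instead of stepping by 0.5.
print(mx.arange(0, 3, 0.5, dtype=mx.int32))  # all zeros
print(mx.arange(0, 3, 0.5))                  # float32: 0.0, 0.5, 1.0, ...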
kth (int) – Element index at the kth position in the output will give the sorted position. All indices before the kth position will be of elements less or equal to the element at the kth index and all indices after will be of elements greater or equal to the element at the kth index.
axis (int or None, optional) – Optional axis to partition over. If None, this partitions over the flattened array. If unspecified, it defaults to -1.
axis (int or None, optional) – Optional axis to sort over. If None, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis).
Compare two arrays for equality. Returns True if and only if the arrays have the same shape and their values are equal. The arrays need not have the same type to be considered equal.

Create a view into the array with the given shape and strides.
The resulting array will always be as if the provided array was row contiguous regardless of the provided array's storage order and current strides. Use this function with caution: it changes the shape and strides of the array directly, and a view that points at invalid memory can result in crashes.
Take the bitwise exclusive or of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Clip the values of the array between the given minimum and maximum.
If either a_min or a_max is None, then the corresponding edge is ignored. At least one of a_min and a_max cannot be None.
Returns a compiled function which produces the same output as fun.

Parameters:

fun (Callable) – A function which takes a variable number of array or trees of array and returns a variable number of array or trees of array.
inputs (list or dict, optional) – These inputs will be captured during the function compilation along with the inputs to fun. The inputs can be a list or a dict containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not array are ignored. Default: None
outputs (list or dict, optional) – These outputs will be captured and updated in a compiled function. The outputs can be a list or a dict containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not array are ignored. Default: None
shapeless (bool, optional) – A function compiled with the shapeless option enabled will not be recompiled when the input shape changes. Not all functions can be compiled with shapeless enabled. Attempting to compile such functions with shapeless enabled will throw. Note, changing the number
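An illustrative, non-authoritative sketch of compiling a function (the loss function and shapes below are made up for the example):

import mlx.core as mx

def loss(x, y):
    return mx.mean(mx.square(x - y))

compiled_loss = mx.compile(loss)   # traced and compiled on first call
x = mx.random.normal((16, 8))
y = mx.random.normal((16, 8))
print(compiled_loss(x, y))         # same result as loss(x, y); reused on later calls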
stride (int or tuple(int), optional) – tuple of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: 1.
padding (int or tuple(int), optional) – tuple of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: 0.
dilation (int or tuple(int), optional) – tuple of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.
groups (int, optional) – input feature groups. Default: 1.

stride (int or tuple(int), optional) – tuple of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: 1.
padding (int or tuple(int), optional) – tuple of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: 0.
dilation (int or tuple(int), optional) – tuple of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.
groups (int, optional) – input feature groups. Default: 1.
General convolution over an input with several channels.

Parameters:

input (array) – Input array of shape (N,...,C_in).
weight (array) – Weight array of shape (C_out,...,C_in).
stride (int or list(int), optional) – list with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: 1.
padding (int, list(int), or tuple(list(int), list(int)), optional) – list with input padding. All spatial dimensions get the same padding if only one number is specified. Default: 0.
kernel_dilation (int or list(int), optional) – list with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.
input_dilation (int or list(int), optional) – list with input dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.
groups (int, optional) – Input feature groups. Default: 1.
flip (bool, optional) – Flip the order in which the spatial dimensions of the weights are processed. Performs the cross-correlation operator when flip is False and the convolution operator otherwise. Default: False.
stride (int or tuple(int), optional) – tuple of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: 1.
padding (int or tuple(int), optional) – tuple of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: 0.
dilation (int or tuple(int), optional) – tuple of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.
groups (int, optional) – input feature groups. Default: 1.

stride (int or tuple(int), optional) – tuple of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: 1.
padding (int or tuple(int), optional) – tuple of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: 0.
dilation (int or tuple(int), optional) – tuple of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.
groups (int, optional) – input feature groups. Default: 1.
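A hedged sketch illustrating these convolution parameters with mx.conv2d; the shapes are made up for the example and follow MLX's channels-last (N, H, W, C) convention described above:

import mlx.core as mx

x = mx.random.normal((1, 8, 8, 3))   # input: one 8x8 image with 3 channels
w = mx.random.normal((16, 3, 3, 3))  # weight: (C_out, kH, kW, C_in)
y = mx.conv2d(x, w, stride=1, padding=1, dilation=1, groups=1)
print(y.shape)  # (1, 8, 8, 16) with padding=1 and a 3x3 kernel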
Extract a diagonal or construct a diagonal matrix.
If a is 1-D then a diagonal matrix is constructed with a on the
\(k\)-th diagonal. If a is 2-D then the \(k\)-th diagonal is
returned.
Initialize the communication backend and create the global communication group.
Parameters:
strict (bool, optional) – If set to False it returns a singleton group in case mx.distributed.is_available() returns False, otherwise it throws a runtime error. Default: False.
The function divmod(a,b) is equivalent to but faster than (a//b,a%b). The function uses numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
*args (arrays or trees of arrays) – Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python
list, tuple or dict. Leaves which are not
arrays are ignored.
The normalization is with respect to the last axis of the input x.
bias (array, optional) – An additive offset to be added to the result.
The bias should be one-dimensional with the same size
as the last axis of x. If set to None then no translation happens.
eps (float) – A small additive constant for numerical stability.
input_names (List[str]) – The parameter names of the inputs in the
function signature.
output_names (List[str]) – The parameter names of the outputs in the function signature.
source (str) – Source code. This is the body of a function in Metal, the function signature will be automatically generated.
header (str) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool) – Whether to ensure the inputs are row contiguous before the kernel runs. Default: True.
atomic_outputs (bool) – Whether to use atomic outputs in the function signature, e.g. device atomic<float>. Default: False.
dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
traditional (bool) – If set to True choose the traditional implementation which rotates consecutive dimensions.
base (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None.
scale (float) – The scale used to scale the positions.
n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match n. The default value is a.shape[axis].
axis (int, optional) – Axis along which to perform the FFT. The default is -1.
s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is [-2,-1].
s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is None in which case the FFT is over the last len(s) axes or all axes if s is also None.
n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match n. The default value is a.shape[axis].
axis (int, optional) – Axis along which to perform the FFT. The default is -1.
s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is [-2,-1].
s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is None in which case the FFT is over the last len(s) axes or all axes if s is also None.
n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match n//2+1. The default value is a.shape[axis]//2+1.
axis (int, optional) – Axis along which to perform the FFT. The default is -1.
Note the input is generally complex. The dimensions of the input specified in axes are padded or truncated to match the sizes from s. The last axis in axes is treated as the real axis and will have size s[-1]//2+1.

s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s except for the last axis which has size s[-1]//2+1. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is [-2,-1].
Note the input is generally complex. The dimensions of the input specified in axes are padded or truncated to match the sizes from s. The last axis in axes is treated as the real axis and will have size s[-1]//2+1.

s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is None in which case the FFT is over the last len(s) axes or all axes if s is also None.
One dimensional discrete Fourier Transform on a real input.
The output has the same shape as the input except along axis in
which case it has size n//2+1.
a (array) – The input array. If the array is complex it will be silently cast to a real type.
n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match n. The default value is a.shape[axis].
axis (int, optional) – Axis along which to perform the FFT. The default is -1.
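An illustrative sketch of the output-size relationship for a real FFT (the input values and sizes are examples only):

import mlx.core as mx

x = mx.random.normal((8,))   # real input of length n = 8
y = mx.fft.rfft(x)           # complex output of size n // 2 + 1 = 5
print(y.shape, y.dtype)      # (5,) complex64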
The output has the same shape as the input except along the dimensions in axes in which case it has sizes from s. The last axis in axes is treated as the real axis and will have size s[-1]//2+1.

a (array) – The input array. If the array is complex it will be silently cast to a real type.
s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is [-2,-1].
The output has the same shape as the input except along the dimensions in axes in which case it has sizes from s. The last axis in axes is treated as the real axis and will have size s[-1]//2+1.

a (array) – The input array. If the array is complex it will be silently cast to a real type.
s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in s. The default value is the sizes of a along axes.
axes (list(int), optional) – Axes along which to perform the FFT. The default is None in which case the FFT is over the last len(s) axes or all axes if s is also None.
The axes flattened will be between start_axis and end_axis,
inclusive. Negative axes are supported. After converting negative axis to positive, axes outside the valid range will be clamped to a valid value.
Performs a gather of the operands with the given indices followed by a
(possibly batched) matrix multiplication of two arrays. This operation
Returns a function which computes the gradient of fun.
Parameters:
fun (Callable) – A function which takes a variable number of array or trees of array and returns a scalar output array.
argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames are provided argnums defaults to 0 indicating fun’s first argument.
argnames (str or list(str), optional) – Specify keyword arguments of fun to compute gradients with respect to. It defaults to [] so no gradients for keyword arguments by default.
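A small illustrative sketch of the default argnums behavior (values are made up):

import mlx.core as mx

def f(x, y):
    return mx.sum(x * x + y)

df_dx = mx.grad(f)                # gradient with respect to the first argument
x = mx.array([1.0, 2.0, 3.0])
y = mx.array([0.5, 0.5, 0.5])
print(df_dx(x, y))                # 2 * x -> [2.0, 4.0, 6.0]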
primals (list(array)) – A list of array at which to
evaluate the Jacobian.
tangents (list(array)) – A list of array which are the “vector” in the Jacobian-vector product. The tangents should be the same in number, shape, and type as the inputs of fun (i.e. the primals).
Shift the bits of the first input to the left by the second using
numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of a.
Compute the cross product of two arrays along a specified axis.
The cross product is defined for arrays with size 2 or 3 in the
specified axis. If the size is 2 then the third value is assumed to be zero.
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of a.
This function computes vector or matrix norms depending on the value of the ord and axis parameters.

a (array) – Input array. If axis is None, a must be 1-D or 2-D, unless ord is None. If both axis and ord are None, the 2-norm of a.flatten will be returned.
ord (int, float or str, optional) – Order of the norm (see table under Notes). If None, the 2-norm (or Frobenius norm for matrices) will be computed along the given axis. Default: None.
axis (int or list(int), optional) – If axis is an integer, it specifies the axis of a along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when a is 1-D) or a matrix norm (when a is 2-D) is returned. Default: None.
keepdims (bool, optional) – If True, the axes which are normed over are left in the result as dimensions with size one. Default: False.
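A brief illustrative sketch of vector and matrix norms with these parameters (the input is made up):

import mlx.core as mx

a = mx.array([[1.0, 2.0], [3.0, 4.0]])
print(mx.linalg.norm(a))                                # Frobenius norm of the 2-D array
print(mx.linalg.norm(a, axis=1))                        # 2-norm of each row
print(mx.linalg.norm(a, axis=1, keepdims=True).shape)   # (2, 1)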
This function supports arrays with at least 2 dimensions. The matrices
which are factorized are assumed to be in the last two dimensions of a.
The Singular Value Decomposition (SVD) of the input matrix.
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the function iterates over all indices of the first
Compute the inverse of a triangular square matrix.
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of a.
The supported formats are .npy, .npz, .safetensors, and .gguf.

Parameters:

file (file, str) – File in which the array is saved.
format (str, optional) – Format of the file. If None, the format is inferred from the file extension. Supported formats: npy, npz, and safetensors. Default: None.
return_metadata (bool, optional) – Load the metadata for formats which support metadata. The metadata will be returned as an additional dictionary. Default: False.
sparse (bool, optional) – If True, a sparse grid is returned in which each output array has a single non-zero element. If False, a dense grid is returned. Defaults to False.
indexing (str, optional) – Cartesian (‘xy’) or matrix (‘ij’) indexing of the output arrays. Defaults to 'xy'.
If using more than the given limit, free memory will be reclaimed from the cache on the next allocation. To disable the cache, set the limit to 0. See set_memory_limit() for more details.
Memory allocations will wait on scheduled tasks to complete if the limit is exceeded. If there are no more scheduled tasks an error will be raised.
nan (float, optional) – Value to replace NaN with. Default: 0.
posinf (float, optional) – Value to replace positive infinities with. If None, defaults to largest finite value for the given data type. Default: None.
neginf (float, optional) – Value to replace negative infinities with. If None, defaults to the negative of the largest finite value for the given data type. Default: None.
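An illustrative sketch of these replacement values (the input is made up):

import mlx.core as mx

x = mx.array([float("nan"), float("inf"), -float("inf"), 1.0])
print(mx.nan_to_num(x))                         # NaN -> 0, infinities -> largest finite values
print(mx.nan_to_num(x, nan=-1.0, posinf=10.0))  # custom replacements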
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))) – Number of padded
values to add to the edges of each axis: ((before_1, after_1), (before_2, after_2), ..., (before_N, after_N)). If a single pair
of integers is passed then (before_i,after_i) are all the same.
kth (int) – Element at the kth index will be in its sorted position in the output. All elements before the kth index will be less or equal to the kth element and all elements after will be greater or equal to the kth element in the output.
axis (int or None, optional) – Optional axis to partition over. If None, this partitions over the flattened array. If unspecified, it defaults to -1.
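A short illustrative sketch of the partition guarantee described above (values are made up):

import mlx.core as mx

x = mx.array([5, 1, 4, 2, 3])
y = mx.partition(x, kth=2)
# y[2] holds the 3rd smallest value; everything before it is <= y[2] and
# everything after it is >= y[2] (the order within each side is unspecified).
print(y)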
Quantize the matrix w using bits bits per element.
Note, every group_size elements in a row of w are quantized
together. Hence, the number of columns of w should be divisible by group_size.
Perform the matrix multiplication with the quantized matrix w. The
quantization uses one floating point scale and bias per group_size of
elements. Each element in w takes bits bits and is packed in an
unsigned 32 bit integer.

w (array) – Quantized matrix packed in unsigned integers
scales (array) – The scales to use per group_size elements of w
biases (array) – The biases to use per group_size elements of w
transpose (bool, optional) – Defines whether to multiply with the transposed w or not, namely whether we are performing x@w.T or x@w. Default: True.
group_size (int, optional) – The size of the group in w that shares a scale and bias. Default: 64.
bits (int, optional) – The number of bits occupied by each element in w. Default: 4.
The values are sampled from the bernoulli distribution with parameter p. The parameter p can be a float or array and must be broadcastable to shape.

Parameters:

p (float or array, optional) – Parameter of the Bernoulli distribution. Default: 0.5.
shape (list(int), optional) – Shape of the output. Default: p.shape.
key (array, optional) – A PRNG key. Default: None.
The values are sampled from the categorical distribution specified by the unnormalized values in logits. Note, at most one of shape or num_samples can be specified. If both are None, the output has the same shape as logits with the axis dimension removed.

Parameters:

logits (array) – The unnormalized categorical distribution(s).
axis (int, optional) – The axis which specifies the distribution. Default: -1.
shape (list(int), optional) – The shape of the output. This must be broadcast compatible with logits.shape with the axis dimension removed. Default: None
num_samples (int, optional) – The number of samples to draw from each of the categorical distributions in logits. The output will have num_samples in the last dimension. Default: None.
key (array, optional) – A PRNG key. Default: None.
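An illustrative sketch of sampling with and without num_samples (the logits are made up):

import mlx.core as mx

logits = mx.array([[0.0, 1.0, 2.0],
                   [2.0, 1.0, 0.0]])
print(mx.random.categorical(logits))                  # one sample per distribution, shape (2,)
print(mx.random.categorical(logits, num_samples=5))   # five samples per distribution, shape (2, 5)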
Generate jointly-normal random samples given a mean and covariance.
The matrix cov must be positive semi-definite. The behavior is
undefined if it is not. The only supported dtype is float32.
cov (array) – array of shape (...,n,n), the covariance matrix of the distribution. The batch shape ... must be broadcast-compatible with that of mean.
shape (list(int), optional) – The output shape must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. If empty, the result shape is determined by broadcasting the batch shapes of mean and cov. Default: [].
Generate a random permutation or permute the entries of an array.
Parameters:

x (int or array, optional) – If an integer is provided a random permutation of mx.arange(x) is returned. Otherwise the entries of x along the given axis are randomly permuted.
axis (int, optional) – The axis to permute along. Default: 0.
key (array, optional) – A PRNG key. Default: None.

Returns:

The generated random permutation or randomly permuted input array.
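A short illustrative sketch of both call forms (values are made up):

import mlx.core as mx

print(mx.random.permutation(5))           # a random permutation of mx.arange(5)

x = mx.array([[1, 2], [3, 4], [5, 6]])
print(mx.random.permutation(x, axis=0))   # rows shuffled, columns left intact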
The values are sampled with equal probability from the integers in half-open interval [low,high). The lower and upper bound can be scalars or arrays and must be broadcastable to shape.

low (scalar or array) – Lower bound of the interval.
high (scalar or array) – Upper bound of the interval.
shape (list(int), optional) – Shape of the output. Default: ().
dtype (Dtype, optional) – Type of the output. Default: int32.
key (array, optional) – A PRNG key. Default: None.
Generate values from a truncated normal distribution.
The values are sampled from the truncated normal distribution
on the domain (lower,upper). The bounds lower and upper
can be scalars or arrays and must be broadcastable to shape.

lower (scalar or array) – Lower bound of the domain.
upper (scalar or array) – Upper bound of the domain.
shape (list(int), optional) – The shape of the output. Default: ().
dtype (Dtype, optional) – The data type of the output.
Default: float32.
The values are sampled uniformly in the half-open interval [low,high). The lower and upper bound can be scalars or arrays and must be broadcastable to shape.

low (scalar or array, optional) – Lower bound of the distribution. Default: 0.
high (scalar or array, optional) – Upper bound of the distribution. Default: 1.
shape (list(int), optional) – Shape of the output. Default: ().
dtype (Dtype, optional) – Type of the output. Default: float32.
key (array, optional) – A PRNG key. Default: None.
Computes the remainder of dividing a with b with numpy-style
broadcasting semantics. Either or both input arrays can also be scalars.
Shift the bits of the first input to the right by the second using numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
shift (int or tuple(int)) – The number of places by which elements are shifted. If positive the array is rolled to the right, if negative it is rolled to the left. If an int is provided but the axis is a tuple then the same value is used for all axes.
axis (int or tuple(int), optional) – The axis or axes along which to roll the elements.
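An illustrative sketch of positive and negative shifts (values are made up):

import mlx.core as mx

x = mx.arange(5)
print(mx.roll(x, 2))     # [3, 4, 0, 1, 2] - rolled to the right
print(mx.roll(x, -1))    # [1, 2, 3, 4, 0] - rolled to the left

m = mx.arange(6).reshape(2, 3)
print(mx.roll(m, 1, axis=1))  # each row rolled one place to the right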
metadata (dict(str, Union[array, str, list(str)])) – The dictionary
of metadata to be saved. The values can be a scalar or 1D array, a str, or a list of str.
axis (int or None, optional) – Optional axis to sort over. If None, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis).
indices_or_sections (int or list(int)) – If indices_or_sections is an integer the array is split into that many sections of equal size. An error is raised if this is not possible. If indices_or_sections is a list, the list contains the indices of the start of each subarray along the given axis.
axis (int, optional) – Axis to split along, defaults to 0.
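An illustrative sketch of both forms of indices_or_sections (values are made up):

import mlx.core as mx

x = mx.arange(6)
a, b, c = mx.split(x, 3)         # three equal sections of size 2
print(a, b, c)

left, right = mx.split(x, [4])   # split at index 4
print(left.shape, right.shape)   # (4,) and (2,)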
The elements are taken from indices along the specified axis.
If the axis is not specified the array is treated as a flattened 1-D array prior to performing the take.
axes (int or list(list(int)), optional) – The number of dimensions to
sum over. If an integer is provided, then sum over the last
axes dimensions of a and the first axes dimensions of
b. If a list of lists is provided, then sum over the
axis (int or None, optional) – Optional axis to select over.
If None, this selects the top k elements over the
flattened array. If unspecified, it defaults to -1.
Returns a function which computes the value and gradient of fun.
The function passed to value_and_grad() should return either a scalar loss or a tuple in which the first element is a scalar loss and the remaining elements can be anything.

fun (Callable) – A function which takes a variable number of array or trees of array and returns a scalar output array or a tuple the first element of which should be a scalar array.
argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of fun to compute the gradient with respect to. If neither argnums nor argnames are provided argnums defaults to 0 indicating fun’s first argument.
argnames (str or list(str), optional) – Specify keyword arguments of fun to compute gradients with respect to. It defaults to [] so no gradients for keyword arguments by default.
primals (list(array)) – A list of array at which to
evaluate the Jacobian.
cotangents (list(array)) – A list of array which are the “vector” in the vector-Jacobian product. The cotangents should be the same in number, shape, and type as the outputs of fun.
fun (Callable) – A function which takes a variable number of
array or a tree of array and returns
a variable number of array or a tree of array.
in_axes (int, optional) – An integer or a valid prefix tree of the inputs to fun where each node specifies the vmapped axis. If the value is None then the corresponding input(s) are not vmapped. Defaults to 0.
out_axes (int, optional) – An integer or a valid prefix tree of the outputs of fun where each node specifies the vmapped axis. If the value is None then the corresponding output(s) are not vmapped. Defaults to 0.
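An illustrative sketch of vectorizing a function over a batch axis with in_axes and out_axes (the function and shapes are made up):

import mlx.core as mx

def dot(a, b):
    return mx.sum(a * b)

batched_dot = mx.vmap(dot, in_axes=(0, 0), out_axes=0)  # map over the leading axis of both inputs
a = mx.random.normal((4, 3))
b = mx.random.normal((4, 3))
print(batched_dot(a, b).shape)  # (4,)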
Quantize the sub-modules of a module according to a predicate.
By default all layers that define a to_quantized(group_size, bits) method will be quantized. Both Linear and Embedding layers will be quantized. Note also, the module is updated in-place.

Parameters:

model (Module) – The model whose leaf modules may be quantized.
group_size (int) – The quantization group size (see
class_predicate (Optional[Callable]) – A callable which receives the Module path and Module itself and returns True if
Applies fn to the leaves of the Python tree tree and
returns a new collection with the results.
If rest is provided, every item is assumed to be a superset of tree
and the corresponding leaves are provided as extra positional arguments to
fn. In that respect, tree_map() is closer to itertools.starmap() than to map().
Assuming an input of shape \((N, L, C)\) and kernel_size is
\(k\), the output is a tensor of shape \((N, L_{out}, C)\), given by:

Parameters:

kernel_size (int or tuple(int)) – The size of the pooling window kernel.
stride (int or tuple(int), optional) – The stride of the pooling window. Default: kernel_size.
padding (int or tuple(int), optional) – How much zero padding to apply to the input. The padding amount is applied to both sides of the spatial axis. Default: 0.
Assuming an input of shape \((N, H, W, C)\) and kernel_size is
\((k_H, k_W)\), the output is a tensor of shape \((N, H_{out}, W_{out}, C)\).

Parameters:

kernel_size (int or tuple(int, int)) – The size of the pooling window.
stride (int or tuple(int, int), optional) – The stride of the pooling window. Default: kernel_size.
padding (int or tuple(int, int), optional) – How much zero padding to apply to the input. The padding is applied on both sides of the height and width axis. Default: 0.
Randomly zero a portion of the elements during training.
The remaining elements are multiplied with \(\frac{1}{1-p}\) where
\(p\) is the probability of zeroing an element. This is done so the
expected value of a given element will remain the same.
Randomly zero out entire channels independently with probability \(p\).
This layer expects the channels to be last, i.e. the input shape should be
regularize activations. For more details, see [1].
[1] Efficient Object Localization Using Convolutional Networks. CVPR 2015.

Parameters:

p (float) – Probability of zeroing a channel during training.
Randomly zero out entire channels independently with probability \(p\).
This layer expects the channels to be last, i.e., the input shape should be
often beneficial for convolutional layers processing 3D data, like in medical imaging or video processing.

Parameters:

p (float) – Probability of zeroing a channel during training.
Assuming an input of shape \((N, L, C)\) and kernel_size is
\(k\), the output is a tensor of shape \((N, L_{out}, C)\), given by:

Parameters:

kernel_size (int or tuple(int)) – The size of the pooling window kernel.
stride (int or tuple(int), optional) – The stride of the pooling window. Default: kernel_size.
padding (int or tuple(int), optional) – How much negative infinity padding to apply to the input. The padding amount is applied to both sides of the spatial axis. Default: 0.
Assuming an input of shape \((N, H, W, C)\) and kernel_size is
\((k_H, k_W)\), the output is a tensor of shape \((N, H_{out}, W_{out}, C)\).

Parameters:

kernel_size (int or tuple(int, int)) – The size of the pooling window.
stride (int or tuple(int, int), optional) – The stride of the pooling window. Default: kernel_size.
padding (int or tuple(int, int), optional) – How much negative infinity padding to apply to the input. The padding is applied on both sides of the height and width axis. Default: 0.
Recursively filter the contents of the module using filter_fn,
namely only select keys and values where filter_fn returns true.
This is used to implement parameters() and trainable_parameters()
Freeze the Module’s parameters or some of them. Freezing a parameter means not computing gradients for it.
This function is idempotent i.e. freezing a frozen model is a no-op.

Parameters:

recurse (bool, optional) – If True then freeze the parameters of the submodules as well. Default: True.
keys (str or list[str], optional) – If provided then only these parameters will be frozen otherwise all the parameters of a module. For instance freeze all biases by calling module.freeze(keys="bias").
strict (bool, optional) – If set to True validate that the passed keys exist. Default: False.
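An illustrative sketch (the two-layer model below is made up for the example):

import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
model.freeze()                 # no gradients for any parameter
model.unfreeze(keys="bias")    # train only the biases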
Update the model’s weights from a .npz, a .safetensors file, or a list.
Parameters:
file_or_weights (str or list(tuple(str, mx.array))) – The path to the weights .npz file (.npz or .safetensors) or a list of pairs of parameter names and arrays.
strict (bool, optional) – If True then checks that the provided weights exactly match the parameters of the model. Otherwise, only the weights actually contained in the model are loaded and shapes are not checked. Default: True.
Save the model’s weights to a file. The saving method is determined by the file extension:
- .npz will use mx.savez()
- .safetensors will use mx.save_safetensors()
predicate (Callable, optional) – A predicate to select parameters to cast. By default, only parameters of type floating will be updated to avoid casting integer parameters to the new dtype.
This function is idempotent i.e. unfreezing a model that is not frozen is a no-op.

Parameters:

recurse (bool, optional) – If True then unfreeze the parameters of the submodules as well. Default: True.
keys (str or list[str], optional) – If provided then only these parameters will be unfrozen otherwise all the parameters of a module. For instance unfreeze all biases by calling module.unfreeze(keys="bias").
strict (bool, optional) – If set to True validate that the passed keys exist. Default: False.
Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values the MultiHeadAttention
produces new values by aggregating information from the input values
Parameters:

dims (int) – The model dimensions. This is also the default value for the queries, keys, values, and the output.
num_heads (int) – The number of attention heads to use.
query_input_dims (int, optional) – The input dimensions of the queries. Default: dims.
key_input_dims (int, optional) – The input dimensions of the keys. Default: dims.
value_input_dims (int, optional) – The input dimensions of the values. Default: key_input_dims.
value_dims (int, optional) – The dimensions of the values after the projection. Default: dims.
value_output_dims (int, optional) – The dimensions the new values will be projected to. Default: dims.
bias (bool, optional) – Whether or not to use a bias in the projections. Default: False.
Applies an affine transformation to the input using a quantized weight matrix.
It is the quantized equivalent of mlx.nn.Linear. For now its
parameters are frozen and will not be included in any gradient computation.
The traditional implementation rotates consecutive pairs of elements in the
feature dimension while the default implementation rotates pairs with
Parameters:

dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
traditional (bool, optional) – If set to True choose the traditional implementation which is slightly less efficient. Default: False.
base (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Default: 10000.
scale (float, optional) – The scale used to scale the positions. Default: 1.0.
dims (int, optional) – The number of expected features in the encoder/decoder inputs. Default: 512.
num_heads (int, optional) – The number of attention heads. Default: 8.
num_encoder_layers (int, optional) – The number of encoder layers in the Transformer encoder. Default: 6.
num_decoder_layers (int, optional) – The number of decoder layers in the Transformer decoder. Default: 6.
mlp_dims (int, optional) – The hidden dimension of the MLP block in each Transformer layer. Defaults to 4*dims if not provided. Default: None.
dropout (float, optional) – The dropout value for the Transformer encoder and decoder. Dropout is used after each attention layer and the activation in the MLP layer. Default: 0.0.
activation (function, optional) – the activation function for the MLP hidden layer. Default: None.
custom_decoder (Module, optional) – A custom decoder to replace the standard Transformer decoder. Default: None.
norm_first (bool, optional) – if True, encoder and decoder layers will perform layer normalization before attention and MLP operations, otherwise after. Default: True.
checkpoint (bool, optional) – if True perform gradient checkpointing to reduce the memory usage at the expense of more computation. Default: False.
The spatial dimensions are by convention dimensions 1 to x.ndim-2. The first is the batch dimension and the last is the feature dimension.

Parameters:

scale_factor (float or tuple) – The multiplier for the spatial size. If a float is provided, it is the multiplier for all spatial dimensions. Otherwise, the number of scale factors provided must match the number of spatial dimensions.
mode (str, optional) – The upsampling algorithm, either "nearest", "linear" or "cubic". Default: "nearest".
align_corners (bool, optional) – Changes the way the corners are treated during "linear" and "cubic" upsampling. See the note above and the examples below for more details. Default: False.
This initializer samples from a normal distribution with a standard deviation computed from the number of input (fan_in) and output (fan_out) units. It returns an array with the same shape as the input, filled with samples from the Glorot normal distribution.

This initializer samples from a uniform distribution with a range computed from the number of input (fan_in) and output (fan_out) units. It returns an array with the same shape as the input, filled with samples from the Glorot uniform distribution.

This initializer samples from a normal distribution with a standard deviation computed from the number of input (fan_in) or output (fan_out) units. It returns an array with the same shape as the input, filled with samples from the He normal distribution.

This initializer samples from a uniform distribution with a range computed from the number of input (fan_in) or output (fan_out) units. It returns an array with the same shape as the input, filled with samples from the He uniform distribution.
By default, this function takes the pre-sigmoid logits, which results in a faster
and more precise loss. For improved numerical stability when with_logits=False, the loss calculation clips the input probabilities (in log-space) to a minimum value of -100.

inputs (array) – The predicted values. If with_logits is True, then inputs are unnormalized logits. Otherwise, inputs are probabilities.
targets (array) – The binary target values in {0, 1}.
with_logits (bool, optional) – Whether inputs are logits. Default: True.
weights (array, optional) – Optional weights for each target. Default: None.
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.
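An illustrative sketch of the logits and probability forms (the values are made up):

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([0.2, -1.5, 3.0])
targets = mx.array([0.0, 0.0, 1.0])

# Default: inputs are pre-sigmoid logits and the mean over elements is returned.
print(nn.losses.binary_cross_entropy(logits, targets))

# Probabilities instead of logits, keeping the per-element losses.
probs = mx.sigmoid(logits)
print(nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="none"))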
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients,
and like the L1 loss for large errors, reducing sensitivity to outliers. This dual behavior offers a balanced, robust approach for regression tasks.
The smooth L1 loss is a variant of the L1 loss which replaces the absolute
difference with a squared difference when the absolute difference is less than beta.
learning_rate (float or callable, optional) – The learning rate. Default: None.
eps (tuple(float, float), optional) – The first term \(\epsilon_1\) added to the square of the gradients to improve numerical stability and the second term \(\epsilon_2\) is used for parameter scaling if parameter_scale is set to True. Default: (1e-30,1e-3).
clip_threshold (float, optional) – Clips the unscaled update at clip_threshold. Default: 1.0.
decay_rate (float, optional) – Coefficient for the running average of the squared gradient. Default: -0.8.
beta_1 (float, optional) – If set to a value bigger than zero then first moment will be used. Default: None.
weight_decay (float, optional) – The weight decay \(\lambda\). Default: 0.0.
scale_parameter (bool, optional) – If set to True the learning rate will be scaled by \(\max(\epsilon_1, \text{RMS}(w_{t-1}))\). Default: True.
relative_step (bool, optional) – If set to True the learning_rate will be ignored and relative step size will be computed. Default: True.
warmup_init (bool, optional) – If set to True then the relative step size will be calculated by the current step. Default: False.
learning_rate (float or callable) – The learning rate \(\lambda\).
betas (Tuple[float, float], optional) – The coefficients \((\beta_1, \beta_2)\) used for computing running averages of the gradient and its square. Default: (0.9,0.999)
eps (float, optional) – The term \(\epsilon\) added to the denominator to improve numerical stability. Default: 1e-8
Following the above convention, in contrast with [1], we do not use bias correction in the first and second moments for AdamW. We update the weights with a weight decay \(\lambda\) as:

\[w_{t+1} = w_t - \alpha \left(\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t\right)\]

Parameters:

learning_rate (float or callable) – The learning rate \(\alpha\).
betas (Tuple[float, float], optional) – The coefficients \((\beta_1, \beta_2)\) used for computing running averages of the gradient and its square. Default: (0.9,0.999)
eps (float, optional) – The term \(\epsilon\) added to the denominator to improve numerical stability. Default: 1e-8
weight_decay (float, optional) – The weight decay \(\lambda\). Default: 0.
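An illustrative training-step sketch using AdamW with mlx.nn (the tiny model and data are made up for the example):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 2)
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)

def loss_fn(model, x, y):
    return mx.mean(mx.square(model(x) - y))

x, y = mx.random.normal((8, 4)), mx.random.normal((8, 2))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)   # applies the update rule above in-place
mx.eval(model.parameters(), optimizer.state)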
learning_rate (float or callable) – The learning rate \(\lambda\).
betas (Tuple[float, float], optional) – The coefficients \((\beta_1, \beta_2)\) used for computing running averages of the gradient and its square. Default: (0.9,0.999)
eps (float, optional) – The term \(\epsilon\) added to the denominator to improve numerical stability. Default: 1e-8
Since updates are computed through the sign operation, they tend to
have larger norm than for other optimizers such as SGD and Adam.
\[w_{t + 1} = w_t - \eta\,(\text{sign}(c_t) + \lambda w_t)\]

Parameters:

learning_rate (float or callable) – The learning rate \(\eta\).
betas (Tuple[float, float], optional) – The coefficients \((\beta_1, \beta_2)\) used for computing the gradient momentum and update direction. Default: (0.9,0.99)
weight_decay (float, optional) – The weight decay \(\lambda\). Default: 0.0
parameters (dict) – A Python tree of parameters. It can be a
superset of the gradients. In that case the returned python
tree will be of the same structure as the gradients.
This function can be used to initialize optimizers which have state
(like momentum in SGD). Using this method is optional as the
to have access to the Optimizer.update().
For the most part, indexing an MLX array works the same as indexing a
NumPy numpy.ndarray. See the NumPy documentation for more details on
how that works.
For example, you can use regular integers and slices (slice) to index arrays:
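A short illustrative sketch (the array and indices are made up):

import mlx.core as mx

arr = mx.arange(10)
print(arr[3])       # a scalar array
print(arr[2:8:2])   # slice with a step: values 2, 4, 6
print(arr[-1])      # negative indices work as in NumPy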
An important behavior to be aware of is when the graph will be implicitly evaluated. Anytime you print an array, convert it to a numpy.ndarray, or otherwise access its memory via memoryview,
the graph will be evaluated. Saving arrays via save() (or any other MLX
saving functions) will also evaluate the array.
Calling array.item() on a scalar array will also evaluate it. In the
Conversion to NumPy and Other Frameworks — MLX 0.18.1 documentation