Some fixes in docs (#1141)

* fixes in docs

* nit
This commit is contained in:
Awni Hannun 2024-05-20 11:51:47 -07:00 committed by GitHub
parent da83f899bb
commit e6fecbb3e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 44 additions and 37 deletions

View File

@ -191,14 +191,14 @@ The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do: GGUF, you can do:
```shell .. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \ cmake ..
-DBUILD_SHARED_LIBS=ON \ -DCMAKE_BUILD_TYPE=MinSizeRel \
-DMLX_BUILD_CPU=ON \ -DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_SAFETENSORS=OFF \ -DMLX_BUILD_CPU=ON \
-DMLX_BUILD_GGUF=OFF -DMLX_BUILD_SAFETENSORS=OFF \
``` -DMLX_BUILD_GGUF=OFF
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@ -10,5 +10,6 @@ Linear Algebra
inv inv
norm norm
cholesky
qr qr
svd svd

View File

@ -11,9 +11,10 @@ class Conv1d(Module):
"""Applies a 1-dimensional convolution over the multi-channel input sequence. """Applies a 1-dimensional convolution over the multi-channel input sequence.
The channels are expected to be last i.e. the input shape should be ``NLC`` where: The channels are expected to be last i.e. the input shape should be ``NLC`` where:
- ``N`` is the batch dimension
- ``L`` is the sequence length * ``N`` is the batch dimension
- ``C`` is the number of input channels * ``L`` is the sequence length
* ``C`` is the number of input channels
Args: Args:
in_channels (int): The number of input channels in_channels (int): The number of input channels
@ -72,10 +73,11 @@ class Conv2d(Module):
"""Applies a 2-dimensional convolution over the multi-channel input image. """Applies a 2-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be ``NHWC`` where: The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
- ``N`` is the batch dimension
- ``H`` is the input image height * ``N`` is the batch dimension
- ``W`` is the input image width * ``H`` is the input image height
- ``C`` is the number of input channels * ``W`` is the input image width
* ``C`` is the number of input channels
Args: Args:
in_channels (int): The number of input channels. in_channels (int): The number of input channels.
@ -136,12 +138,15 @@ class Conv2d(Module):
class Conv3d(Module): class Conv3d(Module):
"""Applies a 3-dimensional convolution over the multi-channel input image. """Applies a 3-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be ``NDHWC`` where: The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:
- ``N`` is the batch dimension
- ``D`` is the input image depth * ``N`` is the batch dimension
- ``H`` is the input image height * ``D`` is the input image depth
- ``W`` is the input image width * ``H`` is the input image height
- ``C`` is the number of input channels * ``W`` is the input image width
* ``C`` is the number of input channels
Args: Args:
in_channels (int): The number of input channels. in_channels (int): The number of input channels.
out_channels (int): The number of output channels. out_channels (int): The number of output channels.

View File

@ -235,7 +235,7 @@ void init_linalg(nb::module_& parent_module) {
Returns: Returns:
tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
``A = U @ diag(S) @ Vt`` ``A = U @ diag(S) @ Vt``
)pbdoc"); )pbdoc");
m.def( m.def(
"inv", "inv",
@ -286,7 +286,8 @@ void init_linalg(nb::module_& parent_module) {
in which case the default stream of the default device is used. in which case the default stream of the default device is used.
Returns: Returns:
array: if ``upper = False``, it returns a lower trinagular ``L``matrix such that ``dot(L, L.T) = a``. array: If ``upper = False``, it returns a lower trinagular ``L`` matrix such
If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``. that ``dot(L, L.T) = a``. If ``upper = True``, it returns an upper triangular
``U`` matrix such that ``dot(U.T, U) = a``.
)pbdoc"); )pbdoc");
} }

View File

@ -3501,7 +3501,7 @@ void init_ops(nb::module_& m) {
support matadata. The metadata will be returned as an support matadata. The metadata will be returned as an
additional dictionary. additional dictionary.
Returns: Returns:
result (array, dict): array or dict:
A single array if loading from a ``.npy`` file or a dict A single array if loading from a ``.npy`` file or a dict
mapping names to arrays if loading from a ``.npz`` or mapping names to arrays if loading from a ``.npz`` or
``.safetensors`` file. If ``return_metadata` is ``True`` an ``.safetensors`` file. If ``return_metadata` is ``True`` an
@ -3584,7 +3584,7 @@ void init_ops(nb::module_& m) {
y (array): The input selected from where condition is ``False``. y (array): The input selected from where condition is ``False``.
Returns: Returns:
result (array): The output containing elements selected from array: The output containing elements selected from
``x`` and ``y``. ``x`` and ``y``.
)pbdoc"); )pbdoc");
m.def( m.def(
@ -3613,7 +3613,7 @@ void init_ops(nb::module_& m) {
decimals (int): Number of decimal places to round to. (default: 0) decimals (int): Number of decimal places to round to. (default: 0)
Returns: Returns:
result (array): An array of the same type as ``a`` rounded to the array: An array of the same type as ``a`` rounded to the
given number of decimals. given number of decimals.
)pbdoc"); )pbdoc");
m.def( m.def(
@ -3650,7 +3650,7 @@ void init_ops(nb::module_& m) {
``w``. (default: ``4``) ``w``. (default: ``4``)
Returns: Returns:
result (array): The result of the multiplication of ``x`` with ``w``. array: The result of the multiplication of ``x`` with ``w``.
)pbdoc"); )pbdoc");
m.def( m.def(
"quantize", "quantize",
@ -3705,11 +3705,11 @@ void init_ops(nb::module_& m) {
``w`` in the returned quantized matrix. (default: ``4``) ``w`` in the returned quantized matrix. (default: ``4``)
Returns: Returns:
(tuple): A tuple containing tuple: A tuple containing
- w_q (array): The quantized version of ``w`` * w_q (array): The quantized version of ``w``
- scales (array): The scale to multiply each element with, namely :math:`s` * scales (array): The scale to multiply each element with, namely :math:`s`
- biases (array): The biases to add to each element, namely :math:`\beta` * biases (array): The biases to add to each element, namely :math:`\beta`
)pbdoc"); )pbdoc");
m.def( m.def(
"dequantize", "dequantize",
@ -3745,7 +3745,7 @@ void init_ops(nb::module_& m) {
``w``. (default: ``4``) ``w``. (default: ``4``)
Returns: Returns:
result (array): The dequantized version of ``w`` array: The dequantized version of ``w``
)pbdoc"); )pbdoc");
m.def( m.def(
"block_sparse_qmm", "block_sparse_qmm",
@ -3790,7 +3790,7 @@ void init_ops(nb::module_& m) {
``w``. (default: ``4``) ``w``. (default: ``4``)
Returns: Returns:
result (array): The result of the multiplication of ``x`` with ``w`` array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``. after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc"); )pbdoc");
m.def( m.def(
@ -3830,7 +3830,7 @@ void init_ops(nb::module_& m) {
corresponding dimensions of ``a`` and ``b``. (default: 2) corresponding dimensions of ``a`` and ``b``. (default: 2)
Returns: Returns:
result (array): The tensor dot product. array: The tensor dot product.
)pbdoc"); )pbdoc");
m.def( m.def(
"inner", "inner",
@ -3849,7 +3849,7 @@ void init_ops(nb::module_& m) {
b (array): Input array b (array): Input array
Returns: Returns:
result (array): The inner product. array: The inner product.
)pbdoc"); )pbdoc");
m.def( m.def(
"outer", "outer",
@ -3868,7 +3868,7 @@ void init_ops(nb::module_& m) {
b (array): Input array b (array): Input array
Returns: Returns:
result (array): The outer product. array: The outer product.
)pbdoc"); )pbdoc");
m.def( m.def(
"tile", "tile",
@ -3895,7 +3895,7 @@ void init_ops(nb::module_& m) {
reps (int or list(int)): The number of times to repeat ``a`` along each axis. reps (int or list(int)): The number of times to repeat ``a`` along each axis.
Returns: Returns:
result (array): The tiled array. array: The tiled array.
)pbdoc"); )pbdoc");
m.def( m.def(
"addmm", "addmm",