mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 03:48:15 +08:00
rebase
This commit is contained in:
14
docs/build/html/_sources/dev/extensions.rst
vendored
14
docs/build/html/_sources/dev/extensions.rst
vendored
@@ -93,9 +93,9 @@ Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
@@ -128,7 +128,7 @@ more concrete:
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
@@ -469,7 +469,7 @@ one we just defined:
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
@@ -481,7 +481,7 @@ one we just defined:
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// If argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
@@ -735,7 +735,7 @@ Let's look at a simple script and its results:
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
@@ -743,7 +743,7 @@ Output:
|
||||
|
||||
c shape: [3, 4]
|
||||
c dtype: float32
|
||||
c correctness: True
|
||||
c is correct: True
|
||||
|
||||
Results
|
||||
^^^^^^^
|
||||
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.async_eval.rst
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.async_eval.rst
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
mlx.core.async\_eval
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: async_eval
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.broadcast_arrays.rst
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.broadcast_arrays.rst
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
mlx.core.broadcast\_arrays
|
||||
==========================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: broadcast_arrays
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.contiguous.rst
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.contiguous.rst
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
mlx.core.contiguous
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autofunction:: contiguous
|
6
docs/build/html/_sources/python/_autosummary/mlx.core.linalg.pinv.rst
vendored
Normal file
6
docs/build/html/_sources/python/_autosummary/mlx.core.linalg.pinv.rst
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
mlx.core.linalg.pinv
|
||||
====================
|
||||
|
||||
.. currentmodule:: mlx.core.linalg
|
||||
|
||||
.. autofunction:: pinv
|
1
docs/build/html/_sources/python/linalg.rst
vendored
1
docs/build/html/_sources/python/linalg.rst
vendored
@@ -20,5 +20,6 @@ Linear Algebra
|
||||
eigh
|
||||
lu
|
||||
lu_factor
|
||||
pinv
|
||||
solve
|
||||
solve_triangular
|
||||
|
2
docs/build/html/_sources/python/ops.rst
vendored
2
docs/build/html/_sources/python/ops.rst
vendored
@@ -36,10 +36,12 @@ Operations
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_arrays
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
contiguous
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
|
@@ -9,6 +9,7 @@ Transforms
|
||||
:toctree: _autosummary
|
||||
|
||||
eval
|
||||
async_eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
|
Reference in New Issue
Block a user