diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index b8c3a4995..2aef28f99 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -93,9 +93,9 @@ Primitives ^^^^^^^^^^^ A :class:`Primitive` is part of the computation graph of an :class:`array`. It -defines how to create outputs arrays given a input arrays. Further, a +defines how to create output arrays given input arrays. Further, a :class:`Primitive` has methods to run on the CPU or GPU and for function -transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be +transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be more concrete: .. code-block:: C++ @@ -128,7 +128,7 @@ more concrete: /** The vector-Jacobian product. */ std::vector vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; @@ -469,7 +469,7 @@ one we just defined: const std::vector& tangents, const std::vector& argnums) { // Forward mode diff that pushes along the tangents - // The jvp transform on the primitive can built with ops + // The jvp transform on the primitive can be built with ops // that are scheduled on the same stream as the primitive // If argnums = {0}, we only push along x in which case the @@ -481,7 +481,7 @@ one we just defined: auto scale_arr = array(scale, tangents[0].dtype()); return {multiply(scale_arr, tangents[0], stream())}; } - // If, argnums = {0, 1}, we take contributions from both + // If argnums = {0, 1}, we take contributions from both // which gives us jvp = tangent_x * alpha + tangent_y * beta else { return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; @@ -735,7 +735,7 @@ Let's look at a simple script and its results: print(f"c shape: {c.shape}") print(f"c dtype: {c.dtype}") - print(f"c correct: {mx.all(c == 6.0).item()}") + print(f"c is correct: {mx.all(c == 6.0).item()}") Output: @@ -743,7 +743,7 @@ Output: c shape: [3, 4] c dtype: float32 - c correctness: True + c is correct: True Results ^^^^^^^