spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
Josh Soref 2024-01-01 22:43:44 -05:00
parent 94eeae4b28
commit 684e946b44
2 changed files with 4 additions and 4 deletions

View File

@ -602,7 +602,7 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the the primitive can built with ops // The jvp transform on the the primitive can built with ops
// that are scheduled on the same stream as the primtive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha // jvp is just the tangent scaled by alpha
@ -642,7 +642,7 @@ own :class:`Primitive`.
.. code-block:: C++ .. code-block:: C++
/** Vectorize primitve along given axis */ /** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap( std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {

View File

@ -312,7 +312,7 @@ array Axpby::jvp(
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the the primitive can built with ops // The jvp transform on the the primitive can built with ops
// that are scheduled on the same stream as the primtive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha // jvp is just the tangent scaled by alpha
@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
return vjps; return vjps;
} }
/** Vectorize primitve along given axis */ /** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap( std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {