diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 90aca6b07..4c31b020c 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -602,7 +602,7 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations. const std::vector& argnums) { // Forward mode diff that pushes along the tangents // The jvp transform on the the primitive can built with ops - // that are scheduled on the same stream as the primtive + // that are scheduled on the same stream as the primitive // If argnums = {0}, we only push along x in which case the // jvp is just the tangent scaled by alpha @@ -642,7 +642,7 @@ own :class:`Primitive`. .. code-block:: C++ - /** Vectorize primitve along given axis */ + /** Vectorize primitive along given axis */ std::pair Axpby::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 5d033c70c..dacdbb9ab 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -312,7 +312,7 @@ array Axpby::jvp( const std::vector& argnums) { // Forward mode diff that pushes along the tangents // The jvp transform on the the primitive can built with ops - // that are scheduled on the same stream as the primtive + // that are scheduled on the same stream as the primitive // If argnums = {0}, we only push along x in which case the // jvp is just the tangent scaled by alpha @@ -345,7 +345,7 @@ std::vector Axpby::vjp( return vjps; } -/** Vectorize primitve along given axis */ +/** Vectorize primitive along given axis */ std::pair Axpby::vmap( const std::vector& inputs, const std::vector& axes) {