mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
spelling: primitive
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
parent
94eeae4b28
commit
684e946b44
@ -602,7 +602,7 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
|||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the the primitive can built with ops
|
// The jvp transform on the the primitive can built with ops
|
||||||
// that are scheduled on the same stream as the primtive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
// jvp is just the tangent scaled by alpha
|
// jvp is just the tangent scaled by alpha
|
||||||
@ -642,7 +642,7 @@ own :class:`Primitive`.
|
|||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
/** Vectorize primitve along given axis */
|
/** Vectorize primitive along given axis */
|
||||||
std::pair<array, int> Axpby::vmap(
|
std::pair<array, int> Axpby::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
@ -312,7 +312,7 @@ array Axpby::jvp(
|
|||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the the primitive can built with ops
|
// The jvp transform on the the primitive can built with ops
|
||||||
// that are scheduled on the same stream as the primtive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
// jvp is just the tangent scaled by alpha
|
// jvp is just the tangent scaled by alpha
|
||||||
@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
|
|||||||
return vjps;
|
return vjps;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Vectorize primitve along given axis */
|
/** Vectorize primitive along given axis */
|
||||||
std::pair<array, int> Axpby::vmap(
|
std::pair<array, int> Axpby::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
Loading…
Reference in New Issue
Block a user