mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -1797,6 +1797,36 @@ std::vector<array> GreaterEqual::jvp(
|
||||
return {zeros(shape, bool_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Imag::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {multiply(
|
||||
array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
|
||||
cotangents[0],
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Imag::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {imag(tangents[0], stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Imag::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{imag(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Less::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -2633,6 +2663,33 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
|
||||
return shape_ == r_other.shape_;
|
||||
}
|
||||
|
||||
std::vector<array> Real::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {astype(cotangents[0], primals[0].dtype(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Real::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {real(tangents[0], stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Real::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{real(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
||||
Reference in New Issue
Block a user