Real and Imag (#1490)

* real and imag

* fix

* fix
This commit is contained in:
Awni Hannun
2024-10-15 16:23:15 -07:00
committed by GitHub
parent 2b8ace6a03
commit 3f86399922
21 changed files with 275 additions and 46 deletions

View File

@@ -1797,6 +1797,36 @@ std::vector<array> GreaterEqual::jvp(
return {zeros(shape, bool_, stream())};
}
std::vector<array> Imag::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(
array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
cotangents[0],
stream())};
}
std::vector<array> Imag::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {imag(tangents[0], stream())};
}
std::pair<std::vector<array>, std::vector<int>> Imag::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{imag(inputs[0], stream())}, axes};
}
std::pair<std::vector<array>, std::vector<int>> Less::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2633,6 +2663,33 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
return shape_ == r_other.shape_;
}
std::vector<array> Real::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {astype(cotangents[0], primals[0].dtype(), stream())};
}
std::vector<array> Real::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {real(tangents[0], stream())};
}
std::pair<std::vector<array>, std::vector<int>> Real::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{real(inputs[0], stream())}, axes};
}
std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {