contiguous op / prim (#1612)

This commit is contained in:
Awni Hannun
2024-11-21 19:51:49 -08:00
committed by GitHub
parent 0d5e7716ad
commit dcca0d7477
11 changed files with 104 additions and 25 deletions

View File

@@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
return {{conjugate(inputs[0], stream())}, axes};
}
std::vector<array> Contiguous::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {cotangents};
}
std::vector<array> Contiguous::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {tangents};
}
std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{contiguous(inputs[0], allow_col_major_, stream())}, axes};
}
bool Contiguous::is_equivalent(const Primitive& other) const {
const Contiguous& c_other = static_cast<const Contiguous&>(other);
return allow_col_major_ == c_other.allow_col_major_;
}
array conv_weight_backward_patches(
const array& in,
const array& wt,