mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
contiguous op / prim (#1612)
This commit is contained in:
@@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
||||
return {{conjugate(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> Contiguous::vjp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {cotangents};
|
||||
}
|
||||
|
||||
std::vector<array> Contiguous::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {tangents};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
return {{contiguous(inputs[0], allow_col_major_, stream())}, axes};
|
||||
}
|
||||
|
||||
bool Contiguous::is_equivalent(const Primitive& other) const {
|
||||
const Contiguous& c_other = static_cast<const Contiguous&>(other);
|
||||
return allow_col_major_ == c_other.allow_col_major_;
|
||||
}
|
||||
|
||||
array conv_weight_backward_patches(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
|
||||
Reference in New Issue
Block a user