Custom transforms (#1246)

This commit is contained in:
Angelos Katharopoulos
2024-07-10 18:00:01 -07:00
committed by GitHub
parent a3c287354f
commit 5c1fa64fb0
16 changed files with 734 additions and 39 deletions

View File

@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(inputs[0]);
}
void CustomVJP::eval(
void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());

View File

@@ -52,7 +52,7 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)