mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom transforms (#1246)
This commit is contained in:
committed by
GitHub
parent
a3c287354f
commit
5c1fa64fb0
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomVJP::eval(
|
||||
void CustomTransforms::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
|
||||
@@ -52,7 +52,7 @@ DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
|
||||
Reference in New Issue
Block a user