diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index a592b4458..096d5a486 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -569,7 +569,7 @@ void init_transforms(py::module_& m) { return lvalue # Returns lvalue, dlvalue/dparams - lvalue, grads = mx.value_and_grad(mse) + lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets) def lasso(params, inputs, targets, a=1.0, b=1.0): outputs = forward(params, inputs) @@ -580,7 +580,7 @@ void init_transforms(py::module_& m) { return loss, mse, l1 - (loss, mse, l1), grads = mx.value_and_grad(lasso) + (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) Args: fun (function): A function which takes a variable number of