From 9ca34c62879ba767897a11aec4020871a496f960 Mon Sep 17 00:00:00 2001 From: vidit Date: Sun, 24 Dec 2023 00:08:17 +0530 Subject: [PATCH] Corrected the example for mx.value_and_grad --- python/src/transforms.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index a592b4458..096d5a486 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -569,7 +569,7 @@ void init_transforms(py::module_& m) { return lvalue # Returns lvalue, dlvalue/dparams - lvalue, grads = mx.value_and_grad(mse) + lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets) def lasso(params, inputs, targets, a=1.0, b=1.0): outputs = forward(params, inputs) @@ -580,7 +580,7 @@ void init_transforms(py::module_& m) { return loss, mse, l1 - (loss, mse, l1), grads = mx.value_and_grad(lasso) + (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) Args: fun (function): A function which takes a variable number of