mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 21:16:47 +08:00
Corrected the example for mx.value_and_grad
This commit is contained in:
parent
f91f450141
commit
9ca34c6287
@ -569,7 +569,7 @@ void init_transforms(py::module_& m) {
|
|||||||
return lvalue
|
return lvalue
|
||||||
|
|
||||||
# Returns lvalue, dlvalue/dparams
|
# Returns lvalue, dlvalue/dparams
|
||||||
lvalue, grads = mx.value_and_grad(mse)
|
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
|
||||||
|
|
||||||
def lasso(params, inputs, targets, a=1.0, b=1.0):
|
def lasso(params, inputs, targets, a=1.0, b=1.0):
|
||||||
outputs = forward(params, inputs)
|
outputs = forward(params, inputs)
|
||||||
@ -580,7 +580,7 @@ void init_transforms(py::module_& m) {
|
|||||||
|
|
||||||
return loss, mse, l1
|
return loss, mse, l1
|
||||||
|
|
||||||
(loss, mse, l1), grads = mx.value_and_grad(lasso)
|
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fun (function): A function which takes a variable number of
|
fun (function): A function which takes a variable number of
|
||||||
|
Loading…
Reference in New Issue
Block a user