diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py index e2cc981e2..14c5cb15e 100644 --- a/python/mlx/nn/layers/dropout.py +++ b/python/mlx/nn/layers/dropout.py @@ -88,5 +88,3 @@ class Dropout2d(Module): mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) return (1 / self._p_1) * mask * x - - diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index a592b4458..096d5a486 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -569,7 +569,7 @@ void init_transforms(py::module_& m) { return lvalue # Returns lvalue, dlvalue/dparams - lvalue, grads = mx.value_and_grad(mse) + lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets) def lasso(params, inputs, targets, a=1.0, b=1.0): outputs = forward(params, inputs) @@ -580,7 +580,7 @@ void init_transforms(py::module_& m) { return loss, mse, l1 - (loss, mse, l1), grads = mx.value_and_grad(lasso) + (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) Args: fun (function): A function which takes a variable number of