mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
@@ -171,12 +171,12 @@ auto py_value_and_grad(
|
||||
nb::list tree;
|
||||
tree.append(args);
|
||||
tree.append(kwargs);
|
||||
tree_replace(tree, arrays, a);
|
||||
tree_fill(tree, a);
|
||||
|
||||
// Call the python function
|
||||
py_value_out = fun(*tree[0], **tree[1]);
|
||||
|
||||
tree_replace(tree, arrays, a);
|
||||
tree_fill(tree, arrays);
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||
|
Reference in New Issue
Block a user