mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fix grad with inplace updates (#1961)
This commit is contained in:
		| @@ -176,7 +176,17 @@ auto py_value_and_grad( | ||||
|           // Call the python function | ||||
|           py_value_out = fun(*tree[0], **tree[1]); | ||||
|  | ||||
|           tree_fill(tree, arrays); | ||||
|           // Replace the tracers with the originals. Don't overwrite | ||||
|           // locations which were written to during the call to fun | ||||
|           int index = 0; | ||||
|           tree_visit_update(tree, [&](nb::handle node) { | ||||
|             auto replace_arr = nb::cast<mx::array>(node); | ||||
|             if (replace_arr.id() == a[index].id()) { | ||||
|               return nb::cast(arrays[index++]); | ||||
|             } else { | ||||
|               return nb::cast(replace_arr); | ||||
|             } | ||||
|           }); | ||||
|  | ||||
|           // Validate the return value of the python function | ||||
|           if (!nb::isinstance<mx::array>(py_value_out)) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun