mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fix malloc or wait deadlock (#1976)
This commit is contained in:
		| @@ -57,23 +57,19 @@ void init_metal(nb::module_& m) { | ||||
|       "set_memory_limit", | ||||
|       &mx::metal::set_memory_limit, | ||||
|       "limit"_a, | ||||
|       nb::kw_only(), | ||||
|       "relaxed"_a = true, | ||||
|       R"pbdoc( | ||||
|       Set the memory limit. | ||||
|  | ||||
|       Memory allocations will wait on scheduled tasks to complete if the limit | ||||
|       is exceeded. If there are no more scheduled tasks an error will be raised | ||||
|       if ``relaxed`` is ``False``. Otherwise memory will be allocated | ||||
|       (including the potential for swap) if ``relaxed`` is ``True``. | ||||
|       The memory limit is a guideline for the maximum amount of memory to use | ||||
|       during graph evaluation. If the memory limit is exceeded and there is no | ||||
|       more RAM (including swap when available) allocations will result in an | ||||
|       exception. | ||||
|  | ||||
|       The memory limit defaults to 1.5 times the maximum recommended working set | ||||
|       size reported by the device. | ||||
|       When metal is available the memory limit defaults to 1.5 times the | ||||
|       maximum recommended working set size reported by the device. | ||||
|  | ||||
|       Args: | ||||
|         limit (int): Memory limit in bytes. | ||||
|         relaxed (bool, optional): If `False`` an error is raised if the limit | ||||
|           is exceeded. Default: ``True`` | ||||
|  | ||||
|       Returns: | ||||
|         int: The previous memory limit in bytes. | ||||
|   | ||||
| @@ -185,6 +185,18 @@ class TestEval(mlx_tests.MLXTestCase): | ||||
|             x = mx.abs(x, stream=s2) | ||||
|         mx.eval(x) | ||||
|  | ||||
|         s1 = mx.default_stream(mx.gpu) | ||||
|         s2 = mx.new_stream(mx.gpu) | ||||
|         old_limit = mx.metal.set_memory_limit(1000) | ||||
|  | ||||
|         x = mx.ones((512, 512), stream=s2) | ||||
|         for _ in range(80): | ||||
|             x = mx.abs(x, stream=s1) | ||||
|         y = mx.abs(x, stream=s2) | ||||
|         z = mx.abs(y, stream=s2) | ||||
|         mx.eval(z) | ||||
|         mx.metal.set_memory_limit(old_limit) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun