Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-01 00:28:11 +08:00)
mpi send use input as output (#1750)

* mpi send use input as output
* move earlier
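In short: mx.distributed.send previously returned an empty shape-{0} int32 array; after this change it returns an array identical to its input, so the send participates in the lazy graph like any other op. A minimal sketch of the new behavior, not part of the diff (assumes two ranks launched under MPI; shapes and names are illustrative):

import mlx.core as mx

group = mx.distributed.init()
rank = group.rank()
x = mx.ones((4, 4)) * rank

if rank == 0:
    # send() now returns an array with x's shape, dtype, and values
    # (previously an empty shape-{0} int32 array).
    out = mx.distributed.send(x, 1, group=group)
    mx.eval(out)  # evaluating the result performs the send
else:
    y = mx.distributed.recv_like(x, 0, group=group)
    mx.eval(y)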
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
@@ -89,13 +90,14 @@ void Send::eval_gpu(
 
   auto& in = inputs[0];
   auto& out = outputs[0];
+  move_or_copy(in, out);
 
   // Schedule an async send on the comm stream
   auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
     if (in.event().valid()) {
       in.event().wait();
     }
-    distributed::detail::send(group, in, dst);
+    distributed::detail::send(group, out, dst);
     out.event().signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
@@ -133,6 +135,7 @@ void Recv::eval_gpu(
   // Encode a wait event as there is no input for the recv to encode a signal.
   auto& s = stream();
   auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);
   command_buffer->encodeWait(
       static_cast<MTL::Event*>(out.event().raw_event().get()),

@@ -78,7 +78,10 @@ array send(
   }
 
   return array(
-      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+      x.shape(),
+      x.dtype(),
+      std::make_shared<Send>(to_stream(s), group, dst),
+      {x});
 }
 
 array recv(
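The ops change above is what gives the output its new metadata. A hedged check of that metadata, not from the repo's test suite (the destination is illustrative; nothing is evaluated, so shape and dtype are visible even lazily):

import mlx.core as mx

group = mx.distributed.init()
dst = (group.rank() + 1) % group.size()  # illustrative destination
x = mx.zeros((2, 3), dtype=mx.float16)
out = mx.distributed.send(x, dst, group=group)

# Shape and dtype are known without evaluating the lazy send; before
# this change the result was always shape (0,) and dtype int32.
assert out.shape == (2, 3)
assert out.dtype == mx.float16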

@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/ops.h"
@@ -105,6 +106,7 @@ void Send::eval_cpu(
   assert(outputs.size() == 1);
 
   distributed::detail::send(group(), inputs[0], dst_);
+  move_or_copy(inputs[0], outputs[0]);
 }
 
 std::pair<std::vector<array>, std::vector<int>> Send::vmap(

@@ -148,7 +148,7 @@ void init_distributed(nb::module_& parent_module) {
             in which case the default stream of the default device is used.
 
         Returns:
-          array: An empty array which when evaluated the send is performed.
+          array: An array identical to ``x`` which when evaluated the send is performed.
       )pbdoc");
 
   m.def(
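As the updated docstring says, the send happens when the returned array is evaluated. A hedged illustration of that laziness (pairing scheme is illustrative and assumes an even world size):

import mlx.core as mx

group = mx.distributed.init()
neighbor = group.rank() ^ 1  # illustrative pairwise partner
x = mx.arange(8)

if group.rank() % 2 == 0:
    out = mx.distributed.send(x, neighbor, group=group)
else:
    out = mx.distributed.recv_like(x, neighbor, group=group)

# Nothing crosses the wire until the lazy result is evaluated, either
# explicitly or via downstream computation that consumes it (possible
# now that the output mirrors x).
mx.eval(out)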

@@ -111,6 +111,14 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+        # Check recv and computation in same eval:
+        y = mx.ones((5, 5)) + mx.array(2.0)
+        if send:
+            x = mx.distributed.send(2 * x, neighbor, group=pairs)
+        else:
+            x = mx.distributed.recv_like(x, neighbor, group=pairs)
+        mx.eval(y, x)
+
     def test_average_gradients(self):
         original_all_sum = mx.distributed.all_sum
         n_calls = 0
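The new test case mixes an unrelated computation with the send/recv in a single eval, which exercises the d.end_encoding(s.index) fix above: any open encoder must be ended before the recv's wait event is encoded into the command buffer. A standalone distillation of the test, with pairs/neighbor defined locally since the test sets them up earlier (illustrative; assumes two ranks, typically launched with something like mpirun -np 2 python script.py):

import mlx.core as mx

pairs = mx.distributed.init()
neighbor = pairs.rank() ^ 1  # illustrative pairing, assumes 2 ranks
x = mx.ones((5, 5))

y = mx.ones((5, 5)) + mx.array(2.0)  # unrelated work in the same eval
if pairs.rank() % 2 == 0:
    x = mx.distributed.send(2 * x, neighbor, group=pairs)
else:
    x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(y, x)  # the send/recv and the unrelated computation share one eval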
Awni Hannun