Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
			
		
		
		
	Ensure we only have one copy of the fence
This commit is contained in:
		| @@ -29,10 +29,10 @@ void AllReduce::eval_gpu( | |||||||
|  |  | ||||||
|   auto& in = inputs[0]; |   auto& in = inputs[0]; | ||||||
|  |  | ||||||
|   Fence f{stream()}; |   auto f = std::make_shared<Fence>(stream()); | ||||||
|  |  | ||||||
|   if (in.event().valid()) { |   if (in.event().valid()) { | ||||||
|     f.update_gpu(in); |     f->update_gpu(in); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto& out = outputs[0]; |   auto& out = outputs[0]; | ||||||
| @@ -41,7 +41,7 @@ void AllReduce::eval_gpu( | |||||||
|   } else { |   } else { | ||||||
|     out.set_data(allocator::malloc_or_wait(out.nbytes())); |     out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||||
|   } |   } | ||||||
|   f.wait_gpu(out); |   f->wait_gpu(out); | ||||||
|  |  | ||||||
|   auto task = [in = in, |   auto task = [in = in, | ||||||
|                out = unsafe_weak_copy(out), |                out = unsafe_weak_copy(out), | ||||||
| @@ -49,7 +49,7 @@ void AllReduce::eval_gpu( | |||||||
|                reduce_type = reduce_type_, |                reduce_type = reduce_type_, | ||||||
|                group = group()]() mutable { |                group = group()]() mutable { | ||||||
|     if (in.event().valid()) { |     if (in.event().valid()) { | ||||||
|       f.wait(); |       f->wait(); | ||||||
|     } |     } | ||||||
|     switch (reduce_type) { |     switch (reduce_type) { | ||||||
|       case Sum: |       case Sum: | ||||||
| @@ -59,7 +59,7 @@ void AllReduce::eval_gpu( | |||||||
|       default: |       default: | ||||||
|         throw std::runtime_error("Only all reduce sum is supported for now"); |         throw std::runtime_error("Only all reduce sum is supported for now"); | ||||||
|     } |     } | ||||||
|     f.update(); |     f->update(); | ||||||
|   }; |   }; | ||||||
|   scheduler::enqueue(detail::communication_stream(), std::move(task)); |   scheduler::enqueue(detail::communication_stream(), std::move(task)); | ||||||
| } | } | ||||||
| @@ -74,22 +74,22 @@ void AllGather::eval_gpu( | |||||||
|  |  | ||||||
|   out.set_data(allocator::malloc_or_wait(out.nbytes())); |   out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||||
|  |  | ||||||
|   Fence f{stream()}; |   auto f = std::make_shared<Fence>(stream()); | ||||||
|  |  | ||||||
|   if (in.event().valid()) { |   if (in.event().valid()) { | ||||||
|     f.update_gpu(in); |     f->update_gpu(in); | ||||||
|   } |   } | ||||||
|   f.wait_gpu(out); |   f->wait_gpu(out); | ||||||
|  |  | ||||||
|   auto task = [in = in, |   auto task = [in = in, | ||||||
|                out = unsafe_weak_copy(out), |                out = unsafe_weak_copy(out), | ||||||
|                f = std::move(f), |                f = std::move(f), | ||||||
|                group = group()]() mutable { |                group = group()]() mutable { | ||||||
|     if (in.event().valid()) { |     if (in.event().valid()) { | ||||||
|       f.wait(); |       f->wait(); | ||||||
|     } |     } | ||||||
|     distributed::detail::all_gather(group, in, out); |     distributed::detail::all_gather(group, in, out); | ||||||
|     f.update(); |     f->update(); | ||||||
|   }; |   }; | ||||||
|   scheduler::enqueue(detail::communication_stream(), std::move(task)); |   scheduler::enqueue(detail::communication_stream(), std::move(task)); | ||||||
| } | } | ||||||
| @@ -103,9 +103,9 @@ void Send::eval_gpu( | |||||||
|   auto& in = inputs[0]; |   auto& in = inputs[0]; | ||||||
|  |  | ||||||
|   // Encode a signal event for the input |   // Encode a signal event for the input | ||||||
|   Fence f{stream()}; |   auto f = std::make_shared<Fence>(stream()); | ||||||
|   if (in.event().valid()) { |   if (in.event().valid()) { | ||||||
|     f.update_gpu(in); |     f->update_gpu(in); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto& out = outputs[0]; |   auto& out = outputs[0]; | ||||||
| @@ -118,7 +118,7 @@ void Send::eval_gpu( | |||||||
|                group = group(), |                group = group(), | ||||||
|                dst = dst_]() mutable { |                dst = dst_]() mutable { | ||||||
|     if (in.event().valid()) { |     if (in.event().valid()) { | ||||||
|       f.wait(); |       f->wait(); | ||||||
|     } |     } | ||||||
|     distributed::detail::send(group, out, dst); |     distributed::detail::send(group, out, dst); | ||||||
|   }; |   }; | ||||||
| @@ -135,8 +135,8 @@ void Recv::eval_gpu( | |||||||
|  |  | ||||||
|   out.set_data(allocator::malloc_or_wait(out.nbytes())); |   out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||||
|  |  | ||||||
|   Fence f{stream()}; |   auto f = std::make_shared<Fence>(stream()); | ||||||
|   f.wait_gpu(out); |   f->wait_gpu(out); | ||||||
|  |  | ||||||
|   // Schedule an async recv on the comm stream |   // Schedule an async recv on the comm stream | ||||||
|   auto task = [out = unsafe_weak_copy(out), |   auto task = [out = unsafe_weak_copy(out), | ||||||
| @@ -144,7 +144,7 @@ void Recv::eval_gpu( | |||||||
|                group = group(), |                group = group(), | ||||||
|                src = src_]() mutable { |                src = src_]() mutable { | ||||||
|     distributed::detail::recv(group, out, src); |     distributed::detail::recv(group, out, src); | ||||||
|     f.update(); |     f->update(); | ||||||
|   }; |   }; | ||||||
|   scheduler::enqueue(detail::communication_stream(), std::move(task)); |   scheduler::enqueue(detail::communication_stream(), std::move(task)); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) { | |||||||
| } | } | ||||||
|  |  | ||||||
| Fence::~Fence() { | Fence::~Fence() { | ||||||
|   if (use_fast_) { |   if (fence_ != nullptr && use_fast_) { | ||||||
|     cpu_value()[0] = INT_MAX; |     cpu_value()[0] = INT_MAX; | ||||||
|   } |   } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -20,6 +20,9 @@ namespace mlx::core { | |||||||
|  */ |  */ | ||||||
| class Fence { | class Fence { | ||||||
|  public: |  public: | ||||||
|  |   Fence(const Fence&) = delete; | ||||||
|  |   Fence& operator=(const Fence&) = delete; | ||||||
|  |  | ||||||
|   Fence(const Stream& stream); |   Fence(const Stream& stream); | ||||||
|   ~Fence(); |   ~Fence(); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	Angelos Katharopoulos