mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Ensure we only have one copy of the fence
This commit is contained in:
		| @@ -29,10 +29,10 @@ void AllReduce::eval_gpu( | ||||
|  | ||||
|   auto& in = inputs[0]; | ||||
|  | ||||
|   Fence f{stream()}; | ||||
|   auto f = std::make_shared<Fence>(stream()); | ||||
|  | ||||
|   if (in.event().valid()) { | ||||
|     f.update_gpu(in); | ||||
|     f->update_gpu(in); | ||||
|   } | ||||
|  | ||||
|   auto& out = outputs[0]; | ||||
| @@ -41,7 +41,7 @@ void AllReduce::eval_gpu( | ||||
|   } else { | ||||
|     out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||
|   } | ||||
|   f.wait_gpu(out); | ||||
|   f->wait_gpu(out); | ||||
|  | ||||
|   auto task = [in = in, | ||||
|                out = unsafe_weak_copy(out), | ||||
| @@ -49,7 +49,7 @@ void AllReduce::eval_gpu( | ||||
|                reduce_type = reduce_type_, | ||||
|                group = group()]() mutable { | ||||
|     if (in.event().valid()) { | ||||
|       f.wait(); | ||||
|       f->wait(); | ||||
|     } | ||||
|     switch (reduce_type) { | ||||
|       case Sum: | ||||
| @@ -59,7 +59,7 @@ void AllReduce::eval_gpu( | ||||
|       default: | ||||
|         throw std::runtime_error("Only all reduce sum is supported for now"); | ||||
|     } | ||||
|     f.update(); | ||||
|     f->update(); | ||||
|   }; | ||||
|   scheduler::enqueue(detail::communication_stream(), std::move(task)); | ||||
| } | ||||
| @@ -74,22 +74,22 @@ void AllGather::eval_gpu( | ||||
|  | ||||
|   out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||
|  | ||||
|   Fence f{stream()}; | ||||
|   auto f = std::make_shared<Fence>(stream()); | ||||
|  | ||||
|   if (in.event().valid()) { | ||||
|     f.update_gpu(in); | ||||
|     f->update_gpu(in); | ||||
|   } | ||||
|   f.wait_gpu(out); | ||||
|   f->wait_gpu(out); | ||||
|  | ||||
|   auto task = [in = in, | ||||
|                out = unsafe_weak_copy(out), | ||||
|                f = std::move(f), | ||||
|                group = group()]() mutable { | ||||
|     if (in.event().valid()) { | ||||
|       f.wait(); | ||||
|       f->wait(); | ||||
|     } | ||||
|     distributed::detail::all_gather(group, in, out); | ||||
|     f.update(); | ||||
|     f->update(); | ||||
|   }; | ||||
|   scheduler::enqueue(detail::communication_stream(), std::move(task)); | ||||
| } | ||||
| @@ -103,9 +103,9 @@ void Send::eval_gpu( | ||||
|   auto& in = inputs[0]; | ||||
|  | ||||
|   // Encode a signal event for the input | ||||
|   Fence f{stream()}; | ||||
|   auto f = std::make_shared<Fence>(stream()); | ||||
|   if (in.event().valid()) { | ||||
|     f.update_gpu(in); | ||||
|     f->update_gpu(in); | ||||
|   } | ||||
|  | ||||
|   auto& out = outputs[0]; | ||||
| @@ -118,7 +118,7 @@ void Send::eval_gpu( | ||||
|                group = group(), | ||||
|                dst = dst_]() mutable { | ||||
|     if (in.event().valid()) { | ||||
|       f.wait(); | ||||
|       f->wait(); | ||||
|     } | ||||
|     distributed::detail::send(group, out, dst); | ||||
|   }; | ||||
| @@ -135,8 +135,8 @@ void Recv::eval_gpu( | ||||
|  | ||||
|   out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||
|  | ||||
|   Fence f{stream()}; | ||||
|   f.wait_gpu(out); | ||||
|   auto f = std::make_shared<Fence>(stream()); | ||||
|   f->wait_gpu(out); | ||||
|  | ||||
|   // Schedule an async recv on the comm stream | ||||
|   auto task = [out = unsafe_weak_copy(out), | ||||
| @@ -144,7 +144,7 @@ void Recv::eval_gpu( | ||||
|                group = group(), | ||||
|                src = src_]() mutable { | ||||
|     distributed::detail::recv(group, out, src); | ||||
|     f.update(); | ||||
|     f->update(); | ||||
|   }; | ||||
|   scheduler::enqueue(detail::communication_stream(), std::move(task)); | ||||
| } | ||||
|   | ||||
| @@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) { | ||||
| } | ||||
|  | ||||
| Fence::~Fence() { | ||||
|   if (use_fast_) { | ||||
|   if (fence_ != nullptr && use_fast_) { | ||||
|     cpu_value()[0] = INT_MAX; | ||||
|   } | ||||
| } | ||||
|   | ||||
| @@ -20,6 +20,9 @@ namespace mlx::core { | ||||
|  */ | ||||
| class Fence { | ||||
|  public: | ||||
|   Fence(const Fence&) = delete; | ||||
|   Fence& operator=(const Fence&) = delete; | ||||
|  | ||||
|   Fence(const Stream& stream); | ||||
|   ~Fence(); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos