Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-01 00:28:11 +08:00)
	register in task submission
@@ -134,7 +134,6 @@ void CommandEncoder::set_input_array(
    const array& a,
    int idx,
    int64_t offset /* = 0 */) {
  all_inputs_.insert(a.buffer().ptr());
  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
  needs_barrier_ =
      needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
@@ -151,7 +150,6 @@ void CommandEncoder::set_output_array(
    int64_t offset /* = 0 */) {
  // Add barriers before adding the output to the output set
  set_input_array(a, idx, offset);
  all_outputs_.insert(a.buffer().ptr());
  auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
  if (concurrent_) {
    concurrent_outputs_.insert(buf);
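For context, the unchanged lines in these two hunks are the intra-encoder hazard tracking: when an input buffer was written by an earlier dispatch in the same encoder, needs_barrier_ is set so a memory barrier is issued before the next dispatch. Below is a minimal, self-contained sketch of that bookkeeping; Resource, EncoderSketch, and dispatch are illustrative stand-ins, not the MLX/Metal types used above.

#include <cstdio>
#include <unordered_set>

// Toy stand-in for a GPU buffer handle (the real code uses MTL::Resource*).
using Resource = const void*;

struct EncoderSketch {
  bool needs_barrier{false};
  std::unordered_set<Resource> prev_outputs; // written by already-issued dispatches
  std::unordered_set<Resource> next_outputs; // written by the dispatch being built

  void set_input(Resource r) {
    // Read-after-write within the same encoder: a barrier is needed
    // before the next dispatch is issued.
    needs_barrier = needs_barrier || (prev_outputs.count(r) > 0);
  }
  void set_output(Resource r) {
    next_outputs.insert(r);
  }
  void dispatch() {
    if (needs_barrier) {
      std::printf("memory barrier before this dispatch\n");
      needs_barrier = false;
    }
    // Outputs of this dispatch become hazards for later dispatches.
    prev_outputs.insert(next_outputs.begin(), next_outputs.end());
    next_outputs.clear();
  }
};

int main() {
  int a = 0, b = 0; // pretend these are device buffers
  EncoderSketch enc;
  enc.set_input(&a);  // kernel 1 reads a
  enc.set_output(&b); // kernel 1 writes b
  enc.dispatch();
  enc.set_input(&b);  // kernel 2 reads b, which kernel 1 wrote -> barrier
  enc.dispatch();
}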
@@ -185,6 +183,24 @@ void CommandEncoder::dispatchThreads(
  enc_->dispatchThreads(grid_dims, group_dims);
}

void DeviceStream::register_inputs(const std::vector<array>& arrays) {
  std::lock_guard<std::mutex> lk(fence_mtx);
  for (auto& a : arrays) {
    auto buf = a.buffer().ptr();
    if (auto it = outputs.find(buf); it != outputs.end()) {
      waiting_on.insert(it->second);
    }
  }
}

void DeviceStream::register_outputs(const std::vector<array>& arrays) {
  for (auto& a : arrays) {
    if (a.data<void>() != nullptr) {
      all_outputs.insert(a.buffer().ptr());
    }
  }
}

Device::Device() {
  auto pool = new_scoped_memory_pool();
  device_ = load_device();
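The two helpers added above implement the per-stream producer tracking that this commit moves out of CommandEncoder: outputs maps a buffer to the fence of the command buffer that last wrote it, register_inputs collects those fences into waiting_on, and register_outputs records which buffers the current command buffer will write. Here is a self-contained sketch of the same bookkeeping; FakeFence, Buffer, and StreamSketch are illustrative stand-ins, not MLX types.

#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct FakeFence {}; // stands in for the Metal fence wrapper

using Buffer = const void*;

struct StreamSketch {
  std::mutex fence_mtx;
  // Buffer -> fence of the command buffer that last wrote it.
  std::unordered_map<Buffer, std::shared_ptr<FakeFence>> outputs;
  // Fences the current command buffer must wait on before running.
  std::unordered_set<std::shared_ptr<FakeFence>> waiting_on;
  // Buffers the current command buffer will write.
  std::unordered_set<Buffer> all_outputs;

  void register_inputs(const std::vector<Buffer>& bufs) {
    std::lock_guard<std::mutex> lk(fence_mtx);
    for (auto b : bufs) {
      if (auto it = outputs.find(b); it != outputs.end()) {
        waiting_on.insert(it->second); // wait on the producer's fence
      }
    }
  }

  void register_outputs(const std::vector<Buffer>& bufs) {
    for (auto b : bufs) {
      all_outputs.insert(b);
    }
  }
};

int main() {
  StreamSketch s;
  int x = 0;
  s.outputs[&x] = std::make_shared<FakeFence>(); // an earlier buffer wrote x
  s.register_inputs({&x});  // waiting_on now holds that producer fence
  s.register_outputs({&x}); // x is recorded as an output of this buffer
}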
@@ -281,38 +297,22 @@ void Device::end_encoding(int index) {
    //   boundaries. These can be removed early from the encoders inputs and
    //   outputs since they don't need synchronization.
    auto& enc = *stream.encoder;
    // Remove temporaries from inputs and outputs
    for (auto& t : stream.temporaries) {
      if (t.data<void>() != nullptr) {
        enc.outputs().erase(t.buffer().ptr());
        enc.inputs().erase(t.buffer().ptr());
      }
    }

    // Keep references to the fences we waited on and put them
    // in the completion handler so they are not prematurely released
    std::unordered_set<std::shared_ptr<Fence>> waiting_on;
    for (auto& f : stream.waiting_on) {
      enc->waitForFence(f->fence);
    }
    {
      std::lock_guard<std::mutex> lk(stream.fence_mtx);
      for (auto in : enc.inputs()) {
        if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
          // If we've already waited on a fence, don't wait on it again.
          if (waiting_on.find(it->second) == waiting_on.end()) {
            enc->waitForFence(it->second->fence);
            waiting_on.insert(it->second);
          }
        }
      }
      for (auto out : enc.outputs()) {
      for (auto out : stream.all_outputs) {
        stream.outputs[out] = stream.fence;
      }
    }
    enc->updateFence(stream.fence->fence);
    stream.buffer->addCompletedHandler(
        [&stream,
         waiting_on = std::move(waiting_on),
         waiting_on = std::move(stream.waiting_on),
         fence = std::move(stream.fence),
         outputs = std::move(enc.outputs()),
         outputs = std::move(stream.all_outputs),
         temporaries =
             std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
          temporaries.clear();
@@ -325,6 +325,9 @@ void Device::end_encoding(int index) {
            }
          }
        });
  } else {
    stream.all_outputs.clear();
    stream.waiting_on.clear();
  }
  stream.encoder = nullptr;
}
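The pattern to note in end_encoding is the ownership transfer into the completion handler: the fences that were waited on, the stream's own fence, the output set, and the temporaries are moved into the lambda so they stay alive until the command buffer completes, and are released on the completion thread. A hedged sketch of that idiom, with a toy add_completed_handler standing in for MTL::CommandBuffer's handler and FakeFence for the fence wrapper:

#include <cstdio>
#include <functional>
#include <memory>
#include <unordered_set>
#include <vector>

struct FakeFence {
  ~FakeFence() { std::printf("fence released\n"); }
};

// Stand-in for addCompletedHandler: here the handler runs immediately; on a
// real device it runs only after the GPU finishes the command buffer.
void add_completed_handler(std::function<void()> handler) {
  handler();
}

int main() {
  auto fence = std::make_shared<FakeFence>();
  std::unordered_set<std::shared_ptr<FakeFence>> waiting_on = {
      std::make_shared<FakeFence>()};
  std::vector<int> temporaries = {1, 2, 3}; // stands in for temporary arrays

  // Moving the state into the lambda keeps it alive until the handler has run,
  // mirroring how end_encoding captures stream.waiting_on, stream.fence,
  // stream.all_outputs, and stream.temporaries above.
  add_completed_handler(
      [waiting_on = std::move(waiting_on),
       fence = std::move(fence),
       temporaries = std::move(temporaries)]() mutable {
        temporaries.clear(); // buffers can be recycled once work is done
        // waiting_on and fence are released when the handler object is destroyed
      });

  std::printf("handler done\n");
}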

@@ -73,16 +73,6 @@ struct CommandEncoder {
  }
  ~CommandEncoder();

  // Inputs to all kernels in the encoder including temporaries
  std::unordered_set<const void*>& inputs() {
    return all_inputs_;
  };

  // Outputs of all kernels in the encoder including temporaries
  std::unordered_set<const void*> outputs() {
    return all_outputs_;
  };

 private:
  MTL::ComputeCommandEncoder* enc_;
  bool needs_barrier_{false};
@@ -90,8 +80,6 @@ struct CommandEncoder {
  std::unordered_set<MTL::Resource*> prev_outputs_;
  std::unordered_set<MTL::Resource*> next_outputs_;
  std::unordered_set<MTL::Resource*> concurrent_outputs_;
  std::unordered_set<const void*> all_inputs_;
  std::unordered_set<const void*> all_outputs_;
};

struct Fence {
@@ -121,11 +109,15 @@ struct DeviceStream {
  MTL::CommandBuffer* buffer{nullptr};
  int buffer_ops{0};

  // The command encoder, fence, and temporaries are updated between command
  // encoders
  void register_inputs(const std::vector<array>& inputs);
  void register_outputs(const std::vector<array>& inputs);

  // The following variables are all reset between command encoders
  std::unique_ptr<CommandEncoder> encoder{nullptr};
  std::shared_ptr<Fence> fence;
  std::vector<array> temporaries;
  std::unordered_set<std::shared_ptr<Fence>> waiting_on;
  std::unordered_set<const void*> all_outputs;
};

class Device {
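The comment on the new fields ("reset between command encoders") is the key invariant: waiting_on and all_outputs describe only the command buffer currently being built, while the buffer-to-fence map used by register_inputs persists across command buffers. A toy two-command-buffer walkthrough of that lifecycle; ToyFence and ToyStream are illustrative stand-ins, not the MLX types, and the reset here only approximates what end_encoding does.

#include <cassert>
#include <memory>
#include <unordered_map>
#include <unordered_set>

struct ToyFence {};
using Buffer = const void*;

struct ToyStream {
  // Persists across command buffers: the last writer of each buffer.
  std::unordered_map<Buffer, std::shared_ptr<ToyFence>> outputs;
  // Reset for every command buffer.
  std::shared_ptr<ToyFence> fence = std::make_shared<ToyFence>();
  std::unordered_set<std::shared_ptr<ToyFence>> waiting_on;
  std::unordered_set<Buffer> all_outputs;

  void end_encoding() {
    for (auto b : all_outputs) {
      outputs[b] = fence; // publish: this buffer's fence now guards b
    }
    // The real code moves waiting_on/all_outputs into the completion handler;
    // here we simply reset them for the next command buffer.
    waiting_on.clear();
    all_outputs.clear();
    fence = std::make_shared<ToyFence>();
  }
};

int main() {
  int x = 0;
  ToyStream s;

  // Command buffer 1 writes x.
  s.all_outputs.insert(&x);
  auto producer_fence = s.fence;
  s.end_encoding();

  // Command buffer 2 reads x: it must wait on the producer's fence.
  if (auto it = s.outputs.find(&x); it != s.outputs.end()) {
    s.waiting_on.insert(it->second);
  }
  assert(s.waiting_on.count(producer_fence) == 1);
  s.end_encoding();
}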
@@ -190,10 +182,15 @@ class Device {

  void set_residency_set(const MTL::ResidencySet* residency_set);

  DeviceStream& get_stream(int index) {
    return stream_map_.find(index)->second;
  }

 private:
  DeviceStream& get_stream_(int index) {
    return stream_map_.find(index)->second;
  }

  MTL::Library* get_library_cache_(const std::string& name);

  MTL::Library* get_library_(const std::string& name);

@@ -60,8 +60,20 @@ std::function<void()> make_task(array arr, bool signal) {
        inputs = arr.inputs();
      }

      //      if (dstream.active_without_fence()) {
      //        for (auto& in : arr.inputs()) {
      //          // If any input needs a fence
      //          if (command_encoder.needs_fence(in)) {
      //            end_encoding();
      //            break;
      //          }
      //        }
      //      }
      auto& dstream = d.get_stream(s.index);
      dstream.register_inputs(arr.inputs());
      debug_set_primitive_buffer_label(command_buffer, arr.primitive());
      arr.primitive().eval_gpu(arr.inputs(), outputs);
      dstream.register_outputs(outputs);
    }
    std::vector<std::shared_ptr<array::Data>> buffers;
    for (auto& in : arr.inputs()) {
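Putting the pieces together, the submission path in the hunk above is: look up the stream (via the new public get_stream), register the op's inputs so producer fences are collected, let eval_gpu encode the kernels, then register the outputs so end_encoding can publish the stream's fence for them. Below is a schematic but compilable sketch of that ordering; ToyArray, ToyStream, ToyDevice, and make_toy_task are illustrative names, not MLX's, and only the call ordering is meant to mirror the diff.

#include <functional>
#include <vector>

struct ToyArray {};
struct ToyStream {
  void register_inputs(const std::vector<ToyArray>&) {}
  void register_outputs(const std::vector<ToyArray>&) {}
};
struct ToyDevice {
  ToyStream stream;
  ToyStream& get_stream(int /*index*/) { return stream; }
};

// The ordering mirrors the hunk above: register the inputs before eval_gpu
// encodes the kernels, and register the outputs afterwards.
std::function<void()> make_toy_task(ToyDevice& d, int stream_index,
                                    std::vector<ToyArray> inputs,
                                    std::vector<ToyArray> outputs,
                                    std::function<void()> eval_gpu) {
  return [&d, stream_index, inputs, outputs, eval_gpu]() {
    auto& dstream = d.get_stream(stream_index);
    dstream.register_inputs(inputs);
    eval_gpu(); // encodes the kernels for this primitive
    dstream.register_outputs(outputs);
  };
}

int main() {
  ToyDevice d;
  auto task = make_toy_task(d, 0, {ToyArray{}}, {ToyArray{}}, [] {});
  task();
}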