mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-22 19:28:14 +08:00 
			
		
		
		
	ensure io/comm streams are active before eval (#1412)
This commit is contained in:
		| @@ -199,7 +199,6 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
| } | } | ||||||
|  |  | ||||||
| void Load::eval_gpu(const std::vector<array>& inputs, array& out) { | void Load::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||||
|   static Stream io_stream = new_stream(Device::cpu); |  | ||||||
|   out.set_data(allocator::malloc_or_wait(out.nbytes())); |   out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||||
|  |  | ||||||
|   auto read_task = [out = out, |   auto read_task = [out = out, | ||||||
| @@ -213,7 +212,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     fut.wait(); |     fut.wait(); | ||||||
|     out.event().signal(); |     out.event().signal(); | ||||||
|   }; |   }; | ||||||
|   scheduler::enqueue(io_stream, std::move(signal_task)); |   scheduler::enqueue(io_stream(), std::move(signal_task)); | ||||||
|   auto& d = metal::device(stream().device); |   auto& d = metal::device(stream().device); | ||||||
|   d.end_encoding(stream().index); |   d.end_encoding(stream().index); | ||||||
|   auto command_buffer = d.get_command_buffer(stream().index); |   auto command_buffer = d.get_command_buffer(stream().index); | ||||||
|   | |||||||
| @@ -255,6 +255,9 @@ Group init(bool strict /* = false */) { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   // Ensure the communication stream is alive before | ||||||
|  |   // the graph is evaluated | ||||||
|  |   detail::communication_stream(); | ||||||
|   return Group(global_group); |   return Group(global_group); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1146,9 +1146,13 @@ class Load : public UnaryPrimitive { | |||||||
|       size_t offset, |       size_t offset, | ||||||
|       bool swap_endianness = false) |       bool swap_endianness = false) | ||||||
|       : UnaryPrimitive(stream), |       : UnaryPrimitive(stream), | ||||||
|         reader_(reader), |         reader_(std::move(reader)), | ||||||
|         offset_(offset), |         offset_(offset), | ||||||
|         swap_endianness_(swap_endianness) {} |         swap_endianness_(swap_endianness) { | ||||||
|  |     if (stream.device == Device::gpu) { | ||||||
|  |       io_stream(); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|   void eval_cpu(const std::vector<array>& inputs, array& out) override; |   void eval_cpu(const std::vector<array>& inputs, array& out) override; | ||||||
|   void eval_gpu(const std::vector<array>& inputs, array& out) override; |   void eval_gpu(const std::vector<array>& inputs, array& out) override; | ||||||
| @@ -1156,6 +1160,10 @@ class Load : public UnaryPrimitive { | |||||||
|   DEFINE_PRINT(Load) |   DEFINE_PRINT(Load) | ||||||
|  |  | ||||||
|  private: |  private: | ||||||
|  |   Stream& io_stream() { | ||||||
|  |     static Stream io_stream = new_stream(Device::cpu); | ||||||
|  |     return io_stream; | ||||||
|  |   }; | ||||||
|   void eval(const std::vector<array>& inputs, array& out); |   void eval(const std::vector<array>& inputs, array& out); | ||||||
|   std::shared_ptr<io::Reader> reader_; |   std::shared_ptr<io::Reader> reader_; | ||||||
|   size_t offset_; |   size_t offset_; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun