mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	more donation take 2
This commit is contained in:
		| @@ -20,21 +20,16 @@ void eval(array& arr) { | ||||
|     } | ||||
|     arr.primitive().eval_cpu(arr.inputs(), outputs); | ||||
|   } | ||||
| } | ||||
|  | ||||
|   std::unordered_set<std::shared_ptr<array::Data>> buffers; | ||||
|   for (auto& in : arr.inputs()) { | ||||
|     buffers.insert(in.data_shared_ptr()); | ||||
|   } | ||||
|   for (auto& s : arr.siblings()) { | ||||
|     buffers.insert(s.data_shared_ptr()); | ||||
|   } | ||||
|   // Remove the output if it was donated to by an input | ||||
|   if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { | ||||
|     buffers.erase(it); | ||||
|   } | ||||
| void finalize( | ||||
|     Stream s, | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) { | ||||
|   auto& encoder = cpu::get_command_encoder(s); | ||||
|   encoder.dispatch([buffers = std::move(buffers), | ||||
|                     temps = std::move(encoder.temporaries())]() {}); | ||||
|   encoder.dispatch([s, | ||||
|                     buffers = std::move(retain_buffers), | ||||
|                     temps = std::move(encoder.temporaries())]() { | ||||
|   }); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core::cpu | ||||
|   | ||||
| @@ -2,11 +2,16 @@ | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <unordered_set> | ||||
|  | ||||
| #include "mlx/array.h" | ||||
| #include "mlx/stream.h" | ||||
|  | ||||
| namespace mlx::core::cpu { | ||||
|  | ||||
| void eval(array& arr); | ||||
| void finalize( | ||||
|     Stream s, | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers); | ||||
|  | ||||
| } // namespace mlx::core::cpu | ||||
|   | ||||
| @@ -1,5 +1,4 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <cstdlib> | ||||
| #include <sstream> | ||||
|  | ||||
|   | ||||
| @@ -40,46 +40,32 @@ void eval(array& arr) { | ||||
|     debug_set_primitive_buffer_label(command_buffer, arr.primitive()); | ||||
|     arr.primitive().eval_gpu(arr.inputs(), outputs); | ||||
|   } | ||||
|   std::unordered_set<std::shared_ptr<array::Data>> buffers; | ||||
|   for (auto& in : arr.inputs()) { | ||||
|     buffers.insert(in.data_shared_ptr()); | ||||
|   } | ||||
|   for (auto& s : arr.siblings()) { | ||||
|     buffers.insert(s.data_shared_ptr()); | ||||
|   } | ||||
|   // Remove the output if it was donated to by an input | ||||
|   if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { | ||||
|     buffers.erase(it); | ||||
|   } | ||||
| } | ||||
|  | ||||
|   if (d.command_buffer_needs_commit(s.index)) { | ||||
| void finalize( | ||||
|     Stream s, | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, | ||||
|     bool force_commit) { | ||||
|   auto pool = new_scoped_memory_pool(); | ||||
|   auto& d = metal::device(s.device); | ||||
|   auto command_buffer = d.get_command_buffer(s.index); | ||||
|   if (d.command_buffer_needs_commit(s.index) || force_commit) { | ||||
|     d.end_encoding(s.index); | ||||
|     scheduler::notify_new_task(s); | ||||
|     command_buffer->addCompletedHandler( | ||||
|         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { | ||||
|         [s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) { | ||||
|           scheduler::notify_task_completion(s); | ||||
|           check_error(cbuf); | ||||
|         }); | ||||
|     d.commit_command_buffer(s.index); | ||||
|     d.get_command_buffer(s.index); | ||||
|   } else { | ||||
|     command_buffer->addCompletedHandler( | ||||
|         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { | ||||
|         [s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) { | ||||
|           check_error(cbuf); | ||||
|         }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void finalize(Stream s) { | ||||
|   auto pool = new_scoped_memory_pool(); | ||||
|   auto& d = metal::device(s.device); | ||||
|   auto cb = d.get_command_buffer(s.index); | ||||
|   d.end_encoding(s.index); | ||||
|   cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); | ||||
|   d.commit_command_buffer(s.index); | ||||
|   d.get_command_buffer(s.index); | ||||
| } | ||||
|  | ||||
| void synchronize(Stream s) { | ||||
|   auto pool = new_scoped_memory_pool(); | ||||
|   auto& d = metal::device(s.device); | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
|  | ||||
| #include <future> | ||||
| #include <memory> | ||||
| #include <unordered_set> | ||||
|  | ||||
| #include "mlx/array.h" | ||||
| #include "mlx/stream.h" | ||||
| @@ -15,7 +16,10 @@ void new_stream(Stream stream); | ||||
| std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool(); | ||||
|  | ||||
| void eval(array& arr); | ||||
| void finalize(Stream s); | ||||
| void finalize( | ||||
|     Stream s, | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, | ||||
|     bool force_commit); | ||||
| void synchronize(Stream s); | ||||
|  | ||||
| } // namespace mlx::core::metal | ||||
|   | ||||
| @@ -21,7 +21,10 @@ void eval(array&) { | ||||
|       "[metal::eval] Cannot eval on GPU without metal backend"); | ||||
| } | ||||
|  | ||||
| void finalize(Stream) { | ||||
| void finalize( | ||||
|     Stream, | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, | ||||
|     bool) { | ||||
|   throw std::runtime_error( | ||||
|       "[metal::finalize] Cannot finalize GPU without metal backend"); | ||||
| } | ||||
|   | ||||
| @@ -160,6 +160,7 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|         } | ||||
|  | ||||
|         auto it = cache.find(in.id()); | ||||
|  | ||||
|         it->second -= 1; | ||||
|  | ||||
|         if (it->second != 0) { | ||||
| @@ -180,6 +181,8 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::unordered_map<std::uintptr_t, std::weak_ptr<array::Data>> | ||||
|       unretained_buffers; | ||||
|   while (!tape.empty()) { | ||||
|     auto arr = std::move(tape.back()); | ||||
|     tape.pop_back(); | ||||
| @@ -225,7 +228,7 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|       // Commit any open streams | ||||
|       for (auto& [_, e] : events) { | ||||
|         if (e.stream().device == Device::gpu) { | ||||
|           metal::finalize(e.stream()); | ||||
|           metal::finalize(e.stream(), {}, true); | ||||
|         } | ||||
|       } | ||||
|       scheduler::wait_for_one(); | ||||
| @@ -246,24 +249,59 @@ array eval_impl(std::vector<array> outputs, bool async) { | ||||
|     }; | ||||
|  | ||||
|     arr.set_status(array::Status::evaluated); | ||||
|     // TODO Maybe always want the fence coherent kernel in the same cbuf | ||||
|     // as the other kernels? | ||||
|  | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers; | ||||
|  | ||||
|     maybe_update_fence(arr); | ||||
|     for (auto& sib : arr.siblings()) { | ||||
|       sib.set_status(array::Status::evaluated); | ||||
|       maybe_update_fence(sib); | ||||
|       retain_buffers.insert(sib.data_shared_ptr()); | ||||
|     } | ||||
|  | ||||
|     for (auto& in : arr.inputs()) { | ||||
|       retain_buffers.insert(in.data_shared_ptr()); | ||||
|     } | ||||
|  | ||||
|     if (!arr.is_tracer()) { | ||||
|       arr.detach(); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|     for (auto it = retain_buffers.begin(); it != retain_buffers.end();) { | ||||
|       if (it->use_count() > 1) { | ||||
|         // At this point the buffer must be in one of two states: | ||||
|         // 1. Held by another array | ||||
|         // 2. Held from a prevous async_eval | ||||
|         unretained_buffers.emplace(std::uintptr_t(it->get()), *it); | ||||
|         it = retain_buffers.erase(it); | ||||
|       } else { | ||||
|         unretained_buffers.erase(std::uintptr_t(it->get())); | ||||
|         ++it; | ||||
|       } | ||||
|     } | ||||
|     if (stream.device == Device::gpu) { | ||||
|       metal::finalize(stream, std::move(retain_buffers), false); | ||||
|     } else { | ||||
|       cpu::finalize(stream, std::move(retain_buffers)); | ||||
|     } | ||||
|   } | ||||
|   // Signal the event in its stream | ||||
|   for (auto& [_, e] : events) { | ||||
|     auto s = e.stream(); | ||||
|     e.signal(s); | ||||
|     std::unordered_set<std::shared_ptr<array::Data>> retain; | ||||
|     if (s == stream) { | ||||
|       for (auto& [_, b] : unretained_buffers) { | ||||
|         auto ptr = b.lock(); | ||||
|         if (ptr) { | ||||
|           retain.insert(ptr); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     if (s.device == Device::gpu) { | ||||
|       metal::finalize(s); | ||||
|       metal::finalize(s, std::move(retain), true); | ||||
|     } else { | ||||
|       cpu::finalize(s, std::move(retain)); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   | ||||
| @@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase): | ||||
|         mx.eval(z) | ||||
|         mx.set_memory_limit(old_limit) | ||||
|  | ||||
|     def test_donation_multiple_inputs(self): | ||||
|         def fun(its, x, y): | ||||
|             for _ in range(its): | ||||
|                 a = x + y  # y should donate | ||||
|                 b = x + a  # x should donate | ||||
|                 x, y = a, b | ||||
|             return x, y | ||||
|  | ||||
|         x = mx.zeros((128, 128)) | ||||
|         y = mx.zeros((128, 128)) | ||||
|         mx.reset_peak_memory() | ||||
|         a, b = fun(2, x, y) | ||||
|         mx.eval(a, b) | ||||
|         mx.synchronize() | ||||
|         mem2 = mx.get_peak_memory() | ||||
|         a, b = fun(10, x, y) | ||||
|         mx.eval(a, b) | ||||
|         mx.synchronize() | ||||
|         mem10 = mx.get_peak_memory() | ||||
|         self.assertEqual(mem2, mem10) | ||||
|  | ||||
|     def test_async_with_delete(self): | ||||
|         a = mx.ones((5, 5)) | ||||
|         for _ in range(100): | ||||
|             a = mx.abs(a) | ||||
|         mx.async_eval(a) | ||||
|         del a | ||||
|         mx.clear_cache() | ||||
|         mx.synchronize() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun