mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	more donation take 2
This commit is contained in:
		| @@ -20,21 +20,16 @@ void eval(array& arr) { | |||||||
|     } |     } | ||||||
|     arr.primitive().eval_cpu(arr.inputs(), outputs); |     arr.primitive().eval_cpu(arr.inputs(), outputs); | ||||||
|   } |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|   std::unordered_set<std::shared_ptr<array::Data>> buffers; | void finalize( | ||||||
|   for (auto& in : arr.inputs()) { |     Stream s, | ||||||
|     buffers.insert(in.data_shared_ptr()); |     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) { | ||||||
|   } |  | ||||||
|   for (auto& s : arr.siblings()) { |  | ||||||
|     buffers.insert(s.data_shared_ptr()); |  | ||||||
|   } |  | ||||||
|   // Remove the output if it was donated to by an input |  | ||||||
|   if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { |  | ||||||
|     buffers.erase(it); |  | ||||||
|   } |  | ||||||
|   auto& encoder = cpu::get_command_encoder(s); |   auto& encoder = cpu::get_command_encoder(s); | ||||||
|   encoder.dispatch([buffers = std::move(buffers), |   encoder.dispatch([s, | ||||||
|                     temps = std::move(encoder.temporaries())]() {}); |                     buffers = std::move(retain_buffers), | ||||||
|  |                     temps = std::move(encoder.temporaries())]() { | ||||||
|  |   }); | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace mlx::core::cpu | } // namespace mlx::core::cpu | ||||||
|   | |||||||
| @@ -2,11 +2,16 @@ | |||||||
|  |  | ||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
|  | #include <unordered_set> | ||||||
|  |  | ||||||
| #include "mlx/array.h" | #include "mlx/array.h" | ||||||
| #include "mlx/stream.h" | #include "mlx/stream.h" | ||||||
|  |  | ||||||
| namespace mlx::core::cpu { | namespace mlx::core::cpu { | ||||||
|  |  | ||||||
| void eval(array& arr); | void eval(array& arr); | ||||||
|  | void finalize( | ||||||
|  |     Stream s, | ||||||
|  |     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers); | ||||||
|  |  | ||||||
| } // namespace mlx::core::cpu | } // namespace mlx::core::cpu | ||||||
|   | |||||||
| @@ -1,5 +1,4 @@ | |||||||
| // Copyright © 2023-2024 Apple Inc. | // Copyright © 2023-2024 Apple Inc. | ||||||
|  |  | ||||||
| #include <cstdlib> | #include <cstdlib> | ||||||
| #include <sstream> | #include <sstream> | ||||||
|  |  | ||||||
|   | |||||||
| @@ -40,46 +40,32 @@ void eval(array& arr) { | |||||||
|     debug_set_primitive_buffer_label(command_buffer, arr.primitive()); |     debug_set_primitive_buffer_label(command_buffer, arr.primitive()); | ||||||
|     arr.primitive().eval_gpu(arr.inputs(), outputs); |     arr.primitive().eval_gpu(arr.inputs(), outputs); | ||||||
|   } |   } | ||||||
|   std::unordered_set<std::shared_ptr<array::Data>> buffers; | } | ||||||
|   for (auto& in : arr.inputs()) { |  | ||||||
|     buffers.insert(in.data_shared_ptr()); |  | ||||||
|   } |  | ||||||
|   for (auto& s : arr.siblings()) { |  | ||||||
|     buffers.insert(s.data_shared_ptr()); |  | ||||||
|   } |  | ||||||
|   // Remove the output if it was donated to by an input |  | ||||||
|   if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { |  | ||||||
|     buffers.erase(it); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   if (d.command_buffer_needs_commit(s.index)) { | void finalize( | ||||||
|  |     Stream s, | ||||||
|  |     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, | ||||||
|  |     bool force_commit) { | ||||||
|  |   auto pool = new_scoped_memory_pool(); | ||||||
|  |   auto& d = metal::device(s.device); | ||||||
|  |   auto command_buffer = d.get_command_buffer(s.index); | ||||||
|  |   if (d.command_buffer_needs_commit(s.index) || force_commit) { | ||||||
|     d.end_encoding(s.index); |     d.end_encoding(s.index); | ||||||
|     scheduler::notify_new_task(s); |     scheduler::notify_new_task(s); | ||||||
|     command_buffer->addCompletedHandler( |     command_buffer->addCompletedHandler( | ||||||
|         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { |         [s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) { | ||||||
|           scheduler::notify_task_completion(s); |           scheduler::notify_task_completion(s); | ||||||
|           check_error(cbuf); |           check_error(cbuf); | ||||||
|         }); |         }); | ||||||
|     d.commit_command_buffer(s.index); |     d.commit_command_buffer(s.index); | ||||||
|     d.get_command_buffer(s.index); |  | ||||||
|   } else { |   } else { | ||||||
|     command_buffer->addCompletedHandler( |     command_buffer->addCompletedHandler( | ||||||
|         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { |         [s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) { | ||||||
|           check_error(cbuf); |           check_error(cbuf); | ||||||
|         }); |         }); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| void finalize(Stream s) { |  | ||||||
|   auto pool = new_scoped_memory_pool(); |  | ||||||
|   auto& d = metal::device(s.device); |  | ||||||
|   auto cb = d.get_command_buffer(s.index); |  | ||||||
|   d.end_encoding(s.index); |  | ||||||
|   cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); |  | ||||||
|   d.commit_command_buffer(s.index); |  | ||||||
|   d.get_command_buffer(s.index); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void synchronize(Stream s) { | void synchronize(Stream s) { | ||||||
|   auto pool = new_scoped_memory_pool(); |   auto pool = new_scoped_memory_pool(); | ||||||
|   auto& d = metal::device(s.device); |   auto& d = metal::device(s.device); | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ | |||||||
|  |  | ||||||
| #include <future> | #include <future> | ||||||
| #include <memory> | #include <memory> | ||||||
|  | #include <unordered_set> | ||||||
|  |  | ||||||
| #include "mlx/array.h" | #include "mlx/array.h" | ||||||
| #include "mlx/stream.h" | #include "mlx/stream.h" | ||||||
| @@ -15,7 +16,10 @@ void new_stream(Stream stream); | |||||||
| std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool(); | std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool(); | ||||||
|  |  | ||||||
| void eval(array& arr); | void eval(array& arr); | ||||||
| void finalize(Stream s); | void finalize( | ||||||
|  |     Stream s, | ||||||
|  |     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, | ||||||
|  |     bool force_commit); | ||||||
| void synchronize(Stream s); | void synchronize(Stream s); | ||||||
|  |  | ||||||
| } // namespace mlx::core::metal | } // namespace mlx::core::metal | ||||||
|   | |||||||
| @@ -21,7 +21,10 @@ void eval(array&) { | |||||||
|       "[metal::eval] Cannot eval on GPU without metal backend"); |       "[metal::eval] Cannot eval on GPU without metal backend"); | ||||||
| } | } | ||||||
|  |  | ||||||
| void finalize(Stream) { | void finalize( | ||||||
|  |     Stream, | ||||||
|  |     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, | ||||||
|  |     bool) { | ||||||
|   throw std::runtime_error( |   throw std::runtime_error( | ||||||
|       "[metal::finalize] Cannot finalize GPU without metal backend"); |       "[metal::finalize] Cannot finalize GPU without metal backend"); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -160,6 +160,7 @@ array eval_impl(std::vector<array> outputs, bool async) { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         auto it = cache.find(in.id()); |         auto it = cache.find(in.id()); | ||||||
|  |  | ||||||
|         it->second -= 1; |         it->second -= 1; | ||||||
|  |  | ||||||
|         if (it->second != 0) { |         if (it->second != 0) { | ||||||
| @@ -180,6 +181,8 @@ array eval_impl(std::vector<array> outputs, bool async) { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   std::unordered_map<std::uintptr_t, std::weak_ptr<array::Data>> | ||||||
|  |       unretained_buffers; | ||||||
|   while (!tape.empty()) { |   while (!tape.empty()) { | ||||||
|     auto arr = std::move(tape.back()); |     auto arr = std::move(tape.back()); | ||||||
|     tape.pop_back(); |     tape.pop_back(); | ||||||
| @@ -225,7 +228,7 @@ array eval_impl(std::vector<array> outputs, bool async) { | |||||||
|       // Commit any open streams |       // Commit any open streams | ||||||
|       for (auto& [_, e] : events) { |       for (auto& [_, e] : events) { | ||||||
|         if (e.stream().device == Device::gpu) { |         if (e.stream().device == Device::gpu) { | ||||||
|           metal::finalize(e.stream()); |           metal::finalize(e.stream(), {}, true); | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|       scheduler::wait_for_one(); |       scheduler::wait_for_one(); | ||||||
| @@ -246,24 +249,59 @@ array eval_impl(std::vector<array> outputs, bool async) { | |||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     arr.set_status(array::Status::evaluated); |     arr.set_status(array::Status::evaluated); | ||||||
|     // TODO Maybe always want the fence coherent kernel in the same cbuf |  | ||||||
|     // as the other kernels? |     std::unordered_set<std::shared_ptr<array::Data>> retain_buffers; | ||||||
|  |  | ||||||
|     maybe_update_fence(arr); |     maybe_update_fence(arr); | ||||||
|     for (auto& sib : arr.siblings()) { |     for (auto& sib : arr.siblings()) { | ||||||
|       sib.set_status(array::Status::evaluated); |       sib.set_status(array::Status::evaluated); | ||||||
|       maybe_update_fence(sib); |       maybe_update_fence(sib); | ||||||
|  |       retain_buffers.insert(sib.data_shared_ptr()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     for (auto& in : arr.inputs()) { | ||||||
|  |       retain_buffers.insert(in.data_shared_ptr()); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     if (!arr.is_tracer()) { |     if (!arr.is_tracer()) { | ||||||
|       arr.detach(); |       arr.detach(); | ||||||
|     } |     } | ||||||
|   } |  | ||||||
|  |  | ||||||
|  |     for (auto it = retain_buffers.begin(); it != retain_buffers.end();) { | ||||||
|  |       if (it->use_count() > 1) { | ||||||
|  |         // At this point the buffer must be in one of two states: | ||||||
|  |         // 1. Held by another array | ||||||
|  |         // 2. Held from a prevous async_eval | ||||||
|  |         unretained_buffers.emplace(std::uintptr_t(it->get()), *it); | ||||||
|  |         it = retain_buffers.erase(it); | ||||||
|  |       } else { | ||||||
|  |         unretained_buffers.erase(std::uintptr_t(it->get())); | ||||||
|  |         ++it; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     if (stream.device == Device::gpu) { | ||||||
|  |       metal::finalize(stream, std::move(retain_buffers), false); | ||||||
|  |     } else { | ||||||
|  |       cpu::finalize(stream, std::move(retain_buffers)); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|   // Signal the event in its stream |   // Signal the event in its stream | ||||||
|   for (auto& [_, e] : events) { |   for (auto& [_, e] : events) { | ||||||
|     auto s = e.stream(); |     auto s = e.stream(); | ||||||
|     e.signal(s); |     e.signal(s); | ||||||
|  |     std::unordered_set<std::shared_ptr<array::Data>> retain; | ||||||
|  |     if (s == stream) { | ||||||
|  |       for (auto& [_, b] : unretained_buffers) { | ||||||
|  |         auto ptr = b.lock(); | ||||||
|  |         if (ptr) { | ||||||
|  |           retain.insert(ptr); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|     if (s.device == Device::gpu) { |     if (s.device == Device::gpu) { | ||||||
|       metal::finalize(s); |       metal::finalize(s, std::move(retain), true); | ||||||
|  |     } else { | ||||||
|  |       cpu::finalize(s, std::move(retain)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase): | |||||||
|         mx.eval(z) |         mx.eval(z) | ||||||
|         mx.set_memory_limit(old_limit) |         mx.set_memory_limit(old_limit) | ||||||
|  |  | ||||||
|  |     def test_donation_multiple_inputs(self): | ||||||
|  |         def fun(its, x, y): | ||||||
|  |             for _ in range(its): | ||||||
|  |                 a = x + y  # y should donate | ||||||
|  |                 b = x + a  # x should donate | ||||||
|  |                 x, y = a, b | ||||||
|  |             return x, y | ||||||
|  |  | ||||||
|  |         x = mx.zeros((128, 128)) | ||||||
|  |         y = mx.zeros((128, 128)) | ||||||
|  |         mx.reset_peak_memory() | ||||||
|  |         a, b = fun(2, x, y) | ||||||
|  |         mx.eval(a, b) | ||||||
|  |         mx.synchronize() | ||||||
|  |         mem2 = mx.get_peak_memory() | ||||||
|  |         a, b = fun(10, x, y) | ||||||
|  |         mx.eval(a, b) | ||||||
|  |         mx.synchronize() | ||||||
|  |         mem10 = mx.get_peak_memory() | ||||||
|  |         self.assertEqual(mem2, mem10) | ||||||
|  |  | ||||||
|  |     def test_async_with_delete(self): | ||||||
|  |         a = mx.ones((5, 5)) | ||||||
|  |         for _ in range(100): | ||||||
|  |             a = mx.abs(a) | ||||||
|  |         mx.async_eval(a) | ||||||
|  |         del a | ||||||
|  |         mx.clear_cache() | ||||||
|  |         mx.synchronize() | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun