mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
		| @@ -33,10 +33,12 @@ DEFAULT(ArgSort) | ||||
| DEFAULT(AsStrided) | ||||
| DEFAULT(Broadcast) | ||||
| DEFAULT(Ceil) | ||||
| DEFAULT_MULTI(Compiled) | ||||
| DEFAULT(Concatenate) | ||||
| DEFAULT(Copy) | ||||
| DEFAULT_MULTI(CustomVJP) | ||||
| DEFAULT_MULTI(Depends) | ||||
| DEFAULT_MULTI(DivMod) | ||||
| DEFAULT(Equal) | ||||
| DEFAULT(Erf) | ||||
| DEFAULT(ErfInv) | ||||
| @@ -57,6 +59,7 @@ DEFAULT(Minimum) | ||||
| DEFAULT(NotEqual) | ||||
| DEFAULT(Pad) | ||||
| DEFAULT(Partition) | ||||
| DEFAULT_MULTI(QRF) | ||||
| DEFAULT(RandomBits) | ||||
| DEFAULT(Reshape) | ||||
| DEFAULT(Round) | ||||
| @@ -68,8 +71,6 @@ DEFAULT_MULTI(Split) | ||||
| DEFAULT(Sort) | ||||
| DEFAULT(StopGradient) | ||||
| DEFAULT(Transpose) | ||||
| DEFAULT_MULTI(DivMod) | ||||
| DEFAULT_MULTI(QRF) | ||||
|  | ||||
| void Abs::eval_cpu(const std::vector<array>& inputs, array& out) { | ||||
|   assert(inputs.size() == 1); | ||||
|   | ||||
| @@ -3,6 +3,7 @@ target_sources( | ||||
|   PRIVATE | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp | ||||
|   | ||||
							
								
								
									
										59
									
								
								mlx/backend/common/compiled.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								mlx/backend/common/compiled.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <queue> | ||||
|  | ||||
| #include "mlx/primitives.h" | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| // Build the real tape: rebuild every array in trace_tape with its inputs remapped from tracer ids to real arrays, and return the rebuilt tape together with the real arrays that correspond to trace_outputs | ||||
| std::pair<std::queue<array>, std::vector<array>> trace_to_real( | ||||
|     const std::vector<array>& trace_tape, | ||||
|     const std::vector<array>& trace_inputs, | ||||
|     const std::vector<array>& trace_outputs, | ||||
|     const std::vector<array>& inputs) { | ||||
|   std::unordered_map<uintptr_t, array> trace_to_real; | ||||
|   for (int i = 0; i < inputs.size(); ++i) { | ||||
|     trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); | ||||
|   } | ||||
|   std::queue<array> tape; | ||||
|   for (auto& a : trace_tape) { | ||||
|     // Look up the real counterpart of each input of this tracer (at() throws if the id was never registered) | ||||
|     std::vector<array> real_inputs; | ||||
|     for (auto& in : a.inputs()) { | ||||
|       real_inputs.push_back(trace_to_real.at(in.id())); | ||||
|     } | ||||
|     tape.push( | ||||
|         array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs))); | ||||
|     trace_to_real.insert({a.id(), tape.back()}); | ||||
|   } | ||||
|  | ||||
|   std::vector<array> outputs; | ||||
|   for (auto& o : trace_outputs) { | ||||
|     outputs.push_back(trace_to_real.at(o.id())); | ||||
|   } | ||||
|   return {tape, outputs}; | ||||
| } | ||||
|  | ||||
| void Compiled::eval( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs) { | ||||
|   // Make a real tape from the tracers (tape_, inputs_, outputs_) and the actual inputs | ||||
|   auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs); | ||||
|  | ||||
|   // Run the tape in queue order: evaluate each array's primitive with eval_cpu, then detach the array. Note that the loop-local `outputs` below shadows the `outputs` parameter. | ||||
|   while (!tape.empty()) { | ||||
|     auto a = std::move(tape.front()); | ||||
|     tape.pop(); | ||||
|     auto outputs = a.outputs(); | ||||
|     a.primitive().eval_cpu(a.inputs(), outputs); | ||||
|     a.detach(); | ||||
|   } | ||||
|  | ||||
|   // Copy results into outputs by calling copy_shared_buffer on each output with the matching real output | ||||
|   for (int o = 0; o < real_outputs.size(); ++o) { | ||||
|     outputs[o].copy_shared_buffer(real_outputs[o]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
| @@ -41,7 +41,9 @@ DEFAULT(ArgSort) | ||||
| DEFAULT(AsType) | ||||
| DEFAULT(AsStrided) | ||||
| DEFAULT(Broadcast) | ||||
| DEFAULT_MULTI(DivMod) | ||||
| DEFAULT(Ceil) | ||||
| DEFAULT_MULTI(Compiled) | ||||
| DEFAULT(Concatenate) | ||||
| DEFAULT(Convolution) | ||||
| DEFAULT(Copy) | ||||
| @@ -78,6 +80,7 @@ DEFAULT(NotEqual) | ||||
| DEFAULT(Pad) | ||||
| DEFAULT(Partition) | ||||
| DEFAULT(Power) | ||||
| DEFAULT_MULTI(QRF) | ||||
| DEFAULT(QuantizedMatmul) | ||||
| DEFAULT(RandomBits) | ||||
| DEFAULT(Reduce) | ||||
| @@ -100,8 +103,6 @@ DEFAULT(Subtract) | ||||
| DEFAULT(Tan) | ||||
| DEFAULT(Tanh) | ||||
| DEFAULT(Transpose) | ||||
| DEFAULT_MULTI(DivMod) | ||||
| DEFAULT_MULTI(QRF) | ||||
|  | ||||
| namespace { | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ target_sources( | ||||
|   mlx | ||||
|   PRIVATE | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp | ||||
|   | ||||
							
								
								
									
										44
									
								
								mlx/backend/metal/compiled.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								mlx/backend/metal/compiled.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include "mlx/backend/metal/device.h" | ||||
| #include "mlx/primitives.h" | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| void Compiled::eval_gpu( | ||||
|     const std::vector<array>& inputs, | ||||
|     std::vector<array>& outputs) { | ||||
|   // Just a fall-back to the original tape for now: each primitive on tape_ is evaluated in order with eval_gpu, using a map from tracer ids to real arrays. NOTE(review): the map is also captured by copy in the completion handler at the end, presumably to keep the arrays alive until the command buffer completes - confirm. | ||||
|   std::unordered_map<uintptr_t, array> trace_to_real; | ||||
|   for (int i = 0; i < inputs.size(); ++i) { | ||||
|     trace_to_real.insert({inputs_[i].id(), inputs[i]}); | ||||
|   } | ||||
|   for (int i = 0; i < outputs.size(); ++i) { | ||||
|     trace_to_real.insert({outputs_[i].id(), outputs[i]}); | ||||
|   } | ||||
|  | ||||
|   for (auto& a : tape_) { | ||||
|     std::vector<array> p_inputs; | ||||
|     for (auto& in : a.inputs()) { | ||||
|       p_inputs.push_back(trace_to_real.at(in.id())); | ||||
|     } | ||||
|     // If a is an output, get it from the map; otherwise create it. | ||||
|     // NB: this is safe as long as no multi-output sub-primitives are allowed | ||||
|     // in Compiled. | ||||
|     std::vector<array> p_outputs; | ||||
|     if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) { | ||||
|       p_outputs.push_back(it->second); | ||||
|     } else { | ||||
|       p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {})); | ||||
|       trace_to_real.insert({a.id(), p_outputs[0]}); | ||||
|     } | ||||
|     a.primitive().eval_gpu(p_inputs, p_outputs); | ||||
|   } | ||||
|   auto& s = stream(); | ||||
|   auto& d = metal::device(s.device); | ||||
|   auto command_buffer = d.get_command_buffer(s.index); | ||||
|   command_buffer->addCompletedHandler( | ||||
|       [trace_to_real](MTL::CommandBuffer*) mutable {}); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
| @@ -32,6 +32,7 @@ NO_GPU(AsType) | ||||
| NO_GPU(AsStrided) | ||||
| NO_GPU(Broadcast) | ||||
| NO_GPU(Ceil) | ||||
| NO_GPU_MULTI(Compiled) | ||||
| NO_GPU(Concatenate) | ||||
| NO_GPU(Convolution) | ||||
| NO_GPU(Copy) | ||||
| @@ -40,6 +41,7 @@ NO_GPU(Cosh) | ||||
| NO_GPU_MULTI(CustomVJP) | ||||
| NO_GPU_MULTI(Depends) | ||||
| NO_GPU(Divide) | ||||
| NO_GPU_MULTI(DivMod) | ||||
| NO_GPU(Remainder) | ||||
| NO_GPU(Equal) | ||||
| NO_GPU(Erf) | ||||
| @@ -69,6 +71,7 @@ NO_GPU(NotEqual) | ||||
| NO_GPU(Pad) | ||||
| NO_GPU(Partition) | ||||
| NO_GPU(Power) | ||||
| NO_GPU_MULTI(QRF) | ||||
| NO_GPU(QuantizedMatmul) | ||||
| NO_GPU(RandomBits) | ||||
| NO_GPU(Reduce) | ||||
| @@ -91,6 +94,5 @@ NO_GPU(Subtract) | ||||
| NO_GPU(Tan) | ||||
| NO_GPU(Tanh) | ||||
| NO_GPU(Transpose) | ||||
| NO_GPU_MULTI(DivMod) | ||||
| NO_GPU_MULTI(QRF) | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun