mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-25 12:48:14 +08:00
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
@@ -33,10 +33,12 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT_MULTI(Compiled)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
@@ -57,6 +59,7 @@ DEFAULT(Minimum)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
@@ -68,8 +71,6 @@ DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT_MULTI(QRF)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
@@ -3,6 +3,7 @@ target_sources(
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
|
||||
59
mlx/backend/common/compiled.cpp
Normal file
59
mlx/backend/common/compiled.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <queue>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Build the real tape
|
||||
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
|
||||
const std::vector<array>& trace_tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs) {
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
}
|
||||
std::queue<array> tape;
|
||||
for (auto& a : trace_tape) {
|
||||
// Find real inputs
|
||||
std::vector<array> real_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
tape.push(
|
||||
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
|
||||
trace_to_real.insert({a.id(), tape.back()});
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
for (auto& o : trace_outputs) {
|
||||
outputs.push_back(trace_to_real.at(o.id()));
|
||||
}
|
||||
return {tape, outputs};
|
||||
}
|
||||
|
||||
void Compiled::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Make the a real tape from the tracers
|
||||
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
|
||||
|
||||
// Run the tape
|
||||
while (!tape.empty()) {
|
||||
auto a = std::move(tape.front());
|
||||
tape.pop();
|
||||
auto outputs = a.outputs();
|
||||
a.primitive().eval_cpu(a.inputs(), outputs);
|
||||
a.detach();
|
||||
}
|
||||
|
||||
// Copy results into outputs
|
||||
for (int o = 0; o < real_outputs.size(); ++o) {
|
||||
outputs[o].copy_shared_buffer(real_outputs[o]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -41,7 +41,9 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT_MULTI(Compiled)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
@@ -78,6 +80,7 @@ DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
@@ -100,8 +103,6 @@ DEFAULT(Subtract)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT_MULTI(QRF)
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
|
||||
44
mlx/backend/metal/compiled.cpp
Normal file
44
mlx/backend/metal/compiled.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Just a fall-back to the original tape for now
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({inputs_[i].id(), inputs[i]});
|
||||
}
|
||||
for (int i = 0; i < outputs.size(); ++i) {
|
||||
trace_to_real.insert({outputs_[i].id(), outputs[i]});
|
||||
}
|
||||
|
||||
for (auto& a : tape_) {
|
||||
std::vector<array> p_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
p_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
// If a is an output get it from the map, otherwise create it
|
||||
// NB this is safe as long as no multi-output sub primitves are allowed
|
||||
// in Compiled
|
||||
std::vector<array> p_outputs;
|
||||
if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) {
|
||||
p_outputs.push_back(it->second);
|
||||
} else {
|
||||
p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {}));
|
||||
trace_to_real.insert({a.id(), p_outputs[0]});
|
||||
}
|
||||
a.primitive().eval_gpu(p_inputs, p_outputs);
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler(
|
||||
[trace_to_real](MTL::CommandBuffer*) mutable {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -32,6 +32,7 @@ NO_GPU(AsType)
|
||||
NO_GPU(AsStrided)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Concatenate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
@@ -40,6 +41,7 @@ NO_GPU(Cosh)
|
||||
NO_GPU_MULTI(CustomVJP)
|
||||
NO_GPU_MULTI(Depends)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
@@ -69,6 +71,7 @@ NO_GPU(NotEqual)
|
||||
NO_GPU(Pad)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Reduce)
|
||||
@@ -91,6 +94,5 @@ NO_GPU(Subtract)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Transpose)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU_MULTI(QRF)
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user