Share more common code in Compiled (#2240)

* Share more common code in Compiled

* Remove build_lib_name
This commit is contained in:
Cheng
2025-06-04 08:48:50 +09:00
committed by GitHub
parent 5685ceb3c7
commit 0bb89e9e5f
6 changed files with 193 additions and 233 deletions

View File

@@ -1,16 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -82,7 +86,54 @@ Compiled::Compiled(
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
tape_(std::move(tape)),
constant_ids_(std::move(constant_ids)) {}
constant_ids_(std::move(constant_ids)),
is_constant_([this](size_t i) {
return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
}) {
// Build the kernel name.
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (const auto& x : inputs_) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (const auto& a : tape_) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (const auto& x : inputs_) {
if (constant_ids_.find(x.id()) != constant_ids_.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (const auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
kernel_lib_ = os.str();
}
std::vector<array> Compiled::vjp(
const std::vector<array>&,