mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Share more common code in Compiled (#2240)
* Share more common code in Compiled * Remove build_lib_name
This commit is contained in:
@@ -1,16 +1,20 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cstdlib>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -82,7 +86,54 @@ Compiled::Compiled(
|
||||
inputs_(std::move(inputs)),
|
||||
outputs_(std::move(outputs)),
|
||||
tape_(std::move(tape)),
|
||||
constant_ids_(std::move(constant_ids)) {}
|
||||
constant_ids_(std::move(constant_ids)),
|
||||
is_constant_([this](size_t i) {
|
||||
return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
|
||||
}) {
|
||||
// Build the kernel name.
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (const auto& x : inputs_) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (const auto& a : tape_) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (const auto& x : inputs_) {
|
||||
if (constant_ids_.find(x.id()) != constant_ids_.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (const auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
kernel_lib_ = os.str();
|
||||
}
|
||||
|
||||
std::vector<array> Compiled::vjp(
|
||||
const std::vector<array>&,
|
||||
|
||||
Reference in New Issue
Block a user