From fdb936a5cb4b11cc3dba3612882a2ec66e88c52e Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 4 Jun 2025 08:05:21 +0900 Subject: [PATCH] Remove build_lib_name --- mlx/backend/common/compiled.cpp | 51 --------------------------------- mlx/backend/common/compiled.h | 8 ------ mlx/compile.cpp | 50 ++++++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 61 deletions(-) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index a5578b186..98c48cca9 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -2,8 +2,6 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/graph_utils.h" -#include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { @@ -80,55 +78,6 @@ std::string get_type_string(Dtype d) { } } -std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids) { - NodeNamer namer; - std::ostringstream os; - std::ostringstream constant_hasher; - - // Fill the input names. This is not really necessary, I just like having A, - // B, C, ... as the inputs. - for (auto& x : inputs) { - namer.get_name(x); - } - - // The primitives describing the tape. For unary and binary primitives this - // must be enough to describe the full computation. - for (auto& a : tape) { - // name and type of output - os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); - // computation performed - a.primitive().print(os); - // name of inputs to the function - for (auto& inp : a.inputs()) { - os << namer.get_name(inp); - } - } - os << "_"; - - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - os << "C"; - print_constant(constant_hasher, x); - } else { - os << (is_scalar(x) ? "S" : "V"); - } - } - os << "_"; - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - continue; - } - os << kindof(x.dtype()) << x.itemsize(); - } - os << "_" << std::hash{}(constant_hasher.str()); - - return os.str(); -} - bool compiled_check_contiguity( const std::vector& inputs, const Shape& shape) { diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index e451ee239..6fccaacd6 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -3,8 +3,6 @@ #include #include -#include -#include #include "mlx/array.h" #include "mlx/primitives.h" @@ -15,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) { return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); } -std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids); - std::string get_type_string(Dtype d); template diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7487b54cb..79a55ba8f 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include #include @@ -9,9 +10,11 @@ #include "mlx/compile.h" #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" +#include "mlx/utils.h" namespace mlx::core { @@ -86,8 +89,51 @@ Compiled::Compiled( constant_ids_(std::move(constant_ids)), is_constant_([this](size_t i) { return constant_ids_.find(inputs_[i].id()) != constant_ids_.end(); - }), - kernel_lib_(build_lib_name(inputs_, outputs_, tape_, constant_ids_)) {} + }) { + // Build the kernel name. + NodeNamer namer; + std::ostringstream os; + std::ostringstream constant_hasher; + + // Fill the input names. This is not really necessary, I just like having A, + // B, C, ... as the inputs. + for (const auto& x : inputs_) { + namer.get_name(x); + } + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. + for (const auto& a : tape_) { + // name and type of output + os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); + // computation performed + a.primitive().print(os); + // name of inputs to the function + for (auto& inp : a.inputs()) { + os << namer.get_name(inp); + } + } + os << "_"; + + for (const auto& x : inputs_) { + if (constant_ids_.find(x.id()) != constant_ids_.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << (is_scalar(x) ? "S" : "V"); + } + } + os << "_"; + for (const auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + kernel_lib_ = os.str(); +} std::vector Compiled::vjp( const std::vector&,