Fix exporting with constants (#2769)

This commit is contained in:
Awni Hannun
2025-11-14 12:52:08 -08:00
committed by GitHub
parent 3b2ffcefc3
commit 27ff069175
3 changed files with 24 additions and 3 deletions

View File

@@ -716,7 +716,7 @@ void FunctionExporter::export_with_callback(
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
continue;
}
if (constants.insert(arr.id()).second) {
if (constants.insert({arr.id(), arr}).second) {
new_constants.emplace_back(namer.get_name(arr), arr);
}
}
@@ -848,7 +848,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
if (input_set.find(arr.id()) == input_set.end()) {
serialize(os, true);
// Save constant data if not already saved
if (constants.insert(arr.id()).second) {
if (constants.insert({arr.id(), arr}).second) {
serialize(os, arr.shape());
serialize(os, arr.dtype());
os.write(arr.data<char>(), arr.nbytes());

View File

@@ -72,7 +72,7 @@ struct FunctionExporter {
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::vector<std::string>& kwarg_keys);
std::set<std::uintptr_t> constants;
std::unordered_map<std::uintptr_t, array> constants;
int count{0};
bool closed{false};
std::shared_ptr<FunctionTable> ftable;