mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix exporting with constants (#2769)
This commit is contained in:
@@ -716,7 +716,7 @@ void FunctionExporter::export_with_callback(
|
||||
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
||||
continue;
|
||||
}
|
||||
if (constants.insert(arr.id()).second) {
|
||||
if (constants.insert({arr.id(), arr}).second) {
|
||||
new_constants.emplace_back(namer.get_name(arr), arr);
|
||||
}
|
||||
}
|
||||
@@ -848,7 +848,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
if (input_set.find(arr.id()) == input_set.end()) {
|
||||
serialize(os, true);
|
||||
// Save constant data if not already saved
|
||||
if (constants.insert(arr.id()).second) {
|
||||
if (constants.insert({arr.id(), arr}).second) {
|
||||
serialize(os, arr.shape());
|
||||
serialize(os, arr.dtype());
|
||||
os.write(arr.data<char>(), arr.nbytes());
|
||||
|
||||
@@ -72,7 +72,7 @@ struct FunctionExporter {
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<std::string>& kwarg_keys);
|
||||
std::set<std::uintptr_t> constants;
|
||||
std::unordered_map<std::uintptr_t, array> constants;
|
||||
int count{0};
|
||||
bool closed{false};
|
||||
std::shared_ptr<FunctionTable> ftable;
|
||||
|
||||
Reference in New Issue
Block a user