mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides * Convert SmallVector to tuple
This commit is contained in:
@@ -179,7 +179,7 @@ void serialize(Writer& os, const array& arr) {
|
||||
}
|
||||
template <>
|
||||
array deserialize(Reader& is) {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto shape = deserialize<Shape>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
return array(std::move(shape), type, nullptr, std::vector<array>{});
|
||||
}
|
||||
@@ -640,7 +640,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
auto outputs = arr.outputs();
|
||||
serialize(os, arrays_to_ids(outputs));
|
||||
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
for (auto& o : outputs) {
|
||||
shapes.push_back(o.shape());
|
||||
@@ -813,14 +813,14 @@ ImportedFunction::ImportedFunction(const std::string& file)
|
||||
std::shared_ptr<Primitive> prim = factory.load(is);
|
||||
auto num_siblings = deserialize<uint64_t>(is);
|
||||
if (num_siblings == 0) {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto shape = deserialize<Shape>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
tape.emplace_back(
|
||||
std::move(shape), type, std::move(prim), std::move(inputs));
|
||||
array_map.emplace(id, tape.back());
|
||||
} else {
|
||||
auto ids = deserialize<std::vector<uint64_t>>(is);
|
||||
auto shapes = deserialize<std::vector<std::vector<int>>>(is);
|
||||
auto shapes = deserialize<std::vector<Shape>>(is);
|
||||
auto types = deserialize<std::vector<Dtype>>(is);
|
||||
auto arrays = array::make_arrays(
|
||||
std::move(shapes),
|
||||
@@ -841,7 +841,7 @@ ImportedFunction::ImportedFunction(const std::string& file)
|
||||
if (auto it = constants.find(id); it != constants.end()) {
|
||||
tape.push_back(it->second);
|
||||
} else {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto shape = deserialize<Shape>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
size_t offset = is.tell();
|
||||
tape.push_back(array(
|
||||
|
||||
Reference in New Issue
Block a user