More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -25,7 +25,7 @@ bool retain_graph() {
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
auto cval = static_cast<complex64_t>(val);
init(&cval);
}
@@ -61,14 +61,14 @@ std::vector<array> array::make_arrays(
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
float32)) {
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
dtype)) {
init(data.begin());
}
@@ -322,7 +322,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0);
auto start = Shape(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();
shape.erase(shape.begin());