More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -17,7 +17,8 @@ namespace mlx::core {
class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
@@ -498,7 +499,7 @@ class array {
template <typename T>
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
init(&val);
}
@@ -516,7 +517,7 @@ array::array(
std::initializer_list<T> data,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
dtype)) {
init(data.begin());
}