mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -17,7 +17,8 @@ namespace mlx::core {
|
||||
class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using Shape = std::vector<int32_t>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
|
||||
class array {
|
||||
@@ -498,7 +499,7 @@ class array {
|
||||
|
||||
template <typename T>
|
||||
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
init(&val);
|
||||
}
|
||||
|
||||
@@ -516,7 +517,7 @@ array::array(
|
||||
std::initializer_list<T> data,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user