mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Shape and Strides 1 / N (#1645)
* shape and stride type def * more shape
This commit is contained in:
		| @@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */) | ||||
| } | ||||
|  | ||||
| array::array( | ||||
|     std::vector<int> shape, | ||||
|     Shape shape, | ||||
|     Dtype dtype, | ||||
|     std::shared_ptr<Primitive> primitive, | ||||
|     std::vector<array> inputs) | ||||
| @@ -42,7 +42,7 @@ array::array( | ||||
|           std::move(inputs))) {} | ||||
|  | ||||
| std::vector<array> array::make_arrays( | ||||
|     std::vector<std::vector<int>> shapes, | ||||
|     std::vector<Shape> shapes, | ||||
|     const std::vector<Dtype>& dtypes, | ||||
|     const std::shared_ptr<Primitive>& primitive, | ||||
|     const std::vector<array>& inputs) { | ||||
| @@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype) | ||||
| } | ||||
|  | ||||
| /* Build an array from a shared buffer */ | ||||
| array::array( | ||||
|     allocator::Buffer data, | ||||
|     std::vector<int> shape, | ||||
|     Dtype dtype, | ||||
|     deleter_t deleter) | ||||
| array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter) | ||||
|     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { | ||||
|   set_data(data, deleter); | ||||
| } | ||||
| @@ -126,7 +122,7 @@ bool array::is_tracer() const { | ||||
|   return array_desc_->is_tracer && in_tracing() || retain_graph(); | ||||
| } | ||||
|  | ||||
| void array::set_data(allocator::Buffer buffer, deleter_t d) { | ||||
| void array::set_data(allocator::Buffer buffer, Deleter d) { | ||||
|   array_desc_->data = std::make_shared<Data>(buffer, d); | ||||
|   array_desc_->data_ptr = buffer.raw_ptr(); | ||||
|   array_desc_->data_size = size(); | ||||
| @@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) { | ||||
| void array::set_data( | ||||
|     allocator::Buffer buffer, | ||||
|     size_t data_size, | ||||
|     std::vector<size_t> strides, | ||||
|     Strides strides, | ||||
|     Flags flags, | ||||
|     deleter_t d) { | ||||
|     Deleter d) { | ||||
|   array_desc_->data = std::make_shared<Data>(buffer, d); | ||||
|   array_desc_->data_ptr = buffer.raw_ptr(); | ||||
|   array_desc_->data_size = data_size; | ||||
| @@ -151,7 +147,7 @@ void array::set_data( | ||||
|  | ||||
| void array::copy_shared_buffer( | ||||
|     const array& other, | ||||
|     const std::vector<size_t>& strides, | ||||
|     const Strides& strides, | ||||
|     Flags flags, | ||||
|     size_t data_size, | ||||
|     size_t offset /* = 0 */) { | ||||
| @@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) { | ||||
|  | ||||
| void array::move_shared_buffer( | ||||
|     array other, | ||||
|     const std::vector<size_t>& strides, | ||||
|     const Strides& strides, | ||||
|     Flags flags, | ||||
|     size_t data_size, | ||||
|     size_t offset /* = 0 */) { | ||||
| @@ -237,13 +233,13 @@ void array::ArrayDesc::init() { | ||||
|   } | ||||
| } | ||||
|  | ||||
| array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype) | ||||
| array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype) | ||||
|     : shape(std::move(shape)), dtype(dtype), status(Status::available) { | ||||
|   init(); | ||||
| } | ||||
|  | ||||
| array::ArrayDesc::ArrayDesc( | ||||
|     std::vector<int> shape, | ||||
|     Shape shape, | ||||
|     Dtype dtype, | ||||
|     std::shared_ptr<Primitive> primitive, | ||||
|     std::vector<array> inputs) | ||||
|   | ||||
							
								
								
									
										51
									
								
								mlx/array.h
									
									
									
									
									
								
							
							
						
						
									
										51
									
								
								mlx/array.h
									
									
									
									
									
								
							| @@ -15,7 +15,10 @@ namespace mlx::core { | ||||
|  | ||||
| // Forward declaration | ||||
| class Primitive; | ||||
| using deleter_t = std::function<void(allocator::Buffer)>; | ||||
|  | ||||
| using Deleter = std::function<void(allocator::Buffer)>; | ||||
| using Shape = std::vector<int32_t>; | ||||
| using Strides = std::vector<size_t>; | ||||
|  | ||||
| class array { | ||||
|   /* An array is really a node in a graph. It contains a shared ArrayDesc | ||||
| @@ -33,7 +36,7 @@ class array { | ||||
|   template <typename It> | ||||
|   array( | ||||
|       It data, | ||||
|       std::vector<int> shape, | ||||
|       Shape shape, | ||||
|       Dtype dtype = | ||||
|           TypeToDtype<typename std::iterator_traits<It>::value_type>()); | ||||
|  | ||||
| @@ -49,15 +52,15 @@ class array { | ||||
|   template <typename T> | ||||
|   array( | ||||
|       std::initializer_list<T> data, | ||||
|       std::vector<int> shape, | ||||
|       Shape shape, | ||||
|       Dtype dtype = TypeToDtype<T>()); | ||||
|  | ||||
|   /* Build an array from a buffer */ | ||||
|   array( | ||||
|       allocator::Buffer data, | ||||
|       std::vector<int> shape, | ||||
|       Shape shape, | ||||
|       Dtype dtype, | ||||
|       deleter_t deleter = allocator::free); | ||||
|       Deleter deleter = allocator::free); | ||||
|  | ||||
|   /** Assignment to rvalue does not compile. */ | ||||
|   array& operator=(const array& other) && = delete; | ||||
| @@ -96,7 +99,7 @@ class array { | ||||
|   } | ||||
|  | ||||
|   /** The shape of the array as a vector of integers. */ | ||||
|   const std::vector<int>& shape() const { | ||||
|   const Shape& shape() const { | ||||
|     return array_desc_->shape; | ||||
|   } | ||||
|  | ||||
| @@ -105,12 +108,12 @@ class array { | ||||
|    * | ||||
|    *  This function supports negative indexing and provides | ||||
|    *  bounds checking. */ | ||||
|   int shape(int dim) const { | ||||
|   auto shape(int dim) const { | ||||
|     return shape().at(dim < 0 ? dim + ndim() : dim); | ||||
|   } | ||||
|  | ||||
|   /** The strides of the array. */ | ||||
|   const std::vector<size_t>& strides() const { | ||||
|   const Strides& strides() const { | ||||
|     return array_desc_->strides; | ||||
|   } | ||||
|  | ||||
| @@ -119,7 +122,7 @@ class array { | ||||
|    * | ||||
|    *  This function supports negative indexing and provides | ||||
|    *  bounds checking. */ | ||||
|   size_t strides(int dim) const { | ||||
|   auto strides(int dim) const { | ||||
|     return strides().at(dim < 0 ? dim + ndim() : dim); | ||||
|   } | ||||
|  | ||||
| @@ -184,13 +187,13 @@ class array { | ||||
|    */ | ||||
|  | ||||
|   array( | ||||
|       std::vector<int> shape, | ||||
|       Shape shape, | ||||
|       Dtype dtype, | ||||
|       std::shared_ptr<Primitive> primitive, | ||||
|       std::vector<array> inputs); | ||||
|  | ||||
|   static std::vector<array> make_arrays( | ||||
|       std::vector<std::vector<int>> shapes, | ||||
|       std::vector<Shape> shapes, | ||||
|       const std::vector<Dtype>& dtypes, | ||||
|       const std::shared_ptr<Primitive>& primitive, | ||||
|       const std::vector<array>& inputs); | ||||
| @@ -207,8 +210,8 @@ class array { | ||||
|  | ||||
|   struct Data { | ||||
|     allocator::Buffer buffer; | ||||
|     deleter_t d; | ||||
|     Data(allocator::Buffer buffer, deleter_t d = allocator::free) | ||||
|     Deleter d; | ||||
|     Data(allocator::Buffer buffer, Deleter d = allocator::free) | ||||
|         : buffer(buffer), d(d) {} | ||||
|     // Not copyable | ||||
|     Data(const Data& d) = delete; | ||||
| @@ -397,18 +400,18 @@ class array { | ||||
|   // Check if the array is a tracer array | ||||
|   bool is_tracer() const; | ||||
|  | ||||
|   void set_data(allocator::Buffer buffer, deleter_t d = allocator::free); | ||||
|   void set_data(allocator::Buffer buffer, Deleter d = allocator::free); | ||||
|  | ||||
|   void set_data( | ||||
|       allocator::Buffer buffer, | ||||
|       size_t data_size, | ||||
|       std::vector<size_t> strides, | ||||
|       Strides strides, | ||||
|       Flags flags, | ||||
|       deleter_t d = allocator::free); | ||||
|       Deleter d = allocator::free); | ||||
|  | ||||
|   void copy_shared_buffer( | ||||
|       const array& other, | ||||
|       const std::vector<size_t>& strides, | ||||
|       const Strides& strides, | ||||
|       Flags flags, | ||||
|       size_t data_size, | ||||
|       size_t offset = 0); | ||||
| @@ -417,7 +420,7 @@ class array { | ||||
|  | ||||
|   void move_shared_buffer( | ||||
|       array other, | ||||
|       const std::vector<size_t>& strides, | ||||
|       const Strides& strides, | ||||
|       Flags flags, | ||||
|       size_t data_size, | ||||
|       size_t offset = 0); | ||||
| @@ -436,8 +439,8 @@ class array { | ||||
|   void init(const It src); | ||||
|  | ||||
|   struct ArrayDesc { | ||||
|     std::vector<int> shape; | ||||
|     std::vector<size_t> strides; | ||||
|     Shape shape; | ||||
|     Strides strides; | ||||
|     size_t size; | ||||
|     Dtype dtype; | ||||
|     std::shared_ptr<Primitive> primitive; | ||||
| @@ -471,10 +474,10 @@ class array { | ||||
|     // The arrays position in the output list | ||||
|     uint32_t position{0}; | ||||
|  | ||||
|     explicit ArrayDesc(std::vector<int> shape, Dtype dtype); | ||||
|     explicit ArrayDesc(Shape shape, Dtype dtype); | ||||
|  | ||||
|     explicit ArrayDesc( | ||||
|         std::vector<int> shape, | ||||
|         Shape shape, | ||||
|         Dtype dtype, | ||||
|         std::shared_ptr<Primitive> primitive, | ||||
|         std::vector<array> inputs); | ||||
| @@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */) | ||||
| template <typename It> | ||||
| array::array( | ||||
|   It data, | ||||
|   std::vector<int> shape, | ||||
|   Shape shape, | ||||
|   Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) : | ||||
|     array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { | ||||
|   init(data); | ||||
| @@ -521,7 +524,7 @@ array::array( | ||||
| template <typename T> | ||||
| array::array( | ||||
|     std::initializer_list<T> data, | ||||
|     std::vector<int> shape, | ||||
|     Shape shape, | ||||
|     Dtype dtype /* = TypeToDtype<T>() */) | ||||
|     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { | ||||
|   if (data.size() != size()) { | ||||
|   | ||||
							
								
								
									
										244
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							
							
						
						
									
										244
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							| @@ -16,10 +16,9 @@ namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool> | ||||
| compute_reduce_shape( | ||||
| std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape( | ||||
|     const std::vector<int>& axes, | ||||
|     const std::vector<int>& shape) { | ||||
|     const Shape& shape) { | ||||
|   bool is_noop = true; | ||||
|   std::set<int> axes_set; | ||||
|   auto ndim = shape.size(); | ||||
| @@ -36,8 +35,8 @@ compute_reduce_shape( | ||||
|   if (axes_set.size() != axes.size()) { | ||||
|     throw std::invalid_argument("Duplicate axes detected in reduction."); | ||||
|   } | ||||
|   std::vector<int> out_shape; | ||||
|   std::vector<int> squeezed_shape; | ||||
|   Shape out_shape; | ||||
|   Shape squeezed_shape; | ||||
|   for (int i = 0; i < ndim; ++i) { | ||||
|     if (axes_set.count(i) == 0) { | ||||
|       out_shape.push_back(shape[i]); | ||||
| @@ -63,7 +62,7 @@ array indices_or_default( | ||||
|     return indices.value(); | ||||
|   } | ||||
|  | ||||
|   std::vector<int> shape(x.shape().begin(), x.shape().end() - 2); | ||||
|   Shape shape(x.shape().begin(), x.shape().end() - 2); | ||||
|   int total = | ||||
|       std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>()); | ||||
|   return reshape(arange(total, uint32, s), shape, s); | ||||
| @@ -254,8 +253,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { | ||||
|  | ||||
| array as_strided( | ||||
|     array a, | ||||
|     std::vector<int> shape, | ||||
|     std::vector<size_t> strides, | ||||
|     Shape shape, | ||||
|     Strides strides, | ||||
|     size_t offset, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   auto copied_shape = shape; // |shape| will be moved | ||||
| @@ -279,12 +278,8 @@ array copy(array a, StreamOrDevice s /* = {} */) { | ||||
|       {std::move(a)}); | ||||
| } | ||||
|  | ||||
| array full( | ||||
|     std::vector<int> shape, | ||||
|     array vals, | ||||
|     Dtype dtype, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) { | ||||
| array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) { | ||||
|   if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) { | ||||
|     throw std::invalid_argument("[full] Negative dimensions not allowed."); | ||||
|   } | ||||
|   auto copied_shape = shape; // |shape| will be moved | ||||
| @@ -295,15 +290,12 @@ array full( | ||||
|       {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)}); | ||||
| } | ||||
|  | ||||
| array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) { | ||||
| array full(Shape shape, array vals, StreamOrDevice s /* = {} */) { | ||||
|   auto dtype = vals.dtype(); // |vals| will be moved | ||||
|   return full(std::move(shape), std::move(vals), dtype, to_stream(s)); | ||||
| } | ||||
|  | ||||
| array zeros( | ||||
|     const std::vector<int>& shape, | ||||
|     Dtype dtype, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
| array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) { | ||||
|   return full(shape, array(0, dtype), to_stream(s)); | ||||
| } | ||||
|  | ||||
| @@ -311,10 +303,7 @@ array zeros_like(const array& a, StreamOrDevice s /* = {} */) { | ||||
|   return zeros(a.shape(), a.dtype(), to_stream(s)); | ||||
| } | ||||
|  | ||||
| array ones( | ||||
|     const std::vector<int>& shape, | ||||
|     Dtype dtype, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
| array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) { | ||||
|   return full(shape, array(1, dtype), to_stream(s)); | ||||
| } | ||||
|  | ||||
| @@ -368,10 +357,7 @@ array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) { | ||||
|   return where(mask, zeros_like(x, s), x, s); | ||||
| } | ||||
|  | ||||
| array reshape( | ||||
|     const array& a, | ||||
|     std::vector<int> shape, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
| array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) { | ||||
|   if (a.shape() == shape) { | ||||
|     return a; | ||||
|   } | ||||
| @@ -445,11 +431,11 @@ array flatten( | ||||
|   if (start_ax == end_ax) { | ||||
|     return a; | ||||
|   } | ||||
|   std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax); | ||||
|   Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax); | ||||
|   new_shape.push_back(-1); | ||||
|   new_shape.insert( | ||||
|       new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end()); | ||||
|   return reshape(a, new_shape, s); | ||||
|   return reshape(a, std::move(new_shape), s); | ||||
| } | ||||
|  | ||||
| array flatten(const array& a, StreamOrDevice s /* = {} */) { | ||||
| @@ -496,7 +482,7 @@ array squeeze( | ||||
|     throw std::invalid_argument("[squeeze] Received duplicate axes."); | ||||
|   } | ||||
|   std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end()); | ||||
|   std::vector<int> shape; | ||||
|   Shape shape; | ||||
|   for (int i = 0, j = 0; i < a.ndim(); ++i) { | ||||
|     if (j < sorted_axes.size() && i == sorted_axes[j]) { | ||||
|       j++; | ||||
| @@ -584,12 +570,9 @@ array expand_dims( | ||||
| // Slice helper | ||||
| namespace { | ||||
|  | ||||
| inline auto normalize_slice( | ||||
|     const std::vector<int>& shape, | ||||
|     std::vector<int>& start, | ||||
|     std::vector<int>& stop, | ||||
|     std::vector<int>& strides) { | ||||
|   std::vector<int> out_shape(shape.size()); | ||||
| inline auto | ||||
| normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) { | ||||
|   Shape out_shape(shape.size()); | ||||
|   bool has_neg_strides = false; | ||||
|  | ||||
|   for (int i = 0; i < shape.size(); ++i) { | ||||
| @@ -641,9 +624,9 @@ inline auto normalize_slice( | ||||
|  | ||||
| array slice( | ||||
|     const array& a, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     std::vector<int> strides, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     Shape strides, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   if (start.size() != a.ndim() || stop.size() != a.ndim() || | ||||
|       strides.size() != a.ndim()) { | ||||
| @@ -670,24 +653,20 @@ array slice( | ||||
|  | ||||
| array slice( | ||||
|     const array& a, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   return slice( | ||||
|       a, | ||||
|       std::move(start), | ||||
|       std::move(stop), | ||||
|       std::vector<int>(a.ndim(), 1), | ||||
|       to_stream(s)); | ||||
|       a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s)); | ||||
| } | ||||
|  | ||||
| /** Update a slice from the source array */ | ||||
| array slice_update( | ||||
|     const array& src, | ||||
|     const array& update, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     std::vector<int> strides, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     Shape strides, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   // Check dimensions | ||||
|   if (start.size() != src.ndim() || stop.size() != src.ndim() || | ||||
| @@ -721,12 +700,11 @@ array slice_update( | ||||
| array slice_update( | ||||
|     const array& src, | ||||
|     const array& update, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   auto strides = std::vector<int>(src.ndim(), 1); | ||||
|   return slice_update( | ||||
|       src, update, std::move(start), std::move(stop), std::move(strides), s); | ||||
|       src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); | ||||
| } | ||||
|  | ||||
| std::vector<array> split( | ||||
| @@ -750,7 +728,7 @@ std::vector<array> split( | ||||
|       std::is_sorted(indices.begin(), indices.end(), std::less<>{}) && | ||||
|       indices[0] > 0 && indices.back() < a.shape(ax)) { | ||||
|     std::vector<Dtype> dtypes(indices.size() + 1, a.dtype()); | ||||
|     std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape()); | ||||
|     std::vector<Shape> shapes(indices.size() + 1, a.shape()); | ||||
|     shapes[0][ax] = indices[0]; | ||||
|     for (int i = 1; i < indices.size(); i++) { | ||||
|       shapes[i][ax] = indices[i] - indices[i - 1]; | ||||
| @@ -765,8 +743,7 @@ std::vector<array> split( | ||||
|   } | ||||
|  | ||||
|   std::vector<array> res; | ||||
|   auto out_shape = a.shape(); | ||||
|   auto start_indices = std::vector<int>(a.ndim(), 0); | ||||
|   auto start_indices = Shape(a.ndim(), 0); | ||||
|   auto stop_indices = a.shape(); | ||||
|   for (int i = 0; i < indices.size() + 1; ++i) { | ||||
|     stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax); | ||||
| @@ -826,13 +803,13 @@ std::vector<array> meshgrid( | ||||
|   auto ndim = arrays.size(); | ||||
|   std::vector<array> outputs; | ||||
|   for (int i = 0; i < ndim; ++i) { | ||||
|     std::vector<int> shape(ndim, 1); | ||||
|     Shape shape(ndim, 1); | ||||
|     shape[i] = -1; | ||||
|     outputs.push_back(reshape(arrays[i], std::move(shape), s)); | ||||
|   } | ||||
|  | ||||
|   if (indexing == "xy" and ndim > 1) { | ||||
|     std::vector<int> shape(ndim, 1); | ||||
|     Shape shape(ndim, 1); | ||||
|  | ||||
|     shape[1] = arrays[0].size(); | ||||
|     outputs[0] = reshape(arrays[0], shape, s); | ||||
| @@ -895,7 +872,7 @@ array concatenate( | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   }; | ||||
|  | ||||
|   std::vector<int> shape = arrays[0].shape(); | ||||
|   auto shape = arrays[0].shape(); | ||||
|   shape[ax] = 0; | ||||
|   // Make the output shape and validate that all arrays have the same shape | ||||
|   // except for the concatenation axis. | ||||
| @@ -980,7 +957,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { | ||||
|   } | ||||
|  | ||||
|   // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...) | ||||
|   std::vector<int> shape(arr.shape()); | ||||
|   auto shape = arr.shape(); | ||||
|   shape.insert(shape.begin() + axis + 1, repeats); | ||||
|   array out = expand_dims(arr, axis + 1, s); | ||||
|   out = broadcast_to(out, shape, s); | ||||
| @@ -1009,9 +986,9 @@ array tile( | ||||
|     shape.insert(shape.begin(), reps.size() - shape.size(), 1); | ||||
|   } | ||||
|  | ||||
|   std::vector<int> expand_shape; | ||||
|   std::vector<int> broad_shape; | ||||
|   std::vector<int> final_shape; | ||||
|   Shape expand_shape; | ||||
|   Shape broad_shape; | ||||
|   Shape final_shape; | ||||
|   for (int i = 0; i < shape.size(); i++) { | ||||
|     if (reps[i] != 1) { | ||||
|       expand_shape.push_back(1); | ||||
| @@ -1022,17 +999,17 @@ array tile( | ||||
|     final_shape.push_back(reps[i] * shape[i]); | ||||
|   } | ||||
|  | ||||
|   auto x = reshape(arr, expand_shape, s); | ||||
|   x = broadcast_to(x, broad_shape, s); | ||||
|   return reshape(x, final_shape, s); | ||||
|   auto x = reshape(arr, std::move(expand_shape), s); | ||||
|   x = broadcast_to(x, std::move(broad_shape), s); | ||||
|   return reshape(x, std::move(final_shape), s); | ||||
| } | ||||
|  | ||||
| array edge_pad( | ||||
|     const array& a, | ||||
|     const std::vector<int>& axes, | ||||
|     const std::vector<int>& low_pad_size, | ||||
|     const std::vector<int>& high_pad_size, | ||||
|     const std::vector<int>& out_shape, | ||||
|     const Shape& low_pad_size, | ||||
|     const Shape& high_pad_size, | ||||
|     const Shape& out_shape, | ||||
|     StreamOrDevice s /* = {}*/) { | ||||
|   array out = zeros(out_shape, a.dtype(), s); | ||||
|   auto stops = a.shape(); | ||||
| @@ -1044,7 +1021,7 @@ array edge_pad( | ||||
|  | ||||
|   for (int axis = 0; axis < a.ndim(); axis++) { | ||||
|     if (low_pad_size[axis] > 0) { | ||||
|       std::vector<int> starts(a.ndim(), 0); | ||||
|       Shape starts(a.ndim(), 0); | ||||
|       starts[axis] = low_pad_size[axis]; | ||||
|       auto stops = out.shape(); | ||||
|       stops[axis] = low_pad_size[axis] + 1; | ||||
| @@ -1058,7 +1035,7 @@ array edge_pad( | ||||
|     } | ||||
|  | ||||
|     if (high_pad_size[axis] > 0) { | ||||
|       std::vector<int> starts(a.ndim(), 0); | ||||
|       Shape starts(a.ndim(), 0); | ||||
|       starts[axis] = -high_pad_size[axis] - 1; | ||||
|       auto stops = out.shape(); | ||||
|       stops[axis] = -high_pad_size[axis]; | ||||
| @@ -1075,9 +1052,9 @@ array edge_pad( | ||||
| /** Pad an array with a constant value */ | ||||
| array pad( | ||||
|     const array& a, | ||||
|     const std::vector<int>& axes, | ||||
|     const std::vector<int>& low_pad_size, | ||||
|     const std::vector<int>& high_pad_size, | ||||
|     const Shape& axes, | ||||
|     const Shape& low_pad_size, | ||||
|     const Shape& high_pad_size, | ||||
|     const array& pad_value /*= array(0)*/, | ||||
|     const std::string mode /*= "constant"*/, | ||||
|     StreamOrDevice s /* = {}*/) { | ||||
| @@ -1089,7 +1066,7 @@ array pad( | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
|   std::vector<int> out_shape = a.shape(); | ||||
|   auto out_shape = a.shape(); | ||||
|  | ||||
|   for (int i = 0; i < axes.size(); i++) { | ||||
|     if (low_pad_size[i] < 0) { | ||||
| @@ -1113,7 +1090,7 @@ array pad( | ||||
|  | ||||
|   if (mode == "constant") { | ||||
|     return array( | ||||
|         out_shape, | ||||
|         std::move(out_shape), | ||||
|         a.dtype(), | ||||
|         std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size), | ||||
|         {a, astype(pad_value, a.dtype(), s)}); | ||||
| @@ -1136,8 +1113,8 @@ array pad( | ||||
|   std::vector<int> axes(a.ndim(), 0); | ||||
|   std::iota(axes.begin(), axes.end(), 0); | ||||
|  | ||||
|   std::vector<int> lows; | ||||
|   std::vector<int> highs; | ||||
|   Shape lows; | ||||
|   Shape highs; | ||||
|  | ||||
|   for (auto& pads : pad_width) { | ||||
|     lows.push_back(pads.first); | ||||
| @@ -1240,7 +1217,7 @@ array transpose( | ||||
|   } | ||||
|  | ||||
|   // Check in bounds and for duplicates | ||||
|   std::vector<int> shape(axes.size(), 0); | ||||
|   Shape shape(axes.size(), 0); | ||||
|   for (auto& ax : axes) { | ||||
|     if (ax < 0 || ax >= a.ndim()) { | ||||
|       std::ostringstream msg; | ||||
| @@ -1272,7 +1249,7 @@ array transpose(const array& a, StreamOrDevice s /* = {} */) { | ||||
|  | ||||
| array broadcast_to( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shape, | ||||
|     const Shape& shape, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   if (a.shape() == shape) { | ||||
|     return a; | ||||
| @@ -1295,14 +1272,14 @@ array broadcast_to( | ||||
|  | ||||
| std::vector<array> | ||||
| broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) { | ||||
|   std::vector<int> shape = broadcast_shapes(a.shape(), b.shape()); | ||||
|   auto shape = broadcast_shapes(a.shape(), b.shape()); | ||||
|   return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)}; | ||||
| } | ||||
|  | ||||
| std::vector<array> broadcast_arrays( | ||||
|     const std::vector<array>& inputs, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   std::vector<int> shape{}; | ||||
|   Shape shape{}; | ||||
|   for (const auto& in : inputs) { | ||||
|     shape = broadcast_shapes(shape, in.shape()); | ||||
|   } | ||||
| @@ -1913,7 +1890,7 @@ array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { | ||||
|   int size = a.size(); | ||||
|   auto result = argmax(reshape(a, {size}, s), 0, true, s); | ||||
|   if (keepdims) { | ||||
|     result = reshape(result, std::vector<int>(a.shape().size(), 1), s); | ||||
|     result = reshape(result, Shape(a.shape().size(), 1), s); | ||||
|   } else { | ||||
|     result = squeeze(result, s); | ||||
|   } | ||||
| @@ -2098,8 +2075,8 @@ array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) { | ||||
|   } | ||||
|  | ||||
|   array a_partitioned = partition(a, -k, axis_, s); | ||||
|   std::vector<int> slice_starts(a.ndim(), 0); | ||||
|   std::vector<int> slice_ends = a.shape(); | ||||
|   Shape slice_starts(a.ndim(), 0); | ||||
|   auto slice_ends = a.shape(); | ||||
|   slice_starts[axis_] = a.shape(axis_) - k; | ||||
|   return slice(a_partitioned, slice_starts, slice_ends, s); | ||||
| } | ||||
| @@ -2613,8 +2590,8 @@ array matmul( | ||||
|   } | ||||
|  | ||||
|   if (a.ndim() > 2 || b.ndim() > 2) { | ||||
|     std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); | ||||
|     std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); | ||||
|     Shape bsx_a(a.shape().begin(), a.shape().end() - 2); | ||||
|     Shape bsx_b(b.shape().begin(), b.shape().end() - 2); | ||||
|     auto inner_shape = broadcast_shapes(bsx_a, bsx_b); | ||||
|  | ||||
|     // Broadcast a | ||||
| @@ -2648,7 +2625,7 @@ array gather( | ||||
|     const array& a, | ||||
|     const std::vector<array>& indices, | ||||
|     const std::vector<int>& axes, | ||||
|     const std::vector<int>& slice_sizes, | ||||
|     const Shape& slice_sizes, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   // Checks that indices, dimensions, and slice_sizes are all valid | ||||
|   if (indices.size() > a.ndim()) { | ||||
| @@ -2703,7 +2680,7 @@ array gather( | ||||
|     idx = astype(idx, dtype, s); | ||||
|   } | ||||
|  | ||||
|   std::vector<int> out_shape; | ||||
|   Shape out_shape; | ||||
|   if (!inputs.empty()) { | ||||
|     out_shape = inputs[0].shape(); | ||||
|   } | ||||
| @@ -2741,7 +2718,7 @@ array take( | ||||
|   axis = axis < 0 ? a.ndim() + axis : axis; | ||||
|  | ||||
|   // Make slice sizes to pass to gather | ||||
|   std::vector<int> slice_sizes = a.shape(); | ||||
|   Shape slice_sizes = a.shape(); | ||||
|   slice_sizes[axis] = indices.size() > 0 ? 1 : 0; | ||||
|  | ||||
|   auto out = gather(a, indices, axis, slice_sizes, s); | ||||
| @@ -2759,7 +2736,7 @@ array take( | ||||
|   } | ||||
|  | ||||
|   // Squeeze the axis we take over | ||||
|   std::vector<int> out_shape = out.shape(); | ||||
|   auto out_shape = out.shape(); | ||||
|   out_shape.erase(out_shape.begin() + indices.ndim() + axis); | ||||
|   return reshape(out, std::move(out_shape), s); | ||||
| } | ||||
| @@ -2787,8 +2764,8 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) { | ||||
|   // Handle negative axis | ||||
|   axis = axis < 0 ? a.ndim() + axis : axis; | ||||
|  | ||||
|   std::vector<int> starts(a.ndim(), 0); | ||||
|   std::vector<int> stops = a.shape(); | ||||
|   Shape starts(a.ndim(), 0); | ||||
|   Shape stops = a.shape(); | ||||
|   starts[axis] = index; | ||||
|   stops[axis] = index + 1; | ||||
|   return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s); | ||||
| @@ -2821,7 +2798,7 @@ array take_along_axis( | ||||
|   axis = axis < 0 ? a.ndim() + axis : axis; | ||||
|  | ||||
|   std::vector<array> nd_indices; | ||||
|   std::vector<int> index_shape(a.ndim(), 1); | ||||
|   Shape index_shape(a.ndim(), 1); | ||||
|   for (int i = 0; i < a.ndim(); ++i) { | ||||
|     if (i == axis) { | ||||
|       nd_indices.push_back(indices); | ||||
| @@ -2834,12 +2811,11 @@ array take_along_axis( | ||||
|   } | ||||
|   std::vector<int> dims(a.ndim()); | ||||
|   std::iota(dims.begin(), dims.end(), 0); | ||||
|   std::vector<int> slice_sizes(a.ndim(), a.size() > 0); | ||||
|   Shape slice_sizes(a.ndim(), a.size() > 0); | ||||
|   auto out = gather(a, nd_indices, dims, slice_sizes, s); | ||||
|  | ||||
|   // Squeeze out the slice shape | ||||
|   std::vector<int> out_shape( | ||||
|       out.shape().begin(), out.shape().begin() + a.ndim()); | ||||
|   Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim()); | ||||
|   return reshape(out, std::move(out_shape), s); | ||||
| } | ||||
|  | ||||
| @@ -2867,7 +2843,7 @@ array put_along_axis( | ||||
|   axis = axis < 0 ? a.ndim() + axis : axis; | ||||
|  | ||||
|   std::vector<array> nd_indices; | ||||
|   std::vector<int> index_shape(a.ndim(), 1); | ||||
|   Shape index_shape(a.ndim(), 1); | ||||
|   for (int i = 0; i < a.ndim(); ++i) { | ||||
|     if (i == axis) { | ||||
|       nd_indices.push_back(indices); | ||||
| @@ -2927,7 +2903,7 @@ array scatter( | ||||
|   // Broadcast and cast indices if necessary | ||||
|   auto inputs = broadcast_arrays(indices); | ||||
|  | ||||
|   std::vector<int> idx_shape; | ||||
|   Shape idx_shape; | ||||
|   if (!inputs.empty()) { | ||||
|     idx_shape = inputs[0].shape(); | ||||
|   } | ||||
| @@ -3198,7 +3174,7 @@ inline int dilate_size(int dim, int dil) { | ||||
|   return 1 + dil * (dim - 1); | ||||
| } | ||||
|  | ||||
| inline std::vector<int> conv_out_shape( | ||||
| Shape conv_out_shape( | ||||
|     const std::vector<int>& in_shape, | ||||
|     const std::vector<int>& wt_shape, | ||||
|     const std::vector<int>& strides, | ||||
| @@ -3208,7 +3184,7 @@ inline std::vector<int> conv_out_shape( | ||||
|     const std::vector<int>& input_dilation) { | ||||
|   int N = in_shape[0]; | ||||
|   int O = wt_shape[0]; | ||||
|   std::vector<int> out_shape(in_shape.size()); | ||||
|   Shape out_shape(in_shape.size()); | ||||
|   int i = 0; | ||||
|   out_shape[i++] = N; | ||||
|  | ||||
| @@ -3577,8 +3553,8 @@ array conv_general( | ||||
|  | ||||
|   // Handle negative padding | ||||
|   if (has_neg_padding) { | ||||
|     std::vector<int> starts(in.ndim(), 0); | ||||
|     std::vector<int> stops = in.shape(); | ||||
|     Shape starts(in.ndim(), 0); | ||||
|     auto stops = in.shape(); | ||||
|  | ||||
|     for (int i = 0; i < spatial_dims; i++) { | ||||
|       if (padding_lo[i] < 0) { | ||||
| @@ -3596,7 +3572,7 @@ array conv_general( | ||||
|   } | ||||
|  | ||||
|   // Get output shapes | ||||
|   std::vector<int> out_shape = conv_out_shape( | ||||
|   auto out_shape = conv_out_shape( | ||||
|       in.shape(), | ||||
|       wt.shape(), | ||||
|       stride, | ||||
| @@ -3606,7 +3582,7 @@ array conv_general( | ||||
|       input_dilation); | ||||
|  | ||||
|   return array( | ||||
|       out_shape, | ||||
|       std::move(out_shape), | ||||
|       in.dtype(), | ||||
|       std::make_shared<Convolution>( | ||||
|           to_stream(s), | ||||
| @@ -3634,8 +3610,8 @@ array quantized_matmul( | ||||
|  | ||||
|   // QuantizedMatmul handles w.ndim == 2 case. | ||||
|   if (x.ndim() > 2 && w.ndim() > 2) { | ||||
|     std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2); | ||||
|     std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2); | ||||
|     Shape bsx_x(x.shape().begin(), x.shape().end() - 2); | ||||
|     Shape bsx_w(w.shape().begin(), w.shape().end() - 2); | ||||
|     auto inner_shape = broadcast_shapes(bsx_x, bsx_w); | ||||
|  | ||||
|     // Broadcast x | ||||
| @@ -3731,7 +3707,7 @@ array gather_qmm( | ||||
|   // and output type | ||||
|   auto out_type = result_type(x, scales, biases); | ||||
|  | ||||
|   auto out = array( | ||||
|   return array( | ||||
|       std::move(out_shape), | ||||
|       out_type, | ||||
|       std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose), | ||||
| @@ -3741,8 +3717,6 @@ array gather_qmm( | ||||
|        astype(biases, out_type, s), | ||||
|        lhs_indices, | ||||
|        rhs_indices}); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
|  | ||||
| array tensordot( | ||||
| @@ -3802,7 +3776,7 @@ array tensordot( | ||||
|  | ||||
|   std::vector<int> t1; | ||||
|   std::vector<int> t2; | ||||
|   std::vector<int> rshape; | ||||
|   Shape rshape; | ||||
|   int size1 = 1; | ||||
|   int size2 = 1; | ||||
|   for (int i = 0; i < a.ndim(); i++) { | ||||
| @@ -3898,7 +3872,7 @@ array addmm( | ||||
|  | ||||
|   // We can batch the multiplication by reshaping a | ||||
|   if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) { | ||||
|     std::vector<int> out_shape = a.shape(); | ||||
|     auto out_shape = a.shape(); | ||||
|     a = reshape(a, {-1, out_shape.back()}, s); | ||||
|     out_shape.back() = b.shape(-1); | ||||
|  | ||||
| @@ -3917,8 +3891,8 @@ array addmm( | ||||
|   } | ||||
|  | ||||
|   if (a.ndim() > 2 || b.ndim() > 2) { | ||||
|     std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); | ||||
|     std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); | ||||
|     Shape bsx_a(a.shape().begin(), a.shape().end() - 2); | ||||
|     Shape bsx_b(b.shape().begin(), b.shape().end() - 2); | ||||
|     auto inner_shape = broadcast_shapes(bsx_a, bsx_b); | ||||
|  | ||||
|     // Broadcast a | ||||
| @@ -4042,8 +4016,8 @@ array block_masked_mm( | ||||
|   b = astype(b, out_type, s); | ||||
|  | ||||
|   // Handle broadcasting | ||||
|   std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); | ||||
|   std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); | ||||
|   Shape bsx_a(a.shape().begin(), a.shape().end() - 2); | ||||
|   Shape bsx_b(b.shape().begin(), b.shape().end() - 2); | ||||
|  | ||||
|   auto bsx_shape = broadcast_shapes(bsx_a, bsx_b); | ||||
|  | ||||
| @@ -4079,7 +4053,7 @@ array block_masked_mm( | ||||
|  | ||||
|   // Broadcast and astype mask | ||||
|   auto broadcast_mask = [](array mask, | ||||
|                            std::vector<int>& bs_shape, | ||||
|                            Shape& bs_shape, | ||||
|                            int y, | ||||
|                            int x, | ||||
|                            Dtype mask_dtype, | ||||
| @@ -4397,7 +4371,7 @@ std::vector<array> depends( | ||||
|   Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream() | ||||
|                                          : to_stream({}); | ||||
|   // Make the output info | ||||
|   std::vector<std::vector<int>> shapes; | ||||
|   std::vector<Shape> shapes; | ||||
|   std::vector<Dtype> dtypes; | ||||
|   for (const auto& in : inputs) { | ||||
|     shapes.emplace_back(in.shape()); | ||||
| @@ -4434,7 +4408,7 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { | ||||
|     case 0: | ||||
|       return reshape(a, {1, 1}, s); | ||||
|     case 1: | ||||
|       return reshape(a, {1, static_cast<int>(a.size())}, s); | ||||
|       return reshape(a, {1, a.shape(0)}, s); | ||||
|     default: | ||||
|       return a; | ||||
|   } | ||||
| @@ -4456,7 +4430,7 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { | ||||
|     case 0: | ||||
|       return reshape(a, {1, 1, 1}, s); | ||||
|     case 1: | ||||
|       return reshape(a, {1, static_cast<int>(a.size()), 1}, s); | ||||
|       return reshape(a, {1, a.shape(0), 1}, s); | ||||
|     case 2: | ||||
|       return reshape(a, {a.shape(0), a.shape(1), 1}, s); | ||||
|     default: | ||||
| @@ -4493,7 +4467,7 @@ array number_of_elements( | ||||
|   } | ||||
|  | ||||
|   return stop_gradient(array( | ||||
|       std::vector<int>{}, | ||||
|       Shape{}, | ||||
|       dtype, | ||||
|       std::make_shared<NumberOfElements>( | ||||
|           to_stream(s), std::move(axes), inverted, dtype), | ||||
| @@ -4613,7 +4587,7 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) { | ||||
|  | ||||
| array roll( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shift, | ||||
|     const Shape& shift, | ||||
|     const std::vector<int>& axes, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   if (axes.empty()) { | ||||
| @@ -4627,7 +4601,6 @@ array roll( | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
|   std::vector<array> parts; | ||||
|   array result = a; | ||||
|   for (int i = 0; i < axes.size(); i++) { | ||||
|     int ax = axes[i]; | ||||
| @@ -4641,11 +4614,11 @@ array roll( | ||||
|       throw std::invalid_argument(msg.str()); | ||||
|     } | ||||
|  | ||||
|     int sh = shift[i]; | ||||
|     int split_index = | ||||
|     auto sh = shift[i]; | ||||
|     auto split_index = | ||||
|         (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); | ||||
|  | ||||
|     parts = split(result, std::vector<int>{split_index}, ax, s); | ||||
|     auto parts = split(result, Shape{split_index}, ax, s); | ||||
|     std::swap(parts[0], parts[1]); | ||||
|     result = concatenate(parts, ax, s); | ||||
|   } | ||||
| @@ -4656,19 +4629,12 @@ array roll( | ||||
| array roll(const array& a, int shift, StreamOrDevice s /* = {} */) { | ||||
|   auto shape = a.shape(); | ||||
|   return reshape( | ||||
|       roll( | ||||
|           reshape(a, std::vector<int>{-1}, s), | ||||
|           std::vector<int>{shift}, | ||||
|           std::vector<int>{0}, | ||||
|           s), | ||||
|       roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s), | ||||
|       std::move(shape), | ||||
|       s); | ||||
| } | ||||
|  | ||||
| array roll( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shift, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
| array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) { | ||||
|   int total_shift = 0; | ||||
|   for (auto& s : shift) { | ||||
|     total_shift += s; | ||||
| @@ -4677,7 +4643,7 @@ array roll( | ||||
| } | ||||
|  | ||||
| array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) { | ||||
|   return roll(a, std::vector<int>{shift}, std::vector<int>{axis}, s); | ||||
|   return roll(a, Shape{shift}, std::vector<int>{axis}, s); | ||||
| } | ||||
|  | ||||
| array roll( | ||||
| @@ -4685,20 +4651,20 @@ array roll( | ||||
|     int shift, | ||||
|     const std::vector<int>& axes, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   std::vector<int> shifts(axes.size(), shift); | ||||
|   Shape shifts(axes.size(), shift); | ||||
|   return roll(a, shifts, axes, s); | ||||
| } | ||||
|  | ||||
| array roll( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shift, | ||||
|     const Shape& shift, | ||||
|     int axis, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   int total_shift = 0; | ||||
|   for (auto& s : shift) { | ||||
|     total_shift += s; | ||||
|   } | ||||
|   return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s); | ||||
|   return roll(a, Shape{total_shift}, std::vector<int>{axis}, s); | ||||
| } | ||||
|  | ||||
| array real(const array& a, StreamOrDevice s /* = {} */) { | ||||
|   | ||||
							
								
								
									
										76
									
								
								mlx/ops.h
									
									
									
									
									
								
							
							
						
						
									
										76
									
								
								mlx/ops.h
									
									
									
									
									
								
							| @@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {}); | ||||
| /** Create a view of an array with the given shape and strides. */ | ||||
| array as_strided( | ||||
|     array a, | ||||
|     std::vector<int> shape, | ||||
|     std::vector<size_t> strides, | ||||
|     Shape shape, | ||||
|     Strides strides, | ||||
|     size_t offset, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| @@ -58,31 +58,27 @@ array as_strided( | ||||
| array copy(array a, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Fill an array of the given shape with the given value(s). */ | ||||
| array full( | ||||
|     std::vector<int> shape, | ||||
|     array vals, | ||||
|     Dtype dtype, | ||||
|     StreamOrDevice s = {}); | ||||
| array full(std::vector<int> shape, array vals, StreamOrDevice s = {}); | ||||
| array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {}); | ||||
| array full(Shape shape, array vals, StreamOrDevice s = {}); | ||||
| template <typename T> | ||||
| array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) { | ||||
| array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) { | ||||
|   return full(std::move(shape), array(val, dtype), to_stream(s)); | ||||
| } | ||||
| template <typename T> | ||||
| array full(std::vector<int> shape, T val, StreamOrDevice s = {}) { | ||||
| array full(Shape shape, T val, StreamOrDevice s = {}) { | ||||
|   return full(std::move(shape), array(val), to_stream(s)); | ||||
| } | ||||
|  | ||||
| /** Fill an array of the given shape with zeros. */ | ||||
| array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {}); | ||||
| inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) { | ||||
| array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); | ||||
| inline array zeros(const Shape& shape, StreamOrDevice s = {}) { | ||||
|   return zeros(shape, float32, s); | ||||
| } | ||||
| array zeros_like(const array& a, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Fill an array of the given shape with ones. */ | ||||
| array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {}); | ||||
| inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) { | ||||
| array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); | ||||
| inline array ones(const Shape& shape, StreamOrDevice s = {}) { | ||||
|   return ones(shape, float32, s); | ||||
| } | ||||
| array ones_like(const array& a, StreamOrDevice s = {}); | ||||
| @@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {}); | ||||
| array triu(array x, int k = 0, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Reshape an array to the given shape. */ | ||||
| array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {}); | ||||
| array reshape(const array& a, Shape shape, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Flatten the dimensions in the range `[start_axis, end_axis]` . */ | ||||
| array flatten( | ||||
| @@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {}); | ||||
| /** Slice an array. */ | ||||
| array slice( | ||||
|     const array& a, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     std::vector<int> strides, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     Shape strides, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| /** Slice an array with a stride of 1 in each dimension. */ | ||||
| array slice( | ||||
|     const array& a, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     StreamOrDevice s = {}); | ||||
| array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Update a slice from the source array */ | ||||
| array slice_update( | ||||
|     const array& src, | ||||
|     const array& update, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     std::vector<int> strides, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     Shape strides, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| /** Update a slice from the source array with stride 1 in each dimension */ | ||||
| array slice_update( | ||||
|     const array& src, | ||||
|     const array& update, | ||||
|     std::vector<int> start, | ||||
|     std::vector<int> stop, | ||||
|     Shape start, | ||||
|     Shape stop, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
| /** Split an array into sub-arrays along a given axis. */ | ||||
| @@ -288,10 +280,7 @@ array pad( | ||||
| array transpose(const array& a, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Broadcast an array to a given shape. */ | ||||
| array broadcast_to( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shape, | ||||
|     StreamOrDevice s = {}); | ||||
| array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Broadcast a vector of arrays against one another. */ | ||||
| std::vector<array> broadcast_arrays( | ||||
| @@ -917,13 +906,13 @@ array gather( | ||||
|     const array& a, | ||||
|     const std::vector<array>& indices, | ||||
|     const std::vector<int>& axes, | ||||
|     const std::vector<int>& slice_sizes, | ||||
|     const Shape& slice_sizes, | ||||
|     StreamOrDevice s = {}); | ||||
| inline array gather( | ||||
|     const array& a, | ||||
|     const array& indices, | ||||
|     int axis, | ||||
|     const std::vector<int>& slice_sizes, | ||||
|     const Shape& slice_sizes, | ||||
|     StreamOrDevice s = {}) { | ||||
|   return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s); | ||||
| } | ||||
| @@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {}); | ||||
|  | ||||
| /** Roll elements along an axis and introduce them on the other side */ | ||||
| array roll(const array& a, int shift, StreamOrDevice s = {}); | ||||
| array roll( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shift, | ||||
|     StreamOrDevice s = {}); | ||||
| array roll(const array& a, const Shape& shift, StreamOrDevice s = {}); | ||||
| array roll(const array& a, int shift, int axis, StreamOrDevice s = {}); | ||||
| array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {}); | ||||
| array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {}); | ||||
| array roll( | ||||
|     const array& a, | ||||
|     int shift, | ||||
|     const std::vector<int>& axes, | ||||
|     StreamOrDevice s = {}); | ||||
| array roll( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shift, | ||||
|     int axis, | ||||
|     StreamOrDevice s = {}); | ||||
| array roll( | ||||
|     const array& a, | ||||
|     const std::vector<int>& shift, | ||||
|     const Shape& shift, | ||||
|     const std::vector<int>& axes, | ||||
|     StreamOrDevice s = {}); | ||||
|  | ||||
|   | ||||
| @@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) { | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| std::vector<int> broadcast_shapes( | ||||
|     const std::vector<int>& s1, | ||||
|     const std::vector<int>& s2) { | ||||
| Shape broadcast_shapes(const Shape& s1, const Shape& s2) { | ||||
|   // Use the same broadcasting rules as numpy | ||||
|   // https://numpy.org/doc/1.20/user/theory.broadcasting.html | ||||
|   // "The size of the trailing axes for both arrays in an operation must | ||||
| @@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes( | ||||
|   int diff = std::abs(ndim1 - ndim2); | ||||
|   const auto& big = ndim1 > ndim2 ? s1 : s2; | ||||
|   const auto& small = ndim1 > ndim2 ? s2 : s1; | ||||
|   std::vector<int> out_shape(ndim); | ||||
|   Shape out_shape(ndim); | ||||
|   for (int i = ndim - 1; i >= diff; --i) { | ||||
|     int a = big[i]; | ||||
|     int b = small[i - diff]; | ||||
| @@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| inline size_t elem_to_loc( | ||||
|     int elem, | ||||
|     const std::vector<int>& shape, | ||||
|     const std::vector<size_t>& strides) { | ||||
| inline size_t | ||||
| elem_to_loc(int elem, const Shape& shape, const Strides& strides) { | ||||
|   size_t loc = 0; | ||||
|   for (int i = shape.size() - 1; i >= 0; --i) { | ||||
|     auto q_and_r = ldiv(elem, shape[i]); | ||||
| @@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { | ||||
|  | ||||
| template <typename T> | ||||
| void print_array(std::ostream& os, const array& a) { | ||||
|   std::vector<int> indices(a.ndim(), 0); | ||||
|   os << std::boolalpha; | ||||
|   os << "array("; | ||||
|   if (a.ndim() == 0) { | ||||
| @@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) { | ||||
|   return os; | ||||
| } | ||||
|  | ||||
| std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) { | ||||
| std::ostream& operator<<(std::ostream& os, const Shape& v) { | ||||
|   os << "("; | ||||
|   for (int i = 0; i < v.size(); ++i) { | ||||
|     os << v[i] << ((i == v.size() - 1) ? "" : ","); | ||||
| @@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) { | ||||
|   return os; | ||||
| } | ||||
|  | ||||
| std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) { | ||||
| std::ostream& operator<<(std::ostream& os, const Strides& v) { | ||||
|   os << "("; | ||||
|   for (int i = 0; i < v.size(); ++i) { | ||||
|     os << v[i] << ((i == v.size() - 1) ? "" : ","); | ||||
|   | ||||
| @@ -62,9 +62,7 @@ inline Dtype result_type(const array& a, const array& b, const array& c) { | ||||
| } | ||||
| Dtype result_type(const std::vector<array>& arrays); | ||||
|  | ||||
| std::vector<int> broadcast_shapes( | ||||
|     const std::vector<int>& s1, | ||||
|     const std::vector<int>& s2); | ||||
| Shape broadcast_shapes(const Shape& s1, const Shape& s2); | ||||
|  | ||||
| bool is_same_shape(const std::vector<array>& arrays); | ||||
|  | ||||
| @@ -96,8 +94,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s); | ||||
| std::ostream& operator<<(std::ostream& os, const Dtype& d); | ||||
| std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); | ||||
| std::ostream& operator<<(std::ostream& os, array a); | ||||
| std::ostream& operator<<(std::ostream& os, const std::vector<int>& v); | ||||
| std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v); | ||||
| std::ostream& operator<<(std::ostream& os, const Shape& v); | ||||
| std::ostream& operator<<(std::ostream& os, const Strides& v); | ||||
| std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v); | ||||
| inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { | ||||
|   return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun