mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Shape and Strides 1 / N (#1645)
* shape and stride type def * more shape
This commit is contained in:
76
mlx/ops.h
76
mlx/ops.h
@@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {});
|
||||
/** Create a view of an array with the given shape and strides. */
|
||||
array as_strided(
|
||||
array a,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@@ -58,31 +58,27 @@ array as_strided(
|
||||
array copy(array a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape with the given value(s). */
|
||||
array full(
|
||||
std::vector<int> shape,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {});
|
||||
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
|
||||
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
|
||||
array full(Shape shape, array vals, StreamOrDevice s = {});
|
||||
template <typename T>
|
||||
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||
array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val, dtype), to_stream(s));
|
||||
}
|
||||
template <typename T>
|
||||
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
|
||||
array full(Shape shape, T val, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val), to_stream(s));
|
||||
}
|
||||
|
||||
/** Fill an array of the given shape with zeros. */
|
||||
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
|
||||
return zeros(shape, float32, s);
|
||||
}
|
||||
array zeros_like(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape with ones. */
|
||||
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array ones(const Shape& shape, StreamOrDevice s = {}) {
|
||||
return ones(shape, float32, s);
|
||||
}
|
||||
array ones_like(const array& a, StreamOrDevice s = {});
|
||||
@@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {});
|
||||
array triu(array x, int k = 0, StreamOrDevice s = {});
|
||||
|
||||
/** Reshape an array to the given shape. */
|
||||
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
||||
|
||||
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
||||
array flatten(
|
||||
@@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {});
|
||||
/** Slice an array. */
|
||||
array slice(
|
||||
const array& a,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
std::vector<int> strides,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Slice an array with a stride of 1 in each dimension. */
|
||||
array slice(
|
||||
const array& a,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s = {});
|
||||
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
std::vector<int> strides,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array with stride 1 in each dimension */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Split an array into sub-arrays along a given axis. */
|
||||
@@ -288,10 +280,7 @@ array pad(
|
||||
array transpose(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Broadcast an array to a given shape. */
|
||||
array broadcast_to(
|
||||
const array& a,
|
||||
const std::vector<int>& shape,
|
||||
StreamOrDevice s = {});
|
||||
array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
|
||||
|
||||
/** Broadcast a vector of arrays against one another. */
|
||||
std::vector<array> broadcast_arrays(
|
||||
@@ -917,13 +906,13 @@ array gather(
|
||||
const array& a,
|
||||
const std::vector<array>& indices,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes,
|
||||
const Shape& slice_sizes,
|
||||
StreamOrDevice s = {});
|
||||
inline array gather(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
int axis,
|
||||
const std::vector<int>& slice_sizes,
|
||||
const Shape& slice_sizes,
|
||||
StreamOrDevice s = {}) {
|
||||
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
||||
}
|
||||
@@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
|
||||
|
||||
/** Roll elements along an axis and introduce them on the other side */
|
||||
array roll(const array& a, int shift, StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
StreamOrDevice s = {});
|
||||
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
|
||||
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
|
||||
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
|
||||
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
int shift,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
const Shape& shift,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user