mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
42
mlx/ops.cpp
42
mlx/ops.cpp
@@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
|
||||
// Clamp to bounds
|
||||
auto st = std::min(s, n - 1);
|
||||
auto ed = std::max(-1, e);
|
||||
auto ed = e > -1 ? e : -1;
|
||||
|
||||
start[i] = st;
|
||||
stop[i] = ed > st ? st : ed;
|
||||
@@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
|
||||
} else {
|
||||
// Clamp to bounds
|
||||
auto st = std::max(0, std::min(s, n));
|
||||
auto ed = std::max(0, std::min(e, n));
|
||||
auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));
|
||||
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
|
||||
|
||||
start[i] = st;
|
||||
stop[i] = ed < st ? st : ed;
|
||||
@@ -765,7 +765,7 @@ array slice_update(
|
||||
|
||||
std::vector<array> split(
|
||||
const array& a,
|
||||
const std::vector<int>& indices,
|
||||
const Shape& indices,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto ax = axis < 0 ? axis + a.ndim() : axis;
|
||||
@@ -809,10 +809,8 @@ std::vector<array> split(
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<array> split(
|
||||
const array& a,
|
||||
const std::vector<int>& indices,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array>
|
||||
split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
|
||||
return split(a, indices, 0, s);
|
||||
}
|
||||
|
||||
@@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto split_size = q_and_r.quot;
|
||||
std::vector<int> indices(num_splits - 1);
|
||||
Shape indices(num_splits - 1);
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
indices[i] = (i + 1) * split_size;
|
||||
}
|
||||
@@ -1104,7 +1102,7 @@ array edge_pad(
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
const Shape& axes,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
@@ -1904,9 +1902,11 @@ array min(
|
||||
|
||||
array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
||||
int size = a.size();
|
||||
auto result = argmin(reshape(a, {size}, s), 0, true, s);
|
||||
auto result = argmin(flatten(a, s), 0, true, s);
|
||||
if (keepdims) {
|
||||
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
|
||||
std::vector<int> axes(a.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
result = expand_dims(result, axes, s);
|
||||
} else {
|
||||
result = squeeze(result, s);
|
||||
}
|
||||
@@ -1940,9 +1940,11 @@ array argmin(
|
||||
|
||||
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
||||
int size = a.size();
|
||||
auto result = argmax(reshape(a, {size}, s), 0, true, s);
|
||||
auto result = argmax(flatten(a, s), 0, true, s);
|
||||
if (keepdims) {
|
||||
result = reshape(result, Shape(a.shape().size(), 1), s);
|
||||
std::vector<int> axes(a.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
result = expand_dims(result, axes, s);
|
||||
} else {
|
||||
result = squeeze(result, s);
|
||||
}
|
||||
@@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
|
||||
}
|
||||
|
||||
Shape conv_out_shape(
|
||||
const std::vector<int>& in_shape,
|
||||
const std::vector<int>& wt_shape,
|
||||
const Shape& in_shape,
|
||||
const Shape& wt_shape,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& pads_lo,
|
||||
const std::vector<int>& pads_hi,
|
||||
@@ -4329,16 +4331,16 @@ array diagonal(
|
||||
"[diagonal] axis1 and axis2 cannot be the same axis");
|
||||
}
|
||||
|
||||
auto off1 = std::max(-offset, 0);
|
||||
auto off2 = std::max(offset, 0);
|
||||
ShapeElem off1 = std::max(-offset, 0);
|
||||
ShapeElem off2 = std::max(offset, 0);
|
||||
|
||||
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
|
||||
diag_size = std::max(diag_size, 0);
|
||||
diag_size = diag_size < 0 ? 0 : diag_size;
|
||||
|
||||
std::vector<array> indices = {
|
||||
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
|
||||
|
||||
std::vector<int> slice_sizes = a.shape();
|
||||
Shape slice_sizes = a.shape();
|
||||
slice_sizes[ax1] = 1;
|
||||
slice_sizes[ax2] = 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user