mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
@@ -86,7 +86,7 @@ void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
slice(inputs[0], out, start_indices_, strides_);
|
||||
}
|
||||
void Split::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
@@ -262,29 +262,6 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
@@ -355,7 +332,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_inplace(
|
||||
|
||||
Reference in New Issue
Block a user