mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
|
||||
instantiate_kernel( \
|
||||
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
@@ -499,8 +499,8 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_gpu_inplace(
|
||||
|
@@ -14,18 +14,7 @@ void slice_gpu(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
// Calculate out strides and initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < strides.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
void concatenate_gpu(
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -49,7 +48,7 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
kernel_name = "gn" + std::to_string(work_per_thread);
|
||||
if (large) {
|
||||
kernel_name += "_large";
|
||||
kernel_name += "large";
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out));
|
||||
|
Reference in New Issue
Block a user