mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (gpu)
This commit is contained in:
@@ -51,7 +51,7 @@ void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
constexpr size_t extra_bytes = 16384;
|
constexpr int64_t extra_bytes = 16384;
|
||||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||||
(in.flags().row_contiguous ||
|
(in.flags().row_contiguous ||
|
||||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ void slice_gpu(
|
|||||||
array& out,
|
array& out,
|
||||||
const Shape& start_indices,
|
const Shape& start_indices,
|
||||||
const Shape& strides,
|
const Shape& strides,
|
||||||
const Stream& s) {
|
const Stream& /* s */) {
|
||||||
slice(in, out, start_indices, strides);
|
slice(in, out, start_indices, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ void pad_gpu(
|
|||||||
|
|
||||||
// Find offset for start of input values
|
// Find offset for start of input values
|
||||||
size_t data_offset = 0;
|
size_t data_offset = 0;
|
||||||
for (int i = 0; i < axes.size(); i++) {
|
for (int i = 0; i < std::ssize(axes); i++) {
|
||||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user