mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops * add pad internal op * formatting + no_cpu fix
This commit is contained in:
parent
ea9090bbc4
commit
375a8bbdcc
@ -48,6 +48,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
|
@ -250,49 +250,6 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@ -492,7 +493,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
|
52
mlx/backend/common/slicing.cpp
Normal file
52
mlx/backend/common/slicing.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
std::vector<int>& start_indices,
|
||||
std::vector<int>& strides) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
|
||||
copy_needed |= strides[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
20
mlx/backend/common/slicing.h
Normal file
20
mlx/backend/common/slicing.h
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
std::vector<int>& start_indices,
|
||||
std::vector<int>& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out);
|
||||
|
||||
} // namespace mlx::core
|
@ -135,6 +135,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
|
@ -10,16 +10,14 @@ namespace mlx::core {
|
||||
|
||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void binary_op(
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
@ -61,7 +59,6 @@ void binary_op(
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
|
||||
@ -120,15 +117,36 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op(
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
std::vector<array>& outputs,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op) {
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -168,7 +186,6 @@ void binary_op(
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a, out);
|
||||
@ -221,102 +238,123 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
binary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
auto& s = out.primitive().stream();
|
||||
binary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "add");
|
||||
binary_op_gpu(inputs, out, "add");
|
||||
}
|
||||
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "arctan2");
|
||||
binary_op_gpu(inputs, out, "arctan2");
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op(inputs, out, "bitwise_and");
|
||||
binary_op_gpu(inputs, out, "bitwise_and");
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op(inputs, out, "bitwise_or");
|
||||
binary_op_gpu(inputs, out, "bitwise_or");
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op(inputs, out, "bitwise_xor");
|
||||
binary_op_gpu(inputs, out, "bitwise_xor");
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op(inputs, out, "left_shift");
|
||||
binary_op_gpu(inputs, out, "left_shift");
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op(inputs, out, "right_shift");
|
||||
binary_op_gpu(inputs, out, "right_shift");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
binary_op_gpu(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op(inputs, outputs, "divmod");
|
||||
binary_op_gpu(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "rem");
|
||||
binary_op_gpu(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
||||
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "ge");
|
||||
binary_op_gpu(inputs, out, "ge");
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "geq");
|
||||
binary_op_gpu(inputs, out, "geq");
|
||||
}
|
||||
|
||||
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "le");
|
||||
binary_op_gpu(inputs, out, "le");
|
||||
}
|
||||
|
||||
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "leq");
|
||||
binary_op_gpu(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "land");
|
||||
binary_op_gpu(inputs, out, "land");
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lor");
|
||||
binary_op_gpu(inputs, out, "lor");
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lae");
|
||||
binary_op_gpu(inputs, out, "lae");
|
||||
}
|
||||
|
||||
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "max");
|
||||
binary_op_gpu(inputs, out, "max");
|
||||
}
|
||||
|
||||
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
binary_op_gpu(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
binary_op_gpu(inputs, out, "mul");
|
||||
}
|
||||
|
||||
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "neq");
|
||||
binary_op_gpu(inputs, out, "neq");
|
||||
}
|
||||
|
||||
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "pow");
|
||||
binary_op_gpu(inputs, out, "pow");
|
||||
}
|
||||
|
||||
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "sub");
|
||||
binary_op_gpu(inputs, out, "sub");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
33
mlx/backend/metal/binary.h
Normal file
33
mlx/backend/metal/binary.h
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@ -7,6 +7,7 @@
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@ -163,30 +164,7 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
sizes.push_back(p.shape(axis_));
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
auto& d = metal::device(stream().device);
|
||||
auto& compute_encoder = d.get_command_encoder(stream().index);
|
||||
auto concurrent_ctx = compute_encoder.start_concurrent();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -238,23 +216,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Padding value, input and output must be of the same type
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
// Fill output with val
|
||||
copy_gpu(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes_.size(); i++) {
|
||||
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size_[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
|
||||
}
|
||||
|
||||
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -331,28 +293,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_gpu_inplace(
|
||||
/* const array& in = */ in,
|
||||
/* array& out = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General,
|
||||
/* const Stream& s = */ stream());
|
||||
} else {
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, out);
|
||||
}
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
98
mlx/backend/metal/slicing.cpp
Normal file
98
mlx/backend/metal/slicing.cpp
Normal file
@ -0,0 +1,98 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
std::vector<int> start_indices,
|
||||
std::vector<int> strides,
|
||||
const Stream& s) {
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices, strides);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_gpu_inplace(
|
||||
/* const array& in = */ in,
|
||||
/* array& out = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General,
|
||||
/* const Stream& s = */ s);
|
||||
} else {
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, out);
|
||||
}
|
||||
}
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
sizes.push_back(p.shape(axis));
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
auto& d = metal::device(s.device);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto concurrent_ctx = compute_encoder.start_concurrent();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
}
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
std::vector<int> axes,
|
||||
std::vector<int> low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
copy_gpu(val, out, CopyType::Scalar, s);
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
|
||||
data_offset += out.strides()[ax] * low_pad_size[i];
|
||||
}
|
||||
|
||||
// Extract slice from output where input will be pasted
|
||||
array out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
out_slice.copy_shared_buffer(
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
30
mlx/backend/metal/slicing.h
Normal file
30
mlx/backend/metal/slicing.h
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void slice_gpu(
|
||||
const array& in,
|
||||
array& out,
|
||||
std::vector<int> start_indices,
|
||||
std::vector<int> strides,
|
||||
const Stream& s);
|
||||
|
||||
void concatenate_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
int axis,
|
||||
const Stream& s);
|
||||
|
||||
void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
std::vector<int> axes,
|
||||
std::vector<int> low_pad_size,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@ -10,16 +10,16 @@ namespace mlx::core {
|
||||
|
||||
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void ternary_op(
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 3);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
|
||||
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
@ -47,7 +47,6 @@ void ternary_op(
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_ternary_kernel(d, kernel_name, out);
|
||||
@ -101,8 +100,29 @@ void ternary_op(
|
||||
}
|
||||
}
|
||||
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
|
||||
ternary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ternary_op(inputs, out, "select");
|
||||
ternary_op_gpu(inputs, out, "select");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
21
mlx/backend/metal/ternary.h
Normal file
21
mlx/backend/metal/ternary.h
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@ -7,30 +7,17 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void unary_op(
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
if (contig) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
|
||||
@ -59,39 +46,70 @@ void unary_op(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
if (contig) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
unary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
auto& s = out.primitive().stream();
|
||||
unary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "abs");
|
||||
unary_op_gpu(inputs, out, "abs");
|
||||
}
|
||||
|
||||
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arccos");
|
||||
unary_op_gpu(inputs, out, "arccos");
|
||||
}
|
||||
|
||||
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arccosh");
|
||||
unary_op_gpu(inputs, out, "arccosh");
|
||||
}
|
||||
|
||||
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arcsin");
|
||||
unary_op_gpu(inputs, out, "arcsin");
|
||||
}
|
||||
|
||||
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arcsinh");
|
||||
unary_op_gpu(inputs, out, "arcsinh");
|
||||
}
|
||||
|
||||
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctan");
|
||||
unary_op_gpu(inputs, out, "arctan");
|
||||
}
|
||||
|
||||
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctanh");
|
||||
unary_op_gpu(inputs, out, "arctanh");
|
||||
}
|
||||
|
||||
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_op(inputs, out, "conj");
|
||||
unary_op_gpu(inputs, out, "conj");
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
@ -99,68 +117,68 @@ void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "cos");
|
||||
unary_op_gpu(inputs, out, "cos");
|
||||
}
|
||||
|
||||
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "cosh");
|
||||
unary_op_gpu(inputs, out, "cosh");
|
||||
}
|
||||
|
||||
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "erf");
|
||||
unary_op_gpu(inputs, out, "erf");
|
||||
}
|
||||
|
||||
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "erfinv");
|
||||
unary_op_gpu(inputs, out, "erfinv");
|
||||
}
|
||||
|
||||
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "exp");
|
||||
unary_op_gpu(inputs, out, "exp");
|
||||
}
|
||||
|
||||
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "expm1");
|
||||
unary_op_gpu(inputs, out, "expm1");
|
||||
}
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op(inputs, out, "log");
|
||||
unary_op_gpu(inputs, out, "log");
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op(inputs, out, "log2");
|
||||
unary_op_gpu(inputs, out, "log2");
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op(inputs, out, "log10");
|
||||
unary_op_gpu(inputs, out, "log10");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "log1p");
|
||||
unary_op_gpu(inputs, out, "log1p");
|
||||
}
|
||||
|
||||
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "lnot");
|
||||
unary_op_gpu(inputs, out, "lnot");
|
||||
}
|
||||
|
||||
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "floor");
|
||||
unary_op_gpu(inputs, out, "floor");
|
||||
}
|
||||
|
||||
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "ceil");
|
||||
unary_op_gpu(inputs, out, "ceil");
|
||||
}
|
||||
|
||||
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "neg");
|
||||
unary_op_gpu(inputs, out, "neg");
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op(inputs, out, "round");
|
||||
unary_op_gpu(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@ -168,39 +186,39 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sigmoid");
|
||||
unary_op_gpu(inputs, out, "sigmoid");
|
||||
}
|
||||
|
||||
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sign");
|
||||
unary_op_gpu(inputs, out, "sign");
|
||||
}
|
||||
|
||||
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sin");
|
||||
unary_op_gpu(inputs, out, "sin");
|
||||
}
|
||||
|
||||
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sinh");
|
||||
unary_op_gpu(inputs, out, "sinh");
|
||||
}
|
||||
|
||||
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "square");
|
||||
unary_op_gpu(inputs, out, "square");
|
||||
}
|
||||
|
||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (recip_) {
|
||||
unary_op(inputs, out, "rsqrt");
|
||||
unary_op_gpu(inputs, out, "rsqrt");
|
||||
} else {
|
||||
unary_op(inputs, out, "sqrt");
|
||||
unary_op_gpu(inputs, out, "sqrt");
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "tan");
|
||||
unary_op_gpu(inputs, out, "tan");
|
||||
}
|
||||
|
||||
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "tanh");
|
||||
unary_op_gpu(inputs, out, "tanh");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
21
mlx/backend/metal/unary.h
Normal file
21
mlx/backend/metal/unary.h
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@ -6,4 +6,5 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
|
||||
)
|
||||
|
@ -1794,14 +1794,6 @@ class Slice : public UnaryPrimitive {
|
||||
std::vector<int> strides_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in);
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out);
|
||||
};
|
||||
|
||||
class SliceUpdate : public UnaryPrimitive {
|
||||
|
Loading…
Reference in New Issue
Block a user