Add some internal GPU apis (#1177)

* Add unary/binary/ternary/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
This commit is contained in:
Alex Barron
2024-06-04 09:24:26 -07:00
committed by GitHub
parent ea9090bbc4
commit 375a8bbdcc
17 changed files with 449 additions and 203 deletions

View File

@@ -10,16 +10,16 @@ namespace mlx::core {
// NOTE(review): presumably the largest number of dimensions for which a
// dimension-specialized ternary kernel is selected; the use site is outside
// this view — confirm against the kernel-name selection code.
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
void ternary_op(
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const std::string op,
const Stream& s) {
assert(inputs.size() == 3);
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
if (out.size() == 0) {
return;
@@ -47,7 +47,6 @@ void ternary_op(
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_ternary_kernel(d, kernel_name, out);
@@ -101,8 +100,29 @@ void ternary_op(
}
}
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
ternary_op_gpu_inplace(inputs, out, op, s);
}
// Convenience overload: runs the ternary op on the stream recorded in the
// primitive attached to `out`, forwarding everything else unchanged to the
// stream-taking overload.
void ternary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const std::string op) {
  ternary_op_gpu(inputs, out, op, out.primitive().stream());
}
// GPU evaluation of Select: forwards to the ternary GPU path under the
// "select" op name.
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Dispatch exactly once: ternary_op_gpu both sets up `out` (with donation)
  // and runs the op, so a second dispatch on the same inputs/output would be
  // incorrect.
  ternary_op_gpu(inputs, out, "select");
}
} // namespace mlx::core