mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add some internal GPU apis (#1177)
* Add unary/binary/ternary/slice/concat internal GPU ops * add pad internal op * formatting + no_cpu fix
This commit is contained in:
@@ -10,16 +10,16 @@ namespace mlx::core {
|
||||
|
||||
// NOTE(review): the use of this constant is not visible in this excerpt;
// presumably it caps the number of dimensions for which a dimension-specialized
// ternary kernel is selected -- confirm against the elided part of the file.
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void ternary_op(
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 3);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
|
||||
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
@@ -47,7 +47,6 @@ void ternary_op(
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_ternary_kernel(d, kernel_name, out);
|
||||
@@ -101,8 +100,29 @@ void ternary_op(
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the ternary op named `op` on stream `s`, preparing `out` first.
//
// Steps visible here:
//   1. Reads the three operands from inputs[0], inputs[1], inputs[2]
//      (no size check is performed in this function).
//   2. Classifies them with get_ternary_op_type(a, b, c).
//   3. Prepares `out` through set_ternary_op_output_data, passing `true`
//      for donate_with_move.
//   4. Delegates to ternary_op_gpu_inplace(inputs, out, op, s).
//
// NOTE(review): the diff markers in this excerpt are lost, and the
// ternary_op_gpu_inplace hunk above also shows a set_ternary_op_output_data
// call. If that line is still live, the output is prepared twice -- confirm
// that the in-place variant no longer calls it.
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
|
||||
auto& b = inputs[1];
|
||||
|
||||
auto& c = inputs[2];
|
||||
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
|
||||
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
|
||||
|
||||
ternary_op_gpu_inplace(inputs, out, op, s);
|
||||
|
||||
}
|
||||
|
||||
// Convenience overload without an explicit stream: takes the stream from
// out.primitive().stream() and forwards to the four-argument
// ternary_op_gpu(inputs, out, op, s) overload.
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
auto& s = out.primitive().stream();
|
||||
|
||||
ternary_op_gpu(inputs, out, op, s);
|
||||
|
||||
}
|
||||
|
||||
// GPU evaluation of Select: dispatches the ternary op named "select" on
// `inputs`, writing into `out`.
//
// NOTE(review): this excerpt is a rendered diff whose +/- markers were lost,
// so both the ternary_op(...) and ternary_op_gpu(...) calls appear below.
// Presumably ternary_op(...) is the removed line and ternary_op_gpu(...) is
// its replacement -- confirm against the actual file.
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ternary_op(inputs, out, "select");
|
||||
ternary_op_gpu(inputs, out, "select");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user