Refactor JIT for unary/binary/ternary ops (#1206)

* refactor unary/binary/ternary ops

* get_primitive_string util

---------
This commit is contained in:
Alex Barron
2024-06-12 14:22:12 -07:00
committed by GitHub
parent de2b9e7d0a
commit 934683088e
16 changed files with 379 additions and 935 deletions

View File

@@ -49,7 +49,7 @@ void ternary_op_gpu_inplace(
auto& d = metal::device(s.device);
auto kernel = get_ternary_kernel(d, kernel_name, out);
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
@@ -122,7 +122,7 @@ void ternary_op_gpu(
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op_gpu(inputs, out, "select");
ternary_op_gpu(inputs, out, get_primitive_string(this));
}
} // namespace mlx::core