mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops * get_primitive_string util ---------
This commit is contained in:
@@ -49,7 +49,7 @@ void ternary_op_gpu_inplace(
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_ternary_kernel(d, kernel_name, out);
|
||||
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@@ -122,7 +122,7 @@ void ternary_op_gpu(
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ternary_op_gpu(inputs, out, "select");
|
||||
ternary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user