fix + comment

This commit is contained in:
Awni Hannun
2025-08-15 12:11:23 -07:00
parent f403ea1764
commit 5e542d98e0
2 changed files with 3 additions and 3 deletions

View File

@@ -121,7 +121,7 @@ void copy_general_input(
0, 0,
in_ptr, in_ptr,
out_ptr, out_ptr,
out.size(), rest,
const_param<dims_constant()>(shape), const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in)); const_param<dims_constant()>(strides_in));
}); });
@@ -137,7 +137,7 @@ void copy_general_input(
0, 0,
in_ptr, in_ptr,
out_ptr, out_ptr,
out.size(), rest,
const_param(shape), const_param(shape),
const_param(strides_in), const_param(strides_in),
ndim); ndim);

View File

@@ -4,7 +4,7 @@
namespace mlx::core { namespace mlx::core {
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) { void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu"); nvtx3::scoped_range r("Sqrt::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
if (recip_) { if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s); unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);