mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix + comment
This commit is contained in:
@@ -121,7 +121,7 @@ void copy_general_input(
|
|||||||
0,
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
rest,
|
||||||
const_param<dims_constant()>(shape),
|
const_param<dims_constant()>(shape),
|
||||||
const_param<dims_constant()>(strides_in));
|
const_param<dims_constant()>(strides_in));
|
||||||
});
|
});
|
||||||
@@ -137,7 +137,7 @@ void copy_general_input(
|
|||||||
0,
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
rest,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
ndim);
|
ndim);
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
nvtx3::scoped_range r("Sqrt::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
if (recip_) {
|
if (recip_) {
|
||||||
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
||||||
|
|||||||
Reference in New Issue
Block a user