From a64a8dfe45e7596023f164f2055ef9c1c3b8d3be Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 2 Jan 2025 16:16:16 -0800 Subject: [PATCH] fix extension (#1740) --- examples/extensions/axpby/axpby.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 70b02fb73..1a5d8c1c9 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -58,9 +58,9 @@ mx::array axpby( // Construct the array as the output of the Axpby primitive // with the broadcasted and upcasted arrays as inputs return mx::array( - /* const std::vector& shape = */ out_shape, + /* const mx::Shape& shape = */ out_shape, /* mx::Dtype dtype = */ out_dtype, - /* std::unique_ptr primitive = */ + /* std::shared_ptr primitive = */ std::make_shared(to_stream(s), alpha, beta), /* const std::vector& inputs = */ broadcasted_inputs); } @@ -279,7 +279,7 @@ void Axpby::eval_gpu( if (!contiguous_kernel) { compute_encoder.set_vector_bytes(x.shape(), 5); compute_encoder.set_vector_bytes(x.strides(), 6); - compute_encoder.set_bytes(y.strides(), 7); + compute_encoder.set_vector_bytes(y.strides(), 7); compute_encoder.set_bytes(ndim, 8); }