diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 5f63c6337..9198548a4 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -223,7 +223,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive. /* const std::vector& shape = */ out_shape, /* Dtype dtype = */ out_dtype, /* std::unique_ptr primitive = */ - std::make_unique(to_stream(s), alpha, beta), + std::make_shared(to_stream(s), alpha, beta), /* const std::vector& inputs = */ broadcasted_inputs); } diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 732dc43b6..43b3aedc9 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -61,7 +61,7 @@ array axpby( /* const std::vector& shape = */ out_shape, /* Dtype dtype = */ out_dtype, /* std::unique_ptr primitive = */ - std::make_unique(to_stream(s), alpha, beta), + std::make_shared(to_stream(s), alpha, beta), /* const std::vector& inputs = */ broadcasted_inputs); } diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 96d0424ab..95791a57c 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -95,7 +95,7 @@ array fft_impl( return array( out_shape, out_type, - std::make_unique(to_stream(s), valid_axes, inverse, real), + std::make_shared(to_stream(s), valid_axes, inverse, real), {astype(in, in_type, s)}); } diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index c9c618fe9..de91ad549 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -217,7 +217,7 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { auto loaded_array = array( shape, dtype, - std::make_unique(to_stream(s), in_stream, offset, swap_endianness), + std::make_shared(to_stream(s), in_stream, offset, swap_endianness), std::vector{}); if (col_contiguous) { loaded_array = transpose(loaded_array, s); diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 1dd59f444..6f25aefee 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -136,7 +136,7 @@ SafetensorsLoad load_safetensors( auto loaded_array = array( shape, type, - std::make_unique( + std::make_shared( to_stream(s), in_stream, offset + data_offsets.at(0), false), std::vector{}); res.insert({item.key(), loaded_array}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4fb0fc7f4..4529e0131 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1267,7 +1267,7 @@ std::pair, std::vector> FFT::vmap( {array( out_shape, real_ && inverse_ ? float32 : complex64, - std::make_unique(stream(), fft_axes, inverse_, real_), + std::make_shared(stream(), fft_axes, inverse_, real_), {in})}, {ax}}; } @@ -1377,7 +1377,7 @@ std::pair, std::vector> Full::vmap( assert(axes.size() == 1); auto& in = inputs[0]; auto out = - array(in.shape(), in.dtype(), std::make_unique(stream()), {in}); + array(in.shape(), in.dtype(), std::make_shared(stream()), {in}); return {{out}, axes}; } @@ -1604,7 +1604,7 @@ std::pair, std::vector> Log::vmap( {array( in.shape(), in.dtype(), - std::make_unique(stream(), base_), + std::make_shared(stream(), base_), {in})}, axes}; } @@ -2259,7 +2259,7 @@ std::pair, std::vector> RandomBits::vmap( auto out = array( shape, get_dtype(), - std::make_unique(stream(), shape, width_), + std::make_shared(stream(), shape, width_), {key}); return {{out}, {kax}}; } @@ -2493,7 +2493,7 @@ std::pair, std::vector> Scan::vmap( {array( in.shape(), out_dtype, - std::make_unique( + std::make_shared( stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_), {in})}, axes}; @@ -3303,7 +3303,7 @@ std::pair, std::vector> NumberOfElements::vmap( array out = array( std::vector{}, dtype_, - std::make_unique(stream(), new_axes, inverted_, dtype_), + std::make_shared(stream(), new_axes, inverted_, dtype_), inputs); return {{out}, {-1}}; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 42c9213f6..74b1e1b04 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -53,7 +53,7 @@ void eval(std::vector outputs) { } auto synchronizer = array( - {}, bool_, std::make_unique(stream), std::move(outputs)); + {}, bool_, std::make_shared(stream), std::move(outputs)); size_t depth_counter = 0; recurse = [&](const array& a) { @@ -118,7 +118,7 @@ void eval(std::vector outputs) { } std::shared_ptr> p; if (auto it = deps.find(arr.output(0).id()); it != deps.end()) { - p = std::make_unique>(); + p = std::make_shared>(); ps.push_back(p); it->second = p->get_future().share(); } diff --git a/tests/arg_reduce_tests.cpp b/tests/arg_reduce_tests.cpp index f10cf5285..7fa01d837 100644 --- a/tests/arg_reduce_tests.cpp +++ b/tests/arg_reduce_tests.cpp @@ -16,7 +16,7 @@ void test_arg_reduce_small( std::vector expected_output) { auto s = default_stream(d); auto y = - array(out_shape, uint32, std::make_unique(s, r, axis), {x}); + array(out_shape, uint32, std::make_shared(s, r, axis), {x}); y.eval(); const uint32_t* ydata = y.data(); for (int i = 0; i < y.size(); i++) { @@ -32,12 +32,12 @@ void test_arg_reduce_against_cpu( auto y1 = array( out_shape, uint32, - std::make_unique(default_stream(Device::cpu), r, axis), + std::make_shared(default_stream(Device::cpu), r, axis), {x}); auto y2 = array( out_shape, uint32, - std::make_unique(default_stream(Device::gpu), r, axis), + std::make_shared(default_stream(Device::gpu), r, axis), {x}); y1.eval(); y2.eval(); @@ -136,7 +136,7 @@ void test_arg_reduce_small_bool( {2, 3, 4}); x.eval(); auto y = - array(out_shape, uint32, std::make_unique(s, r, axis), {x}); + array(out_shape, uint32, std::make_shared(s, r, axis), {x}); y.eval(); const uint32_t* ydata = y.data(); for (int i = 0; i < y.size(); i++) {