mirror of https://github.com/ml-explore/mlx.git
Don't use make_unique to create shared_ptr (#902)
The code compiled because shared_ptr's constructor actually accepts unique_ptr.
commit 90dfa43ff1
parent dc175f08d3
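For context, here is a minimal sketch of why the old calls compiled and what the change buys. It is not taken from the MLX sources: Primitive, Axpby, and their constructor arguments are simplified stand-ins for the real classes.

#include <memory>

struct Primitive {
  virtual ~Primitive() = default;
};

struct Axpby : public Primitive {
  Axpby(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
  float alpha_;
  float beta_;
};

int main() {
  // Compiles: std::shared_ptr has a converting constructor that takes
  // ownership from a std::unique_ptr (shared_ptr(unique_ptr<Y, D>&&)).
  // The cost is two allocations: one for the Axpby object made by
  // make_unique, and one for the control block shared_ptr adds when it
  // adopts the pointer.
  std::shared_ptr<Primitive> p1 = std::make_unique<Axpby>(1.0f, 2.0f);

  // Preferred when the destination is a shared_ptr: make_shared allocates
  // the object and its control block together in a single allocation, and
  // states the intent directly.
  std::shared_ptr<Primitive> p2 = std::make_shared<Axpby>(1.0f, 2.0f);
  return 0;
}

Both forms are correct with respect to ownership; the practical difference is intent and the extra control-block allocation noted above.
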
@@ -223,7 +223,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
 /* const std::vector<int>& shape = */ out_shape,
 /* Dtype dtype = */ out_dtype,
 /* std::unique_ptr<Primitive> primitive = */
-std::make_unique<Axpby>(to_stream(s), alpha, beta),
+std::make_shared<Axpby>(to_stream(s), alpha, beta),
 /* const std::vector<array>& inputs = */ broadcasted_inputs);
 }
 
@@ -61,7 +61,7 @@ array axpby(
 /* const std::vector<int>& shape = */ out_shape,
 /* Dtype dtype = */ out_dtype,
 /* std::unique_ptr<Primitive> primitive = */
-std::make_unique<Axpby>(to_stream(s), alpha, beta),
+std::make_shared<Axpby>(to_stream(s), alpha, beta),
 /* const std::vector<array>& inputs = */ broadcasted_inputs);
 }
 
@@ -95,7 +95,7 @@ array fft_impl(
 return array(
 out_shape,
 out_type,
-std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
+std::make_shared<FFT>(to_stream(s), valid_axes, inverse, real),
 {astype(in, in_type, s)});
 }
 
@@ -217,7 +217,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
 auto loaded_array = array(
 shape,
 dtype,
-std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
+std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
 std::vector<array>{});
 if (col_contiguous) {
 loaded_array = transpose(loaded_array, s);
@@ -136,7 +136,7 @@ SafetensorsLoad load_safetensors(
 auto loaded_array = array(
 shape,
 type,
-std::make_unique<Load>(
+std::make_shared<Load>(
 to_stream(s), in_stream, offset + data_offsets.at(0), false),
 std::vector<array>{});
 res.insert({item.key(), loaded_array});
@@ -1267,7 +1267,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
 {array(
 out_shape,
 real_ && inverse_ ? float32 : complex64,
-std::make_unique<FFT>(stream(), fft_axes, inverse_, real_),
+std::make_shared<FFT>(stream(), fft_axes, inverse_, real_),
 {in})},
 {ax}};
 }
@@ -1377,7 +1377,7 @@ std::pair<std::vector<array>, std::vector<int>> Full::vmap(
 assert(axes.size() == 1);
 auto& in = inputs[0];
 auto out =
-array(in.shape(), in.dtype(), std::make_unique<Full>(stream()), {in});
+array(in.shape(), in.dtype(), std::make_shared<Full>(stream()), {in});
 return {{out}, axes};
 }
 
@@ -1604,7 +1604,7 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
 {array(
 in.shape(),
 in.dtype(),
-std::make_unique<Log>(stream(), base_),
+std::make_shared<Log>(stream(), base_),
 {in})},
 axes};
 }
@@ -2259,7 +2259,7 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
 auto out = array(
 shape,
 get_dtype(),
-std::make_unique<RandomBits>(stream(), shape, width_),
+std::make_shared<RandomBits>(stream(), shape, width_),
 {key});
 return {{out}, {kax}};
 }
@@ -2493,7 +2493,7 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
 {array(
 in.shape(),
 out_dtype,
-std::make_unique<Scan>(
+std::make_shared<Scan>(
 stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
 {in})},
 axes};
@@ -3303,7 +3303,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
 array out = array(
 std::vector<int>{},
 dtype_,
-std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
+std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
 inputs);
 
 return {{out}, {-1}};
@@ -53,7 +53,7 @@ void eval(std::vector<array> outputs) {
 }
 
 auto synchronizer = array(
-{}, bool_, std::make_unique<Synchronizer>(stream), std::move(outputs));
+{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
 
 size_t depth_counter = 0;
 recurse = [&](const array& a) {
@@ -118,7 +118,7 @@ void eval(std::vector<array> outputs) {
 }
 std::shared_ptr<std::promise<void>> p;
 if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
-p = std::make_unique<std::promise<void>>();
+p = std::make_shared<std::promise<void>>();
 ps.push_back(p);
 it->second = p->get_future().share();
 }
@@ -16,7 +16,7 @@ void test_arg_reduce_small(
 std::vector<int> expected_output) {
 auto s = default_stream(d);
 auto y =
-array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
+array(out_shape, uint32, std::make_shared<ArgReduce>(s, r, axis), {x});
 y.eval();
 const uint32_t* ydata = y.data<uint32_t>();
 for (int i = 0; i < y.size(); i++) {
@@ -32,12 +32,12 @@ void test_arg_reduce_against_cpu(
 auto y1 = array(
 out_shape,
 uint32,
-std::make_unique<ArgReduce>(default_stream(Device::cpu), r, axis),
+std::make_shared<ArgReduce>(default_stream(Device::cpu), r, axis),
 {x});
 auto y2 = array(
 out_shape,
 uint32,
-std::make_unique<ArgReduce>(default_stream(Device::gpu), r, axis),
+std::make_shared<ArgReduce>(default_stream(Device::gpu), r, axis),
 {x});
 y1.eval();
 y2.eval();
@@ -136,7 +136,7 @@ void test_arg_reduce_small_bool(
 {2, 3, 4});
 x.eval();
 auto y =
-array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
+array(out_shape, uint32, std::make_shared<ArgReduce>(s, r, axis), {x});
 y.eval();
 const uint32_t* ydata = y.data<uint32_t>();
 for (int i = 0; i < y.size(); i++) {