mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
@@ -180,7 +180,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
primitive(std::move(primitive)),
|
||||
inputs(inputs) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : inputs) {
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
@@ -197,7 +197,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
primitive(std::move(primitive)),
|
||||
inputs(std::move(inputs)) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : inputs) {
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
|
@@ -63,7 +63,15 @@ std::function<void()> make_task(
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
auto outputs = arr.outputs();
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.push_back(in.data_shared_ptr());
|
||||
|
@@ -1779,10 +1779,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
inputs);
|
||||
}
|
||||
array operator&&(const array& a, const array& b) {
|
||||
// check if a and b are bool arrays
|
||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
}
|
||||
|
||||
@@ -1797,11 +1793,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
inputs);
|
||||
}
|
||||
array operator||(const array& a, const array& b) {
|
||||
// check if a and b are bool arrays
|
||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"[operator||] is only supported for bool arrays.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user