Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -116,11 +116,11 @@ class array {
};
/** Evaluate the array. */
void eval(bool retain_graph = false);
void eval();
/** Get the value from a scalar array. */
template <typename T>
T item(bool retain_graph = false);
T item();
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -265,9 +265,7 @@ class array {
array_desc_->is_tracer = is_tracer;
}
// Check if the array is a tracer array
bool is_tracer() const {
return array_desc_->is_tracer;
}
bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
@@ -381,11 +379,11 @@ array::array(
}
template <typename T>
T array::item(bool retain_graph /* = false */) {
T array::item() {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
eval(retain_graph);
eval();
return *data<T>();
}