Multi output primitives (#330)

* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-01-08 16:39:08 -08:00
committed by GitHub
parent f45f70f133
commit f099ebe535
26 changed files with 2313 additions and 1039 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
@@ -174,7 +173,13 @@ class array {
array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@@ -182,6 +187,11 @@ class array {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
}
/** A unique identifier for an arrays primitive. */
std::uintptr_t primitive_id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
}
struct Data {
allocator::Buffer buffer;
deleter_t d;
@@ -219,12 +229,32 @@ class array {
return array_desc_->inputs;
};
/** A non-const reference to the array's inputs so that they can be used to
* edit the graph. */
std::vector<array>& editable_inputs() {
std::vector<array>& inputs() {
return array_desc_->inputs;
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
}
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
auto idx = array_desc_->position;
std::vector<array> outputs;
outputs.reserve(siblings().size() + 1);
outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs;
};
/** Detach the array from the graph. */
void detach();
@@ -299,7 +329,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::unique_ptr<Primitive> primitive{nullptr};
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -321,16 +351,19 @@ class array {
Flags flags;
std::vector<array> inputs;
// An array to keep track of the siblings from a multi-output
// primitive.
std::vector<array> siblings;
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
~ArrayDesc();
};
// The ArrayDesc contains the details of the materialized array including the