mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Multi output primitives (#330)
* Multi-output primitives --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
51
mlx/array.h
51
mlx/array.h
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
@@ -174,7 +173,13 @@ class array {
|
||||
array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::unique_ptr<Primitive> primitive,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
@@ -182,6 +187,11 @@ class array {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||
}
|
||||
|
||||
/** A unique identifier for an arrays primitive. */
|
||||
std::uintptr_t primitive_id() const {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
|
||||
}
|
||||
|
||||
struct Data {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
@@ -219,12 +229,32 @@ class array {
|
||||
return array_desc_->inputs;
|
||||
};
|
||||
|
||||
/** A non-const reference to the array's inputs so that they can be used to
|
||||
* edit the graph. */
|
||||
std::vector<array>& editable_inputs() {
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
array_desc_->position = position;
|
||||
}
|
||||
|
||||
/** The outputs of the array's primitive (i.e. this array and
|
||||
* its siblings) in the order the primitive expects. */
|
||||
std::vector<array> outputs() const {
|
||||
auto idx = array_desc_->position;
|
||||
std::vector<array> outputs;
|
||||
outputs.reserve(siblings().size() + 1);
|
||||
outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
|
||||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
};
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
|
||||
@@ -299,7 +329,7 @@ class array {
|
||||
std::vector<size_t> strides;
|
||||
size_t size;
|
||||
Dtype dtype;
|
||||
std::unique_ptr<Primitive> primitive{nullptr};
|
||||
std::shared_ptr<Primitive> primitive{nullptr};
|
||||
|
||||
// Indicates an array is being used in a graph transform
|
||||
// and should not be detached from the graph
|
||||
@@ -321,16 +351,19 @@ class array {
|
||||
Flags flags;
|
||||
|
||||
std::vector<array> inputs;
|
||||
// An array to keep track of the siblings from a multi-output
|
||||
// primitive.
|
||||
std::vector<array> siblings;
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::unique_ptr<Primitive> primitive,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
~ArrayDesc();
|
||||
};
|
||||
|
||||
// The ArrayDesc contains the details of the materialized array including the
|
||||
|
||||
Reference in New Issue
Block a user