More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
}
template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
size_t size = 1;
for (auto c : term) {
size *= dict[c];
@@ -120,7 +120,7 @@ size_t flop_count(
const CharSet& term,
bool inner,
int num_terms,
std::unordered_map<char, int> dict) {
std::unordered_map<char, ShapeElem> dict) {
size_t size = term_size(term, dict);
auto op_factor = 1;
if ((num_terms - 1) > op_factor) {
@@ -135,7 +135,7 @@ size_t flop_count(
std::pair<size_t, int> compute_cost_and_scaling(
const std::vector<Subscript>& inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map) {
std::unordered_map<char, ShapeElem> dim_map) {
CharSet contractions;
for (auto& in : inputs) {
contractions.insert(in.set.begin(), in.set.end());
@@ -155,7 +155,7 @@ std::pair<size_t, int> compute_cost_and_scaling(
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
std::vector<Subscript> inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map,
std::unordered_map<char, ShapeElem> dim_map,
size_t cost_limit,
size_t memory_limit) {
// Helper struct for building the greedy path
@@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
}
Shape idx_shape(n_expand--, 1);
idx_shape[0] = in.shape(axes.back());
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
auto idx = reshape(
arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
for (int i = 0; i < v; ++i) {
indices.push_back(idx);
}
@@ -663,7 +664,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
}
Subscript output(out_subscript, std::move(out_set));
std::unordered_map<char, int> dim_map;
std::unordered_map<char, ShapeElem> dim_map;
std::vector<Subscript> inputs;
for (int i = 0; i < in_subscripts.size(); ++i) {
auto& in = in_subscripts[i];
@@ -680,7 +681,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
// Check repeat subscripts are valid
if (in_set.size() < in.size()) {
std::unordered_map<char, int> local_dims;
std::unordered_map<char, ShapeElem> local_dims;
for (int j = 0; j < in.size(); ++j) {
auto dim = operands[i].shape(j);
auto inserted = local_dims.insert({in[j], dim});