mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
@@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
|
||||
size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
|
||||
size_t size = 1;
|
||||
for (auto c : term) {
|
||||
size *= dict[c];
|
||||
@@ -120,7 +120,7 @@ size_t flop_count(
|
||||
const CharSet& term,
|
||||
bool inner,
|
||||
int num_terms,
|
||||
std::unordered_map<char, int> dict) {
|
||||
std::unordered_map<char, ShapeElem> dict) {
|
||||
size_t size = term_size(term, dict);
|
||||
auto op_factor = 1;
|
||||
if ((num_terms - 1) > op_factor) {
|
||||
@@ -135,7 +135,7 @@ size_t flop_count(
|
||||
std::pair<size_t, int> compute_cost_and_scaling(
|
||||
const std::vector<Subscript>& inputs,
|
||||
const Subscript& output,
|
||||
std::unordered_map<char, int> dim_map) {
|
||||
std::unordered_map<char, ShapeElem> dim_map) {
|
||||
CharSet contractions;
|
||||
for (auto& in : inputs) {
|
||||
contractions.insert(in.set.begin(), in.set.end());
|
||||
@@ -155,7 +155,7 @@ std::pair<size_t, int> compute_cost_and_scaling(
|
||||
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
|
||||
std::vector<Subscript> inputs,
|
||||
const Subscript& output,
|
||||
std::unordered_map<char, int> dim_map,
|
||||
std::unordered_map<char, ShapeElem> dim_map,
|
||||
size_t cost_limit,
|
||||
size_t memory_limit) {
|
||||
// Helper struct for building the greedy path
|
||||
@@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
|
||||
}
|
||||
Shape idx_shape(n_expand--, 1);
|
||||
idx_shape[0] = in.shape(axes.back());
|
||||
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
|
||||
auto idx = reshape(
|
||||
arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
|
||||
for (int i = 0; i < v; ++i) {
|
||||
indices.push_back(idx);
|
||||
}
|
||||
@@ -663,7 +664,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
|
||||
}
|
||||
Subscript output(out_subscript, std::move(out_set));
|
||||
|
||||
std::unordered_map<char, int> dim_map;
|
||||
std::unordered_map<char, ShapeElem> dim_map;
|
||||
std::vector<Subscript> inputs;
|
||||
for (int i = 0; i < in_subscripts.size(); ++i) {
|
||||
auto& in = in_subscripts[i];
|
||||
@@ -680,7 +681,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
|
||||
|
||||
// Check repeat subscripts are valid
|
||||
if (in_set.size() < in.size()) {
|
||||
std::unordered_map<char, int> local_dims;
|
||||
std::unordered_map<char, ShapeElem> local_dims;
|
||||
for (int j = 0; j < in.size(); ++j) {
|
||||
auto dim = operands[i].shape(j);
|
||||
auto inserted = local_dims.insert({in[j], dim});
|
||||
|
Reference in New Issue
Block a user