mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	compile binding
This commit is contained in:
		@@ -169,7 +169,7 @@ array::ArrayDesc::ArrayDesc(
 | 
				
			|||||||
      dtype(dtype),
 | 
					      dtype(dtype),
 | 
				
			||||||
      primitive(std::move(primitive)),
 | 
					      primitive(std::move(primitive)),
 | 
				
			||||||
      inputs(inputs) {
 | 
					      inputs(inputs) {
 | 
				
			||||||
  std::tie(size, strides) = cum_prod(shape);
 | 
					  std::tie(size, strides) = cum_prod(this->shape);
 | 
				
			||||||
  for (auto& in : inputs) {
 | 
					  for (auto& in : inputs) {
 | 
				
			||||||
    is_tracer |= in.is_tracer();
 | 
					    is_tracer |= in.is_tracer();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -184,7 +184,7 @@ array::ArrayDesc::ArrayDesc(
 | 
				
			|||||||
      dtype(dtype),
 | 
					      dtype(dtype),
 | 
				
			||||||
      primitive(std::move(primitive)),
 | 
					      primitive(std::move(primitive)),
 | 
				
			||||||
      inputs(std::move(inputs)) {
 | 
					      inputs(std::move(inputs)) {
 | 
				
			||||||
  std::tie(size, strides) = cum_prod(shape);
 | 
					  std::tie(size, strides) = cum_prod(this->shape);
 | 
				
			||||||
  for (auto& in : inputs) {
 | 
					  for (auto& in : inputs) {
 | 
				
			||||||
    is_tracer |= in.is_tracer();
 | 
					    is_tracer |= in.is_tracer();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										246
									
								
								mlx/compile.cpp
									
									
									
									
									
								
							
							
						
						
									
										246
									
								
								mlx/compile.cpp
									
									
									
									
									
								
							@@ -1,15 +1,21 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
#include <iostream> // TODO
 | 
					#include <iostream> // TODO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
#include <unordered_map>
 | 
					#include <unordered_map>
 | 
				
			||||||
#include <unordered_set>
 | 
					#include <unordered_set>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlx/primitives.h"
 | 
				
			||||||
#include "mlx/transforms.h"
 | 
					#include "mlx/transforms.h"
 | 
				
			||||||
#include "mlx/transforms_impl.h"
 | 
					#include "mlx/transforms_impl.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mlx::core {
 | 
					namespace mlx::core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace detail {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
 | 
					using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
 | 
				
			||||||
 | 
					using ParentsMap =
 | 
				
			||||||
 | 
					    std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename T, typename... U>
 | 
					template <typename T, typename... U>
 | 
				
			||||||
size_t getAddress(std::function<T(U...)> f) {
 | 
					size_t getAddress(std::function<T(U...)> f) {
 | 
				
			||||||
@@ -28,9 +34,9 @@ struct CompilerCache {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Returns a reference to a CacheEntry which can be updated
 | 
					  // Returns a reference to a CacheEntry which can be updated
 | 
				
			||||||
  // by the caller to avoid copying large tapes / inputs / outputs
 | 
					  // by the caller to avoid copying large tapes / inputs / outputs
 | 
				
			||||||
  CacheEntry& find(const CompileFn& fn, const std::vector<array>& inputs) {
 | 
					  CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
 | 
				
			||||||
    // Try to find the entry
 | 
					    // Try to find the entry
 | 
				
			||||||
    auto inserted = cache_.insert({getAddress(fn), {}});
 | 
					    auto inserted = cache_.insert({fun_id, {}});
 | 
				
			||||||
    auto& entries = inserted.first->second;
 | 
					    auto& entries = inserted.first->second;
 | 
				
			||||||
    auto is_match = [](const std::vector<array>& in1,
 | 
					    auto is_match = [](const std::vector<array>& in1,
 | 
				
			||||||
                       const std::vector<array>& in2) {
 | 
					                       const std::vector<array>& in2) {
 | 
				
			||||||
@@ -93,38 +99,40 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
 | 
				
			|||||||
  return {tracer_inputs, fun(tracer_inputs)};
 | 
					  return {tracer_inputs, fun(tracer_inputs)};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::vector<array> compile_dfs_graph(
 | 
					// Traverses the graph to build a tape and a map of array ids to their parents
 | 
				
			||||||
 | 
					std::pair<std::vector<array>, ParentsMap> compile_dfs(
 | 
				
			||||||
    const std::vector<array>& inputs,
 | 
					    const std::vector<array>& inputs,
 | 
				
			||||||
    const std::vector<array>& outputs) {
 | 
					    const std::vector<array>& outputs) {
 | 
				
			||||||
  std::unordered_set<std::uintptr_t> needs_compile;
 | 
					  std::function<void(const array&)> recurse;
 | 
				
			||||||
 | 
					  std::vector<array> tape;
 | 
				
			||||||
  std::unordered_set<std::uintptr_t> cache;
 | 
					  std::unordered_set<std::uintptr_t> cache;
 | 
				
			||||||
 | 
					  std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
 | 
				
			||||||
 | 
					      parents_map;
 | 
				
			||||||
 | 
					  std::unordered_set<std::uintptr_t> needs_compile;
 | 
				
			||||||
  for (int i = 0; i < inputs.size(); ++i) {
 | 
					  for (int i = 0; i < inputs.size(); ++i) {
 | 
				
			||||||
    auto in = inputs[i];
 | 
					    auto in = inputs[i];
 | 
				
			||||||
    needs_compile.insert(in.id());
 | 
					    needs_compile.insert(in.id());
 | 
				
			||||||
    cache.insert(in.id());
 | 
					    cache.insert(in.id());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Topologically sort the graph
 | 
					  // DFS the graph to build the tape, and log parents and scalars
 | 
				
			||||||
  std::vector<array> tape;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  std::function<void(const array&)> recurse;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  recurse = [&](const array& a) {
 | 
					  recurse = [&](const array& a) {
 | 
				
			||||||
    auto id = a.id();
 | 
					    auto id = a.id();
 | 
				
			||||||
    if (cache.find(id) != cache.end()) {
 | 
					    if (cache.find(id) != cache.end()) {
 | 
				
			||||||
      return;
 | 
					      return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    for (int i = 0; i < a.inputs().size(); i++) {
 | 
				
			||||||
 | 
					      auto& in = a.inputs()[i];
 | 
				
			||||||
 | 
					      parents_map[in.id()].push_back({a, i});
 | 
				
			||||||
 | 
					      for (auto& s : a.siblings()) {
 | 
				
			||||||
 | 
					        parents_map[in.id()].push_back({s, i});
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      recurse(in);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    cache.insert(id);
 | 
					    cache.insert(id);
 | 
				
			||||||
    for (auto& s : a.siblings()) {
 | 
					    for (auto& s : a.siblings()) {
 | 
				
			||||||
      cache.insert(s.id());
 | 
					      cache.insert(s.id());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Recurse on inputs
 | 
					 | 
				
			||||||
    for (auto& input : a.inputs()) {
 | 
					 | 
				
			||||||
      recurse(input);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    // If any input needs a vmap, then the outputs also need
 | 
					 | 
				
			||||||
    // a vmap
 | 
					 | 
				
			||||||
    for (auto& input : a.inputs()) {
 | 
					    for (auto& input : a.inputs()) {
 | 
				
			||||||
      if (needs_compile.find(input.id()) != needs_compile.end()) {
 | 
					      if (needs_compile.find(input.id()) != needs_compile.end()) {
 | 
				
			||||||
        tape.push_back(a);
 | 
					        tape.push_back(a);
 | 
				
			||||||
@@ -136,16 +144,165 @@ std::vector<array> compile_dfs_graph(
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					  for (auto& a : outputs) {
 | 
				
			||||||
  for (auto& out : outputs) {
 | 
					    recurse(a);
 | 
				
			||||||
    if (out.has_primitive()) {
 | 
					 | 
				
			||||||
      recurse(out);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return tape;
 | 
					  return {tape, parents_map};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::vector<array> compile_tape_replace(
 | 
					// Simplify the tape. Note, this function modifies in-place both the tape and
 | 
				
			||||||
 | 
					// the parents map to remove orphaned arrays
 | 
				
			||||||
 | 
					void compile_simplify(
 | 
				
			||||||
 | 
					    std::vector<array>& tape,
 | 
				
			||||||
 | 
					    ParentsMap& parents_map,
 | 
				
			||||||
 | 
					    const std::vector<array>& outputs,
 | 
				
			||||||
 | 
					    int passes) {
 | 
				
			||||||
 | 
					  // Helpers to identify identical scalars
 | 
				
			||||||
 | 
					  std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
 | 
				
			||||||
 | 
					  auto is_scalar = [](const array& a) {
 | 
				
			||||||
 | 
					    return a.is_evaled() && a.ndim() == 0;
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					  auto get_scalar_rep = [](const array& a) {
 | 
				
			||||||
 | 
					    uint64_t v = 0;
 | 
				
			||||||
 | 
					    int dtype;
 | 
				
			||||||
 | 
					    switch (a.dtype().size) {
 | 
				
			||||||
 | 
					      case 1:
 | 
				
			||||||
 | 
					        v = *a.data<uint8_t>();
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
 | 
					      case 4:
 | 
				
			||||||
 | 
					        v = *a.data<uint32_t>();
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
 | 
					      case 8:
 | 
				
			||||||
 | 
					        v = *a.data<uint64_t>();
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return std::make_pair(v, a.dtype().val);
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (auto& a : tape) {
 | 
				
			||||||
 | 
					    if (is_scalar(a)) {
 | 
				
			||||||
 | 
					      scalars.insert({get_scalar_rep(a), a});
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Helper that fuses two arrays in the graph by setting the parents of the
 | 
				
			||||||
 | 
					  // source to point to the destination
 | 
				
			||||||
 | 
					  auto fuse = [&](array& dst, array& src) {
 | 
				
			||||||
 | 
					    // Canonicalize the order of the primitives outputs
 | 
				
			||||||
 | 
					    auto sources = src.outputs();
 | 
				
			||||||
 | 
					    auto dests = dst.outputs();
 | 
				
			||||||
 | 
					    // For each src parent, point it to the corresponding dest
 | 
				
			||||||
 | 
					    for (int i = 0; i < sources.size(); ++i) {
 | 
				
			||||||
 | 
					      auto src_parents = parents_map.find(sources[i].id());
 | 
				
			||||||
 | 
					      if (src_parents == parents_map.end()) {
 | 
				
			||||||
 | 
					        continue;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      auto& pairs = parents_map[dests[i].id()];
 | 
				
			||||||
 | 
					      for (auto& parent : src_parents->second) {
 | 
				
			||||||
 | 
					        parent.first.inputs()[parent.second] = dests[i];
 | 
				
			||||||
 | 
					        pairs.push_back(parent);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      // Remove the source from the map to avoid fusing with it again
 | 
				
			||||||
 | 
					      parents_map.erase(src_parents);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Depth-1 array equivalence check.
 | 
				
			||||||
 | 
					  auto array_equivalent = [](const array& a, const array& b) {
 | 
				
			||||||
 | 
					    if (!a.has_primitive() || !b.has_primitive()) {
 | 
				
			||||||
 | 
					      return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (a.primitive_id() == b.primitive_id()) {
 | 
				
			||||||
 | 
					      return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    const auto& pa = a.primitive();
 | 
				
			||||||
 | 
					    const auto& pb = b.primitive();
 | 
				
			||||||
 | 
					    if (typeid(pa) != typeid(pb)) {
 | 
				
			||||||
 | 
					      return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (a.inputs().size() != b.inputs().size()) {
 | 
				
			||||||
 | 
					      return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int i = 0; i < a.inputs().size(); i++) {
 | 
				
			||||||
 | 
					      if (a.inputs()[i].id() != b.inputs()[i].id()) {
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return pa.is_equivalent(pb);
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Pass 0: fuse scalars
 | 
				
			||||||
 | 
					  std::vector<array> new_tape;
 | 
				
			||||||
 | 
					  for (auto& arr : tape) {
 | 
				
			||||||
 | 
					    // Check if we can fuse scalars
 | 
				
			||||||
 | 
					    if (is_scalar(arr)) {
 | 
				
			||||||
 | 
					      auto scalar = scalars.find(get_scalar_rep(arr));
 | 
				
			||||||
 | 
					      if (scalar->second.id() != arr.id()) {
 | 
				
			||||||
 | 
					        fuse(scalar->second, arr);
 | 
				
			||||||
 | 
					        // Don't keep orphaned scalars in the tape
 | 
				
			||||||
 | 
					        continue;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    new_tape.push_back(std::move(arr));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  tape = std::move(new_tape);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::unordered_set<uintptr_t> output_set;
 | 
				
			||||||
 | 
					  for (auto& o : outputs) {
 | 
				
			||||||
 | 
					    output_set.insert(o.id());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Pass 1 to passes: fuse only keeping non-orphaned arrays in the tape
 | 
				
			||||||
 | 
					  for (int pass = 0; pass < passes; ++pass) {
 | 
				
			||||||
 | 
					    for (auto& arr : tape) {
 | 
				
			||||||
 | 
					      // Helper to check if we can fuse the parents of the
 | 
				
			||||||
 | 
					      // given array
 | 
				
			||||||
 | 
					      // If an array has no parents and siblings have
 | 
				
			||||||
 | 
					      auto maybe_fuse_parents = [&](auto& a) {
 | 
				
			||||||
 | 
					        auto parents = parents_map.find(a.id());
 | 
				
			||||||
 | 
					        if (parents != parents_map.end()) {
 | 
				
			||||||
 | 
					          auto N = parents->second.size();
 | 
				
			||||||
 | 
					          std::vector<bool> mask(N, false);
 | 
				
			||||||
 | 
					          for (int i = 0; i < N; i++) {
 | 
				
			||||||
 | 
					            if (mask[i]) {
 | 
				
			||||||
 | 
					              continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            for (int j = i + 1; j < N; j++) {
 | 
				
			||||||
 | 
					              if (mask[j]) {
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					              auto& src = parents->second[j].first;
 | 
				
			||||||
 | 
					              auto& dst = parents->second[i].first;
 | 
				
			||||||
 | 
					              if (src.id() != dst.id() && array_equivalent(src, dst)) {
 | 
				
			||||||
 | 
					                fuse(dst, src);
 | 
				
			||||||
 | 
					                mask[j] = true;
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          return false;
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          return output_set.find(a.id()) != output_set.end();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      bool discard = maybe_fuse_parents(arr);
 | 
				
			||||||
 | 
					      for (auto& s : arr.siblings()) {
 | 
				
			||||||
 | 
					        discard &= maybe_fuse_parents(s);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      // If an array and its siblings have no parents, and none of them are
 | 
				
			||||||
 | 
					      // outputs, it is safe to remove it from the tape
 | 
				
			||||||
 | 
					      if (!discard) {
 | 
				
			||||||
 | 
					        new_tape.push_back(std::move(arr));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    tape = std::move(new_tape);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::vector<array> compile_replace(
 | 
				
			||||||
    const std::vector<array>& tape,
 | 
					    const std::vector<array>& tape,
 | 
				
			||||||
    const std::vector<array>& trace_inputs,
 | 
					    const std::vector<array>& trace_inputs,
 | 
				
			||||||
    const std::vector<array>& trace_outputs,
 | 
					    const std::vector<array>& trace_outputs,
 | 
				
			||||||
@@ -155,7 +312,6 @@ std::vector<array> compile_tape_replace(
 | 
				
			|||||||
    trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
 | 
					    trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // We need a map here of traced inputs to real inputs
 | 
					 | 
				
			||||||
  for (auto& a : tape) {
 | 
					  for (auto& a : tape) {
 | 
				
			||||||
    if (!a.has_primitive()) {
 | 
					    if (!a.has_primitive()) {
 | 
				
			||||||
      std::runtime_error(
 | 
					      std::runtime_error(
 | 
				
			||||||
@@ -177,36 +333,50 @@ std::vector<array> compile_tape_replace(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
					std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
				
			||||||
    const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
 | 
					    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
 | 
				
			||||||
  //  std::cout << getAddress(fun) << std::endl;
 | 
					    size_t fun_id) {
 | 
				
			||||||
  return [&fun](const std::vector<array>& inputs) {
 | 
					  return [&fun, fun_id](const std::vector<array>& inputs) {
 | 
				
			||||||
    // Find a cache entry with the correct inputs
 | 
					    // Find a cache entry with the correct inputs
 | 
				
			||||||
    auto& entry = compiler_cache().find(fun, inputs);
 | 
					    auto& entry = compiler_cache().find(fun_id, inputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // No matching cache entry existed, so compile
 | 
					    // No matching cache entry existed, so compile
 | 
				
			||||||
    if (entry.empty) {
 | 
					    if (entry.empty) {
 | 
				
			||||||
      std::cout << "RECOMPILING? " << std::endl;
 | 
					 | 
				
			||||||
      // Mark the entry as not empty since we are about to fill it
 | 
					      // Mark the entry as not empty since we are about to fill it
 | 
				
			||||||
      entry.empty = false;
 | 
					      entry.empty = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // Trace te build the graph
 | 
					      // Trace te build the graph
 | 
				
			||||||
      std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
 | 
					      std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // This is a good point to do optimizations:
 | 
					      // DFS the graph and get a tape, and a map of array id to (parent,
 | 
				
			||||||
      // - simplify
 | 
					      // position in parent inputs)
 | 
				
			||||||
      // - kernel fusion to generate new primitives
 | 
					      std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
 | 
				
			||||||
      // - may make sense to keep the tape from simplify
 | 
					          parents_map;
 | 
				
			||||||
      //   and pass it around so that we don't have to keep rebuilding it
 | 
					      std::tie(entry.tape, parents_map) =
 | 
				
			||||||
 | 
					          compile_dfs(entry.inputs, entry.outputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // Recurse to build the tape
 | 
					      // Simplify the tape
 | 
				
			||||||
      entry.tape = compile_dfs_graph(entry.inputs, entry.outputs);
 | 
					      // compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */
 | 
				
			||||||
 | 
					      // 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // This is a good point to do more optimizations, e.g. kernel fusion to
 | 
				
			||||||
 | 
					      // generate new primitives. The tape needs to be updated accordingly
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // At this point we must have a tape, now replace the placeholders
 | 
					    // At this point we must have a tape, now replace the placeholders
 | 
				
			||||||
    // with real arrays that can be evaluated
 | 
					    // with real arrays that can be evaluated
 | 
				
			||||||
    return compile_tape_replace(
 | 
					    return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
 | 
				
			||||||
        entry.tape, entry.inputs, entry.outputs, inputs);
 | 
					 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					} // namespace detail
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
				
			||||||
 | 
					    const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
 | 
				
			||||||
 | 
					  auto fun_id = detail::getAddress(fun);
 | 
				
			||||||
 | 
					  if (fun_id == 0) {
 | 
				
			||||||
 | 
					    throw std::invalid_argument(
 | 
				
			||||||
 | 
					        "[compile] Cannot compile a non-addressable function.");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return detail::compile(fun, fun_id);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace mlx::core
 | 
					} // namespace mlx::core
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,6 +14,12 @@ std::vector<array> vmap_replace(
 | 
				
			|||||||
    const std::vector<int>& in_axes,
 | 
					    const std::vector<int>& in_axes,
 | 
				
			||||||
    const std::vector<int>& out_axes);
 | 
					    const std::vector<int>& out_axes);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// This is not part of the general C++ API as calling with a bad id is a bad
 | 
				
			||||||
 | 
					// idea.
 | 
				
			||||||
 | 
					std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
				
			||||||
 | 
					    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
 | 
				
			||||||
 | 
					    size_t fun_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Create an InTracing object during tracing operations to signify to the rest
 | 
					// Create an InTracing object during tracing operations to signify to the rest
 | 
				
			||||||
// of the codebase that we are during tracing so evals should not throw away
 | 
					// of the codebase that we are during tracing so evals should not throw away
 | 
				
			||||||
// the graph.
 | 
					// the graph.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,5 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					#include <iostream> // TODO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <pybind11/functional.h>
 | 
					#include <pybind11/functional.h>
 | 
				
			||||||
#include <pybind11/pybind11.h>
 | 
					#include <pybind11/pybind11.h>
 | 
				
			||||||
@@ -437,6 +438,34 @@ auto py_vmap(
 | 
				
			|||||||
  };
 | 
					  };
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Wrap a Python callable in a lambda that flattens its arguments, runs the
					// callable through detail::compile, and unflattens the results. The
					// returned lambda keeps its own copy of `fun`.
					auto py_compile(const py::function& fun) {
					  return [fun](const py::args& args) {
					    // Inputs must be array or tree of arrays
					    auto inputs = tree_flatten(args, true);

					    // py_outputs will hold the output of the python function in order to be
					    // able to reconstruct the python tree of extra return values
					    py::object py_outputs;

					    // NOTE(review): detail::compile only invokes this callback when its
					    // cache entry is empty. On a cache hit py_outputs is never assigned
					    // before the tree_unflatten call below - TODO handle that case.
					    auto compile_fun =
					        [&fun, &args, &py_outputs](const std::vector<array>& a) {
					          // Call the python function
					          py_outputs = fun(*tree_unflatten(args, a));

					          // Flatten the outputs
					          return tree_flatten(py_outputs, true);
					        };

					    // Compile and call. The address of the Python function object is used
					    // as the function id.
					    // TODO, awni, I think this cast is ok??
					    size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
					    auto outputs = detail::compile(compile_fun, fun_id)(inputs);

					    // Put the outputs back in the container
					    return tree_unflatten(py_outputs, outputs);
					  };
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void init_transforms(py::module_& m) {
 | 
					void init_transforms(py::module_& m) {
 | 
				
			||||||
  py::options options;
 | 
					  py::options options;
 | 
				
			||||||
  options.disable_function_signatures();
 | 
					  options.disable_function_signatures();
 | 
				
			||||||
@@ -736,4 +765,22 @@ void init_transforms(py::module_& m) {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
      "file"_a);
 | 
					      "file"_a);
 | 
				
			||||||
 | 
					  m.def(
 | 
				
			||||||
 | 
					      "compile",
 | 
				
			||||||
 | 
					      [](const py::function& fun) { return py::cpp_function(py_compile(fun)); },
 | 
				
			||||||
 | 
					      "fun"_a,
 | 
				
			||||||
 | 
					      R"pbdoc(
 | 
				
			||||||
 | 
					        compile(fun: function) -> function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns a compiled function which produces the same output as ``fun``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            fun (function): A function which takes a variable number of
 | 
				
			||||||
 | 
					              :class:`array` or trees of :class:`array` and returns
 | 
				
			||||||
 | 
					              a variable number of :class:`array` or trees of :class:`array`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            function: A compiled function which has the same input arguments
 | 
				
			||||||
 | 
					            as ``fun`` and returns the the same output(s).
 | 
				
			||||||
 | 
					      )pbdoc");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										22
									
								
								python/tests/test_compile.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								python/tests/test_compile.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					# Copyright © 2023-2024 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					import mlx_tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestCompile(mlx_tests.MLXTestCase):
					    # Smoke test only: nothing is asserted and the compiled function is
					    # never invoked (the call is still commented out below).
					    def test_simple_compile(self):
					        def add(a, b):
					            return a + b

					        # The same callable is wrapped twice on purpose, as in the original.
					        compiled_add = mx.compile(add)
					        compiled_add = mx.compile(add)
					        lhs = mx.array(1.0)
					        rhs = mx.array(1.0)
					        # out = compiled_add(lhs, rhs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    unittest.main()
 | 
				
			||||||
		Reference in New Issue
	
	Block a user