// Copyright © 2023-2024 Apple Inc. #pragma once #include #include "mlx/array.h" namespace mlx::core::detail { // This is not part of the general C++ API as calling with a bad id is a bad // idea. std::function(const std::vector&)> compile( std::function(const std::vector&)> fun, std::uintptr_t fun_id, bool shapeless = false, std::vector constants = {}); // Erase cached compile functions void compile_erase(std::uintptr_t fun_id); // Clear the compiler cache causing a recompilation of all compiled functions // when called again. void compile_clear_cache(); bool compile_available_for_device(const Device& device); std::pair, std::vector> compile_trace( const std::function(const std::vector&)>& fun, const std::vector& inputs, bool shapeless); using ParentsMap = std::unordered_map>>; // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, const std::vector& outputs, const std::vector& original_inputs); // Simplify the tape. void compile_simplify( std::vector& tape, ParentsMap& parents_map, std::vector& outputs, int passes); std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, const std::vector& inputs, bool shapeless); void compile_validate_shapeless(const std::vector& tape); } // namespace mlx::core::detail