// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include #include "mlx/array.h" #include "mlx/primitives.h" namespace mlx::core { inline bool is_static_cast(const Primitive& p) { return ( typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); } std::string build_lib_name( const std::vector& inputs, const std::vector& outputs, const std::vector& tape, const std::unordered_set& constant_ids); std::string get_type_string(Dtype d); template void print_float_constant(std::ostream& os, const array& x) { auto old_precision = os.precision(); os << std::setprecision(std::numeric_limits::digits10 + 1) << x.item() << std::setprecision(old_precision); } template void print_int_constant(std::ostream& os, const array& x) { os << x.item(); } template void print_complex_constant(std::ostream& os, const array& x) { auto old_precision = os.precision(); T constant = x.item(); os << get_type_string(x.dtype()) << "(" << std::setprecision(std::numeric_limits::digits10 + 1) << constant.real() << ", " << constant.imag() << ")" << std::setprecision(old_precision); } void print_constant(std::ostream& os, const array& x); } // namespace mlx::core