Files
mlx/mlx/backend/common/compiled.h
Awni Hannun c4230747a1 redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00

68 lines
1.8 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <iomanip>
#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/primitives.h"
namespace mlx::core {
// True when the dynamic type of `p` is exactly `Broadcast` or exactly
// `AsType`. The comparison is on `typeid`, so a class derived from either of
// them does not match.
inline bool is_static_cast(const Primitive& p) {
  const auto& dynamic_type = typeid(p);
  if (dynamic_type == typeid(Broadcast)) {
    return true;
  }
  return dynamic_type == typeid(AsType);
}
// Returns a string computed from `inputs`, `outputs`, `tape` and
// `constant_ids`. Only the declaration is in this header; the definition is
// not visible here.
// NOTE(review): judging by the name, the result identifies the library
// generated for a compiled graph, and `constant_ids` holds the ids of the
// arrays treated as constants -- confirm against the definition.
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
// Returns a string for dtype `d`. In this header the result is written
// directly in front of "(" when printing a complex constant, i.e. it is used
// as a type name in a constructor-style expression. Declared here; the
// definition is not in this header.
std::string get_type_string(Dtype d);
// Write the scalar stored in `x`, read as `T`, to `os` in decimal form.
//
// The stream precision is raised only for this one value and is put back to
// whatever it was on entry before returning, so callers' formatting is left
// untouched.
//
// Precision is `digits10 + 1` of the printed type: `double` uses its own
// limit so that double constants are not cut down to 7 significant digits;
// every other `T` keeps using `float`'s limit, exactly as before (the
// `numeric_limits` of `float` is used for them rather than that of `T`).
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
  constexpr int precision = std::is_same_v<T, double>
      ? std::numeric_limits<double>::digits10 + 1
      : std::numeric_limits<float>::digits10 + 1;

  const auto old_precision = os.precision();
  os << std::setprecision(precision) << x.item<T>()
     << std::setprecision(old_precision);
}
// Write the scalar stored in `x`, read as `T`, to `os` as a number.
//
// Character types need special care: `int8_t` / `uint8_t` are aliases of
// `signed char` / `unsigned char`, and streaming those selects the character
// overload of `operator<<`, which would emit a raw byte instead of digits.
// They are widened to `int` first. All other types are streamed unchanged.
template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
  if constexpr (
      std::is_same_v<T, char> || std::is_same_v<T, signed char> ||
      std::is_same_v<T, unsigned char>) {
    os << static_cast<int>(x.item<T>());
  } else {
    os << x.item<T>();
  }
}
// Write the complex scalar stored in `x`, read as `T`, to `os` in the form
//   <type string>(<real>, <imag>)
// where <type string> is `get_type_string(x.dtype())`.
//
// Both components are printed with `float`'s `digits10 + 1` significant
// digits; the precision the stream had on entry is restored afterwards.
template <typename T>
void print_complex_constant(std::ostream& os, const array& x) {
  constexpr int component_precision = std::numeric_limits<float>::digits10 + 1;

  const auto saved_precision = os.precision();
  T value = x.item<T>();

  os << get_type_string(x.dtype()) << "(";
  os << std::setprecision(component_precision);
  os << value.real() << ", " << value.imag();
  os << ")" << std::setprecision(saved_precision);
}
// Write the constant held by `x` to `os`. Declared here; the definition is
// not in this header.
// NOTE(review): presumably selects one of the print_*_constant templates in
// this header based on `x.dtype()` -- confirm against the definition.
void print_constant(std::ostream& os, const array& x);
// True when `x` has zero dimensions.
inline bool is_scalar(const array& x) {
  const auto rank = x.ndim();
  return rank == 0;
}
// Check if we can use a contiguous operation given inputs and the output shape
//
// `inputs` are the arrays to inspect and `shape` is the output shape.
// Declared here; the definition is not in this header.
// NOTE(review): presumably returns true when the contiguous path is usable;
// the exact criteria are in the definition -- confirm there.
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const Shape& shape);
// Allocate space for the outputs possibly with input donation
//
// `outputs` is the only argument taken by non-const reference, so it is the
// only one this function can modify through its parameters.
// Declared here; the definition is not in this header.
// NOTE(review): the difference between `inputs` and `inputs_` is not visible
// from this declaration; presumably `inputs_` / `constant_ids_` are the
// compiled primitive's own recorded inputs and constant ids, while `inputs`
// are the arrays passed at evaluation time -- confirm against the definition.
// NOTE(review): `contiguous` presumably carries the result of
// compiled_check_contiguity -- confirm against the callers.
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous);
} // namespace mlx::core