MLX
Loading...
Searching...
No Matches
compiled.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2#pragma once
3
#include <iomanip>
#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_set>
7
8#include "mlx/array.h"
9#include "mlx/primitives.h"
10
11namespace mlx::core {
12
13inline bool is_static_cast(const Primitive& p) {
14 return (
15 typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
16 typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
17}
18
19std::string build_lib_name(
20 const std::vector<array>& inputs,
21 const std::vector<array>& outputs,
22 const std::vector<array>& tape,
23 const std::unordered_set<uintptr_t>& constant_ids);
24
25std::string get_type_string(Dtype d);
26
27template <typename T>
28void print_float_constant(std::ostream& os, const array& x) {
29 auto old_precision = os.precision();
30 os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
31 << x.item<T>() << std::setprecision(old_precision);
32}
33
34template <typename T>
35void print_int_constant(std::ostream& os, const array& x) {
36 os << x.item<T>();
37}
38
39template <typename T>
40void print_complex_constant(std::ostream& os, const array& x) {
41 auto old_precision = os.precision();
42 T constant = x.item<T>();
43
44 os << get_type_string(x.dtype()) << "("
45 << std::setprecision(std::numeric_limits<float>::digits10 + 1)
46 << constant.real() << ", " << constant.imag() << ")"
47 << std::setprecision(old_precision);
48}
49
50void print_constant(std::ostream& os, const array& x);
51
52inline bool is_scalar(const array& x) {
53 return x.ndim() == 0;
54}
55
56// Check if we can use a contiguous operation given inputs and the output shape
58 const std::vector<array>& inputs,
59 const std::vector<int>& shape);
60
61// Allocate space for the outputs possibly with input donation
63 const std::vector<array>& inputs,
64 std::vector<array>& outputs,
65 const std::vector<array>& inputs_,
66 const std::unordered_set<uintptr_t>& constant_ids_,
67 bool contiguous,
68 bool move_buffers = false);
69
70} // namespace mlx::core
Definition primitives.h:416
Definition primitives.h:526
Definition primitives.h:681
Definition primitives.h:48
Definition primitives.h:1975
Definition array.h:20
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
T item()
Get the value from a scalar array.
Definition array.h:490
Dtype dtype() const
Get the array's data type.
Definition array.h:127
Definition allocator.h:7
void print_complex_constant(std::ostream &os, const array &x)
Definition compiled.h:40
bool compiled_check_contiguity(const std::vector< array > &inputs, const std::vector< int > &shape)
std::string build_lib_name(const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &tape, const std::unordered_set< uintptr_t > &constant_ids)
void print_constant(std::ostream &os, const array &x)
void print_float_constant(std::ostream &os, const array &x)
Definition compiled.h:28
void print_int_constant(std::ostream &os, const array &x)
Definition compiled.h:35
bool is_scalar(const array &x)
Definition compiled.h:52
void compiled_allocate_outputs(const std::vector< array > &inputs, std::vector< array > &outputs, const std::vector< array > &inputs_, const std::unordered_set< uintptr_t > &constant_ids_, bool contiguous, bool move_buffers=false)
std::string get_type_string(Dtype d)
bool is_static_cast(const Primitive &p)
Definition compiled.h:13
Definition dtype.h:15