13using metal::CommandEncoder;
16inline void set_vector_bytes(
18 const std::vector<T>& vec,
21 enc->setBytes(vec.data(), nelems *
sizeof(T), idx);
26set_vector_bytes(CommandEncoder& enc,
const std::vector<T>& vec,
int idx) {
27 return set_vector_bytes(enc, vec, vec.size(), idx);
30std::string type_to_name(
const array& a) {
76MTL::Size get_block_dims(
int dim0,
int dim1,
int dim2) {
77 int pows[3] = {0, 0, 0};
82 if (dim0 >= (1 << (pows[0] + 1))) {
89 if (dim1 >= (1 << (pows[1] + 1))) {
96 if (dim2 >= (1 << (pows[2] + 1))) {
100 if (
sum == presum ||
sum == 10) {
104 return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
107inline NS::String* make_string(std::ostringstream& os) {
108 std::string
string = os.str();
109 return NS::String::string(
string.c_str(), NS::UTF8StringEncoding);
112inline void debug_set_stream_queue_label(MTL::CommandQueue* queue,
int index) {
113#ifdef MLX_METAL_DEBUG
114 std::ostringstream label;
115 label <<
"Stream " << index;
116 queue->setLabel(make_string(label));
120inline void debug_set_primitive_buffer_label(
121 MTL::CommandBuffer* command_buffer,
122 Primitive& primitive) {
123#ifdef MLX_METAL_DEBUG
124 std::ostringstream label;
125 if (
auto cbuf_label = command_buffer->label(); cbuf_label) {
126 label << cbuf_label->utf8String();
128 primitive.print(label);
129 command_buffer->setLabel(make_string(label));
133std::string get_primitive_string(Primitive* primitive) {
134 std::ostringstream op_t;
135 primitive->print(op_t);
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
constexpr Dtype bool_
Definition dtype.h:60
constexpr Dtype uint64
Definition dtype.h:65
constexpr Dtype uint16
Definition dtype.h:63
constexpr Dtype bfloat16
Definition dtype.h:74
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
constexpr Dtype int16
Definition dtype.h:68
constexpr Dtype int8
Definition dtype.h:67
constexpr Dtype int64
Definition dtype.h:70
constexpr Dtype uint8
Definition dtype.h:62
constexpr Dtype float16
Definition dtype.h:72
constexpr Dtype uint32
Definition dtype.h:64
constexpr Dtype complex64
Definition dtype.h:75