2024-04-10 21:45:31 -07:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-11-30 11:12:53 -08:00
|
|
|
|
2023-11-29 10:30:41 -08:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "mlx/array.h"
|
|
|
|
|
#include "mlx/backend/metal/device.h"
|
2024-03-28 09:40:31 -07:00
|
|
|
#include "mlx/primitives.h"
|
2023-11-29 10:30:41 -08:00
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
|
2024-04-10 21:45:31 -07:00
|
|
|
// Brings metal::CommandEncoder into mlx::core so the helpers in this header
// can name it unqualified.
using metal::CommandEncoder;
|
2024-03-20 10:39:25 -07:00
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void set_vector_bytes(
|
2024-04-10 21:45:31 -07:00
|
|
|
CommandEncoder& enc,
|
2024-03-20 10:39:25 -07:00
|
|
|
const std::vector<T>& vec,
|
|
|
|
|
size_t nelems,
|
|
|
|
|
int idx) {
|
|
|
|
|
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
2024-04-10 21:45:31 -07:00
|
|
|
inline void
|
|
|
|
|
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
|
2024-03-20 10:39:25 -07:00
|
|
|
return set_vector_bytes(enc, vec, vec.size(), idx);
|
2023-11-29 10:30:41 -08:00
|
|
|
}
|
|
|
|
|
|
2024-08-07 13:38:07 -07:00
|
|
|
// Returns a type-name string for `a` (defined elsewhere).
// NOTE(review): presumably the spelling of `a`'s dtype used when building
// Metal kernel names — confirm against the definition.
std::string type_to_name(const array& a);
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2024-08-07 13:38:07 -07:00
|
|
|
// Compute thread block dimensions that fit the given input dimensions
// (dim0, dim1, dim2).
// - Each thread block dimension is a power of two.
// - The total thread block size is bounded by 2^pow2 (pow2 defaults to 10).
//   NOTE(review): confirm in the definition whether this bound is strict or
//   inclusive.
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2024-07-30 17:18:39 -07:00
|
|
|
// Computes a 2D grid in which each grid dimension is < UINT_MAX.
// Assumes:
// - the overall size (product of the non-broadcasted dimensions) is
//   < UINT_MAX^2;
// - `shape` and `strides` describe a contiguous (no holes) but possibly
//   broadcasted array.
// NOTE(review): stride units (elements vs bytes) are not visible here —
// confirm against the definition.
MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides);
|
2024-07-30 17:18:39 -07:00
|
|
|
|
2024-03-28 09:40:31 -07:00
|
|
|
// Converts the text accumulated in `os` into an NS::String using UTF-8
// encoding.
//
// The stream is only read (std::ostringstream::str() is a const member), so
// it is taken by const reference. Existing callers that pass a mutable stream
// still bind; const streams are now accepted too.
// NOTE(review): the returned object comes from the metal-cpp factory
// NS::String::string; its ownership follows metal-cpp conventions — confirm
// callers run inside an autorelease pool where required.
inline NS::String* make_string(const std::ostringstream& os) {
  const std::string text = os.str();
  return NS::String::string(text.c_str(), NS::UTF8StringEncoding);
}
|
|
|
|
|
|
|
|
|
|
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
|
|
|
|
|
#ifdef MLX_METAL_DEBUG
|
|
|
|
|
std::ostringstream label;
|
|
|
|
|
label << "Stream " << index;
|
|
|
|
|
queue->setLabel(make_string(label));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline void debug_set_primitive_buffer_label(
|
|
|
|
|
MTL::CommandBuffer* command_buffer,
|
|
|
|
|
Primitive& primitive) {
|
|
|
|
|
#ifdef MLX_METAL_DEBUG
|
|
|
|
|
std::ostringstream label;
|
2024-04-07 21:47:43 -07:00
|
|
|
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
|
|
|
|
|
label << cbuf_label->utf8String();
|
|
|
|
|
}
|
2024-03-28 09:40:31 -07:00
|
|
|
primitive.print(label);
|
|
|
|
|
command_buffer->setLabel(make_string(label));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
2024-08-07 13:38:07 -07:00
|
|
|
// Returns a string for `primitive` (defined elsewhere).
// NOTE(review): presumably the text produced by Primitive::print — confirm
// against the definition, including how a null `primitive` is handled.
std::string get_primitive_string(Primitive* primitive);
|
2023-11-29 10:30:41 -08:00
|
|
|
|
|
|
|
|
} // namespace mlx::core
|