Explicit barriers with concurrent dispatch (#977)

This commit is contained in:
Awni Hannun
2024-04-10 21:45:31 -07:00
committed by GitHub
parent 8580d997ff
commit 12d4507ee3
21 changed files with 326 additions and 267 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -10,29 +10,11 @@ namespace mlx::core {
namespace {
inline void
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
}
inline void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int64_t offset,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
using metal::CommandEncoder;
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
CommandEncoder& enc,
const std::vector<T>& vec,
size_t nelems,
int idx) {
@@ -40,10 +22,8 @@ inline void set_vector_bytes(
}
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
const std::vector<T>& vec,
int idx) {
inline void
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}