mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Explicit barriers with concurrent dispatch (#977)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -10,29 +10,11 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline void
|
||||
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
enc->setBuffer(a_buf, offset, idx);
|
||||
}
|
||||
|
||||
inline void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
int64_t offset,
|
||||
int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
using metal::CommandEncoder;
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
CommandEncoder& enc,
|
||||
const std::vector<T>& vec,
|
||||
size_t nelems,
|
||||
int idx) {
|
||||
@@ -40,10 +22,8 @@ inline void set_vector_bytes(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const std::vector<T>& vec,
|
||||
int idx) {
|
||||
inline void
|
||||
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
|
||||
return set_vector_bytes(enc, vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user