No copy command encoder (#986)

* no copy command encoder

* up layer norm test tolerances
This commit is contained in:
Awni Hannun
2024-04-11 21:15:36 -07:00
committed by GitHub
parent 91eba8e485
commit ae18326533
3 changed files with 10 additions and 7 deletions

View File

@@ -129,7 +129,7 @@ Device::~Device() {
b.second.second->release();
}
for (auto& e : encoder_map_) {
e.second->release();
(*e.second)->release();
}
for (auto& k : kernel_map_) {
k.second->release();
@@ -200,8 +200,8 @@ void Device::commit_command_buffer(int index) {
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
eit->second->endEncoding();
eit->second->release();
(*eit->second)->endEncoding();
(*eit->second)->release();
encoder_map_.erase(eit);
}
}
@@ -214,9 +214,11 @@ CommandEncoder& Device::get_command_encoder(int index) {
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;
eit = encoder_map_
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
.first;
}
return eit->second;
return *(eit->second);
}
void Device::register_library(

View File

@@ -39,6 +39,7 @@ using MTLFCList =
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false){};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
struct ConcurrentContext {
@@ -197,7 +198,7 @@ class Device {
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, CommandEncoder> encoder_map_;
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;