MLX
 
Loading...
Searching...
No Matches
encoder.h
Go to the documentation of this file.
1// Copyright © 2025 Apple Inc.
2
3#pragma once
4
5#include <unordered_map>
6
7#include "mlx/array.h"
8#include "mlx/scheduler.h"
9
10namespace mlx::core::cpu {
11
12// Number of dispatches per scheduler task
13constexpr int DISPATCHES_PER_TASK = 10;
14
16 CommandEncoder(Stream stream) : stream_(stream) {}
17
22
23 void set_input_array(const array& a) {}
25
26 // Hold onto a temporary until any already scheduled tasks which use it as
27 // an input are complete.
28 void add_temporary(array arr) {
29 temporaries_.push_back(std::move(arr));
30 }
31
32 void add_temporaries(std::vector<array> arrays) {
33 temporaries_.insert(
34 temporaries_.end(),
35 std::make_move_iterator(arrays.begin()),
36 std::make_move_iterator(arrays.end()));
37 }
38
39 std::vector<array>& temporaries() {
40 return temporaries_;
41 }
42
43 template <class F, class... Args>
44 void dispatch(F&& f, Args&&... args) {
45 num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
46 auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
47 if (num_ops_ == 0) {
49 auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
50 task();
52 };
53 scheduler::enqueue(stream_, std::move(task_wrap));
54 } else {
55 scheduler::enqueue(stream_, std::move(task));
56 }
57 }
58
59 private:
60 Stream stream_;
61 std::vector<array> temporaries_;
62 int num_ops_{0};
63};
64
66
67} // namespace mlx::core::cpu
Definition array.h:24
Definition encoder.h:10
constexpr int DISPATCHES_PER_TASK
Definition encoder.h:13
CommandEncoder & get_command_encoder(Stream stream)
void notify_task_completion(const Stream &stream)
Definition scheduler.h:177
void notify_new_task(const Stream &stream)
Definition scheduler.h:173
void enqueue(const Stream &stream, F &&f)
Definition scheduler.h:165
std::vector< array > Args
Definition export.h:11
Definition stream.h:9
Definition encoder.h:15
void add_temporary(array arr)
Definition encoder.h:28
CommandEncoder(Stream stream)
Definition encoder.h:16
void add_temporaries(std::vector< array > arrays)
Definition encoder.h:32
CommandEncoder & operator=(const CommandEncoder &)=delete
std::vector< array > & temporaries()
Definition encoder.h:39
CommandEncoder(const CommandEncoder &)=delete
void set_input_array(const array &a)
Definition encoder.h:23
CommandEncoder(CommandEncoder &&)=delete
void set_output_array(array &a)
Definition encoder.h:24
void dispatch(F &&f, Args &&... args)
Definition encoder.h:44
CommandEncoder & operator=(CommandEncoder &&)=delete