2023-12-01 03:12:53 +08:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "mlx/device.h"
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
|
|
|
struct Stream {
|
|
|
|
int index;
|
|
|
|
Device device;
|
|
|
|
explicit Stream(int index, Device device) : index(index), device(device) {}
|
|
|
|
};
|
|
|
|
|
|
|
|
/**
 * Get the default stream for the given device.
 *
 * @param d Device whose default stream is requested (taken by value).
 * @return The default stream, returned by value.
 *
 * Declaration only; the definition lives outside this header.
 */
|
|
|
|
Stream default_stream(Device d);
|
|
|
|
|
|
|
|
/**
 * Make the stream the default for its device.
 *
 * @param s Stream to install as the default (taken by value). The target
 *          device is not passed separately -- presumably it is read from
 *          `s.device`; confirm against the definition.
 *
 * Declaration only; the definition lives outside this header.
 */
|
|
|
|
void set_default_stream(Stream s);
|
|
|
|
|
|
|
|
/**
 * Make a new stream on the given device.
 *
 * @param d Device the new stream is created on (taken by value).
 * @return The newly made stream, returned by value.
 *
 * Declaration only; how the new stream's index is chosen is decided by the
 * definition, which is not visible in this header.
 */
|
|
|
|
Stream new_stream(Device d);
|
|
|
|
|
2024-12-25 03:19:13 +08:00
|
|
|
/**
 * Get the stream with the given index.
 *
 * @param index Index of the stream to look up.
 * @return The matching stream, returned by value.
 *
 * NOTE(review): behavior for an index that matches no stream is not visible
 * from this declaration -- confirm against the definition before relying on it.
 */
|
|
|
|
Stream get_stream(int index);
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
/**
 * Equality for streams.
 *
 * Two streams compare equal when their `index` members match; the `device`
 * member does not take part in the comparison.
 */
inline bool operator==(const Stream& lhs, const Stream& rhs) {
  const bool same_index = (lhs.index == rhs.index);
  return same_index;
}
|
|
|
|
|
|
|
|
/**
 * Inequality for streams: the exact negation of `operator==` on the same
 * pair of operands.
 */
inline bool operator!=(const Stream& lhs, const Stream& rhs) {
  if (lhs == rhs) {
    return false;
  }
  return true;
}
|
|
|
|
|
2024-04-22 23:25:46 +08:00
|
|
|
/**
 * Synchronize with the default stream.
 *
 * Takes no arguments and returns nothing. Declaration only.
 * NOTE(review): presumably blocks the caller until work on the default stream
 * has finished, and which device's default stream is meant is not visible
 * here -- confirm both against the definition.
 */
|
|
|
|
void synchronize();
|
|
|
|
|
|
|
|
/**
 * Synchronize with the provided stream.
 *
 * The single (unnamed) parameter is the stream to synchronize with, taken by
 * value. Returns nothing. Declaration only.
 * NOTE(review): presumably blocks the caller until work on that stream has
 * finished -- confirm against the definition.
 */
|
|
|
|
void synchronize(Stream);
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
} // namespace mlx::core
|