mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
This commit is contained in:
parent
522d8d3917
commit
9307b2ab8b
@ -625,7 +625,7 @@ class RingGroup : public GroupImpl {
|
|||||||
std::min(
|
std::min(
|
||||||
sockets_right_.size() + sockets_left_.size(),
|
sockets_right_.size() + sockets_left_.size(),
|
||||||
nbytes / min_send_size),
|
nbytes / min_send_size),
|
||||||
1UL);
|
size_t(1));
|
||||||
size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
|
size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
|
||||||
std::vector<std::future<void>> all_gathers;
|
std::vector<std::future<void>> all_gathers;
|
||||||
for (int i = 0; i < n_gathers; i++) {
|
for (int i = 0; i < n_gathers; i++) {
|
||||||
@ -740,7 +740,7 @@ class RingGroup : public GroupImpl {
|
|||||||
std::min(
|
std::min(
|
||||||
sockets_right_.size() + sockets_left_.size(),
|
sockets_right_.size() + sockets_left_.size(),
|
||||||
nbytes / (size_ * min_send_size)),
|
nbytes / (size_ * min_send_size)),
|
||||||
1UL);
|
size_t(1));
|
||||||
size_t step = ceildiv(size, n_reduces);
|
size_t step = ceildiv(size, n_reduces);
|
||||||
std::vector<std::future<void>> all_sums;
|
std::vector<std::future<void>> all_sums;
|
||||||
|
|
||||||
@ -777,8 +777,8 @@ class RingGroup : public GroupImpl {
|
|||||||
// We split the data into `size_` segments of size `segment_size` and each
|
// We split the data into `size_` segments of size `segment_size` and each
|
||||||
// of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.
|
// of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.
|
||||||
size_t segment_size = ceildiv(data_size, size_);
|
size_t segment_size = ceildiv(data_size, size_);
|
||||||
size_t BUFFER_SIZE =
|
size_t BUFFER_SIZE = std::max(
|
||||||
std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
|
size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
|
||||||
size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
|
size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
|
||||||
|
|
||||||
// Initial segments
|
// Initial segments
|
||||||
@ -897,7 +897,8 @@ class RingGroup : public GroupImpl {
|
|||||||
|
|
||||||
void
|
void
|
||||||
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
|
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
|
||||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
size_t segment_size =
|
||||||
|
std::max(size_t(1024), ceildiv(data_size, sockets.size()));
|
||||||
std::vector<std::future<void>> sends;
|
std::vector<std::future<void>> sends;
|
||||||
for (int i = 0; i < sockets.size(); i++) {
|
for (int i = 0; i < sockets.size(); i++) {
|
||||||
if (i * segment_size >= data_size) {
|
if (i * segment_size >= data_size) {
|
||||||
@ -914,7 +915,8 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
size_t segment_size =
|
||||||
|
std::max(size_t(1024), ceildiv(data_size, sockets.size()));
|
||||||
std::vector<std::future<void>> recvs;
|
std::vector<std::future<void>> recvs;
|
||||||
for (int i = 0; i < sockets.size(); i++) {
|
for (int i = 0; i < sockets.size(); i++) {
|
||||||
if (i * segment_size >= data_size) {
|
if (i * segment_size >= data_size) {
|
||||||
|
Loading…
Reference in New Issue
Block a user