From 9307b2ab8b6b5aeed909047862c93fde21ecc12a Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Mon, 24 Mar 2025 16:08:40 +0100 Subject: [PATCH] Fixed 32-bit platform support for distributed/ring implementation (#1996) Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1). --- mlx/distributed/ring/ring.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index 0b396f7f6..e55d960e7 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -625,7 +625,7 @@ class RingGroup : public GroupImpl { std::min( sockets_right_.size() + sockets_left_.size(), nbytes / min_send_size), - 1UL); + size_t(1)); size_t bytes_per_gather = ceildiv(nbytes, n_gathers); std::vector> all_gathers; for (int i = 0; i < n_gathers; i++) { @@ -740,7 +740,7 @@ class RingGroup : public GroupImpl { std::min( sockets_right_.size() + sockets_left_.size(), nbytes / (size_ * min_send_size)), - 1UL); + size_t(1)); size_t step = ceildiv(size, n_reduces); std::vector> all_sums; @@ -777,8 +777,8 @@ class RingGroup : public GroupImpl { // We split the data into `size_` segments of size `segment_size` and each // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets. size_t segment_size = ceildiv(data_size, size_); - size_t BUFFER_SIZE = - std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); + size_t BUFFER_SIZE = std::max( + size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); size_t n_packets = ceildiv(segment_size, BUFFER_SIZE); // Initial segments @@ -897,7 +897,8 @@ class RingGroup : public GroupImpl { void send(const std::vector& sockets, const char* data, size_t data_size) { - size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); + size_t segment_size = + std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> sends; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) { @@ -914,7 +915,8 @@ class RingGroup : public GroupImpl { } void recv(const std::vector& sockets, char* data, size_t data_size) { - size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); + size_t segment_size = + std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> recvs; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) {