	Data parallel helper (#1407)

Author: Angelos Katharopoulos (committed via GitHub)
parent 8d68a3e805
commit 914409fef9
@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
   }
 }
 
+template <typename T>
+void simple_sum(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc += *in;
+    acc++;
+    in++;
+  }
+}
+template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+
 struct MPIWrapper {
   MPIWrapper() {
+    initialized_ = false;
+
     libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
     if (libmpi_handle_ == nullptr) {
       return;
@@ -50,6 +71,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Allgather, all_gather);
     LOAD_SYMBOL(MPI_Send, send);
     LOAD_SYMBOL(MPI_Recv, recv);
+    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
+    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
+    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +103,24 @@ struct MPIWrapper {
     if (!is_available()) {
       return false;
     }
-    return init(nullptr, nullptr) == MPI_SUCCESS;
+    bool success = init(nullptr, nullptr) == MPI_SUCCESS;
+
+    // Initialize custom types and ops
+    if (success && !initialized_) {
+      // Custom float16 dtypes
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
+      mpi_type_commit(&mpi_float16_);
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
+      mpi_type_commit(&mpi_bfloat16_);
+
+      // Custom sum ops
+      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
+      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
+
+      initialized_ = true;
+    }
+
+    return success;
   }
 
   void finalize_safe() {
@@ -117,13 +158,21 @@ struct MPIWrapper {
       case complex64:
         return mpi_complex_;
       case float16:
+        return mpi_float16_;
       case bfloat16:
-        throw std::runtime_error("MPI doesn't support 16-bit floats");
+        return mpi_bfloat16_;
     }
   }
 
-  MPI_Op op_sum() {
-    return op_sum_;
+  MPI_Op op_sum(const array& arr) {
+    switch (arr.dtype()) {
+      case float16:
+        return op_sum_f16_;
+      case bfloat16:
+        return op_sum_bf16_;
+      default:
+        return op_sum_;
+    }
   }
 
   void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {
 
   // Ops
   MPI_Op op_sum_;
+  MPI_Op op_sum_f16_;
+  MPI_Op op_sum_bf16_;
 
   // Datatypes
   MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@ struct MPIWrapper {
   MPI_Datatype mpi_uint64_;
   MPI_Datatype mpi_float_;
   MPI_Datatype mpi_complex_;
+  MPI_Datatype mpi_float16_;
+  MPI_Datatype mpi_bfloat16_;
+
+ private:
+  bool initialized_;
+
+  // Private API
+  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
+  int (*mpi_type_commit)(MPI_Datatype*);
+  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
 };
 
 MPIWrapper& mpi() {
@@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) {
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
-      mpi().op_sum(),
+      mpi().op_sum(input),
       to_comm(group));
 }

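Note on the C++ change above: float16 and bfloat16 are registered with MPI as opaque 2-byte contiguous types, and a custom commutative sum (simple_sum) is attached via MPI_Op_create, so all_sum no longer throws for 16-bit floats. A minimal usage sketch from the Python side, not part of this commit, assuming an MPI backend is available and the script is launched across several processes (e.g. with mpirun):

    # Hedged sketch: exercises the new 16-bit all_sum path.
    import mlx.core as mx

    group = mx.distributed.init()            # global process group
    x = mx.ones((4, 4), dtype=mx.float16)    # previously raised "MPI doesn't support 16-bit floats"
    y = mx.distributed.all_sum(x)            # reduced with the custom float16 sum op
    print(group.rank(), y.dtype, y[0, 0])    # every entry equals group.size()
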
@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.
 
-from functools import wraps
-from typing import Callable, Optional
+from functools import reduce, wraps
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 
+from ..utils import tree_flatten, tree_map, tree_unflatten
 from .layers.base import Module
 
 
@@ -68,3 +69,93 @@ def checkpoint(module: Module, fn: Optional[Callable] = None):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 
     return wrapped_checkpointed_fn
+
+
+def average_gradients(
+    gradients: Any,
+    group: Optional[mx.distributed.Group] = None,
+    all_reduce_size: int = 32 * 1024**2,
+    communication_type: Optional[mx.Dtype] = None,
+):
+    """Average the gradients across the distributed processes in the passed group.
+
+    This helper enables concatenating several gradients of small arrays to one
+    big all reduce call for better networking performance.
+
+    Args:
+        gradients (Any): The Python tree containing the gradients (it should
+            have the same structure across processes)
+        group (Optional[mlx.core.distributed.Group]): The group of processes to
+            average the gradients. If set to ``None`` the global group is used.
+            Default: ``None``.
+        all_reduce_size (int): Group arrays until their size in bytes exceeds
+            this number. Perform one communication step per group of arrays. If
+            less or equal to 0 array grouping is disabled. Default: ``32MiB``.
+        communication_type (Optional[mlx.core.Dtype]): If provided cast to this
+            type before performing the communication. Typically cast to a
+            smaller float to reduce the communication size. Default: ``None``.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+
+    if N == 1:
+        return gradients
+
+    def _average(x):
+        dt = x.dtype
+        x = x.astype(communication_type) if communication_type is not None else x
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+
+    if all_reduce_size <= 0:
+        return tree_map(_average, gradients)
+
+    else:
+        flat_grads = tree_flatten(gradients)
+        if len(flat_grads) == 0:
+            return gradients
+
+        # Extract some info for the gradient
+        keys = [k for k, _ in flat_grads]
+        shapes = [v.shape for _, v in flat_grads]
+        sizes = [v.size for _, v in flat_grads]
+        dtypes = [v.dtype for _, v in flat_grads]
+
+        # We can't group them if they have mixed types
+        if not all(dt == dtypes[0] for dt in dtypes):
+            return average_gradients(gradients, group, 0, communication_type)
+        itemsize = (
+            communication_type.size
+            if communication_type is not None
+            else dtypes[0].size
+        )
+
+        # Gather the gradients in groups that are just above or equal to all_reduce_size
+        grad_groups = []
+        grad_group = []
+        grad_group_size = 0
+        for i in range(len(keys)):
+            grad_group.append(i)
+            grad_group_size += sizes[i] * itemsize
+            if grad_group_size >= all_reduce_size:
+                grad_groups.append(grad_group)
+                grad_group = []
+                grad_group_size = 0
+        if grad_group:
+            grad_groups.append(grad_group)
+            grad_group = []
+
+        # Concatenate-reduce-split
+        new_flat_grads = []
+        for grad_group in grad_groups:
+            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
+            big_grad = mx.concatenate(
+                [flat_grads[i][1].reshape(-1) for i in grad_group]
+            )
+            big_grad = _average(big_grad)
+            big_grad = mx.split(big_grad, indices[1:-1])
+            new_flat_grads.extend(
+                (keys[j], big_grad[i].reshape(shapes[j]))
+                for i, j in enumerate(grad_group)
+            )
+
+        return tree_unflatten(new_flat_grads)
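
For context, a hedged sketch of how average_gradients would typically slot into a data-parallel training step; the model, optimizer, loss_fn, and batch names below are placeholders, not part of this commit:

    # Hypothetical training step using the new helper.
    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.utils import average_gradients

    def train_step(model, optimizer, loss_fn, batch):
        # loss_fn is assumed to take (model, *batch) and return a scalar loss.
        loss_and_grad = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad(model, *batch)
        # One (or a few) large all-reduces instead of one per parameter,
        # optionally communicating in float16 to halve the traffic.
        grads = average_gradients(grads, communication_type=mx.float16)
        optimizer.update(model, grads)
        return loss
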
@@ -4,6 +4,7 @@ import unittest
 
 import mlx.core as mx
 import mlx_tests
+from mlx.nn.utils import average_gradients
 
 
 class TestDistributed(mlx_tests.MLXTestCase):
@@ -110,6 +111,59 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+    def test_average_gradients(self):
+        original_all_sum = mx.distributed.all_sum
+        n_calls = 0
+        xtype = None
+
+        def new_all_sum(x, **kwargs):
+            nonlocal n_calls
+            nonlocal xtype
+
+            n_calls += 1
+            if xtype is not None:
+                self.assertEqual(xtype, x.dtype)
+
+            return original_all_sum(x, **kwargs)
+
+        mx.distributed.all_sum = new_all_sum
+
+        try:
+            grads = [mx.ones(10) for i in range(10)]
+            new_grads = average_gradients(grads)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 1)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=0)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 10)
+
+            n_calls = 0
+            xtype = mx.float16
+            new_grads = average_gradients(
+                grads, all_reduce_size=2 * 50, communication_type=mx.float16
+            )
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+        finally:
+            mx.distributed.all_sum = original_all_sum
+
 
 if __name__ == "__main__":
     unittest.main()
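
For reference, the asserted call counts follow from the grouping arithmetic: each gradient is mx.ones(10), i.e. 10 float32 values = 40 bytes, so with all_reduce_size=4 * 50 (200 bytes) a group closes after five gradients and ten gradients need two all_sum calls; with all_reduce_size=0 grouping is disabled (ten calls), and the float16 case (20 bytes per gradient against a 100-byte threshold) again gives two calls. A rough back-of-the-envelope check, not part of the test:

    # Rough check of the expected n_calls values for uniformly sized gradients.
    import math

    def expected_calls(n_grads, bytes_per_grad, threshold):
        if threshold <= 0:
            return n_grads                           # grouping disabled: one call per gradient
        per_group = math.ceil(threshold / bytes_per_grad)
        return math.ceil(n_grads / per_group)

    assert expected_calls(10, 10 * 4, 4 * 50) == 2   # float32, 200-byte threshold
    assert expected_calls(10, 10 * 4, 0) == 10       # grouping disabled
    assert expected_calls(10, 10 * 2, 2 * 50) == 2   # float16, 100-byte threshold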