mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-29 07:03:10 +08:00)

Remove stream from average grads so it uses default

commit d08fa4bef8
parent 30561229c7
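
In short, `average_gradients` no longer takes a `stream` argument, so the all-reduce follows the default stream instead of being pinned to the CPU. A minimal before/after sketch of the call site, assuming the public `mlx.nn.average_gradients` entry point and a single-process distributed group:

import mlx.core as mx
import mlx.nn as nn

grads = [mx.ones(10) for _ in range(4)]  # toy gradient list

# Old signature (removed by this commit): the reduction defaulted to mx.cpu.
# avg = nn.average_gradients(grads, stream=mx.cpu)

# New signature: no stream keyword; the all_sum runs on the default stream.
avg = nn.average_gradients(grads)
mx.eval(avg)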
@@ -76,7 +76,6 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
-    stream: mx.Stream = mx.cpu,
 ):
     """Average the gradients across the distributed processes in the passed group.
 
@@ -95,7 +94,6 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
-        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -106,7 +104,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
+        return mx.distributed.all_sum(x).astype(dt) / N
 
     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
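
Callers that still want the reduction pinned to a particular device can scope the call instead of passing a keyword. A minimal sketch, assuming the `mx.stream` context manager accepts a device:

import mlx.core as mx
import mlx.nn as nn

grads = [mx.ones(10) for _ in range(4)]

# Restore the old behavior (reduction on the CPU stream) without a stream kwarg.
with mx.stream(mx.cpu):
    avg = nn.average_gradients(grads)
mx.eval(avg)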
@@ -65,21 +65,21 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
         mx.distributed.all_sum = new_all_sum
         try:
             grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads, stream=mx.gpu)
+            new_grads = average_gradients(grads)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 1)
 
             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 2)
 
             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=0)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
@@ -91,7 +91,6 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
                 grads,
                 all_reduce_size=2 * 50,
                 communication_type=mx.float16,
-                stream=mx.gpu,
             )
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
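
The call counts asserted above follow from the chunking budget. A rough check, assuming float32 gradients at 4 bytes per element (the exact grouping heuristic is glossed over):

import math
import mlx.core as mx

grads = [mx.ones(10) for _ in range(10)]       # 10 arrays x 10 elements
total_bytes = sum(g.size for g in grads) * 4   # 400 bytes of float32 gradients
print(math.ceil(total_bytes / (4 * 50)))       # 2 chunks -> n_calls == 2 above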