mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Treate 'minimum' differently in cosine decay (#1138)
This commit is contained in:
		| @@ -58,14 +58,14 @@ def step_decay(init: float, decay_rate: float, step_size: int) -> Callable: | ||||
|     return schedule | ||||
|  | ||||
|  | ||||
| def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callable: | ||||
| def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable: | ||||
|     r"""Make a cosine decay scheduler. | ||||
|  | ||||
|     Args: | ||||
|         init (float): Initial value. | ||||
|         decay_steps (int): Number of steps to decay over. The decayed | ||||
|             value is constant for steps beyond ``decay_steps``. | ||||
|         minimum (float, optional): Minimal value to decay to. Default: ``0``. | ||||
|         end (float, optional): Final value to decay to. Default: ``0``. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
| @@ -83,7 +83,7 @@ def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callabl | ||||
|     def scheduler(step): | ||||
|         s = mx.minimum(step, decay_steps) | ||||
|         decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) | ||||
|         return mx.maximum(init * decay, minimum) | ||||
|         return end + decay * (init - end) | ||||
|  | ||||
|     return scheduler | ||||
|  | ||||
|   | ||||
| @@ -329,9 +329,11 @@ class TestSchedulers(unittest.TestCase): | ||||
|         self.assertAlmostEqual(lr, expected_lr, delta=1e-7) | ||||
|  | ||||
|         lr_schedule = opt.cosine_decay(0.1, 10, 0.05) | ||||
|         lr = lr_schedule(9) | ||||
|         expected_end_lr = 0.05 | ||||
|         self.assertGreater(lr, expected_end_lr) | ||||
|         lr = lr_schedule(20) | ||||
|         expected_lr = 0.05 | ||||
|         self.assertEqual(lr, expected_lr) | ||||
|         self.assertEqual(lr, expected_end_lr) | ||||
|  | ||||
|     def test_schedule_joiner(self): | ||||
|         boundaries = [2, 3, 4] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 jlwitthuhn
					jlwitthuhn