mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	
							
								
								
									
										4
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					Before submitting this PR, check the [contribution guidelines](CONTRIBUTING.md).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Make sure your code is formatted: `pre-commit run --all-files`.
 | 
				
			||||||
@@ -1,13 +1,14 @@
 | 
				
			|||||||
# Copyright © 2023 Apple Inc.
 | 
					# Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import unittest
 | 
					 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
import mlx.optimizers as opt
 | 
					import mlx.optimizers as opt
 | 
				
			||||||
import mlx.utils
 | 
					import mlx.utils
 | 
				
			||||||
import mlx_tests
 | 
					import mlx_tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_all_optimizers():
 | 
					def get_all_optimizers():
 | 
				
			||||||
    classes = dict()
 | 
					    classes = dict()
 | 
				
			||||||
    for name, obj in inspect.getmembers(opt):
 | 
					    for name, obj in inspect.getmembers(opt):
 | 
				
			||||||
@@ -16,8 +17,10 @@ def get_all_optimizers():
 | 
				
			|||||||
                classes[name] = obj
 | 
					                classes[name] = obj
 | 
				
			||||||
    return classes
 | 
					    return classes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
optimizers_dict = get_all_optimizers()
 | 
					optimizers_dict = get_all_optimizers()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestOptimizers(mlx_tests.MLXTestCase):
 | 
					class TestOptimizers(mlx_tests.MLXTestCase):
 | 
				
			||||||
    def test_optimizers(self):
 | 
					    def test_optimizers(self):
 | 
				
			||||||
        params = {
 | 
					        params = {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user