mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
@@ -1,13 +1,14 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import unittest
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.optimizers as opt
|
||||
import mlx.utils
|
||||
import mlx_tests
|
||||
|
||||
|
||||
def get_all_optimizers():
|
||||
classes = dict()
|
||||
for name, obj in inspect.getmembers(opt):
|
||||
@@ -16,8 +17,10 @@ def get_all_optimizers():
|
||||
classes[name] = obj
|
||||
return classes
|
||||
|
||||
|
||||
optimizers_dict = get_all_optimizers()
|
||||
|
||||
|
||||
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
def test_optimizers(self):
|
||||
params = {
|
||||
|
Reference in New Issue
Block a user