mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
fd836d891b
commit
89b90dcfec
4
.github/pull_request_template.md
vendored
Normal file
4
.github/pull_request_template.md
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
Before submitting this PR, check the [contribution guidelines](CONTRIBUTING.md).
|
||||||
|
|
||||||
|
Make sure your code is formatted: `pre-commit run --all-files`.
|
@ -1,13 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import unittest
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.optimizers as opt
|
import mlx.optimizers as opt
|
||||||
import mlx.utils
|
import mlx.utils
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
def get_all_optimizers():
|
def get_all_optimizers():
|
||||||
classes = dict()
|
classes = dict()
|
||||||
for name, obj in inspect.getmembers(opt):
|
for name, obj in inspect.getmembers(opt):
|
||||||
@ -16,8 +17,10 @@ def get_all_optimizers():
|
|||||||
classes[name] = obj
|
classes[name] = obj
|
||||||
return classes
|
return classes
|
||||||
|
|
||||||
|
|
||||||
optimizers_dict = get_all_optimizers()
|
optimizers_dict = get_all_optimizers()
|
||||||
|
|
||||||
|
|
||||||
class TestOptimizers(mlx_tests.MLXTestCase):
|
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||||
def test_optimizers(self):
|
def test_optimizers(self):
|
||||||
params = {
|
params = {
|
||||||
|
Loading…
Reference in New Issue
Block a user