mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Feature expand nn linear (#315)
* Added an identity and bilinear layers Added a reset_parameters option Added normal init for bias * pre-commit run * add type hints for parameters and the return type change Bilinear math to x_1 and x_2 change __call__ arguments to x and y instead of input and output add explanation to the Initialization * Remove unnecessary reshape * Added 'i' to bilinear formula * Changed bilinear computation to two matrix multiplications * avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1) * Changed math formula in Linear Added more explanation to math formulas Changed x1, x2 reshape to support all inputs sizes
This commit is contained in:
@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
class TestNN(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
inputs = mx.zeros((10, 4))
|
||||
layer = nn.Identity()
|
||||
outputs = layer(inputs)
|
||||
self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
|
||||
|
||||
def test_linear(self):
|
||||
inputs = mx.zeros((10, 4))
|
||||
layer = nn.Linear(input_dims=4, output_dims=8)
|
||||
outputs = layer(inputs)
|
||||
self.assertEqual(tuple(outputs.shape), (10, 8))
|
||||
|
||||
def test_bilinear(self):
|
||||
inputs1 = mx.zeros((10, 2))
|
||||
inputs2 = mx.zeros((10, 4))
|
||||
layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
|
||||
outputs = layer(inputs1, inputs2)
|
||||
self.assertEqual(tuple(outputs.shape), (10, 6))
|
||||
|
||||
def test_cross_entropy(self):
|
||||
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
|
||||
targets = mx.array([0, 1])
|
||||
|
Reference in New Issue
Block a user