mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 05:14:40 +08:00
@@ -1323,10 +1323,7 @@ void init_ops(nb::module_& m) {
|
||||
start (float or int, optional): Starting value which defaults to ``0``.
|
||||
stop (float or int): Stopping value.
|
||||
step (float or int, optional): Increment which defaults to ``1``.
|
||||
dtype (Dtype, optional): Specifies the data type of the output.
|
||||
If unspecified will default to ``float32`` if any of ``start``,
|
||||
``stop``, or ``step`` are ``float``. Otherwise will default to
|
||||
``int32``.
|
||||
dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
|
||||
|
||||
Returns:
|
||||
array: The range of values.
|
||||
|
@@ -13,6 +13,7 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
@@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
|
||||
# use torch to compute ct
|
||||
out_pt.retain_grad()
|
||||
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
|
||||
out_pt.sum().backward()
|
||||
|
||||
pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
|
||||
pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
|
||||
@@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
|
||||
# use torch to compute ct
|
||||
out_pt.retain_grad()
|
||||
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
|
||||
out_pt.sum().backward()
|
||||
|
||||
pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
|
||||
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
|
||||
@@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
|
||||
# use torch to compute ct
|
||||
out_pt.retain_grad()
|
||||
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
|
||||
out_pt.sum().backward()
|
||||
|
||||
pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
|
||||
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()
|
||||
|
Reference in New Issue
Block a user