Fix cpu segfault (#1488)

* fix cpu segfault

* nit in tests
This commit is contained in:
Awni Hannun
2024-10-14 16:17:03 -07:00
committed by GitHub
parent 020f048cd0
commit 0ab8e099e8
6 changed files with 52 additions and 51 deletions

View File

@@ -1323,10 +1323,7 @@ void init_ops(nb::module_& m) {
start (float or int, optional): Starting value which defaults to ``0``.
stop (float or int): Stopping value.
step (float or int, optional): Increment which defaults to ``1``.
dtype (Dtype, optional): Specifies the data type of the output.
If unspecified will default to ``float32`` if any of ``start``,
``stop``, or ``step`` are ``float``. Otherwise will default to
``int32``.
dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
Returns:
array: The range of values.

View File

@@ -13,6 +13,7 @@
#include "mlx/array.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"

View File

@@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
# use torch to compute ct
out_pt.retain_grad()
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
out_pt.sum().backward()
pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
@@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
# use torch to compute ct
out_pt.retain_grad()
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
out_pt.sum().backward()
pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
@@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
# use torch to compute ct
out_pt.retain_grad()
(out_pt - torch.randn_like(out_pt)).abs().sum().backward()
out_pt.sum().backward()
pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()