Fix divide types + floor divide (//) (#138)

* divide types

* fix black + test
This commit is contained in:
Awni Hannun
2023-12-11 20:20:58 -08:00
committed by GitHub
parent 02de234ef0
commit 25f70d4ca4
4 changed files with 42 additions and 6 deletions

View File

@@ -2,7 +2,7 @@
import os
import unittest
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union
import mlx.core as mx
import numpy as np
@@ -21,7 +21,7 @@ class MLXTestCase(unittest.TestCase):
def assertEqualArray(
self,
args: List[mx.array | float | int],
args: List[Union[mx.array, float, int]],
mlx_func: Callable[..., mx.array],
expected: mx.array,
atol=1e-2,