mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
handling inf, -inf as numpy does, more extensive tests of compatibility with numpy
This commit is contained in:
parent
26bb16e768
commit
49c48de53b
@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
@ -68,6 +69,12 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
const double ord,
|
const double ord,
|
||||||
const bool keepdims,
|
const bool keepdims,
|
||||||
const StreamOrDevice stream) {
|
const StreamOrDevice stream) {
|
||||||
|
if (std::isinf((float)ord) || std::isinf(ord))
|
||||||
|
if (ord > 0)
|
||||||
|
return norm(a, "inf", {}, keepdims, stream);
|
||||||
|
else
|
||||||
|
return norm(a, "-inf", {}, keepdims, stream);
|
||||||
|
|
||||||
return norm(a, ord, {}, keepdims, stream);
|
return norm(a, ord, {}, keepdims, stream);
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -82,6 +89,12 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
const int axis,
|
const int axis,
|
||||||
const bool keepdims,
|
const bool keepdims,
|
||||||
const StreamOrDevice stream) {
|
const StreamOrDevice stream) {
|
||||||
|
if (std::isinf((float)ord) || std::isinf(ord))
|
||||||
|
if (ord > 0)
|
||||||
|
return norm(a, "inf", {axis}, keepdims, stream);
|
||||||
|
else
|
||||||
|
return norm(a, "-inf", {axis}, keepdims, stream);
|
||||||
|
|
||||||
return norm(a, ord, {axis}, keepdims, stream);
|
return norm(a, ord, {axis}, keepdims, stream);
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -97,6 +110,12 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
const std::vector<int>& axis,
|
const std::vector<int>& axis,
|
||||||
const bool keepdims,
|
const bool keepdims,
|
||||||
const StreamOrDevice stream) {
|
const StreamOrDevice stream) {
|
||||||
|
if (std::isinf((float)ord) || std::isinf(ord))
|
||||||
|
if (ord > 0)
|
||||||
|
return norm(a, "inf", axis, keepdims, stream);
|
||||||
|
else
|
||||||
|
return norm(a, "-inf", axis, keepdims, stream);
|
||||||
|
|
||||||
return norm(a, ord, axis, keepdims, stream);
|
return norm(a, ord, axis, keepdims, stream);
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -10,18 +11,19 @@ import numpy as np
|
|||||||
|
|
||||||
class TestLinalg(mlx_tests.MLXTestCase):
|
class TestLinalg(mlx_tests.MLXTestCase):
|
||||||
def test_norm(self):
|
def test_norm(self):
|
||||||
def check_mx_np(a_mx, a_np):
|
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")]
|
||||||
self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6))
|
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
||||||
|
|
||||||
x_mx = mx.arange(18).reshape((2, 3, 3))
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
x_np = np.arange(18).reshape((2, 3, 3))
|
x_mx = mx.arange(math.prod(shape)).reshape(shape)
|
||||||
|
x_np = np.arange(math.prod(shape)).reshape(shape)
|
||||||
for num_axes in range(1, 3):
|
# Test when at least one axis is provided
|
||||||
for axis in itertools.combinations(range(3), num_axes):
|
for num_axes in range(1, len(shape)):
|
||||||
|
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||||
if num_axes == 1:
|
if num_axes == 1:
|
||||||
ords = [None, 0.5, 0, 1, 2, 3, -1, 1]
|
ords = vector_ords
|
||||||
else:
|
else:
|
||||||
ords = [None, "fro", -1, 1]
|
ords = matrix_ords
|
||||||
for o in ords:
|
for o in ords:
|
||||||
for keepdims in [True, False]:
|
for keepdims in [True, False]:
|
||||||
if o:
|
if o:
|
||||||
@ -32,8 +34,18 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
x_mx, ord=o, axis=axis, keepdims=keepdims
|
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
|
out_np = np.linalg.norm(
|
||||||
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
|
x_np, axis=axis, keepdims=keepdims
|
||||||
|
)
|
||||||
|
out_mx = mx.linalg.norm(
|
||||||
|
x_mx, axis=axis, keepdims=keepdims
|
||||||
|
)
|
||||||
|
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
|
||||||
|
# Test when no axes and no ords are provided
|
||||||
|
for keepdims in [True, False]:
|
||||||
|
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||||
|
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||||
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user