handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

This commit is contained in:
Gabrijel Boduljak 2023-12-22 01:19:57 +01:00 committed by Awni Hannun
parent 26bb16e768
commit 49c48de53b
2 changed files with 54 additions and 23 deletions

View File

@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <limits>
#include <numeric> #include <numeric>
#include <ostream> #include <ostream>
#include <variant> #include <variant>
@ -68,6 +69,12 @@ void init_linalg(py::module_& parent_module) {
const double ord, const double ord,
const bool keepdims, const bool keepdims,
const StreamOrDevice stream) { const StreamOrDevice stream) {
if (std::isinf((float)ord) || std::isinf(ord))
if (ord > 0)
return norm(a, "inf", {}, keepdims, stream);
else
return norm(a, "-inf", {}, keepdims, stream);
return norm(a, ord, {}, keepdims, stream); return norm(a, ord, {}, keepdims, stream);
}, },
"a"_a, "a"_a,
@ -82,6 +89,12 @@ void init_linalg(py::module_& parent_module) {
const int axis, const int axis,
const bool keepdims, const bool keepdims,
const StreamOrDevice stream) { const StreamOrDevice stream) {
if (std::isinf((float)ord) || std::isinf(ord))
if (ord > 0)
return norm(a, "inf", {axis}, keepdims, stream);
else
return norm(a, "-inf", {axis}, keepdims, stream);
return norm(a, ord, {axis}, keepdims, stream); return norm(a, ord, {axis}, keepdims, stream);
}, },
"a"_a, "a"_a,
@ -97,6 +110,12 @@ void init_linalg(py::module_& parent_module) {
const std::vector<int>& axis, const std::vector<int>& axis,
const bool keepdims, const bool keepdims,
const StreamOrDevice stream) { const StreamOrDevice stream) {
if (std::isinf((float)ord) || std::isinf(ord))
if (ord > 0)
return norm(a, "inf", axis, keepdims, stream);
else
return norm(a, "-inf", axis, keepdims, stream);
return norm(a, ord, axis, keepdims, stream); return norm(a, ord, axis, keepdims, stream);
}, },
"a"_a, "a"_a,

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import itertools import itertools
import math
import unittest import unittest
import mlx.core as mx import mlx.core as mx
@ -10,18 +11,19 @@ import numpy as np
class TestLinalg(mlx_tests.MLXTestCase): class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self): def test_norm(self):
def check_mx_np(a_mx, a_np): vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")]
self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6)) matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
x_mx = mx.arange(18).reshape((2, 3, 3)) for shape in [(3,), (2, 3), (2, 3, 3)]:
x_np = np.arange(18).reshape((2, 3, 3)) x_mx = mx.arange(math.prod(shape)).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape)
for num_axes in range(1, 3): # Test when at least one axis is provided
for axis in itertools.combinations(range(3), num_axes): for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes):
if num_axes == 1: if num_axes == 1:
ords = [None, 0.5, 0, 1, 2, 3, -1, 1] ords = vector_ords
else: else:
ords = [None, "fro", -1, 1] ords = matrix_ords
for o in ords: for o in ords:
for keepdims in [True, False]: for keepdims in [True, False]:
if o: if o:
@ -32,8 +34,18 @@ class TestLinalg(mlx_tests.MLXTestCase):
x_mx, ord=o, axis=axis, keepdims=keepdims x_mx, ord=o, axis=axis, keepdims=keepdims
) )
else: else:
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) out_np = np.linalg.norm(
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) x_np, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, axis=axis, keepdims=keepdims
)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test when no axes and no ords are provided
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)