8th day of python challenges 111-117
This commit is contained in:
@@ -0,0 +1,223 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pandas.core.dtypes.dtypes import CategoricalDtype, IntervalDtype
|
||||
|
||||
from pandas import (
|
||||
CategoricalIndex,
|
||||
Index,
|
||||
IntervalIndex,
|
||||
NaT,
|
||||
Timedelta,
|
||||
Timestamp,
|
||||
interval_range,
|
||||
)
|
||||
import pandas.util.testing as tm
|
||||
|
||||
|
||||
class Base:
|
||||
"""Tests common to IntervalIndex with any subtype"""
|
||||
|
||||
def test_astype_idempotent(self, index):
|
||||
result = index.astype("interval")
|
||||
tm.assert_index_equal(result, index)
|
||||
|
||||
result = index.astype(index.dtype)
|
||||
tm.assert_index_equal(result, index)
|
||||
|
||||
def test_astype_object(self, index):
|
||||
result = index.astype(object)
|
||||
expected = Index(index.values, dtype="object")
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert not result.equals(index)
|
||||
|
||||
def test_astype_category(self, index):
|
||||
result = index.astype("category")
|
||||
expected = CategoricalIndex(index.values)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
result = index.astype(CategoricalDtype())
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# non-default params
|
||||
categories = index.dropna().unique().values[:-1]
|
||||
dtype = CategoricalDtype(categories=categories, ordered=True)
|
||||
result = index.astype(dtype)
|
||||
expected = CategoricalIndex(index.values, categories=categories, ordered=True)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
[
|
||||
"int64",
|
||||
"uint64",
|
||||
"float64",
|
||||
"complex128",
|
||||
"period[M]",
|
||||
"timedelta64",
|
||||
"timedelta64[ns]",
|
||||
"datetime64",
|
||||
"datetime64[ns]",
|
||||
"datetime64[ns, US/Eastern]",
|
||||
],
|
||||
)
|
||||
def test_astype_cannot_cast(self, index, dtype):
|
||||
msg = "Cannot cast IntervalIndex to dtype"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype(dtype)
|
||||
|
||||
def test_astype_invalid_dtype(self, index):
|
||||
msg = "data type 'fake_dtype' not understood"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype("fake_dtype")
|
||||
|
||||
|
||||
class TestIntSubtype(Base):
|
||||
"""Tests specific to IntervalIndex with integer-like subtype"""
|
||||
|
||||
indexes = [
|
||||
IntervalIndex.from_breaks(np.arange(-10, 11, dtype="int64")),
|
||||
IntervalIndex.from_breaks(np.arange(100, dtype="uint64"), closed="left"),
|
||||
]
|
||||
|
||||
@pytest.fixture(params=indexes)
|
||||
def index(self, request):
|
||||
return request.param
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"subtype", ["float64", "datetime64[ns]", "timedelta64[ns]"]
|
||||
)
|
||||
def test_subtype_conversion(self, index, subtype):
|
||||
dtype = IntervalDtype(subtype)
|
||||
result = index.astype(dtype)
|
||||
expected = IntervalIndex.from_arrays(
|
||||
index.left.astype(subtype), index.right.astype(subtype), closed=index.closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"subtype_start, subtype_end", [("int64", "uint64"), ("uint64", "int64")]
|
||||
)
|
||||
def test_subtype_integer(self, subtype_start, subtype_end):
|
||||
index = IntervalIndex.from_breaks(np.arange(100, dtype=subtype_start))
|
||||
dtype = IntervalDtype(subtype_end)
|
||||
result = index.astype(dtype)
|
||||
expected = IntervalIndex.from_arrays(
|
||||
index.left.astype(subtype_end),
|
||||
index.right.astype(subtype_end),
|
||||
closed=index.closed,
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.xfail(reason="GH#15832")
|
||||
def test_subtype_integer_errors(self):
|
||||
# int64 -> uint64 fails with negative values
|
||||
index = interval_range(-10, 10)
|
||||
dtype = IntervalDtype("uint64")
|
||||
with pytest.raises(ValueError):
|
||||
index.astype(dtype)
|
||||
|
||||
|
||||
class TestFloatSubtype(Base):
|
||||
"""Tests specific to IntervalIndex with float subtype"""
|
||||
|
||||
indexes = [
|
||||
interval_range(-10.0, 10.0, closed="neither"),
|
||||
IntervalIndex.from_arrays(
|
||||
[-1.5, np.nan, 0.0, 0.0, 1.5], [-0.5, np.nan, 1.0, 1.0, 3.0], closed="both"
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture(params=indexes)
|
||||
def index(self, request):
|
||||
return request.param
|
||||
|
||||
@pytest.mark.parametrize("subtype", ["int64", "uint64"])
|
||||
def test_subtype_integer(self, subtype):
|
||||
index = interval_range(0.0, 10.0)
|
||||
dtype = IntervalDtype(subtype)
|
||||
result = index.astype(dtype)
|
||||
expected = IntervalIndex.from_arrays(
|
||||
index.left.astype(subtype), index.right.astype(subtype), closed=index.closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# raises with NA
|
||||
msg = "Cannot convert NA to integer"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
index.insert(0, np.nan).astype(dtype)
|
||||
|
||||
@pytest.mark.xfail(reason="GH#15832")
|
||||
def test_subtype_integer_errors(self):
|
||||
# float64 -> uint64 fails with negative values
|
||||
index = interval_range(-10.0, 10.0)
|
||||
dtype = IntervalDtype("uint64")
|
||||
with pytest.raises(ValueError):
|
||||
index.astype(dtype)
|
||||
|
||||
# float64 -> integer-like fails with non-integer valued floats
|
||||
index = interval_range(0.0, 10.0, freq=0.25)
|
||||
dtype = IntervalDtype("int64")
|
||||
with pytest.raises(ValueError):
|
||||
index.astype(dtype)
|
||||
|
||||
dtype = IntervalDtype("uint64")
|
||||
with pytest.raises(ValueError):
|
||||
index.astype(dtype)
|
||||
|
||||
@pytest.mark.parametrize("subtype", ["datetime64[ns]", "timedelta64[ns]"])
|
||||
def test_subtype_datetimelike(self, index, subtype):
|
||||
dtype = IntervalDtype(subtype)
|
||||
msg = "Cannot convert .* to .*; subtypes are incompatible"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype(dtype)
|
||||
|
||||
|
||||
class TestDatetimelikeSubtype(Base):
|
||||
"""Tests specific to IntervalIndex with datetime-like subtype"""
|
||||
|
||||
indexes = [
|
||||
interval_range(Timestamp("2018-01-01"), periods=10, closed="neither"),
|
||||
interval_range(Timestamp("2018-01-01"), periods=10).insert(2, NaT),
|
||||
interval_range(Timestamp("2018-01-01", tz="US/Eastern"), periods=10),
|
||||
interval_range(Timedelta("0 days"), periods=10, closed="both"),
|
||||
interval_range(Timedelta("0 days"), periods=10).insert(2, NaT),
|
||||
]
|
||||
|
||||
@pytest.fixture(params=indexes)
|
||||
def index(self, request):
|
||||
return request.param
|
||||
|
||||
@pytest.mark.parametrize("subtype", ["int64", "uint64"])
|
||||
def test_subtype_integer(self, index, subtype):
|
||||
dtype = IntervalDtype(subtype)
|
||||
result = index.astype(dtype)
|
||||
expected = IntervalIndex.from_arrays(
|
||||
index.left.astype(subtype), index.right.astype(subtype), closed=index.closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
def test_subtype_float(self, index):
|
||||
dtype = IntervalDtype("float64")
|
||||
msg = "Cannot convert .* to .*; subtypes are incompatible"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype(dtype)
|
||||
|
||||
def test_subtype_datetimelike(self):
|
||||
# datetime -> timedelta raises
|
||||
dtype = IntervalDtype("timedelta64[ns]")
|
||||
msg = "Cannot convert .* to .*; subtypes are incompatible"
|
||||
|
||||
index = interval_range(Timestamp("2018-01-01"), periods=10)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype(dtype)
|
||||
|
||||
index = interval_range(Timestamp("2018-01-01", tz="CET"), periods=10)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype(dtype)
|
||||
|
||||
# timedelta -> datetime raises
|
||||
dtype = IntervalDtype("datetime64[ns]")
|
||||
index = interval_range(Timedelta("0 days"), periods=10)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
index.astype(dtype)
|
@@ -0,0 +1,452 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pandas.core.dtypes.common import is_categorical_dtype
|
||||
from pandas.core.dtypes.dtypes import IntervalDtype
|
||||
|
||||
from pandas import (
|
||||
Categorical,
|
||||
CategoricalIndex,
|
||||
Float64Index,
|
||||
Index,
|
||||
Int64Index,
|
||||
Interval,
|
||||
IntervalIndex,
|
||||
date_range,
|
||||
notna,
|
||||
period_range,
|
||||
timedelta_range,
|
||||
)
|
||||
from pandas.core.arrays import IntervalArray
|
||||
import pandas.core.common as com
|
||||
import pandas.util.testing as tm
|
||||
|
||||
|
||||
@pytest.fixture(params=[None, "foo"])
|
||||
def name(request):
|
||||
return request.param
|
||||
|
||||
|
||||
class Base:
|
||||
"""
|
||||
Common tests for all variations of IntervalIndex construction. Input data
|
||||
to be supplied in breaks format, then converted by the subclass method
|
||||
get_kwargs_from_breaks to the expected format.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"breaks",
|
||||
[
|
||||
[3, 14, 15, 92, 653],
|
||||
np.arange(10, dtype="int64"),
|
||||
Int64Index(range(-10, 11)),
|
||||
Float64Index(np.arange(20, 30, 0.5)),
|
||||
date_range("20180101", periods=10),
|
||||
date_range("20180101", periods=10, tz="US/Eastern"),
|
||||
timedelta_range("1 day", periods=10),
|
||||
],
|
||||
)
|
||||
def test_constructor(self, constructor, breaks, closed, name):
|
||||
result_kwargs = self.get_kwargs_from_breaks(breaks, closed)
|
||||
result = constructor(closed=closed, name=name, **result_kwargs)
|
||||
|
||||
assert result.closed == closed
|
||||
assert result.name == name
|
||||
assert result.dtype.subtype == getattr(breaks, "dtype", "int64")
|
||||
tm.assert_index_equal(result.left, Index(breaks[:-1]))
|
||||
tm.assert_index_equal(result.right, Index(breaks[1:]))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"breaks, subtype",
|
||||
[
|
||||
(Int64Index([0, 1, 2, 3, 4]), "float64"),
|
||||
(Int64Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
|
||||
(Int64Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
|
||||
(Float64Index([0, 1, 2, 3, 4]), "int64"),
|
||||
(date_range("2017-01-01", periods=5), "int64"),
|
||||
(timedelta_range("1 day", periods=5), "int64"),
|
||||
],
|
||||
)
|
||||
def test_constructor_dtype(self, constructor, breaks, subtype):
|
||||
# GH 19262: conversion via dtype parameter
|
||||
expected_kwargs = self.get_kwargs_from_breaks(breaks.astype(subtype))
|
||||
expected = constructor(**expected_kwargs)
|
||||
|
||||
result_kwargs = self.get_kwargs_from_breaks(breaks)
|
||||
iv_dtype = IntervalDtype(subtype)
|
||||
for dtype in (iv_dtype, str(iv_dtype)):
|
||||
result = constructor(dtype=dtype, **result_kwargs)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize("breaks", [[np.nan] * 2, [np.nan] * 4, [np.nan] * 50])
|
||||
def test_constructor_nan(self, constructor, breaks, closed):
|
||||
# GH 18421
|
||||
result_kwargs = self.get_kwargs_from_breaks(breaks)
|
||||
result = constructor(closed=closed, **result_kwargs)
|
||||
|
||||
expected_subtype = np.float64
|
||||
expected_values = np.array(breaks[:-1], dtype=object)
|
||||
|
||||
assert result.closed == closed
|
||||
assert result.dtype.subtype == expected_subtype
|
||||
tm.assert_numpy_array_equal(result._ndarray_values, expected_values)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"breaks",
|
||||
[
|
||||
[],
|
||||
np.array([], dtype="int64"),
|
||||
np.array([], dtype="float64"),
|
||||
np.array([], dtype="datetime64[ns]"),
|
||||
np.array([], dtype="timedelta64[ns]"),
|
||||
],
|
||||
)
|
||||
def test_constructor_empty(self, constructor, breaks, closed):
|
||||
# GH 18421
|
||||
result_kwargs = self.get_kwargs_from_breaks(breaks)
|
||||
result = constructor(closed=closed, **result_kwargs)
|
||||
|
||||
expected_values = np.array([], dtype=object)
|
||||
expected_subtype = getattr(breaks, "dtype", np.int64)
|
||||
|
||||
assert result.empty
|
||||
assert result.closed == closed
|
||||
assert result.dtype.subtype == expected_subtype
|
||||
tm.assert_numpy_array_equal(result._ndarray_values, expected_values)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"breaks",
|
||||
[
|
||||
tuple("0123456789"),
|
||||
list("abcdefghij"),
|
||||
np.array(list("abcdefghij"), dtype=object),
|
||||
np.array(list("abcdefghij"), dtype="<U1"),
|
||||
],
|
||||
)
|
||||
def test_constructor_string(self, constructor, breaks):
|
||||
# GH 19016
|
||||
msg = (
|
||||
"category, object, and string subtypes are not supported "
|
||||
"for IntervalIndex"
|
||||
)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
constructor(**self.get_kwargs_from_breaks(breaks))
|
||||
|
||||
@pytest.mark.parametrize("cat_constructor", [Categorical, CategoricalIndex])
|
||||
def test_constructor_categorical_valid(self, constructor, cat_constructor):
|
||||
# GH 21243/21253
|
||||
if isinstance(constructor, partial) and constructor.func is Index:
|
||||
# Index is defined to create CategoricalIndex from categorical data
|
||||
pytest.skip()
|
||||
|
||||
breaks = np.arange(10, dtype="int64")
|
||||
expected = IntervalIndex.from_breaks(breaks)
|
||||
|
||||
cat_breaks = cat_constructor(breaks)
|
||||
result_kwargs = self.get_kwargs_from_breaks(cat_breaks)
|
||||
result = constructor(**result_kwargs)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
def test_generic_errors(self, constructor):
|
||||
# filler input data to be used when supplying invalid kwargs
|
||||
filler = self.get_kwargs_from_breaks(range(10))
|
||||
|
||||
# invalid closed
|
||||
msg = "invalid option for 'closed': invalid"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
constructor(closed="invalid", **filler)
|
||||
|
||||
# unsupported dtype
|
||||
msg = "dtype must be an IntervalDtype, got int64"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
constructor(dtype="int64", **filler)
|
||||
|
||||
# invalid dtype
|
||||
msg = "data type 'invalid' not understood"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
constructor(dtype="invalid", **filler)
|
||||
|
||||
# no point in nesting periods in an IntervalIndex
|
||||
periods = period_range("2000-01-01", periods=10)
|
||||
periods_kwargs = self.get_kwargs_from_breaks(periods)
|
||||
msg = "Period dtypes are not supported, use a PeriodIndex instead"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
constructor(**periods_kwargs)
|
||||
|
||||
# decreasing values
|
||||
decreasing_kwargs = self.get_kwargs_from_breaks(range(10, -1, -1))
|
||||
msg = "left side of interval must be <= right side"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
constructor(**decreasing_kwargs)
|
||||
|
||||
|
||||
class TestFromArrays(Base):
|
||||
"""Tests specific to IntervalIndex.from_arrays"""
|
||||
|
||||
@pytest.fixture
|
||||
def constructor(self):
|
||||
return IntervalIndex.from_arrays
|
||||
|
||||
def get_kwargs_from_breaks(self, breaks, closed="right"):
|
||||
"""
|
||||
converts intervals in breaks format to a dictionary of kwargs to
|
||||
specific to the format expected by IntervalIndex.from_arrays
|
||||
"""
|
||||
return {"left": breaks[:-1], "right": breaks[1:]}
|
||||
|
||||
def test_constructor_errors(self):
|
||||
# GH 19016: categorical data
|
||||
data = Categorical(list("01234abcde"), ordered=True)
|
||||
msg = (
|
||||
"category, object, and string subtypes are not supported "
|
||||
"for IntervalIndex"
|
||||
)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
IntervalIndex.from_arrays(data[:-1], data[1:])
|
||||
|
||||
# unequal length
|
||||
left = [0, 1, 2]
|
||||
right = [2, 3]
|
||||
msg = "left and right must have the same length"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
IntervalIndex.from_arrays(left, right)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"left_subtype, right_subtype", [(np.int64, np.float64), (np.float64, np.int64)]
|
||||
)
|
||||
def test_mixed_float_int(self, left_subtype, right_subtype):
|
||||
"""mixed int/float left/right results in float for both sides"""
|
||||
left = np.arange(9, dtype=left_subtype)
|
||||
right = np.arange(1, 10, dtype=right_subtype)
|
||||
result = IntervalIndex.from_arrays(left, right)
|
||||
|
||||
expected_left = Float64Index(left)
|
||||
expected_right = Float64Index(right)
|
||||
expected_subtype = np.float64
|
||||
|
||||
tm.assert_index_equal(result.left, expected_left)
|
||||
tm.assert_index_equal(result.right, expected_right)
|
||||
assert result.dtype.subtype == expected_subtype
|
||||
|
||||
|
||||
class TestFromBreaks(Base):
|
||||
"""Tests specific to IntervalIndex.from_breaks"""
|
||||
|
||||
@pytest.fixture
|
||||
def constructor(self):
|
||||
return IntervalIndex.from_breaks
|
||||
|
||||
def get_kwargs_from_breaks(self, breaks, closed="right"):
|
||||
"""
|
||||
converts intervals in breaks format to a dictionary of kwargs to
|
||||
specific to the format expected by IntervalIndex.from_breaks
|
||||
"""
|
||||
return {"breaks": breaks}
|
||||
|
||||
def test_constructor_errors(self):
|
||||
# GH 19016: categorical data
|
||||
data = Categorical(list("01234abcde"), ordered=True)
|
||||
msg = (
|
||||
"category, object, and string subtypes are not supported "
|
||||
"for IntervalIndex"
|
||||
)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
IntervalIndex.from_breaks(data)
|
||||
|
||||
def test_length_one(self):
|
||||
"""breaks of length one produce an empty IntervalIndex"""
|
||||
breaks = [0]
|
||||
result = IntervalIndex.from_breaks(breaks)
|
||||
expected = IntervalIndex.from_breaks([])
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
|
||||
class TestFromTuples(Base):
|
||||
"""Tests specific to IntervalIndex.from_tuples"""
|
||||
|
||||
@pytest.fixture
|
||||
def constructor(self):
|
||||
return IntervalIndex.from_tuples
|
||||
|
||||
def get_kwargs_from_breaks(self, breaks, closed="right"):
|
||||
"""
|
||||
converts intervals in breaks format to a dictionary of kwargs to
|
||||
specific to the format expected by IntervalIndex.from_tuples
|
||||
"""
|
||||
if len(breaks) == 0:
|
||||
return {"data": breaks}
|
||||
|
||||
tuples = list(zip(breaks[:-1], breaks[1:]))
|
||||
if isinstance(breaks, (list, tuple)):
|
||||
return {"data": tuples}
|
||||
elif is_categorical_dtype(breaks):
|
||||
return {"data": breaks._constructor(tuples)}
|
||||
return {"data": com.asarray_tuplesafe(tuples)}
|
||||
|
||||
def test_constructor_errors(self):
|
||||
# non-tuple
|
||||
tuples = [(0, 1), 2, (3, 4)]
|
||||
msg = "IntervalIndex.from_tuples received an invalid item, 2"
|
||||
with pytest.raises(TypeError, match=msg.format(t=tuples)):
|
||||
IntervalIndex.from_tuples(tuples)
|
||||
|
||||
# too few/many items
|
||||
tuples = [(0, 1), (2,), (3, 4)]
|
||||
msg = "IntervalIndex.from_tuples requires tuples of length 2, got {t}"
|
||||
with pytest.raises(ValueError, match=msg.format(t=tuples)):
|
||||
IntervalIndex.from_tuples(tuples)
|
||||
|
||||
tuples = [(0, 1), (2, 3, 4), (5, 6)]
|
||||
with pytest.raises(ValueError, match=msg.format(t=tuples)):
|
||||
IntervalIndex.from_tuples(tuples)
|
||||
|
||||
def test_na_tuples(self):
|
||||
# tuple (NA, NA) evaluates the same as NA as an element
|
||||
na_tuple = [(0, 1), (np.nan, np.nan), (2, 3)]
|
||||
idx_na_tuple = IntervalIndex.from_tuples(na_tuple)
|
||||
idx_na_element = IntervalIndex.from_tuples([(0, 1), np.nan, (2, 3)])
|
||||
tm.assert_index_equal(idx_na_tuple, idx_na_element)
|
||||
|
||||
|
||||
class TestClassConstructors(Base):
|
||||
"""Tests specific to the IntervalIndex/Index constructors"""
|
||||
|
||||
@pytest.fixture(
|
||||
params=[IntervalIndex, partial(Index, dtype="interval")],
|
||||
ids=["IntervalIndex", "Index"],
|
||||
)
|
||||
def constructor(self, request):
|
||||
return request.param
|
||||
|
||||
def get_kwargs_from_breaks(self, breaks, closed="right"):
|
||||
"""
|
||||
converts intervals in breaks format to a dictionary of kwargs to
|
||||
specific to the format expected by the IntervalIndex/Index constructors
|
||||
"""
|
||||
if len(breaks) == 0:
|
||||
return {"data": breaks}
|
||||
|
||||
ivs = [
|
||||
Interval(l, r, closed) if notna(l) else l
|
||||
for l, r in zip(breaks[:-1], breaks[1:])
|
||||
]
|
||||
|
||||
if isinstance(breaks, list):
|
||||
return {"data": ivs}
|
||||
elif is_categorical_dtype(breaks):
|
||||
return {"data": breaks._constructor(ivs)}
|
||||
return {"data": np.array(ivs, dtype=object)}
|
||||
|
||||
def test_generic_errors(self, constructor):
|
||||
"""
|
||||
override the base class implementation since errors are handled
|
||||
differently; checks unnecessary since caught at the Interval level
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_constructor_string(self):
|
||||
# GH23013
|
||||
# When forming the interval from breaks,
|
||||
# the interval of strings is already forbidden.
|
||||
pass
|
||||
|
||||
def test_constructor_errors(self, constructor):
|
||||
# mismatched closed within intervals with no constructor override
|
||||
ivs = [Interval(0, 1, closed="right"), Interval(2, 3, closed="left")]
|
||||
msg = "intervals must all be closed on the same side"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
constructor(ivs)
|
||||
|
||||
# scalar
|
||||
msg = (
|
||||
r"IntervalIndex\(...\) must be called with a collection of "
|
||||
"some kind, 5 was passed"
|
||||
)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
constructor(5)
|
||||
|
||||
# not an interval
|
||||
msg = "type <class 'numpy.int64'> with value 0 is not an interval"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
constructor([0, 1])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data, closed",
|
||||
[
|
||||
([], "both"),
|
||||
([np.nan, np.nan], "neither"),
|
||||
(
|
||||
[Interval(0, 3, closed="neither"), Interval(2, 5, closed="neither")],
|
||||
"left",
|
||||
),
|
||||
(
|
||||
[Interval(0, 3, closed="left"), Interval(2, 5, closed="right")],
|
||||
"neither",
|
||||
),
|
||||
(IntervalIndex.from_breaks(range(5), closed="both"), "right"),
|
||||
],
|
||||
)
|
||||
def test_override_inferred_closed(self, constructor, data, closed):
|
||||
# GH 19370
|
||||
if isinstance(data, IntervalIndex):
|
||||
tuples = data.to_tuples()
|
||||
else:
|
||||
tuples = [(iv.left, iv.right) if notna(iv) else iv for iv in data]
|
||||
expected = IntervalIndex.from_tuples(tuples, closed=closed)
|
||||
result = constructor(data, closed=closed)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"values_constructor", [list, np.array, IntervalIndex, IntervalArray]
|
||||
)
|
||||
def test_index_object_dtype(self, values_constructor):
|
||||
# Index(intervals, dtype=object) is an Index (not an IntervalIndex)
|
||||
intervals = [Interval(0, 1), Interval(1, 2), Interval(2, 3)]
|
||||
values = values_constructor(intervals)
|
||||
result = Index(values, dtype=object)
|
||||
|
||||
assert type(result) is Index
|
||||
tm.assert_numpy_array_equal(result.values, np.array(values))
|
||||
|
||||
def test_index_mixed_closed(self):
|
||||
# GH27172
|
||||
intervals = [
|
||||
Interval(0, 1, closed="left"),
|
||||
Interval(1, 2, closed="right"),
|
||||
Interval(2, 3, closed="neither"),
|
||||
Interval(3, 4, closed="both"),
|
||||
]
|
||||
result = Index(intervals)
|
||||
expected = Index(intervals, dtype=object)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
|
||||
class TestFromIntervals(TestClassConstructors):
|
||||
"""
|
||||
Tests for IntervalIndex.from_intervals, which is deprecated in favor of the
|
||||
IntervalIndex constructor. Same tests as the IntervalIndex constructor,
|
||||
plus deprecation test. Should only need to delete this class when removed.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def constructor(self):
|
||||
def from_intervals_ignore_warnings(*args, **kwargs):
|
||||
with tm.assert_produces_warning(FutureWarning, check_stacklevel=False):
|
||||
return IntervalIndex.from_intervals(*args, **kwargs)
|
||||
|
||||
return from_intervals_ignore_warnings
|
||||
|
||||
def test_deprecated(self):
|
||||
ivs = [Interval(0, 1), Interval(1, 2)]
|
||||
with tm.assert_produces_warning(FutureWarning, check_stacklevel=False):
|
||||
IntervalIndex.from_intervals(ivs)
|
||||
|
||||
@pytest.mark.skip(reason="parent class test that is not applicable")
|
||||
def test_index_object_dtype(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="parent class test that is not applicable")
|
||||
def test_index_mixed_closed(self):
|
||||
pass
|
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,306 @@
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pandas import Interval, IntervalIndex
|
||||
from pandas.core.indexes.base import InvalidIndexError
|
||||
import pandas.util.testing as tm
|
||||
|
||||
|
||||
class TestIntervalIndex:
|
||||
@pytest.mark.parametrize("side", ["right", "left", "both", "neither"])
|
||||
def test_get_loc_interval(self, closed, side):
|
||||
|
||||
idx = IntervalIndex.from_tuples([(0, 1), (2, 3)], closed=closed)
|
||||
|
||||
for bound in [[0, 1], [1, 2], [2, 3], [3, 4], [0, 2], [2.5, 3], [-1, 4]]:
|
||||
# if get_loc is supplied an interval, it should only search
|
||||
# for exact matches, not overlaps or covers, else KeyError.
|
||||
msg = re.escape(
|
||||
"Interval({bound[0]}, {bound[1]}, closed='{side}')".format(
|
||||
bound=bound, side=side
|
||||
)
|
||||
)
|
||||
if closed == side:
|
||||
if bound == [0, 1]:
|
||||
assert idx.get_loc(Interval(0, 1, closed=side)) == 0
|
||||
elif bound == [2, 3]:
|
||||
assert idx.get_loc(Interval(2, 3, closed=side)) == 1
|
||||
else:
|
||||
with pytest.raises(KeyError, match=msg):
|
||||
idx.get_loc(Interval(*bound, closed=side))
|
||||
else:
|
||||
with pytest.raises(KeyError, match=msg):
|
||||
idx.get_loc(Interval(*bound, closed=side))
|
||||
|
||||
@pytest.mark.parametrize("scalar", [-0.5, 0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5])
|
||||
def test_get_loc_scalar(self, closed, scalar):
|
||||
|
||||
# correct = {side: {query: answer}}.
|
||||
# If query is not in the dict, that query should raise a KeyError
|
||||
correct = {
|
||||
"right": {0.5: 0, 1: 0, 2.5: 1, 3: 1},
|
||||
"left": {0: 0, 0.5: 0, 2: 1, 2.5: 1},
|
||||
"both": {0: 0, 0.5: 0, 1: 0, 2: 1, 2.5: 1, 3: 1},
|
||||
"neither": {0.5: 0, 2.5: 1},
|
||||
}
|
||||
|
||||
idx = IntervalIndex.from_tuples([(0, 1), (2, 3)], closed=closed)
|
||||
|
||||
# if get_loc is supplied a scalar, it should return the index of
|
||||
# the interval which contains the scalar, or KeyError.
|
||||
if scalar in correct[closed].keys():
|
||||
assert idx.get_loc(scalar) == correct[closed][scalar]
|
||||
else:
|
||||
with pytest.raises(KeyError, match=str(scalar)):
|
||||
idx.get_loc(scalar)
|
||||
|
||||
def test_slice_locs_with_interval(self):
|
||||
|
||||
# increasing monotonically
|
||||
index = IntervalIndex.from_tuples([(0, 2), (1, 3), (2, 4)])
|
||||
|
||||
assert index.slice_locs(start=Interval(0, 2), end=Interval(2, 4)) == (0, 3)
|
||||
assert index.slice_locs(start=Interval(0, 2)) == (0, 3)
|
||||
assert index.slice_locs(end=Interval(2, 4)) == (0, 3)
|
||||
assert index.slice_locs(end=Interval(0, 2)) == (0, 1)
|
||||
assert index.slice_locs(start=Interval(2, 4), end=Interval(0, 2)) == (2, 1)
|
||||
|
||||
# decreasing monotonically
|
||||
index = IntervalIndex.from_tuples([(2, 4), (1, 3), (0, 2)])
|
||||
|
||||
assert index.slice_locs(start=Interval(0, 2), end=Interval(2, 4)) == (2, 1)
|
||||
assert index.slice_locs(start=Interval(0, 2)) == (2, 3)
|
||||
assert index.slice_locs(end=Interval(2, 4)) == (0, 1)
|
||||
assert index.slice_locs(end=Interval(0, 2)) == (0, 3)
|
||||
assert index.slice_locs(start=Interval(2, 4), end=Interval(0, 2)) == (0, 3)
|
||||
|
||||
# sorted duplicates
|
||||
index = IntervalIndex.from_tuples([(0, 2), (0, 2), (2, 4)])
|
||||
|
||||
assert index.slice_locs(start=Interval(0, 2), end=Interval(2, 4)) == (0, 3)
|
||||
assert index.slice_locs(start=Interval(0, 2)) == (0, 3)
|
||||
assert index.slice_locs(end=Interval(2, 4)) == (0, 3)
|
||||
assert index.slice_locs(end=Interval(0, 2)) == (0, 2)
|
||||
assert index.slice_locs(start=Interval(2, 4), end=Interval(0, 2)) == (2, 2)
|
||||
|
||||
# unsorted duplicates
|
||||
index = IntervalIndex.from_tuples([(0, 2), (2, 4), (0, 2)])
|
||||
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match=re.escape(
|
||||
'"Cannot get left slice bound for non-unique label:'
|
||||
" Interval(0, 2, closed='right')\""
|
||||
),
|
||||
):
|
||||
index.slice_locs(start=Interval(0, 2), end=Interval(2, 4))
|
||||
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match=re.escape(
|
||||
'"Cannot get left slice bound for non-unique label:'
|
||||
" Interval(0, 2, closed='right')\""
|
||||
),
|
||||
):
|
||||
index.slice_locs(start=Interval(0, 2))
|
||||
|
||||
assert index.slice_locs(end=Interval(2, 4)) == (0, 2)
|
||||
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match=re.escape(
|
||||
'"Cannot get right slice bound for non-unique label:'
|
||||
" Interval(0, 2, closed='right')\""
|
||||
),
|
||||
):
|
||||
index.slice_locs(end=Interval(0, 2))
|
||||
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match=re.escape(
|
||||
'"Cannot get right slice bound for non-unique label:'
|
||||
" Interval(0, 2, closed='right')\""
|
||||
),
|
||||
):
|
||||
index.slice_locs(start=Interval(2, 4), end=Interval(0, 2))
|
||||
|
||||
# another unsorted duplicates
|
||||
index = IntervalIndex.from_tuples([(0, 2), (0, 2), (2, 4), (1, 3)])
|
||||
|
||||
assert index.slice_locs(start=Interval(0, 2), end=Interval(2, 4)) == (0, 3)
|
||||
assert index.slice_locs(start=Interval(0, 2)) == (0, 4)
|
||||
assert index.slice_locs(end=Interval(2, 4)) == (0, 3)
|
||||
assert index.slice_locs(end=Interval(0, 2)) == (0, 2)
|
||||
assert index.slice_locs(start=Interval(2, 4), end=Interval(0, 2)) == (2, 2)
|
||||
|
||||
def test_slice_locs_with_ints_and_floats_succeeds(self):
|
||||
|
||||
# increasing non-overlapping
|
||||
index = IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)])
|
||||
|
||||
assert index.slice_locs(0, 1) == (0, 1)
|
||||
assert index.slice_locs(0, 2) == (0, 2)
|
||||
assert index.slice_locs(0, 3) == (0, 2)
|
||||
assert index.slice_locs(3, 1) == (2, 1)
|
||||
assert index.slice_locs(3, 4) == (2, 3)
|
||||
assert index.slice_locs(0, 4) == (0, 3)
|
||||
|
||||
# decreasing non-overlapping
|
||||
index = IntervalIndex.from_tuples([(3, 4), (1, 2), (0, 1)])
|
||||
assert index.slice_locs(0, 1) == (3, 3)
|
||||
assert index.slice_locs(0, 2) == (3, 2)
|
||||
assert index.slice_locs(0, 3) == (3, 1)
|
||||
assert index.slice_locs(3, 1) == (1, 3)
|
||||
assert index.slice_locs(3, 4) == (1, 1)
|
||||
assert index.slice_locs(0, 4) == (3, 1)
|
||||
|
||||
@pytest.mark.parametrize("query", [[0, 1], [0, 2], [0, 3], [0, 4]])
|
||||
@pytest.mark.parametrize(
|
||||
"tuples",
|
||||
[
|
||||
[(0, 2), (1, 3), (2, 4)],
|
||||
[(2, 4), (1, 3), (0, 2)],
|
||||
[(0, 2), (0, 2), (2, 4)],
|
||||
[(0, 2), (2, 4), (0, 2)],
|
||||
[(0, 2), (0, 2), (2, 4), (1, 3)],
|
||||
],
|
||||
)
|
||||
def test_slice_locs_with_ints_and_floats_errors(self, tuples, query):
|
||||
start, stop = query
|
||||
index = IntervalIndex.from_tuples(tuples)
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match=(
|
||||
"'can only get slices from an IntervalIndex if bounds are"
|
||||
" non-overlapping and all monotonic increasing or decreasing'"
|
||||
),
|
||||
):
|
||||
index.slice_locs(start, stop)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, expected",
|
||||
[
|
||||
([Interval(2, 4, closed="right")], [1]),
|
||||
([Interval(2, 4, closed="left")], [-1]),
|
||||
([Interval(2, 4, closed="both")], [-1]),
|
||||
([Interval(2, 4, closed="neither")], [-1]),
|
||||
([Interval(1, 4, closed="right")], [-1]),
|
||||
([Interval(0, 4, closed="right")], [-1]),
|
||||
([Interval(0.5, 1.5, closed="right")], [-1]),
|
||||
([Interval(2, 4, closed="right"), Interval(0, 1, closed="right")], [1, -1]),
|
||||
([Interval(2, 4, closed="right"), Interval(2, 4, closed="right")], [1, 1]),
|
||||
([Interval(5, 7, closed="right"), Interval(2, 4, closed="right")], [2, 1]),
|
||||
([Interval(2, 4, closed="right"), Interval(2, 4, closed="left")], [1, -1]),
|
||||
],
|
||||
)
|
||||
def test_get_indexer_with_interval(self, query, expected):
|
||||
|
||||
tuples = [(0, 2), (2, 4), (5, 7)]
|
||||
index = IntervalIndex.from_tuples(tuples, closed="right")
|
||||
|
||||
result = index.get_indexer(query)
|
||||
expected = np.array(expected, dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, expected",
|
||||
[
|
||||
([-0.5], [-1]),
|
||||
([0], [-1]),
|
||||
([0.5], [0]),
|
||||
([1], [0]),
|
||||
([1.5], [1]),
|
||||
([2], [1]),
|
||||
([2.5], [-1]),
|
||||
([3], [-1]),
|
||||
([3.5], [2]),
|
||||
([4], [2]),
|
||||
([4.5], [-1]),
|
||||
([1, 2], [0, 1]),
|
||||
([1, 2, 3], [0, 1, -1]),
|
||||
([1, 2, 3, 4], [0, 1, -1, 2]),
|
||||
([1, 2, 3, 4, 2], [0, 1, -1, 2, 1]),
|
||||
],
|
||||
)
|
||||
def test_get_indexer_with_int_and_float(self, query, expected):
|
||||
|
||||
tuples = [(0, 1), (1, 2), (3, 4)]
|
||||
index = IntervalIndex.from_tuples(tuples, closed="right")
|
||||
|
||||
result = index.get_indexer(query)
|
||||
expected = np.array(expected, dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tuples, closed",
|
||||
[
|
||||
([(0, 2), (1, 3), (3, 4)], "neither"),
|
||||
([(0, 5), (1, 4), (6, 7)], "left"),
|
||||
([(0, 1), (0, 1), (1, 2)], "right"),
|
||||
([(0, 1), (2, 3), (3, 4)], "both"),
|
||||
],
|
||||
)
|
||||
def test_get_indexer_errors(self, tuples, closed):
|
||||
# IntervalIndex needs non-overlapping for uniqueness when querying
|
||||
index = IntervalIndex.from_tuples(tuples, closed=closed)
|
||||
|
||||
msg = (
|
||||
"cannot handle overlapping indices; use "
|
||||
"IntervalIndex.get_indexer_non_unique"
|
||||
)
|
||||
with pytest.raises(InvalidIndexError, match=msg):
|
||||
index.get_indexer([0, 2])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, expected",
|
||||
[
|
||||
([-0.5], ([-1], [0])),
|
||||
([0], ([0], [])),
|
||||
([0.5], ([0], [])),
|
||||
([1], ([0, 1], [])),
|
||||
([1.5], ([0, 1], [])),
|
||||
([2], ([0, 1, 2], [])),
|
||||
([2.5], ([1, 2], [])),
|
||||
([3], ([2], [])),
|
||||
([3.5], ([2], [])),
|
||||
([4], ([-1], [0])),
|
||||
([4.5], ([-1], [0])),
|
||||
([1, 2], ([0, 1, 0, 1, 2], [])),
|
||||
([1, 2, 3], ([0, 1, 0, 1, 2, 2], [])),
|
||||
([1, 2, 3, 4], ([0, 1, 0, 1, 2, 2, -1], [3])),
|
||||
([1, 2, 3, 4, 2], ([0, 1, 0, 1, 2, 2, -1, 0, 1, 2], [3])),
|
||||
],
|
||||
)
|
||||
def test_get_indexer_non_unique_with_int_and_float(self, query, expected):
|
||||
|
||||
tuples = [(0, 2.5), (1, 3), (2, 4)]
|
||||
index = IntervalIndex.from_tuples(tuples, closed="left")
|
||||
|
||||
result_indexer, result_missing = index.get_indexer_non_unique(query)
|
||||
expected_indexer = np.array(expected[0], dtype="intp")
|
||||
expected_missing = np.array(expected[1], dtype="intp")
|
||||
|
||||
tm.assert_numpy_array_equal(result_indexer, expected_indexer)
|
||||
tm.assert_numpy_array_equal(result_missing, expected_missing)
|
||||
|
||||
# TODO we may also want to test get_indexer for the case when
|
||||
# the intervals are duplicated, decreasing, non-monotonic, etc..
|
||||
|
||||
def test_contains_dunder(self):
|
||||
|
||||
index = IntervalIndex.from_arrays([0, 1], [1, 2], closed="right")
|
||||
|
||||
# __contains__ requires perfect matches to intervals.
|
||||
assert 0 not in index
|
||||
assert 1 not in index
|
||||
assert 2 not in index
|
||||
|
||||
assert Interval(0, 1, closed="right") in index
|
||||
assert Interval(0, 2, closed="right") not in index
|
||||
assert Interval(0, 0.5, closed="right") not in index
|
||||
assert Interval(3, 5, closed="right") not in index
|
||||
assert Interval(-1, 0, closed="left") not in index
|
||||
assert Interval(0, 1, closed="left") not in index
|
||||
assert Interval(0, 1, closed="both") not in index
|
@@ -0,0 +1,355 @@
|
||||
from datetime import timedelta
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pandas.core.dtypes.common import is_integer
|
||||
|
||||
from pandas import (
|
||||
DateOffset,
|
||||
Interval,
|
||||
IntervalIndex,
|
||||
Timedelta,
|
||||
Timestamp,
|
||||
date_range,
|
||||
interval_range,
|
||||
timedelta_range,
|
||||
)
|
||||
import pandas.util.testing as tm
|
||||
|
||||
from pandas.tseries.offsets import Day
|
||||
|
||||
|
||||
@pytest.fixture(scope="class", params=[None, "foo"])
|
||||
def name(request):
|
||||
return request.param
|
||||
|
||||
|
||||
class TestIntervalRange:
|
||||
@pytest.mark.parametrize("freq, periods", [(1, 100), (2.5, 40), (5, 20), (25, 4)])
|
||||
def test_constructor_numeric(self, closed, name, freq, periods):
|
||||
start, end = 0, 100
|
||||
breaks = np.arange(101, step=freq)
|
||||
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)
|
||||
|
||||
# defined from start/end/freq
|
||||
result = interval_range(
|
||||
start=start, end=end, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# defined from start/periods/freq
|
||||
result = interval_range(
|
||||
start=start, periods=periods, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# defined from end/periods/freq
|
||||
result = interval_range(
|
||||
end=end, periods=periods, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 20976: linspace behavior defined from start/end/periods
|
||||
result = interval_range(
|
||||
start=start, end=end, periods=periods, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize("tz", [None, "US/Eastern"])
|
||||
@pytest.mark.parametrize(
|
||||
"freq, periods", [("D", 364), ("2D", 182), ("22D18H", 16), ("M", 11)]
|
||||
)
|
||||
def test_constructor_timestamp(self, closed, name, freq, periods, tz):
|
||||
start, end = Timestamp("20180101", tz=tz), Timestamp("20181231", tz=tz)
|
||||
breaks = date_range(start=start, end=end, freq=freq)
|
||||
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)
|
||||
|
||||
# defined from start/end/freq
|
||||
result = interval_range(
|
||||
start=start, end=end, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# defined from start/periods/freq
|
||||
result = interval_range(
|
||||
start=start, periods=periods, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# defined from end/periods/freq
|
||||
result = interval_range(
|
||||
end=end, periods=periods, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 20976: linspace behavior defined from start/end/periods
|
||||
if not breaks.freq.isAnchored() and tz is None:
|
||||
# matches expected only for non-anchored offsets and tz naive
|
||||
# (anchored/DST transitions cause unequal spacing in expected)
|
||||
result = interval_range(
|
||||
start=start, end=end, periods=periods, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"freq, periods", [("D", 100), ("2D12H", 40), ("5D", 20), ("25D", 4)]
|
||||
)
|
||||
def test_constructor_timedelta(self, closed, name, freq, periods):
|
||||
start, end = Timedelta("0 days"), Timedelta("100 days")
|
||||
breaks = timedelta_range(start=start, end=end, freq=freq)
|
||||
expected = IntervalIndex.from_breaks(breaks, name=name, closed=closed)
|
||||
|
||||
# defined from start/end/freq
|
||||
result = interval_range(
|
||||
start=start, end=end, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# defined from start/periods/freq
|
||||
result = interval_range(
|
||||
start=start, periods=periods, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# defined from end/periods/freq
|
||||
result = interval_range(
|
||||
end=end, periods=periods, freq=freq, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 20976: linspace behavior defined from start/end/periods
|
||||
result = interval_range(
|
||||
start=start, end=end, periods=periods, name=name, closed=closed
|
||||
)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"start, end, freq, expected_endpoint",
|
||||
[
|
||||
(0, 10, 3, 9),
|
||||
(0, 10, 1.5, 9),
|
||||
(0.5, 10, 3, 9.5),
|
||||
(Timedelta("0D"), Timedelta("10D"), "2D4H", Timedelta("8D16H")),
|
||||
(
|
||||
Timestamp("2018-01-01"),
|
||||
Timestamp("2018-02-09"),
|
||||
"MS",
|
||||
Timestamp("2018-02-01"),
|
||||
),
|
||||
(
|
||||
Timestamp("2018-01-01", tz="US/Eastern"),
|
||||
Timestamp("2018-01-20", tz="US/Eastern"),
|
||||
"5D12H",
|
||||
Timestamp("2018-01-17 12:00:00", tz="US/Eastern"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_early_truncation(self, start, end, freq, expected_endpoint):
|
||||
# index truncates early if freq causes end to be skipped
|
||||
result = interval_range(start=start, end=end, freq=freq)
|
||||
result_endpoint = result.right[-1]
|
||||
assert result_endpoint == expected_endpoint
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"start, end, freq",
|
||||
[(0.5, None, None), (None, 4.5, None), (0.5, None, 1.5), (None, 6.5, 1.5)],
|
||||
)
|
||||
def test_no_invalid_float_truncation(self, start, end, freq):
|
||||
# GH 21161
|
||||
if freq is None:
|
||||
breaks = [0.5, 1.5, 2.5, 3.5, 4.5]
|
||||
else:
|
||||
breaks = [0.5, 2.0, 3.5, 5.0, 6.5]
|
||||
expected = IntervalIndex.from_breaks(breaks)
|
||||
|
||||
result = interval_range(start=start, end=end, periods=4, freq=freq)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"start, mid, end",
|
||||
[
|
||||
(
|
||||
Timestamp("2018-03-10", tz="US/Eastern"),
|
||||
Timestamp("2018-03-10 23:30:00", tz="US/Eastern"),
|
||||
Timestamp("2018-03-12", tz="US/Eastern"),
|
||||
),
|
||||
(
|
||||
Timestamp("2018-11-03", tz="US/Eastern"),
|
||||
Timestamp("2018-11-04 00:30:00", tz="US/Eastern"),
|
||||
Timestamp("2018-11-05", tz="US/Eastern"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_linspace_dst_transition(self, start, mid, end):
|
||||
# GH 20976: linspace behavior defined from start/end/periods
|
||||
# accounts for the hour gained/lost during DST transition
|
||||
result = interval_range(start=start, end=end, periods=2)
|
||||
expected = IntervalIndex.from_breaks([start, mid, end])
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize("freq", [2, 2.0])
|
||||
@pytest.mark.parametrize("end", [10, 10.0])
|
||||
@pytest.mark.parametrize("start", [0, 0.0])
|
||||
def test_float_subtype(self, start, end, freq):
|
||||
# Has float subtype if any of start/end/freq are float, even if all
|
||||
# resulting endpoints can safely be upcast to integers
|
||||
|
||||
# defined from start/end/freq
|
||||
index = interval_range(start=start, end=end, freq=freq)
|
||||
result = index.dtype.subtype
|
||||
expected = "int64" if is_integer(start + end + freq) else "float64"
|
||||
assert result == expected
|
||||
|
||||
# defined from start/periods/freq
|
||||
index = interval_range(start=start, periods=5, freq=freq)
|
||||
result = index.dtype.subtype
|
||||
expected = "int64" if is_integer(start + freq) else "float64"
|
||||
assert result == expected
|
||||
|
||||
# defined from end/periods/freq
|
||||
index = interval_range(end=end, periods=5, freq=freq)
|
||||
result = index.dtype.subtype
|
||||
expected = "int64" if is_integer(end + freq) else "float64"
|
||||
assert result == expected
|
||||
|
||||
# GH 20976: linspace behavior defined from start/end/periods
|
||||
index = interval_range(start=start, end=end, periods=5)
|
||||
result = index.dtype.subtype
|
||||
expected = "int64" if is_integer(start + end) else "float64"
|
||||
assert result == expected
|
||||
|
||||
def test_constructor_coverage(self):
|
||||
# float value for periods
|
||||
expected = interval_range(start=0, periods=10)
|
||||
result = interval_range(start=0, periods=10.5)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# equivalent timestamp-like start/end
|
||||
start, end = Timestamp("2017-01-01"), Timestamp("2017-01-15")
|
||||
expected = interval_range(start=start, end=end)
|
||||
|
||||
result = interval_range(start=start.to_pydatetime(), end=end.to_pydatetime())
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
result = interval_range(start=start.asm8, end=end.asm8)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# equivalent freq with timestamp
|
||||
equiv_freq = [
|
||||
"D",
|
||||
Day(),
|
||||
Timedelta(days=1),
|
||||
timedelta(days=1),
|
||||
DateOffset(days=1),
|
||||
]
|
||||
for freq in equiv_freq:
|
||||
result = interval_range(start=start, end=end, freq=freq)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# equivalent timedelta-like start/end
|
||||
start, end = Timedelta(days=1), Timedelta(days=10)
|
||||
expected = interval_range(start=start, end=end)
|
||||
|
||||
result = interval_range(start=start.to_pytimedelta(), end=end.to_pytimedelta())
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
result = interval_range(start=start.asm8, end=end.asm8)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# equivalent freq with timedelta
|
||||
equiv_freq = ["D", Day(), Timedelta(days=1), timedelta(days=1)]
|
||||
for freq in equiv_freq:
|
||||
result = interval_range(start=start, end=end, freq=freq)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
def test_errors(self):
|
||||
# not enough params
|
||||
msg = (
|
||||
"Of the four parameters: start, end, periods, and freq, "
|
||||
"exactly three must be specified"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(start=0)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(end=5)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(periods=2)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range()
|
||||
|
||||
# too many params
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(start=0, end=5, periods=6, freq=1.5)
|
||||
|
||||
# mixed units
|
||||
msg = "start, end, freq need to be type compatible"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=0, end=Timestamp("20130101"), freq=2)
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=0, end=Timedelta("1 day"), freq=2)
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=0, end=10, freq="D")
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=Timestamp("20130101"), end=10, freq="D")
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(
|
||||
start=Timestamp("20130101"), end=Timedelta("1 day"), freq="D"
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(
|
||||
start=Timestamp("20130101"), end=Timestamp("20130110"), freq=2
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=Timedelta("1 day"), end=10, freq="D")
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(
|
||||
start=Timedelta("1 day"), end=Timestamp("20130110"), freq="D"
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=Timedelta("1 day"), end=Timedelta("10 days"), freq=2)
|
||||
|
||||
# invalid periods
|
||||
msg = "periods must be a number, got foo"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=0, periods="foo")
|
||||
|
||||
# invalid start
|
||||
msg = "start must be numeric or datetime-like, got foo"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(start="foo", periods=10)
|
||||
|
||||
# invalid end
|
||||
msg = r"end must be numeric or datetime-like, got \(0, 1\]"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(end=Interval(0, 1), periods=10)
|
||||
|
||||
# invalid freq for datetime-like
|
||||
msg = "freq must be numeric or convertible to DateOffset, got foo"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(start=0, end=10, freq="foo")
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(start=Timestamp("20130101"), periods=10, freq="foo")
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
interval_range(end=Timedelta("1 day"), periods=10, freq="foo")
|
||||
|
||||
# mixed tz
|
||||
start = Timestamp("2017-01-01", tz="US/Eastern")
|
||||
end = Timestamp("2017-01-07", tz="US/Pacific")
|
||||
msg = "Start and end cannot both be tz-aware with different timezones"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
interval_range(start=start, end=end)
|
@@ -0,0 +1,197 @@
|
||||
from itertools import permutations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pandas._libs.interval import IntervalTree
|
||||
|
||||
from pandas import compat
|
||||
import pandas.util.testing as tm
|
||||
|
||||
|
||||
def skipif_32bit(param):
|
||||
"""
|
||||
Skip parameters in a parametrize on 32bit systems. Specifically used
|
||||
here to skip leaf_size parameters related to GH 23440.
|
||||
"""
|
||||
marks = pytest.mark.skipif(
|
||||
compat.is_platform_32bit(), reason="GH 23440: int type mismatch on 32bit"
|
||||
)
|
||||
return pytest.param(param, marks=marks)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="class", params=["int32", "int64", "float32", "float64", "uint64"]
|
||||
)
|
||||
def dtype(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[skipif_32bit(1), skipif_32bit(2), 10])
|
||||
def leaf_size(request):
|
||||
"""
|
||||
Fixture to specify IntervalTree leaf_size parameter; to be used with the
|
||||
tree fixture.
|
||||
"""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
np.arange(5, dtype="int64"),
|
||||
np.arange(5, dtype="int32"),
|
||||
np.arange(5, dtype="uint64"),
|
||||
np.arange(5, dtype="float64"),
|
||||
np.arange(5, dtype="float32"),
|
||||
np.array([0, 1, 2, 3, 4, np.nan], dtype="float64"),
|
||||
np.array([0, 1, 2, 3, 4, np.nan], dtype="float32"),
|
||||
]
|
||||
)
|
||||
def tree(request, leaf_size):
|
||||
left = request.param
|
||||
return IntervalTree(left, left + 2, leaf_size=leaf_size)
|
||||
|
||||
|
||||
class TestIntervalTree:
|
||||
def test_get_loc(self, tree):
|
||||
result = tree.get_loc(1)
|
||||
expected = np.array([0], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
result = np.sort(tree.get_loc(2))
|
||||
expected = np.array([0, 1], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
with pytest.raises(KeyError, match="-1"):
|
||||
tree.get_loc(-1)
|
||||
|
||||
def test_get_indexer(self, tree):
|
||||
result = tree.get_indexer(np.array([1.0, 5.5, 6.5]))
|
||||
expected = np.array([0, 4, -1], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
with pytest.raises(
|
||||
KeyError, match="'indexer does not intersect a unique set of intervals'"
|
||||
):
|
||||
tree.get_indexer(np.array([3.0]))
|
||||
|
||||
def test_get_indexer_non_unique(self, tree):
|
||||
indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5]))
|
||||
|
||||
result = indexer[:1]
|
||||
expected = np.array([0], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
result = np.sort(indexer[1:3])
|
||||
expected = np.array([0, 1], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
result = np.sort(indexer[3:])
|
||||
expected = np.array([-1], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
result = missing
|
||||
expected = np.array([2], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
def test_duplicates(self, dtype):
|
||||
left = np.array([0, 0, 0], dtype=dtype)
|
||||
tree = IntervalTree(left, left + 1)
|
||||
|
||||
result = np.sort(tree.get_loc(0.5))
|
||||
expected = np.array([0, 1, 2], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
with pytest.raises(
|
||||
KeyError, match="'indexer does not intersect a unique set of intervals'"
|
||||
):
|
||||
tree.get_indexer(np.array([0.5]))
|
||||
|
||||
indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))
|
||||
result = np.sort(indexer)
|
||||
expected = np.array([0, 1, 2], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
result = missing
|
||||
expected = np.array([], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
def test_get_loc_closed(self, closed):
|
||||
tree = IntervalTree([0], [1], closed=closed)
|
||||
for p, errors in [(0, tree.open_left), (1, tree.open_right)]:
|
||||
if errors:
|
||||
with pytest.raises(KeyError, match=str(p)):
|
||||
tree.get_loc(p)
|
||||
else:
|
||||
result = tree.get_loc(p)
|
||||
expected = np.array([0], dtype="intp")
|
||||
tm.assert_numpy_array_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"leaf_size", [skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000]
|
||||
)
|
||||
def test_get_indexer_closed(self, closed, leaf_size):
|
||||
x = np.arange(1000, dtype="float64")
|
||||
found = x.astype("intp")
|
||||
not_found = (-1 * np.ones(1000)).astype("intp")
|
||||
|
||||
tree = IntervalTree(x, x + 0.5, closed=closed, leaf_size=leaf_size)
|
||||
tm.assert_numpy_array_equal(found, tree.get_indexer(x + 0.25))
|
||||
|
||||
expected = found if tree.closed_left else not_found
|
||||
tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.0))
|
||||
|
||||
expected = found if tree.closed_right else not_found
|
||||
tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.5))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"left, right, expected",
|
||||
[
|
||||
(np.array([0, 1, 4]), np.array([2, 3, 5]), True),
|
||||
(np.array([0, 1, 2]), np.array([5, 4, 3]), True),
|
||||
(np.array([0, 1, np.nan]), np.array([5, 4, np.nan]), True),
|
||||
(np.array([0, 2, 4]), np.array([1, 3, 5]), False),
|
||||
(np.array([0, 2, np.nan]), np.array([1, 3, np.nan]), False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("order", map(list, permutations(range(3))))
|
||||
def test_is_overlapping(self, closed, order, left, right, expected):
|
||||
# GH 23309
|
||||
tree = IntervalTree(left[order], right[order], closed=closed)
|
||||
result = tree.is_overlapping
|
||||
assert result is expected
|
||||
|
||||
@pytest.mark.parametrize("order", map(list, permutations(range(3))))
|
||||
def test_is_overlapping_endpoints(self, closed, order):
|
||||
"""shared endpoints are marked as overlapping"""
|
||||
# GH 23309
|
||||
left, right = np.arange(3), np.arange(1, 4)
|
||||
tree = IntervalTree(left[order], right[order], closed=closed)
|
||||
result = tree.is_overlapping
|
||||
expected = closed == "both"
|
||||
assert result is expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"left, right",
|
||||
[
|
||||
(np.array([], dtype="int64"), np.array([], dtype="int64")),
|
||||
(np.array([0], dtype="int64"), np.array([1], dtype="int64")),
|
||||
(np.array([np.nan]), np.array([np.nan])),
|
||||
(np.array([np.nan] * 3), np.array([np.nan] * 3)),
|
||||
],
|
||||
)
|
||||
def test_is_overlapping_trivial(self, closed, left, right):
|
||||
# GH 23309
|
||||
tree = IntervalTree(left, right, closed=closed)
|
||||
assert tree.is_overlapping is False
|
||||
|
||||
@pytest.mark.skipif(compat.is_platform_32bit(), reason="GH 23440")
|
||||
def test_construction_overflow(self):
|
||||
# GH 25485
|
||||
left, right = np.arange(101), [np.iinfo(np.int64).max] * 101
|
||||
tree = IntervalTree(left, right)
|
||||
|
||||
# pivot should be average of left/right medians
|
||||
result = tree.root.pivot
|
||||
expected = (50 + np.iinfo(np.int64).max) / 2
|
||||
assert result == expected
|
@@ -0,0 +1,187 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pandas import Index, IntervalIndex, Timestamp, interval_range
|
||||
import pandas.util.testing as tm
|
||||
|
||||
|
||||
@pytest.fixture(scope="class", params=[None, "foo"])
|
||||
def name(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[None, False])
|
||||
def sort(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def monotonic_index(start, end, dtype="int64", closed="right"):
|
||||
return IntervalIndex.from_breaks(np.arange(start, end, dtype=dtype), closed=closed)
|
||||
|
||||
|
||||
def empty_index(dtype="int64", closed="right"):
|
||||
return IntervalIndex(np.array([], dtype=dtype), closed=closed)
|
||||
|
||||
|
||||
class TestIntervalIndex:
|
||||
def test_union(self, closed, sort):
|
||||
index = monotonic_index(0, 11, closed=closed)
|
||||
other = monotonic_index(5, 13, closed=closed)
|
||||
|
||||
expected = monotonic_index(0, 13, closed=closed)
|
||||
result = index[::-1].union(other, sort=sort)
|
||||
if sort is None:
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert tm.equalContents(result, expected)
|
||||
|
||||
result = other[::-1].union(index, sort=sort)
|
||||
if sort is None:
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert tm.equalContents(result, expected)
|
||||
|
||||
tm.assert_index_equal(index.union(index, sort=sort), index)
|
||||
tm.assert_index_equal(index.union(index[:1], sort=sort), index)
|
||||
|
||||
# GH 19101: empty result, same dtype
|
||||
index = empty_index(dtype="int64", closed=closed)
|
||||
result = index.union(index, sort=sort)
|
||||
tm.assert_index_equal(result, index)
|
||||
|
||||
# GH 19101: empty result, different dtypes
|
||||
other = empty_index(dtype="float64", closed=closed)
|
||||
result = index.union(other, sort=sort)
|
||||
tm.assert_index_equal(result, index)
|
||||
|
||||
def test_intersection(self, closed, sort):
|
||||
index = monotonic_index(0, 11, closed=closed)
|
||||
other = monotonic_index(5, 13, closed=closed)
|
||||
|
||||
expected = monotonic_index(5, 11, closed=closed)
|
||||
result = index[::-1].intersection(other, sort=sort)
|
||||
if sort is None:
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert tm.equalContents(result, expected)
|
||||
|
||||
result = other[::-1].intersection(index, sort=sort)
|
||||
if sort is None:
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert tm.equalContents(result, expected)
|
||||
|
||||
tm.assert_index_equal(index.intersection(index, sort=sort), index)
|
||||
|
||||
# GH 19101: empty result, same dtype
|
||||
other = monotonic_index(300, 314, closed=closed)
|
||||
expected = empty_index(dtype="int64", closed=closed)
|
||||
result = index.intersection(other, sort=sort)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 19101: empty result, different dtypes
|
||||
other = monotonic_index(300, 314, dtype="float64", closed=closed)
|
||||
result = index.intersection(other, sort=sort)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 26225: nested intervals
|
||||
index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)])
|
||||
other = IntervalIndex.from_tuples([(1, 2), (1, 3)])
|
||||
expected = IntervalIndex.from_tuples([(1, 2), (1, 3)])
|
||||
result = index.intersection(other)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 26225: duplicate element
|
||||
index = IntervalIndex.from_tuples([(1, 2), (1, 2), (2, 3), (3, 4)])
|
||||
other = IntervalIndex.from_tuples([(1, 2), (2, 3)])
|
||||
expected = IntervalIndex.from_tuples([(1, 2), (1, 2), (2, 3)])
|
||||
result = index.intersection(other)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 26225
|
||||
index = IntervalIndex.from_tuples([(0, 3), (0, 2)])
|
||||
other = IntervalIndex.from_tuples([(0, 2), (1, 3)])
|
||||
expected = IntervalIndex.from_tuples([(0, 2)])
|
||||
result = index.intersection(other)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 26225: duplicate nan element
|
||||
index = IntervalIndex([np.nan, np.nan])
|
||||
other = IntervalIndex([np.nan])
|
||||
expected = IntervalIndex([np.nan])
|
||||
result = index.intersection(other)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
def test_difference(self, closed, sort):
|
||||
index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed)
|
||||
result = index.difference(index[:1], sort=sort)
|
||||
expected = index[1:]
|
||||
if sort is None:
|
||||
expected = expected.sort_values()
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 19101: empty result, same dtype
|
||||
result = index.difference(index, sort=sort)
|
||||
expected = empty_index(dtype="int64", closed=closed)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# GH 19101: empty result, different dtypes
|
||||
other = IntervalIndex.from_arrays(
|
||||
index.left.astype("float64"), index.right, closed=closed
|
||||
)
|
||||
result = index.difference(other, sort=sort)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
def test_symmetric_difference(self, closed, sort):
|
||||
index = monotonic_index(0, 11, closed=closed)
|
||||
result = index[1:].symmetric_difference(index[:-1], sort=sort)
|
||||
expected = IntervalIndex([index[0], index[-1]])
|
||||
if sort is None:
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert tm.equalContents(result, expected)
|
||||
|
||||
# GH 19101: empty result, same dtype
|
||||
result = index.symmetric_difference(index, sort=sort)
|
||||
expected = empty_index(dtype="int64", closed=closed)
|
||||
if sort is None:
|
||||
tm.assert_index_equal(result, expected)
|
||||
assert tm.equalContents(result, expected)
|
||||
|
||||
# GH 19101: empty result, different dtypes
|
||||
other = IntervalIndex.from_arrays(
|
||||
index.left.astype("float64"), index.right, closed=closed
|
||||
)
|
||||
result = index.symmetric_difference(other, sort=sort)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"op_name", ["union", "intersection", "difference", "symmetric_difference"]
|
||||
)
|
||||
@pytest.mark.parametrize("sort", [None, False])
|
||||
def test_set_incompatible_types(self, closed, op_name, sort):
|
||||
index = monotonic_index(0, 11, closed=closed)
|
||||
set_op = getattr(index, op_name)
|
||||
|
||||
# TODO: standardize return type of non-union setops type(self vs other)
|
||||
# non-IntervalIndex
|
||||
if op_name == "difference":
|
||||
expected = index
|
||||
else:
|
||||
expected = getattr(index.astype("O"), op_name)(Index([1, 2, 3]))
|
||||
result = set_op(Index([1, 2, 3]), sort=sort)
|
||||
tm.assert_index_equal(result, expected)
|
||||
|
||||
# mixed closed
|
||||
msg = (
|
||||
"can only do set operations between two IntervalIndex objects "
|
||||
"that are closed on the same side"
|
||||
)
|
||||
for other_closed in {"right", "left", "both", "neither"} - {closed}:
|
||||
other = monotonic_index(0, 11, closed=other_closed)
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
set_op(other, sort=sort)
|
||||
|
||||
# GH 19016: incompatible dtypes
|
||||
other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
|
||||
msg = (
|
||||
"can only do {op} between two IntervalIndex objects that have "
|
||||
"compatible dtypes"
|
||||
).format(op=op_name)
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
set_op(other, sort=sort)
|
Reference in New Issue
Block a user