203 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			203 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import numpy as np
 | 
						|
import pytest
 | 
						|
 | 
						|
from pandas import (
 | 
						|
    Index,
 | 
						|
    IntervalIndex,
 | 
						|
    Timestamp,
 | 
						|
    interval_range,
 | 
						|
)
 | 
						|
import pandas._testing as tm
 | 
						|
 | 
						|
 | 
						|
def monotonic_index(start, end, dtype="int64", closed="right"):
 | 
						|
    return IntervalIndex.from_breaks(np.arange(start, end, dtype=dtype), closed=closed)
 | 
						|
 | 
						|
 | 
						|
def empty_index(dtype="int64", closed="right"):
 | 
						|
    return IntervalIndex(np.array([], dtype=dtype), closed=closed)
 | 
						|
 | 
						|
 | 
						|
class TestIntervalIndex:
 | 
						|
    def test_union(self, closed, sort):
 | 
						|
        index = monotonic_index(0, 11, closed=closed)
 | 
						|
        other = monotonic_index(5, 13, closed=closed)
 | 
						|
 | 
						|
        expected = monotonic_index(0, 13, closed=closed)
 | 
						|
        result = index[::-1].union(other, sort=sort)
 | 
						|
        if sort is None:
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
        assert tm.equalContents(result, expected)
 | 
						|
 | 
						|
        result = other[::-1].union(index, sort=sort)
 | 
						|
        if sort is None:
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
        assert tm.equalContents(result, expected)
 | 
						|
 | 
						|
        tm.assert_index_equal(index.union(index, sort=sort), index)
 | 
						|
        tm.assert_index_equal(index.union(index[:1], sort=sort), index)
 | 
						|
 | 
						|
    def test_union_empty_result(self, closed, sort):
 | 
						|
        # GH 19101: empty result, same dtype
 | 
						|
        index = empty_index(dtype="int64", closed=closed)
 | 
						|
        result = index.union(index, sort=sort)
 | 
						|
        tm.assert_index_equal(result, index)
 | 
						|
 | 
						|
        # GH 19101: empty result, different numeric dtypes -> common dtype is f8
 | 
						|
        other = empty_index(dtype="float64", closed=closed)
 | 
						|
        result = index.union(other, sort=sort)
 | 
						|
        expected = other
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        other = index.union(index, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        other = empty_index(dtype="uint64", closed=closed)
 | 
						|
        result = index.union(other, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        result = other.union(index, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
    def test_intersection(self, closed, sort):
 | 
						|
        index = monotonic_index(0, 11, closed=closed)
 | 
						|
        other = monotonic_index(5, 13, closed=closed)
 | 
						|
 | 
						|
        expected = monotonic_index(5, 11, closed=closed)
 | 
						|
        result = index[::-1].intersection(other, sort=sort)
 | 
						|
        if sort is None:
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
        assert tm.equalContents(result, expected)
 | 
						|
 | 
						|
        result = other[::-1].intersection(index, sort=sort)
 | 
						|
        if sort is None:
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
        assert tm.equalContents(result, expected)
 | 
						|
 | 
						|
        tm.assert_index_equal(index.intersection(index, sort=sort), index)
 | 
						|
 | 
						|
        # GH 26225: nested intervals
 | 
						|
        index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)])
 | 
						|
        other = IntervalIndex.from_tuples([(1, 2), (1, 3)])
 | 
						|
        expected = IntervalIndex.from_tuples([(1, 2), (1, 3)])
 | 
						|
        result = index.intersection(other)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # GH 26225
 | 
						|
        index = IntervalIndex.from_tuples([(0, 3), (0, 2)])
 | 
						|
        other = IntervalIndex.from_tuples([(0, 2), (1, 3)])
 | 
						|
        expected = IntervalIndex.from_tuples([(0, 2)])
 | 
						|
        result = index.intersection(other)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # GH 26225: duplicate nan element
 | 
						|
        index = IntervalIndex([np.nan, np.nan])
 | 
						|
        other = IntervalIndex([np.nan])
 | 
						|
        expected = IntervalIndex([np.nan])
 | 
						|
        result = index.intersection(other)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
    def test_intersection_empty_result(self, closed, sort):
 | 
						|
        index = monotonic_index(0, 11, closed=closed)
 | 
						|
 | 
						|
        # GH 19101: empty result, same dtype
 | 
						|
        other = monotonic_index(300, 314, closed=closed)
 | 
						|
        expected = empty_index(dtype="int64", closed=closed)
 | 
						|
        result = index.intersection(other, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # GH 19101: empty result, different numeric dtypes -> common dtype is float64
 | 
						|
        other = monotonic_index(300, 314, dtype="float64", closed=closed)
 | 
						|
        result = index.intersection(other, sort=sort)
 | 
						|
        expected = other[:0]
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        other = monotonic_index(300, 314, dtype="uint64", closed=closed)
 | 
						|
        result = index.intersection(other, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
    def test_intersection_duplicates(self):
 | 
						|
        # GH#38743
 | 
						|
        index = IntervalIndex.from_tuples([(1, 2), (1, 2), (2, 3), (3, 4)])
 | 
						|
        other = IntervalIndex.from_tuples([(1, 2), (2, 3)])
 | 
						|
        expected = IntervalIndex.from_tuples([(1, 2), (2, 3)])
 | 
						|
        result = index.intersection(other)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
    def test_difference(self, closed, sort):
 | 
						|
        index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed)
 | 
						|
        result = index.difference(index[:1], sort=sort)
 | 
						|
        expected = index[1:]
 | 
						|
        if sort is None:
 | 
						|
            expected = expected.sort_values()
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # GH 19101: empty result, same dtype
 | 
						|
        result = index.difference(index, sort=sort)
 | 
						|
        expected = empty_index(dtype="int64", closed=closed)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # GH 19101: empty result, different dtypes
 | 
						|
        other = IntervalIndex.from_arrays(
 | 
						|
            index.left.astype("float64"), index.right, closed=closed
 | 
						|
        )
 | 
						|
        result = index.difference(other, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
    def test_symmetric_difference(self, closed, sort):
 | 
						|
        index = monotonic_index(0, 11, closed=closed)
 | 
						|
        result = index[1:].symmetric_difference(index[:-1], sort=sort)
 | 
						|
        expected = IntervalIndex([index[0], index[-1]])
 | 
						|
        if sort is None:
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
        assert tm.equalContents(result, expected)
 | 
						|
 | 
						|
        # GH 19101: empty result, same dtype
 | 
						|
        result = index.symmetric_difference(index, sort=sort)
 | 
						|
        expected = empty_index(dtype="int64", closed=closed)
 | 
						|
        if sort is None:
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
        assert tm.equalContents(result, expected)
 | 
						|
 | 
						|
        # GH 19101: empty result, different dtypes
 | 
						|
        other = IntervalIndex.from_arrays(
 | 
						|
            index.left.astype("float64"), index.right, closed=closed
 | 
						|
        )
 | 
						|
        result = index.symmetric_difference(other, sort=sort)
 | 
						|
        expected = empty_index(dtype="float64", closed=closed)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
    @pytest.mark.filterwarnings("ignore:'<' not supported between:RuntimeWarning")
 | 
						|
    @pytest.mark.parametrize(
 | 
						|
        "op_name", ["union", "intersection", "difference", "symmetric_difference"]
 | 
						|
    )
 | 
						|
    def test_set_incompatible_types(self, closed, op_name, sort):
 | 
						|
        index = monotonic_index(0, 11, closed=closed)
 | 
						|
        set_op = getattr(index, op_name)
 | 
						|
 | 
						|
        # TODO: standardize return type of non-union setops type(self vs other)
 | 
						|
        # non-IntervalIndex
 | 
						|
        if op_name == "difference":
 | 
						|
            expected = index
 | 
						|
        else:
 | 
						|
            expected = getattr(index.astype("O"), op_name)(Index([1, 2, 3]))
 | 
						|
        result = set_op(Index([1, 2, 3]), sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # mixed closed -> cast to object
 | 
						|
        for other_closed in {"right", "left", "both", "neither"} - {closed}:
 | 
						|
            other = monotonic_index(0, 11, closed=other_closed)
 | 
						|
            expected = getattr(index.astype(object), op_name)(other, sort=sort)
 | 
						|
            if op_name == "difference":
 | 
						|
                expected = index
 | 
						|
            result = set_op(other, sort=sort)
 | 
						|
            tm.assert_index_equal(result, expected)
 | 
						|
 | 
						|
        # GH 19016: incompatible dtypes -> cast to object
 | 
						|
        other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
 | 
						|
        expected = getattr(index.astype(object), op_name)(other, sort=sort)
 | 
						|
        if op_name == "difference":
 | 
						|
            expected = index
 | 
						|
        result = set_op(other, sort=sort)
 | 
						|
        tm.assert_index_equal(result, expected)
 |