针对pulse-transit的工具
This commit is contained in:
		
							
								
								
									
										369
									
								
								dist/client/pandas/tests/arrays/interval/test_interval.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										369
									
								
								dist/client/pandas/tests/arrays/interval/test_interval.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,369 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
import pandas.util._test_decorators as td
 | 
			
		||||
 | 
			
		||||
import pandas as pd
 | 
			
		||||
from pandas import (
 | 
			
		||||
    Index,
 | 
			
		||||
    Interval,
 | 
			
		||||
    IntervalIndex,
 | 
			
		||||
    Timedelta,
 | 
			
		||||
    Timestamp,
 | 
			
		||||
    date_range,
 | 
			
		||||
    timedelta_range,
 | 
			
		||||
)
 | 
			
		||||
import pandas._testing as tm
 | 
			
		||||
from pandas.core.arrays import IntervalArray
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(
 | 
			
		||||
    params=[
 | 
			
		||||
        (Index([0, 2, 4]), Index([1, 3, 5])),
 | 
			
		||||
        (Index([0.0, 1.0, 2.0]), Index([1.0, 2.0, 3.0])),
 | 
			
		||||
        (timedelta_range("0 days", periods=3), timedelta_range("1 day", periods=3)),
 | 
			
		||||
        (date_range("20170101", periods=3), date_range("20170102", periods=3)),
 | 
			
		||||
        (
 | 
			
		||||
            date_range("20170101", periods=3, tz="US/Eastern"),
 | 
			
		||||
            date_range("20170102", periods=3, tz="US/Eastern"),
 | 
			
		||||
        ),
 | 
			
		||||
    ],
 | 
			
		||||
    ids=lambda x: str(x[0].dtype),
 | 
			
		||||
)
 | 
			
		||||
def left_right_dtypes(request):
 | 
			
		||||
    """
 | 
			
		||||
    Fixture for building an IntervalArray from various dtypes
 | 
			
		||||
    """
 | 
			
		||||
    return request.param
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestAttributes:
 | 
			
		||||
    @pytest.mark.parametrize(
 | 
			
		||||
        "left, right",
 | 
			
		||||
        [
 | 
			
		||||
            (0, 1),
 | 
			
		||||
            (Timedelta("0 days"), Timedelta("1 day")),
 | 
			
		||||
            (Timestamp("2018-01-01"), Timestamp("2018-01-02")),
 | 
			
		||||
            (
 | 
			
		||||
                Timestamp("2018-01-01", tz="US/Eastern"),
 | 
			
		||||
                Timestamp("2018-01-02", tz="US/Eastern"),
 | 
			
		||||
            ),
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
    @pytest.mark.parametrize("constructor", [IntervalArray, IntervalIndex])
 | 
			
		||||
    def test_is_empty(self, constructor, left, right, closed):
 | 
			
		||||
        # GH27219
 | 
			
		||||
        tuples = [(left, left), (left, right), np.nan]
 | 
			
		||||
        expected = np.array([closed != "both", False, False])
 | 
			
		||||
        result = constructor.from_tuples(tuples, closed=closed).is_empty
 | 
			
		||||
        tm.assert_numpy_array_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestMethods:
 | 
			
		||||
    @pytest.mark.parametrize("new_closed", ["left", "right", "both", "neither"])
 | 
			
		||||
    def test_set_closed(self, closed, new_closed):
 | 
			
		||||
        # GH 21670
 | 
			
		||||
        array = IntervalArray.from_breaks(range(10), closed=closed)
 | 
			
		||||
        result = array.set_closed(new_closed)
 | 
			
		||||
        expected = IntervalArray.from_breaks(range(10), closed=new_closed)
 | 
			
		||||
        tm.assert_extension_array_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.parametrize(
 | 
			
		||||
        "other",
 | 
			
		||||
        [
 | 
			
		||||
            Interval(0, 1, closed="right"),
 | 
			
		||||
            IntervalArray.from_breaks([1, 2, 3, 4], closed="right"),
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
    def test_where_raises(self, other):
 | 
			
		||||
        ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4], closed="left"))
 | 
			
		||||
        match = "'value.closed' is 'right', expected 'left'."
 | 
			
		||||
        with pytest.raises(ValueError, match=match):
 | 
			
		||||
            ser.where([True, False, True], other=other)
 | 
			
		||||
 | 
			
		||||
    def test_shift(self):
 | 
			
		||||
        # https://github.com/pandas-dev/pandas/issues/31495
 | 
			
		||||
        a = IntervalArray.from_breaks([1, 2, 3])
 | 
			
		||||
        result = a.shift()
 | 
			
		||||
        # int -> float
 | 
			
		||||
        expected = IntervalArray.from_tuples([(np.nan, np.nan), (1.0, 2.0)])
 | 
			
		||||
        tm.assert_interval_array_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
    def test_shift_datetime(self):
 | 
			
		||||
        a = IntervalArray.from_breaks(date_range("2000", periods=4))
 | 
			
		||||
        result = a.shift(2)
 | 
			
		||||
        expected = a.take([-1, -1, 0], allow_fill=True)
 | 
			
		||||
        tm.assert_interval_array_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
        result = a.shift(-1)
 | 
			
		||||
        expected = a.take([1, 2, -1], allow_fill=True)
 | 
			
		||||
        tm.assert_interval_array_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSetitem:
 | 
			
		||||
    def test_set_na(self, left_right_dtypes):
 | 
			
		||||
        left, right = left_right_dtypes
 | 
			
		||||
        left = left.copy(deep=True)
 | 
			
		||||
        right = right.copy(deep=True)
 | 
			
		||||
        result = IntervalArray.from_arrays(left, right)
 | 
			
		||||
 | 
			
		||||
        if result.dtype.subtype.kind not in ["m", "M"]:
 | 
			
		||||
            msg = "'value' should be an interval type, got <.*NaTType'> instead."
 | 
			
		||||
            with pytest.raises(TypeError, match=msg):
 | 
			
		||||
                result[0] = pd.NaT
 | 
			
		||||
        if result.dtype.subtype.kind in ["i", "u"]:
 | 
			
		||||
            msg = "Cannot set float NaN to integer-backed IntervalArray"
 | 
			
		||||
            with pytest.raises(ValueError, match=msg):
 | 
			
		||||
                result[0] = np.NaN
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        result[0] = np.nan
 | 
			
		||||
 | 
			
		||||
        expected_left = Index([left._na_value] + list(left[1:]))
 | 
			
		||||
        expected_right = Index([right._na_value] + list(right[1:]))
 | 
			
		||||
        expected = IntervalArray.from_arrays(expected_left, expected_right)
 | 
			
		||||
 | 
			
		||||
        tm.assert_extension_array_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
    def test_setitem_mismatched_closed(self):
 | 
			
		||||
        arr = IntervalArray.from_breaks(range(4))
 | 
			
		||||
        orig = arr.copy()
 | 
			
		||||
        other = arr.set_closed("both")
 | 
			
		||||
 | 
			
		||||
        msg = "'value.closed' is 'both', expected 'right'"
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[0] = other[0]
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[:1] = other[:1]
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[:0] = other[:0]
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[:] = other[::-1]
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[:] = list(other[::-1])
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[:] = other[::-1].astype(object)
 | 
			
		||||
        with pytest.raises(ValueError, match=msg):
 | 
			
		||||
            arr[:] = other[::-1].astype("category")
 | 
			
		||||
 | 
			
		||||
        # empty list should be no-op
 | 
			
		||||
        arr[:0] = []
 | 
			
		||||
        tm.assert_interval_array_equal(arr, orig)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_repr():
 | 
			
		||||
    # GH 25022
 | 
			
		||||
    arr = IntervalArray.from_tuples([(0, 1), (1, 2)])
 | 
			
		||||
    result = repr(arr)
 | 
			
		||||
    expected = (
 | 
			
		||||
        "<IntervalArray>\n"
 | 
			
		||||
        "[(0, 1], (1, 2]]\n"
 | 
			
		||||
        "Length: 2, dtype: interval[int64, right]"
 | 
			
		||||
    )
 | 
			
		||||
    assert result == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestReductions:
 | 
			
		||||
    def test_min_max_invalid_axis(self, left_right_dtypes):
 | 
			
		||||
        left, right = left_right_dtypes
 | 
			
		||||
        left = left.copy(deep=True)
 | 
			
		||||
        right = right.copy(deep=True)
 | 
			
		||||
        arr = IntervalArray.from_arrays(left, right)
 | 
			
		||||
 | 
			
		||||
        msg = "`axis` must be fewer than the number of dimensions"
 | 
			
		||||
        for axis in [-2, 1]:
 | 
			
		||||
            with pytest.raises(ValueError, match=msg):
 | 
			
		||||
                arr.min(axis=axis)
 | 
			
		||||
            with pytest.raises(ValueError, match=msg):
 | 
			
		||||
                arr.max(axis=axis)
 | 
			
		||||
 | 
			
		||||
        msg = "'>=' not supported between"
 | 
			
		||||
        with pytest.raises(TypeError, match=msg):
 | 
			
		||||
            arr.min(axis="foo")
 | 
			
		||||
        with pytest.raises(TypeError, match=msg):
 | 
			
		||||
            arr.max(axis="foo")
 | 
			
		||||
 | 
			
		||||
    def test_min_max(self, left_right_dtypes, index_or_series_or_array):
 | 
			
		||||
        # GH#44746
 | 
			
		||||
        left, right = left_right_dtypes
 | 
			
		||||
        left = left.copy(deep=True)
 | 
			
		||||
        right = right.copy(deep=True)
 | 
			
		||||
        arr = IntervalArray.from_arrays(left, right)
 | 
			
		||||
 | 
			
		||||
        # The expected results below are only valid if monotonic
 | 
			
		||||
        assert left.is_monotonic_increasing
 | 
			
		||||
        assert Index(arr).is_monotonic_increasing
 | 
			
		||||
 | 
			
		||||
        MIN = arr[0]
 | 
			
		||||
        MAX = arr[-1]
 | 
			
		||||
 | 
			
		||||
        indexer = np.arange(len(arr))
 | 
			
		||||
        np.random.shuffle(indexer)
 | 
			
		||||
        arr = arr.take(indexer)
 | 
			
		||||
 | 
			
		||||
        arr_na = arr.insert(2, np.nan)
 | 
			
		||||
 | 
			
		||||
        arr = index_or_series_or_array(arr)
 | 
			
		||||
        arr_na = index_or_series_or_array(arr_na)
 | 
			
		||||
 | 
			
		||||
        for skipna in [True, False]:
 | 
			
		||||
            res = arr.min(skipna=skipna)
 | 
			
		||||
            assert res == MIN
 | 
			
		||||
            assert type(res) == type(MIN)
 | 
			
		||||
 | 
			
		||||
            res = arr.max(skipna=skipna)
 | 
			
		||||
            assert res == MAX
 | 
			
		||||
            assert type(res) == type(MAX)
 | 
			
		||||
 | 
			
		||||
        res = arr_na.min(skipna=False)
 | 
			
		||||
        assert np.isnan(res)
 | 
			
		||||
        res = arr_na.max(skipna=False)
 | 
			
		||||
        assert np.isnan(res)
 | 
			
		||||
 | 
			
		||||
        res = arr_na.min(skipna=True)
 | 
			
		||||
        assert res == MIN
 | 
			
		||||
        assert type(res) == type(MIN)
 | 
			
		||||
        res = arr_na.max(skipna=True)
 | 
			
		||||
        assert res == MAX
 | 
			
		||||
        assert type(res) == type(MAX)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ----------------------------------------------------------------------------
 | 
			
		||||
# Arrow interaction
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
pyarrow_skip = td.skip_if_no("pyarrow")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pyarrow_skip
 | 
			
		||||
def test_arrow_extension_type():
 | 
			
		||||
    import pyarrow as pa
 | 
			
		||||
 | 
			
		||||
    from pandas.core.arrays._arrow_utils import ArrowIntervalType
 | 
			
		||||
 | 
			
		||||
    p1 = ArrowIntervalType(pa.int64(), "left")
 | 
			
		||||
    p2 = ArrowIntervalType(pa.int64(), "left")
 | 
			
		||||
    p3 = ArrowIntervalType(pa.int64(), "right")
 | 
			
		||||
 | 
			
		||||
    assert p1.closed == "left"
 | 
			
		||||
    assert p1 == p2
 | 
			
		||||
    assert not p1 == p3
 | 
			
		||||
    assert hash(p1) == hash(p2)
 | 
			
		||||
    assert not hash(p1) == hash(p3)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pyarrow_skip
 | 
			
		||||
def test_arrow_array():
 | 
			
		||||
    import pyarrow as pa
 | 
			
		||||
 | 
			
		||||
    from pandas.core.arrays._arrow_utils import ArrowIntervalType
 | 
			
		||||
 | 
			
		||||
    intervals = pd.interval_range(1, 5, freq=1).array
 | 
			
		||||
 | 
			
		||||
    result = pa.array(intervals)
 | 
			
		||||
    assert isinstance(result.type, ArrowIntervalType)
 | 
			
		||||
    assert result.type.closed == intervals.closed
 | 
			
		||||
    assert result.type.subtype == pa.int64()
 | 
			
		||||
    assert result.storage.field("left").equals(pa.array([1, 2, 3, 4], type="int64"))
 | 
			
		||||
    assert result.storage.field("right").equals(pa.array([2, 3, 4, 5], type="int64"))
 | 
			
		||||
 | 
			
		||||
    expected = pa.array([{"left": i, "right": i + 1} for i in range(1, 5)])
 | 
			
		||||
    assert result.storage.equals(expected)
 | 
			
		||||
 | 
			
		||||
    # convert to its storage type
 | 
			
		||||
    result = pa.array(intervals, type=expected.type)
 | 
			
		||||
    assert result.equals(expected)
 | 
			
		||||
 | 
			
		||||
    # unsupported conversions
 | 
			
		||||
    with pytest.raises(TypeError, match="Not supported to convert IntervalArray"):
 | 
			
		||||
        pa.array(intervals, type="float64")
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(TypeError, match="different 'subtype'"):
 | 
			
		||||
        pa.array(intervals, type=ArrowIntervalType(pa.float64(), "left"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pyarrow_skip
 | 
			
		||||
def test_arrow_array_missing():
 | 
			
		||||
    import pyarrow as pa
 | 
			
		||||
 | 
			
		||||
    from pandas.core.arrays._arrow_utils import ArrowIntervalType
 | 
			
		||||
 | 
			
		||||
    arr = IntervalArray.from_breaks([0.0, 1.0, 2.0, 3.0])
 | 
			
		||||
    arr[1] = None
 | 
			
		||||
 | 
			
		||||
    result = pa.array(arr)
 | 
			
		||||
    assert isinstance(result.type, ArrowIntervalType)
 | 
			
		||||
    assert result.type.closed == arr.closed
 | 
			
		||||
    assert result.type.subtype == pa.float64()
 | 
			
		||||
 | 
			
		||||
    # fields have missing values (not NaN)
 | 
			
		||||
    left = pa.array([0.0, None, 2.0], type="float64")
 | 
			
		||||
    right = pa.array([1.0, None, 3.0], type="float64")
 | 
			
		||||
    assert result.storage.field("left").equals(left)
 | 
			
		||||
    assert result.storage.field("right").equals(right)
 | 
			
		||||
 | 
			
		||||
    # structarray itself also has missing values on the array level
 | 
			
		||||
    vals = [
 | 
			
		||||
        {"left": 0.0, "right": 1.0},
 | 
			
		||||
        {"left": None, "right": None},
 | 
			
		||||
        {"left": 2.0, "right": 3.0},
 | 
			
		||||
    ]
 | 
			
		||||
    expected = pa.StructArray.from_pandas(vals, mask=np.array([False, True, False]))
 | 
			
		||||
    assert result.storage.equals(expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pyarrow_skip
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "breaks",
 | 
			
		||||
    [[0.0, 1.0, 2.0, 3.0], date_range("2017", periods=4, freq="D")],
 | 
			
		||||
    ids=["float", "datetime64[ns]"],
 | 
			
		||||
)
 | 
			
		||||
def test_arrow_table_roundtrip(breaks):
 | 
			
		||||
    import pyarrow as pa
 | 
			
		||||
 | 
			
		||||
    from pandas.core.arrays._arrow_utils import ArrowIntervalType
 | 
			
		||||
 | 
			
		||||
    arr = IntervalArray.from_breaks(breaks)
 | 
			
		||||
    arr[1] = None
 | 
			
		||||
    df = pd.DataFrame({"a": arr})
 | 
			
		||||
 | 
			
		||||
    table = pa.table(df)
 | 
			
		||||
    assert isinstance(table.field("a").type, ArrowIntervalType)
 | 
			
		||||
    result = table.to_pandas()
 | 
			
		||||
    assert isinstance(result["a"].dtype, pd.IntervalDtype)
 | 
			
		||||
    tm.assert_frame_equal(result, df)
 | 
			
		||||
 | 
			
		||||
    table2 = pa.concat_tables([table, table])
 | 
			
		||||
    result = table2.to_pandas()
 | 
			
		||||
    expected = pd.concat([df, df], ignore_index=True)
 | 
			
		||||
    tm.assert_frame_equal(result, expected)
 | 
			
		||||
 | 
			
		||||
    # GH-41040
 | 
			
		||||
    table = pa.table(
 | 
			
		||||
        [pa.chunked_array([], type=table.column(0).type)], schema=table.schema
 | 
			
		||||
    )
 | 
			
		||||
    result = table.to_pandas()
 | 
			
		||||
    tm.assert_frame_equal(result, expected[0:0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pyarrow_skip
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "breaks",
 | 
			
		||||
    [[0.0, 1.0, 2.0, 3.0], date_range("2017", periods=4, freq="D")],
 | 
			
		||||
    ids=["float", "datetime64[ns]"],
 | 
			
		||||
)
 | 
			
		||||
def test_arrow_table_roundtrip_without_metadata(breaks):
 | 
			
		||||
    import pyarrow as pa
 | 
			
		||||
 | 
			
		||||
    arr = IntervalArray.from_breaks(breaks)
 | 
			
		||||
    arr[1] = None
 | 
			
		||||
    df = pd.DataFrame({"a": arr})
 | 
			
		||||
 | 
			
		||||
    table = pa.table(df)
 | 
			
		||||
    # remove the metadata
 | 
			
		||||
    table = table.replace_schema_metadata()
 | 
			
		||||
    assert table.schema.metadata is None
 | 
			
		||||
 | 
			
		||||
    result = table.to_pandas()
 | 
			
		||||
    assert isinstance(result["a"].dtype, pd.IntervalDtype)
 | 
			
		||||
    tm.assert_frame_equal(result, df)
 | 
			
		||||
		Reference in New Issue
	
	Block a user