184 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			184 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import datetime as dt
 | 
						|
from typing import (
 | 
						|
    Any,
 | 
						|
    Optional,
 | 
						|
    Sequence,
 | 
						|
    Tuple,
 | 
						|
    Union,
 | 
						|
    cast,
 | 
						|
)
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
from pandas._typing import (
 | 
						|
    Dtype,
 | 
						|
    PositionalIndexer,
 | 
						|
)
 | 
						|
 | 
						|
from pandas.core.dtypes.dtypes import register_extension_dtype
 | 
						|
 | 
						|
from pandas.api.extensions import (
 | 
						|
    ExtensionArray,
 | 
						|
    ExtensionDtype,
 | 
						|
)
 | 
						|
from pandas.api.types import pandas_dtype
 | 
						|
 | 
						|
 | 
						|
@register_extension_dtype
class DateDtype(ExtensionDtype):
    """Extension dtype whose scalar type is ``datetime.date``.

    The dtype is identified by the string ``"DateDtype"`` and is backed by
    ``DateArray``.
    """

    @property
    def type(self):
        """Scalar type of the elements: ``datetime.date``."""
        return dt.date

    @property
    def name(self):
        """String identifier of this dtype."""
        return "DateDtype"

    @classmethod
    def construct_from_string(cls, string: str):
        """Build the dtype from its string name.

        Parameters
        ----------
        string : str
            Must be equal to the class name.

        Returns
        -------
        A new instance of ``cls``.

        Raises
        ------
        TypeError
            If ``string`` is not a ``str``, or is not the class name.
        """
        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )

        # Guard clause: only the exact class name is accepted.
        if string != cls.__name__:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
        return cls()

    @classmethod
    def construct_array_type(cls):
        """Return the array class associated with this dtype."""
        return DateArray

    @property
    def na_value(self):
        """Value reported as the missing-value marker: ``datetime.date.min``."""
        return dt.date.min

    def __repr__(self) -> str:
        """Use the dtype name as its representation."""
        return self.name
 | 
						|
 | 
						|
 | 
						|
class DateArray(ExtensionArray):
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        dates: Union[
 | 
						|
            dt.date,
 | 
						|
            Sequence[dt.date],
 | 
						|
            Tuple[np.ndarray, np.ndarray, np.ndarray],
 | 
						|
            np.ndarray,
 | 
						|
        ],
 | 
						|
    ) -> None:
 | 
						|
        if isinstance(dates, dt.date):
 | 
						|
            self._year = np.array([dates.year])
 | 
						|
            self._month = np.array([dates.month])
 | 
						|
            self._day = np.array([dates.year])
 | 
						|
            return
 | 
						|
 | 
						|
        ldates = len(dates)
 | 
						|
        if isinstance(dates, list):
 | 
						|
            # pre-allocate the arrays since we know the size before hand
 | 
						|
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
 | 
						|
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)
 | 
						|
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
 | 
						|
            # populate them
 | 
						|
            for i, (y, m, d) in enumerate(
 | 
						|
                map(lambda date: (date.year, date.month, date.day), dates)
 | 
						|
            ):
 | 
						|
                self._year[i] = y
 | 
						|
                self._month[i] = m
 | 
						|
                self._day[i] = d
 | 
						|
 | 
						|
        elif isinstance(dates, tuple):
 | 
						|
            # only support triples
 | 
						|
            if ldates != 3:
 | 
						|
                raise ValueError("only triples are valid")
 | 
						|
            # check if all elements have the same type
 | 
						|
            if any(map(lambda x: not isinstance(x, np.ndarray), dates)):
 | 
						|
                raise TypeError("invalid type")
 | 
						|
            ly, lm, ld = (len(cast(np.ndarray, d)) for d in dates)
 | 
						|
            if not ly == lm == ld:
 | 
						|
                raise ValueError(
 | 
						|
                    f"tuple members must have the same length: {(ly, lm, ld)}"
 | 
						|
                )
 | 
						|
            self._year = dates[0].astype(np.uint16)
 | 
						|
            self._month = dates[1].astype(np.uint8)
 | 
						|
            self._day = dates[2].astype(np.uint8)
 | 
						|
 | 
						|
        elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
 | 
						|
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
 | 
						|
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)
 | 
						|
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
 | 
						|
 | 
						|
            # "object_" object is not iterable  [misc]
 | 
						|
            for (i,), (y, m, d) in np.ndenumerate(  # type: ignore[misc]
 | 
						|
                np.char.split(dates, sep="-")
 | 
						|
            ):
 | 
						|
                self._year[i] = int(y)
 | 
						|
                self._month[i] = int(m)
 | 
						|
                self._day[i] = int(d)
 | 
						|
 | 
						|
        else:
 | 
						|
            raise TypeError(f"{type(dates)} is not supported")
 | 
						|
 | 
						|
    @property
    def dtype(self) -> ExtensionDtype:
        """Return a fresh ``DateDtype`` instance describing this array."""
        date_dtype = DateDtype()
        return date_dtype
 | 
						|
 | 
						|
    def astype(self, dtype, copy=True):
        """Cast the array to ``dtype``.

        ``dtype`` is first normalised with ``pandas_dtype``. If the result
        is a ``DateDtype`` the array itself is returned (copied when
        ``copy`` is true); otherwise the conversion is delegated to
        ``to_numpy`` with ``datetime.date.min`` as the NA value.
        """
        target = pandas_dtype(dtype)

        if not isinstance(target, DateDtype):
            return self.to_numpy(dtype=target, copy=copy, na_value=dt.date.min)

        return self.copy() if copy else self
 | 
						|
 | 
						|
    @property
 | 
						|
    def nbytes(self) -> int:
 | 
						|
        return self._year.nbytes + self._month.nbytes + self._day.nbytes
 | 
						|
 | 
						|
    def __len__(self) -> int:
 | 
						|
        return len(self._year)  # all 3 arrays are enforced to have the same length
 | 
						|
 | 
						|
    def __getitem__(self, item: PositionalIndexer):
 | 
						|
        if isinstance(item, int):
 | 
						|
            return dt.date(self._year[item], self._month[item], self._day[item])
 | 
						|
        else:
 | 
						|
            raise NotImplementedError("only ints are supported as indexes")
 | 
						|
 | 
						|
    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any):
 | 
						|
        if not isinstance(key, int):
 | 
						|
            raise NotImplementedError("only ints are supported as indexes")
 | 
						|
 | 
						|
        if not isinstance(value, dt.date):
 | 
						|
            raise TypeError("you can only set datetime.date types")
 | 
						|
 | 
						|
        self._year[key] = value.year
 | 
						|
        self._month[key] = value.month
 | 
						|
        self._day[key] = value.day
 | 
						|
 | 
						|
    def __repr__(self) -> str:
 | 
						|
        return f"DateArray{list(zip(self._year, self._month, self._day))}"
 | 
						|
 | 
						|
    def copy(self) -> "DateArray":
 | 
						|
        return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))
 | 
						|
 | 
						|
    def isna(self) -> np.ndarray:
 | 
						|
        return np.logical_and(
 | 
						|
            np.logical_and(
 | 
						|
                self._year == dt.date.min.year, self._month == dt.date.min.month
 | 
						|
            ),
 | 
						|
            self._day == dt.date.min.day,
 | 
						|
        )
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
 | 
						|
        if isinstance(scalars, dt.date):
 | 
						|
            pass
 | 
						|
        elif isinstance(scalars, DateArray):
 | 
						|
            pass
 | 
						|
        elif isinstance(scalars, np.ndarray):
 | 
						|
            scalars = scalars.astype("U10")  # 10 chars for yyyy-mm-dd
 | 
						|
            return DateArray(scalars)
 |