273 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			273 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
import decimal
 | 
						|
import numbers
 | 
						|
import random
 | 
						|
import sys
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
from pandas._typing import type_t
 | 
						|
 | 
						|
from pandas.core.dtypes.base import ExtensionDtype
 | 
						|
from pandas.core.dtypes.common import (
 | 
						|
    is_dtype_equal,
 | 
						|
    is_float,
 | 
						|
    pandas_dtype,
 | 
						|
)
 | 
						|
 | 
						|
import pandas as pd
 | 
						|
from pandas.api.extensions import (
 | 
						|
    no_default,
 | 
						|
    register_extension_dtype,
 | 
						|
)
 | 
						|
from pandas.api.types import (
 | 
						|
    is_list_like,
 | 
						|
    is_scalar,
 | 
						|
)
 | 
						|
from pandas.core import arraylike
 | 
						|
from pandas.core.arraylike import OpsMixin
 | 
						|
from pandas.core.arrays import (
 | 
						|
    ExtensionArray,
 | 
						|
    ExtensionScalarOpsMixin,
 | 
						|
)
 | 
						|
from pandas.core.indexers import check_array_indexer
 | 
						|
 | 
						|
 | 
						|
@register_extension_dtype
class DecimalDtype(ExtensionDtype):
    """
    Extension dtype whose scalar type is ``decimal.Decimal``.

    Parameters
    ----------
    context : decimal.Context, optional
        Context stored on the dtype as ``self.context``. When it is omitted
        (or otherwise falsy) the value of ``decimal.getcontext()`` is stored
        instead.
    """

    name = "decimal"
    type = decimal.Decimal
    # Missing values are represented by the Decimal NaN.
    na_value = decimal.Decimal("NaN")
    # Instance attributes that identify a particular dtype instance.
    _metadata = ("context",)

    def __init__(self, context=None):
        self.context = context if context else decimal.getcontext()

    def __repr__(self) -> str:
        return "DecimalDtype(context={})".format(self.context)

    @classmethod
    def construct_array_type(cls) -> type_t[DecimalArray]:
        """
        Give the array class that pairs with this dtype.

        Returns
        -------
        type
            The ``DecimalArray`` class.
        """
        return DecimalArray

    @property
    def _is_numeric(self) -> bool:
        # Decimal data is flagged as numeric.
        return True
 | 
						|
 | 
						|
 | 
						|
class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
    """
    Extension array that keeps ``decimal.Decimal`` scalars in an object ndarray.

    The backing ndarray is stored on ``_data`` (also reachable through the
    aliases ``_items`` and ``data``). Missing values are the Decimal NaN.
    """

    # High priority so NumPy prefers this class when mixing with ndarrays.
    __array_priority__ = 1000

    def __init__(self, values, dtype=None, copy=False, context=None):
        """
        Build the array from Decimal (or float) scalars.

        Parameters
        ----------
        values : sequence of decimal.Decimal or float
            Floats are converted to ``Decimal``; float NaN becomes the
            dtype's ``na_value``. The caller's container is never modified.
        dtype : DecimalDtype, optional
            Only consulted for its ``context`` when ``context`` is not given.
        copy : bool, default False
            If True, the backing ndarray never shares memory with ``values``.
        context : decimal.Context, optional
            Context handed to the array's ``DecimalDtype``.

        Raises
        ------
        TypeError
            If an element is neither a float nor a ``decimal.Decimal``.
        """
        converted = None
        for i, val in enumerate(values):
            if is_float(val):
                if converted is None:
                    # Copy lazily: inputs that need no conversion are passed
                    # through untouched (so ndarray views keep sharing memory),
                    # and the caller's container is never written to.
                    converted = list(values)
                if np.isnan(val):
                    converted[i] = DecimalDtype.na_value
                else:
                    converted[i] = DecimalDtype.type(val)
            elif not isinstance(val, decimal.Decimal):
                raise TypeError("All values must be of type " + str(decimal.Decimal))
        if converted is not None:
            values = converted

        if copy:
            # np.array copies by default, detaching us from the input buffer.
            values = np.array(values, dtype=object)
        else:
            values = np.asarray(values, dtype=object)

        self._data = values
        # Some aliases for common attribute names to ensure pandas supports
        # these
        self._items = self.data = self._data
        # those aliases are currently not working due to assumptions
        # in internal code (GH-20735)
        # self._values = self.values = self.data

        # Fall back to the context carried by ``dtype`` so that e.g. ``copy()``
        # does not silently reset a custom context.
        if context is None and isinstance(dtype, DecimalDtype):
            context = dtype.context
        self._dtype = DecimalDtype(context)

    @property
    def dtype(self):
        """The ``DecimalDtype`` instance attached to this array."""
        return self._dtype

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Construct an array from ``scalars``; ``dtype``/``copy`` are unused."""
        return cls(scalars)

    @classmethod
    def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
        """Parse each string with ``decimal.Decimal`` and build an array."""
        return cls._from_sequence([decimal.Decimal(x) for x in strings], dtype, copy)

    @classmethod
    def _from_factorized(cls, values, original):
        """Rebuild an array from factorized ``values``; ``original`` is unused."""
        return cls(values)

    # Input types accepted by ``__array_ufunc__`` (besides DecimalArray itself).
    _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray)

    def to_numpy(
        self, dtype=None, copy: bool = False, na_value=no_default, decimals=None
    ) -> np.ndarray:
        """
        Convert to a NumPy ndarray.

        Parameters
        ----------
        dtype : optional
            Forwarded to ``np.asarray``.
        copy : bool, default False
            Accepted but not used by this implementation.
        na_value : optional
            Accepted but not used by this implementation.
        decimals : int, optional
            If given, every element is passed through ``round(x, decimals)``.

        Returns
        -------
        numpy.ndarray
        """
        result = np.asarray(self, dtype=dtype)
        if decimals is not None:
            result = np.asarray([round(x, decimals) for x in result])
        return result

    def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
        """
        Apply a NumPy ufunc to the underlying object ndarray(s).

        Returns ``NotImplemented`` when any input is not a DecimalArray or one
        of ``_HANDLED_TYPES``. Array results are wrapped back into
        ``DecimalArray``; Decimal/number results are returned unchanged.
        """
        handled = self._HANDLED_TYPES + (DecimalArray,)
        if not all(isinstance(t, handled) for t in inputs):
            return NotImplemented

        if method == "reduce":
            # Offer the reduction to pandas' dispatcher first, with the
            # original (still wrapped) inputs. If it declines we fall through
            # to the plain ufunc below instead of trying to wrap
            # ``NotImplemented``.
            result = arraylike.dispatch_reduction_ufunc(
                self, ufunc, method, *inputs, **kwargs
            )
            if result is not NotImplemented:
                return result

        raw_inputs = tuple(
            x._data if isinstance(x, DecimalArray) else x for x in inputs
        )
        result = getattr(ufunc, method)(*raw_inputs, **kwargs)

        def reconstruct(x):
            # Scalars pass through; anything else is treated as a sequence.
            if isinstance(x, (decimal.Decimal, numbers.Number)):
                return x
            return DecimalArray._from_sequence(x)

        if ufunc.nout > 1:
            return tuple(reconstruct(x) for x in result)
        return reconstruct(result)

    def __getitem__(self, item):
        """Return a scalar for an integer ``item``, otherwise a new array."""
        if isinstance(item, numbers.Integral):
            return self._data[item]
        # array, slice.
        item = check_array_indexer(self, item)
        return type(self)(self._data[item])

    def take(self, indexer, allow_fill=False, fill_value=None):
        """
        Take elements by position.

        When ``allow_fill`` is True and no ``fill_value`` is supplied, the
        dtype's ``na_value`` is used as the fill value.
        """
        from pandas.api.extensions import take

        data = self._data
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill)
        return self._from_sequence(result)

    def copy(self):
        """Return a new array over a copy of the backing data, same dtype."""
        return type(self)(self._data.copy(), dtype=self.dtype)

    def astype(self, dtype, copy=True):
        """
        Cast to ``dtype``.

        Returns ``self`` when the dtype is equal and ``copy`` is False.
        A ``DecimalDtype`` target yields a ``DecimalArray`` using that dtype's
        context (copying the data when ``copy`` is True). Any other target is
        delegated to the parent implementation.
        """
        if not copy and is_dtype_equal(dtype, self._dtype):
            return self
        dtype = pandas_dtype(dtype)
        if isinstance(dtype, type(self.dtype)):
            return type(self)(self._data, copy=copy, context=dtype.context)

        return super().astype(dtype, copy=copy)

    def __setitem__(self, key, value):
        """
        Assign ``value`` at ``key``, coercing it (or each of its items) to
        ``decimal.Decimal``.

        Raises
        ------
        ValueError
            If ``value`` is list-like while ``key`` is a scalar.
        """
        if is_list_like(value):
            if is_scalar(key):
                raise ValueError("setting an array element with a sequence.")
            value = [decimal.Decimal(v) for v in value]
        else:
            value = decimal.Decimal(value)

        key = check_array_indexer(self, key)
        self._data[key] = value

    def __len__(self) -> int:
        return len(self._data)

    def __contains__(self, item) -> bool | np.bool_:
        """Membership test; a NaN ``item`` matches any missing element."""
        if not isinstance(item, decimal.Decimal):
            return False
        if item.is_nan():
            return self.isna().any()
        return super().__contains__(item)

    @property
    def nbytes(self) -> int:
        """Estimated size: ``len(self)`` times ``sys.getsizeof`` of element 0."""
        n = len(self)
        if n:
            return n * sys.getsizeof(self[0])
        return 0

    def isna(self):
        """Boolean ndarray marking elements for which ``is_nan()`` is True."""
        return np.array([x.is_nan() for x in self._data], dtype=bool)

    @property
    def _na_value(self):
        return decimal.Decimal("NaN")

    def _formatter(self, boxed=False):
        """Return the callable used to render a single element."""
        if boxed:
            return "Decimal: {}".format
        return repr

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate the backing ndarrays of ``to_concat`` into one array."""
        return cls(np.concatenate([x._data for x in to_concat]))

    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
        """
        Run the ndarray reduction called ``name`` over the data.

        Parameters
        ----------
        name : str
            Name of a method on the backing ndarray (e.g. ``"sum"``).
        skipna : bool, default True
            If True, missing values are dropped before reducing.
        **kwargs
            Forwarded only to the recursive call made after dropping NAs.

        Raises
        ------
        NotImplementedError
            If the backing ndarray has no attribute ``name``.
        """
        if skipna:
            # If we don't have any NAs, we can ignore skipna
            mask = self.isna()
            if mask.any():
                other = self[~mask]
                return other._reduce(name, **kwargs)

        if name == "sum" and len(self) == 0:
            # GH#29630 avoid returning int 0 or np.bool_(False) on old numpy
            return decimal.Decimal(0)

        try:
            op = getattr(self.data, name)
        except AttributeError as err:
            raise NotImplementedError(
                f"decimal does not support the {name} operation"
            ) from err
        return op(axis=0)

    def _cmp_method(self, other, op):
        """
        Element-wise comparison returning a boolean ndarray.

        ``other`` is used as-is when it is an ExtensionArray or list-like;
        any other object is repeated ``len(self)`` times.
        """

        # For use with OpsMixin
        def convert_values(param):
            if isinstance(param, ExtensionArray) or is_list_like(param):
                return param
            # Assume it's an object
            return [param] * len(self)

        lvalues = self
        rvalues = convert_values(other)

        # If the operator is not defined for the underlying objects,
        # a TypeError should be raised
        res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]

        return np.asarray(res, dtype=bool)

    def value_counts(self, dropna: bool = True):
        """Count values by delegating to pandas' ``value_counts`` algorithm."""
        from pandas.core.algorithms import value_counts

        return value_counts(self.to_numpy(), dropna=dropna)
 | 
						|
 | 
						|
 | 
						|
def to_decimal(values, context=None):
    """
    Build a ``DecimalArray`` by passing every item of ``values`` through
    ``decimal.Decimal``; ``context`` is forwarded to the array constructor.
    """
    as_decimals = []
    for item in values:
        as_decimals.append(decimal.Decimal(item))
    return DecimalArray(as_decimals, context=context)
 | 
						|
 | 
						|
 | 
						|
def make_data():
    """Return a list of 100 Decimals, each built from ``random.random()``."""
    samples = []
    for _ in range(100):
        samples.append(decimal.Decimal(random.random()))
    return samples
 | 
						|
 | 
						|
 | 
						|
# Runs once at import time, after the class body above has been executed.
# NOTE(review): `_add_arithmetic_ops` is not defined in this file; presumably
# it is inherited from ExtensionScalarOpsMixin and installs the arithmetic
# operator methods on DecimalArray -- confirm against the pandas version in use.
DecimalArray._add_arithmetic_ops()
 |