133 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
Test extension array for storing nested data in a pandas container.

The ListArray stores an ndarray of lists.
"""
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import numbers
 | 
						|
import random
 | 
						|
import string
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
from pandas._typing import type_t
 | 
						|
 | 
						|
from pandas.core.dtypes.base import ExtensionDtype
 | 
						|
 | 
						|
import pandas as pd
 | 
						|
from pandas.api.types import (
 | 
						|
    is_object_dtype,
 | 
						|
    is_string_dtype,
 | 
						|
)
 | 
						|
from pandas.core.arrays import ExtensionArray
 | 
						|
 | 
						|
 | 
						|
class ListDtype(ExtensionDtype):
    """
    Extension dtype whose non-missing scalars are builtin ``list`` objects.

    Missing values are represented by ``np.nan``.
    """

    # String alias reported for this dtype.
    name = "list"
    # Python type of the non-missing scalars held by the paired array.
    type = list
    # Marker used for missing entries.
    na_value = np.nan

    @classmethod
    def construct_array_type(cls) -> type_t[ListArray]:
        """
        Return the array class paired with this dtype.

        Returns
        -------
        type
            The ``ListArray`` class itself (not an instance).
        """
        return ListArray
 | 
						|
 | 
						|
 | 
						|
class ListArray(ExtensionArray):
    """
    ExtensionArray backed by a 1-D object ndarray of lists.

    Every element of ``self.data`` is either a ``list`` or a missing-value
    scalar (a scalar for which ``pd.isna`` is true, e.g. ``np.nan``).
    """

    dtype = ListDtype()
    __array_priority__ = 1000

    def __init__(self, values, dtype=None, copy=False):
        """
        Wrap ``values`` after validating its elements.

        Parameters
        ----------
        values : np.ndarray
            Array whose elements are lists or missing-value scalars.
        dtype : object, optional
            Accepted for ExtensionArray API compatibility; ignored.
        copy : bool, default False
            If True, store a copy of ``values`` rather than the caller's array.

        Raises
        ------
        TypeError
            If ``values`` is not an ndarray, or an element is neither a list
            nor a missing-value scalar.
        """
        if not isinstance(values, np.ndarray):
            raise TypeError("Need to pass a numpy array as values")
        for val in values:
            if isinstance(val, self.dtype.type):
                continue
            # pd.isna returns an array for array-like input, whose truth value
            # is ambiguous, so only trust it for scalars.
            if not (pd.api.types.is_scalar(val) and pd.isna(val)):
                raise TypeError("All values must be of type " + str(self.dtype.type))
        self.data = values.copy() if copy else values

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """
        Build a ListArray from a sized sequence of lists / missing values.

        ``dtype`` and ``copy`` are accepted for API compatibility and ignored.
        """
        data = np.empty(len(scalars), dtype=object)
        # Slice-assigning into a pre-allocated 1-D object array keeps each
        # list as a single element instead of a second array dimension.
        data[:] = scalars
        return cls(data)

    def __getitem__(self, item):
        """
        Return one element for an integer key, else a new ListArray.

        Non-integer keys (slice, list-like, boolean mask) are forwarded to
        the backing ndarray and the result is re-wrapped.
        """
        if isinstance(item, numbers.Integral):
            return self.data[item]
        else:
            # slice, list-like, mask
            return type(self)(self.data[item])

    def __len__(self) -> int:
        """Return the number of elements."""
        return len(self.data)

    def isna(self):
        """Return a boolean ndarray that is True where the element is missing."""
        # pd.isna instead of np.isnan: __init__ admits any missing scalar
        # (None, pd.NA, ...), and np.isnan cannot handle those.
        return np.array(
            [not isinstance(x, list) and pd.isna(x) for x in self.data],
            dtype=bool,
        )

    def take(self, indexer, allow_fill=False, fill_value=None):
        """
        Take elements by position.

        Parameters
        ----------
        indexer : sequence of int
            Positions to take. Without ``allow_fill``, negative positions
            index from the end (plain ndarray indexing).
        allow_fill : bool, default False
            If True, ``-1`` marks a slot to fill with ``fill_value``. Any
            value below ``-1`` is rejected.
        fill_value : object, optional
            Value used for ``-1`` slots when ``allow_fill`` is True.
            Defaults to ``self.dtype.na_value``.

        Returns
        -------
        ListArray

        Raises
        ------
        ValueError
            If ``allow_fill`` is True and ``indexer`` contains values < -1.
        IndexError
            If a position is out of bounds.
        """
        # Re-implemented here, since NumPy has trouble setting sized objects
        # such as lists into scalar slots of an ndarray.
        indexer = np.asarray(indexer)
        msg = (
            "Index is out of bounds or cannot do a "
            "non-empty take from an empty array."
        )

        if allow_fill:
            if fill_value is None:
                fill_value = self.dtype.na_value
            # bounds check: -1 is the only permitted negative value
            if (indexer < -1).any():
                raise ValueError(
                    "Invalid value in 'indexer': when allow_fill is True, "
                    "all values must be non-negative or -1."
                )
            try:
                output = [
                    self.data[loc] if loc != -1 else fill_value for loc in indexer
                ]
            except IndexError as err:
                raise IndexError(msg) from err
        else:
            try:
                output = [self.data[loc] for loc in indexer]
            except IndexError as err:
                raise IndexError(msg) from err

        return self._from_sequence(output)

    def copy(self):
        """Return a new ListArray that owns a copy of the backing ndarray."""
        # ``self.data[:]`` would be a view sharing memory with ``self.data``.
        # This is still a shallow copy: the list elements are shared.
        return type(self)(self.data.copy())

    def astype(self, dtype, copy=True):
        """
        Cast to ``dtype``.

        Returns ``self`` (or a copy of it) for an equal ListDtype, a NumPy
        array of ``str(element)`` for non-object string dtypes, and otherwise
        a NumPy array of the elements in ``dtype``.
        """
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # numpy has problems with astype(str) for nested elements
            return np.array([str(x) for x in self.data], dtype=dtype)
        elif not copy:
            # NumPy 2 makes np.array(..., copy=False) raise if a copy is
            # unavoidable; np.asarray keeps "copy only when needed".
            return np.asarray(self.data, dtype=dtype)
        return np.array(self.data, dtype=dtype, copy=copy)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate a sequence of ListArrays into one new ListArray."""
        data = np.concatenate([x.data for x in to_concat])
        return cls(data)
 | 
						|
 | 
						|
 | 
						|
def make_data():
    """
    Build random test data: an object ndarray of 100 lists.

    Each list holds between 0 and 10 (inclusive) characters drawn from
    ``string.ascii_letters``.

    Returns
    -------
    np.ndarray
        1-D object array of length 100 whose elements are lists.
    """
    rows = []
    for _ in range(100):
        length = random.randint(0, 10)
        rows.append([random.choice(string.ascii_letters) for _ in range(length)])
    # Slice-assign into a pre-allocated 1-D object array so each list is
    # stored as one element rather than being read as a second dimension.
    out = np.empty(100, dtype=object)
    out[:] = rows
    return out
 |