Source code for cellarr_frame.base

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Literal, Optional, Tuple, Union
from warnings import warn

import numpy as np
import pandas as pd
import tiledb

# Module metadata: author, copyright holder and license identifier.
__author__ = "Jayaram Kancherla"
__copyright__ = "Jayaram Kancherla"
__license__ = "MIT"


class CellArrayBaseFrame(ABC):
    """Abstract base class for TileDB DataFrame operations.

    An instance wraps either a URI (the array is opened on demand and closed
    after each operation) or an already opened TileDB array object supplied
    by the caller (the instance never closes it).

    Subclasses must implement :py:meth:`_read_slice`, :py:meth:`_read_query`
    and :py:meth:`write_batch`.
    """

    # Open modes accepted by the constructor and by the ``mode`` setter.
    _VALID_MODES = ("r", "w", "m", "d")

    def __init__(
        self,
        uri: Optional[str] = None,
        tiledb_array_obj: Optional["tiledb.Array"] = None,
        mode: Optional[Literal["r", "w", "d", "m"]] = None,
        config_or_context: Optional[Union["tiledb.Config", "tiledb.Ctx"]] = None,
        validate: bool = True,
    ):
        """Initialize the object.

        Args:
            uri:
                URI to the array. Required if 'tiledb_array_obj' is not
                provided.

            tiledb_array_obj:
                Optional, an already opened tiledb object. When given, it
                takes precedence over 'uri', and its own URI, mode and
                context are used ('config_or_context' is ignored).

            mode:
                Open mode ('r', 'w', 'd', 'm'). Defaults to None (auto).

            config_or_context:
                TileDB Config or Ctx. Only used when opening by 'uri'.

            validate:
                Whether to validate the connection by opening the array.

        Raises:
            ValueError:
                If neither 'uri' nor 'tiledb_array_obj' is provided, if the
                provided array is not open, if 'mode' is not a supported
                mode, or if 'mode' conflicts with the provided array's mode.
            TypeError:
                If 'config_or_context' is neither a TileDB Config nor Ctx.
        """
        self._array_passed_in = False
        self._opened_array_external = None
        self._ctx = None

        if tiledb_array_obj is not None:
            if not tiledb_array_obj.isopen:
                raise ValueError("Provided 'tiledb_array_obj' must be open.")

            self.uri = tiledb_array_obj.uri
            self._array_passed_in = True
            self._opened_array_external = tiledb_array_obj

            if mode is not None and tiledb_array_obj.mode != mode:
                raise ValueError(
                    f"Provided array mode '{tiledb_array_obj.mode}' does not match requested mode '{mode}'."
                )

            self._mode = tiledb_array_obj.mode
            self._ctx = tiledb_array_obj.ctx
        elif uri is not None:
            # Same check as the ``mode`` setter, so a bad mode fails here
            # instead of at the first attempt to open the array.
            self._check_mode(mode)

            self.uri = uri
            self._mode = mode

            if config_or_context is None:
                self._ctx = None
            elif isinstance(config_or_context, tiledb.Config):
                self._ctx = tiledb.Ctx(config_or_context)
            elif isinstance(config_or_context, tiledb.Ctx):
                self._ctx = config_or_context
            else:
                raise TypeError("'config_or_context' must be a TileDB Config or Ctx object.")
        else:
            raise ValueError("Either 'uri' or 'tiledb_array_obj' must be provided.")

        # Lazily populated caches; ``None`` means "not computed yet".
        self._column_names = None
        self._index_names = None
        self._shape = None
        self._nonempty_domain = None
        self._index = None

        if validate:
            self._validate()

    @classmethod
    def _check_mode(cls, value: Optional[str]) -> None:
        """Raise ``ValueError`` unless ``value`` is None or a supported mode."""
        if value is not None and value not in cls._VALID_MODES:
            raise ValueError("Mode must be one of: None, 'r', 'w', 'm', 'd'")

    def _uses_external_array(self) -> bool:
        """Return True when operating on a caller-supplied array object.

        The array is compared against ``None`` explicitly instead of being
        truth-tested: Python falls back to ``__len__`` for truth testing, so
        truthiness of an array object is not a reliable "is set" check.
        """
        return self._array_passed_in and self._opened_array_external is not None

    def _validate(self):
        """Validate that the URI points to a valid TileDB array/dataframe.

        The array is opened (and, for URI-based frames, closed again); any
        error raised while opening propagates to the caller.
        """
        with self.open_array(mode="r"):
            pass

    @property
    def mode(self) -> Optional[str]:
        """Get current array mode. If an external array is used, this is its open mode."""
        if self._uses_external_array():
            return self._opened_array_external.mode
        return self._mode

    @mode.setter
    def mode(self, value: Optional[str]):
        """Set array mode for subsequent operations if not using an external array.

        Raises:
            ValueError:
                If the array is externally managed or ``value`` is not one
                of None, 'r', 'w', 'm', 'd'.
        """
        if self._array_passed_in:
            raise ValueError("Cannot change mode of an externally managed array.")
        self._check_mode(value)
        self._mode = value

    @contextmanager
    def open_array(self, mode: Optional[str] = None):
        """Context manager for array operations.

        When an external array was supplied, that object is yielded as-is
        (after an attempt to reopen it if it is closed); ``mode`` is ignored
        and the array is not closed on exit. Otherwise the array at
        ``self.uri`` is opened and always closed on exit.

        Args:
            mode:
                Open mode for URI-based frames. Defaults to the frame's
                configured mode, or 'r' if none is set.

        Raises:
            tiledb.TileDBError:
                If a closed external array cannot be reopened.
        """
        if self._uses_external_array():
            external = self._opened_array_external
            if not external.isopen:
                try:
                    external.reopen()
                except Exception as e:
                    raise tiledb.TileDBError(f"External array closed/cannot reopen: {e}") from e
            yield external
            return

        effective_mode = mode if mode is not None else (self.mode or "r")
        array = tiledb.open(self.uri, mode=effective_mode, ctx=self._ctx)
        try:
            yield array
        finally:
            array.close()

    @property
    def column_names(self) -> List[str]:
        """Get attribute/column names of the dataframe (cached after first read)."""
        if self._column_names is None:
            with self.open_array(mode="r") as A:
                self._column_names = [A.schema.attr(i).name for i in range(A.schema.nattr)]
        return self._column_names

    @property
    def index_names(self) -> List[str]:
        """Get dimension/index names of the dataframe (cached after first read)."""
        if self._index_names is None:
            with self.open_array(mode="r") as A:
                self._index_names = [dim.name for dim in A.schema.domain]
        return self._index_names

    @property
    def index(self) -> pd.DataFrame:
        """Get index of the dataframe (cached after first read).

        For sparse arrays the array is queried with no attributes selected
        and the result is wrapped in a DataFrame; if that query fails a
        warning is issued and an empty DataFrame is returned. Dense arrays
        always yield an empty DataFrame.
        """
        if self._index is None:
            with self.open_array(mode="r") as A:
                if A.schema.sparse:
                    try:
                        self._index = pd.DataFrame(A.query(attrs=[])[:])
                    except Exception as e:
                        # Best effort: report the cause instead of hiding it.
                        warn(f"Failed to get index values: {e}")
                        self._index = pd.DataFrame()
                else:
                    self._index = pd.DataFrame()
        return self._index

    def rownames(self) -> pd.DataFrame:
        """Alias to :py:attr:`index`."""
        return self.index

    @staticmethod
    def _infer_row_count(A) -> int:
        """Derive the number of rows from an open array; -1 means unknown.

        Sparse arrays: the inclusive span of the non-empty domain along the
        first dimension (0 if the array is empty, -1 if the bounds are not
        numeric). Dense arrays: the extent of the first dimension's domain.
        """
        if A.schema.sparse:
            ned = A.nonempty_domain()
            if not ned:
                return 0
            bounds = ned[0]
            if not (isinstance(bounds, tuple) and len(bounds) == 2):
                return -1
            if not isinstance(bounds[0], (int, np.integer, float, np.floating)):
                return -1
            return int(bounds[1]) - int(bounds[0]) + 1

        dom = A.schema.domain
        if dom.ndim == 1:
            dim = dom.dim(0)
            if np.issubdtype(dim.dtype, np.str_):
                return -1
            return int(dim.domain[1]) - int(dim.domain[0]) + 1

        try:
            return A.shape[0]
        except Exception:
            return -1

    @property
    def shape(self) -> Tuple[int, ...]:
        """Get the shape of the dataframe (rows, columns).

        The column count is the number of attributes in the schema. The row
        count is -1 when it cannot be determined. The result is cached after
        the first read.
        """
        if self._shape is None:
            with self.open_array(mode="r") as A:
                self._shape = (self._infer_row_count(A), A.schema.nattr)
        return self._shape

    def vacuum(self) -> None:
        """Run ``tiledb.vacuum`` on the array at ``self.uri``."""
        tiledb.vacuum(self.uri, ctx=self._ctx)

    def consolidate(self) -> None:
        """Run ``tiledb.consolidate`` on the array, then :py:meth:`vacuum` it."""
        tiledb.consolidate(self.uri, ctx=self._ctx)
        self.vacuum()

    def __getitem__(self, key: Union[slice, str, Tuple[Any, ...]]) -> pd.DataFrame:
        """Route slicing/querying to implementation.

        Note that strings passed with square bracket notation
        e.g. A["cell001"] are assumed to be queries.

        If you want to select a row using string indices, use a list of
        strings e.g. A[["cell001"]].

        Args:
            key:
                - str: Query condition (e.g., "age > 20")
                - slice/int: Row selection
                - tuple: (rows, columns) selection

        Returns:
            Whatever the subclass returns from :py:meth:`_read_query` (string
            row selector) or :py:meth:`_read_slice` (anything else).
        """
        row_spec = slice(None)
        col_spec = None  # None implies all columns

        if isinstance(key, str):
            return self._read_query(condition=key, columns=None)

        if not isinstance(key, tuple):
            key = (key,)

        if len(key) >= 1:
            row_spec = key[0]
        if len(key) >= 2:
            col_spec = key[1]

        # Normalize the column selector to a list of column names.
        if isinstance(col_spec, str):
            col_spec = [col_spec]
        if isinstance(col_spec, (slice, range)):
            if isinstance(col_spec, range):
                col_spec = slice(col_spec.start, col_spec.stop, col_spec.step)
            col_spec = self.column_names[col_spec]

        if isinstance(row_spec, str):
            return self._read_query(condition=row_spec, columns=col_spec)

        return self._read_slice(row_spec, col_spec)

    @abstractmethod
    def _read_slice(self, rows: Any, cols: Optional[List[str]]) -> pd.DataFrame:
        """Read rows selected by position/label.

        Called by :py:meth:`__getitem__` with the row selector as given by
        the caller and ``cols`` as a list of column names, or None for all
        columns.
        """
        pass

    @abstractmethod
    def _read_query(self, condition: str, columns: Optional[List[str]]) -> pd.DataFrame:
        """Read rows matching a query condition.

        Called by :py:meth:`__getitem__` when the row selector is a string;
        ``columns`` is a list of column names, or None for all columns.
        """
        pass

    @abstractmethod
    def write_batch(self, data: pd.DataFrame, **kwargs) -> None:
        """Write or append data to the frame."""
        pass