"""Source code for ``cellarr_array.dataloaders.denseloader``: a PyTorch Dataset and DataLoader factory for dense TileDB arrays."""

from typing import Optional
from warnings import warn

import numpy as np
import tiledb
import torch
from torch.utils.data import DataLoader, Dataset

from ..core.dense import DenseCellArray

# Package metadata: the author also holds the copyright.
__author__ = __copyright__ = "Jayaram Kancherla"
__license__ = "MIT"


class DenseArrayDataset(Dataset):
    """Map-style PyTorch ``Dataset`` whose items are rows of a dense TileDB array.

    Rows are read on demand through a ``DenseCellArray`` handle. The handle is
    opened lazily, once per process, on the first ``__getitem__`` call.
    """

    def __init__(
        self,
        array_uri: str,
        attribute_name: str = "data",
        num_rows: Optional[int] = None,
        num_columns: Optional[int] = None,
        cellarr_ctx_config: Optional[dict] = None,
        transform=None,
    ):
        """PyTorch Dataset for dense TileDB arrays accessed via DenseCellArray.

        Args:
            array_uri: URI of the TileDB dense array.
            attribute_name: Name of the attribute to read from.
            num_rows: Total number of rows in the dataset. If None, will infer
                from ``array.shape[0]``.
            num_columns: The number of columns in the dataset. If None, will
                attempt to infer from ``array.shape[1]`` (1 for a 1-D array).
            cellarr_ctx_config: Optional TileDB context configuration dict for
                CellArray.
            transform: Optional transform to be applied on a sample. It may
                return a NumPy array or a ``torch.Tensor``.

        Raises:
            ValueError: If the shape must be inferred and opening the array
                fails, if the array is neither 1-D nor 2-D, or if
                ``num_columns`` is missing or non-positive for a non-empty
                dataset.
        """
        self.array_uri = array_uri
        self.attribute_name = attribute_name
        self.cellarr_ctx_config = cellarr_ctx_config
        self.transform = transform
        # Opened lazily by _init_worker_state so each process gets its own handle.
        self.cell_array_instance = None

        if num_rows is not None and num_columns is not None:
            self._len = num_rows
            self.num_columns = num_columns
        else:
            self._len, self.num_columns = self._infer_shape(num_rows, num_columns)

        # Alias of ``num_columns`` retained for backward compatibility.
        self.feature_dim = self.num_columns

        # Zero (or negative) columns are tolerated only when the dataset is empty.
        if self.num_columns is None or (self.num_columns <= 0 and self._len > 0):
            raise ValueError(
                f"num_columns ({self.num_columns}) is invalid or could not be determined for array '{array_uri}'."
            )
        if self._len == 0:
            warn(f"Dataset for '{array_uri}' has length 0.", RuntimeWarning)

    def _infer_shape(self, num_rows: Optional[int], num_columns: Optional[int]) -> tuple:
        """Open the array once and fill in whichever of rows/columns was not supplied.

        Returns:
            A ``(rows, columns)`` tuple. Explicitly supplied values take
            precedence over the probed shape. A 1-D array always has 1 column.

        Raises:
            ValueError: If the array cannot be opened/probed or its ndim is not 1 or 2.
        """
        array_uri = self.array_uri
        print(f"Dataset '{array_uri}': num_rows or num_columns not provided. Probing array...")
        init_ctx_config = tiledb.Config(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
        try:
            temp_arr = DenseCellArray(uri=array_uri, attr=self.attribute_name, config_or_context=init_ctx_config)
            ndim = temp_arr.ndim
            shape = temp_arr.shape
        except Exception as e:
            # This method only runs when at least one dimension is missing, so
            # there is nothing to fall back on; keep the original error as cause.
            raise ValueError(
                f"num_rows and num_columns must be provided if inferring array shape fails for '{array_uri}'."
            ) from e

        if ndim == 1:
            rows = num_rows if num_rows is not None else shape[0]
            cols = 1
        elif ndim == 2:
            rows = num_rows if num_rows is not None else shape[0]
            cols = num_columns if num_columns is not None else shape[1]
        else:
            raise ValueError(f"Array ndim {ndim} not supported.")

        print(f"Dataset '{array_uri}': Inferred shape. Rows: {rows}, Columns: {cols}")
        return rows, cols

    def _init_worker_state(self):
        """Initializes the DenseCellArray instance for the current worker."""
        if self.cell_array_instance is None:
            ctx = tiledb.Ctx(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
            self.cell_array_instance = DenseCellArray(
                uri=self.array_uri, attr=self.attribute_name, mode="r", config_or_context=ctx
            )

    def __getstate__(self):
        """Return picklable state without the open array handle.

        This matters when the dataset is sent to spawned DataLoader workers
        after the parent process has already read from it. The receiving
        process reopens the handle lazily via ``_init_worker_state``.
        """
        state = self.__dict__.copy()
        state["cell_array_instance"] = None
        return state

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Return row ``idx`` as a tensor (after ``transform``, if any).

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
            ValueError: If the underlying array is neither 1-D nor 2-D.
        """
        if not 0 <= idx < self._len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {self._len}.")

        self._init_worker_state()

        ndim = self.cell_array_instance.ndim
        if ndim == 2:
            item_slice = (slice(idx, idx + 1), slice(None))
        elif ndim == 1:
            item_slice = slice(idx, idx + 1)
        else:
            raise ValueError(f"Array ndim {self.cell_array_instance.ndim} not supported in __getitem__.")

        sample = self.cell_array_instance[item_slice]

        if sample.ndim == 2 and sample.shape[0] == 1:
            # Drop the length-1 row axis so a 2-D read yields shape (num_columns,).
            sample = sample.squeeze(0)
        elif sample.ndim == 0 and self.num_columns == 1:
            # Promote a scalar read to a 1-element vector.
            sample = np.array([sample])
        # A (1,)-shaped read from a 1-D array is already in the expected form.

        if self.transform is not None:
            sample = self.transform(sample)

        # Transforms may already produce a tensor; only convert array-likes.
        if isinstance(sample, torch.Tensor):
            return sample
        return torch.from_numpy(np.asarray(sample))
def construct_dense_array_dataloader(
    array_uri: str,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    batch_size: int = 1000,
    num_workers_dl: int = 2,
    shuffle: bool = True,
    cellarr_ctx_config: Optional[dict] = None,
) -> Optional[DataLoader]:
    """Construct an instance of `DenseArrayDataset` with PyTorch DataLoader.

    Args:
        array_uri: URI of the TileDB array.
        attribute_name: Name of the attribute to read from.
        num_rows: The total number of rows in the TileDB array.
        num_columns: The total number of columns in the TileDB array.
        batch_size: Number of samples per batch.
        num_workers_dl: Number of worker processes for the DataLoader. With 0,
            loading happens in the main process.
        shuffle: Whether to reshuffle the row order every epoch. Defaults to True.
        cellarr_ctx_config: TileDB context configuration dict handed to the
            dataset. If None, a default with a 1000 MB tile cache and
            ``sm.num_reader_threads=4`` is used.

    Returns:
        A ``DataLoader`` over the array rows, or ``None`` if the dataset is empty.
    """
    if cellarr_ctx_config is None:
        cellarr_ctx_config = {
            "sm.tile_cache_size": 1000 * 1024**2,  # 1000 MB tile cache per worker
            "sm.num_reader_threads": 4,
        }

    dataset = DenseArrayDataset(
        array_uri=array_uri,
        attribute_name=attribute_name,
        num_rows=num_rows,
        num_columns=num_columns,
        cellarr_ctx_config=cellarr_ctx_config,
    )

    if len(dataset) == 0:
        print("Dataset is empty, cannot create DataLoader.")
        return None

    uses_workers = num_workers_dl > 0
    worker_kwargs = {}
    if uses_workers:
        # prefetch_factor is only meaningful with worker processes; recent
        # PyTorch raises ValueError if it is given together with num_workers=0.
        worker_kwargs["prefetch_factor"] = 2

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers_dl,
        pin_memory=True,
        persistent_workers=uses_workers,
        **worker_kwargs,
    )