"""Source code for cellarr_array.dataloaders.sparseloader."""

from typing import Optional
from warnings import warn

import scipy.sparse as sp
import tiledb
import torch
from torch.utils.data import DataLoader, Dataset

from ..core.sparse import SparseCellArray

__author__ = "Jayaram Kancherla"
__copyright__ = "Jayaram Kancherla"
__license__ = "MIT"


class SparseArrayDataset(Dataset):
    """PyTorch Dataset over a sparse TileDB array accessed via :class:`SparseCellArray`.

    Each item is one row of the array, returned as a SciPy COO matrix of
    shape ``(1, num_columns)`` (after optional ``transform``).
    """

    def __init__(
        self,
        array_uri: str,
        attribute_name: str = "data",
        num_rows: Optional[int] = None,
        num_columns: Optional[int] = None,
        sparse_format=sp.csr_matrix,
        cellarr_ctx_config: Optional[dict] = None,
        transform=None,
    ):
        """PyTorch Dataset for sparse TileDB arrays accessed via SparseCellArray.

        Args:
            array_uri:
                URI of the TileDB sparse array.

            attribute_name:
                Name of the attribute to read from.

            num_rows:
                Total number of rows in the dataset.
                If None, will infer from `array.shape[0]`.

            num_columns:
                The number of columns in the dataset.
                If None, will attempt to infer from `array.shape[1]`.

            sparse_format:
                Format to return, defaults to csr_matrix.

            cellarr_ctx_config:
                Optional TileDB context configuration dict for CellArray.

            transform:
                Optional transform to be applied on a sample.
        """
        self.array_uri = array_uri
        self.attribute_name = attribute_name
        self.sparse_format = sparse_format
        self.cellarr_ctx_config = cellarr_ctx_config
        self.transform = transform
        # Per-worker handle; created lazily in _init_worker_state so each
        # DataLoader worker process opens its own array connection.
        self.cell_array_instance = None

        if num_rows is not None and num_columns is not None:
            # Both dimensions supplied — no need to touch the array at all.
            self._len = num_rows
            self.num_columns = num_columns
        else:
            print(f"Dataset '{array_uri}': num_rows or num_columns not provided. Probing sparse array...")
            init_ctx_config = tiledb.Config(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
            try:
                temp_arr = SparseCellArray(
                    uri=self.array_uri,
                    attr=self.attribute_name,
                    config_or_context=init_ctx_config,
                    return_sparse=True,
                    sparse_format=self.sparse_format,
                )
                if temp_arr.ndim == 1:
                    self._len = num_rows if num_rows is not None else temp_arr.shape[0]
                    self.num_columns = 1
                elif temp_arr.ndim == 2:
                    self._len = num_rows if num_rows is not None else temp_arr.shape[0]
                    self.num_columns = num_columns if num_columns is not None else temp_arr.shape[1]
                else:
                    raise ValueError(f"Array ndim {temp_arr.ndim} not supported.")
                print(
                    f"Dataset '{array_uri}': Inferred sparse shape. "
                    f"Rows: {self._len}, Columns: {self.num_columns}"
                )
            except Exception as e:
                # We only reach this branch when at least one dimension was
                # missing, so inference failure is fatal here.
                if num_rows is None or num_columns is None:
                    raise ValueError(
                        f"num_rows and num_columns must be provided if inferring sparse array shape fails for '{array_uri}'. Original error: {e}"
                    ) from e
                self._len = num_rows if num_rows is not None else 0
                self.num_columns = num_columns if num_columns is not None else 0
                warn(
                    f"Falling back to provided or zero dimensions for sparse '{array_uri}' due to inference error: {e}",
                    RuntimeWarning,
                )

        # Explicit parentheses: None is always invalid; a non-positive column
        # count is only invalid when the dataset actually has rows.
        if self.num_columns is None or (self.num_columns <= 0 and self._len > 0):
            raise ValueError(
                f"num_columns ({self.num_columns}) is invalid or could not be determined for sparse array '{array_uri}'."
            )
        if self._len == 0:
            warn(f"SparseDataset for '{array_uri}' has length 0.", RuntimeWarning)

    def _init_worker_state(self):
        """Open the SparseCellArray handle once per worker (lazy, idempotent)."""
        if self.cell_array_instance is None:
            ctx = tiledb.Ctx(self.cellarr_ctx_config) if self.cellarr_ctx_config else None
            # NOTE(review): this passes `sparse_coerce=` while the probe in
            # __init__ passes `sparse_format=` — confirm which keyword
            # SparseCellArray actually expects.
            self.cell_array_instance = SparseCellArray(
                uri=self.array_uri,
                attr=self.attribute_name,
                mode="r",
                config_or_context=ctx,
                return_sparse=True,
                sparse_coerce=self.sparse_format,
            )

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return self._len

    def __getitem__(self, idx):
        """Fetch row ``idx`` as a SciPy COO matrix of shape (1, num_columns).

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < self._len:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {self._len}.")

        self._init_worker_state()

        item_slice = (slice(idx, idx + 1), slice(None))
        scipy_sparse_sample = self.cell_array_instance[item_slice]

        if self.transform:
            # e.g., convert to COO for easier collation
            scipy_sparse_sample = self.transform(scipy_sparse_sample)

        # Normalize to COO so the collate function can rely on .row/.col/.data.
        if not isinstance(scipy_sparse_sample, sp.coo_matrix):
            scipy_sparse_sample = scipy_sparse_sample.tocoo()

        return scipy_sparse_sample
def sparse_coo_collate_fn(batch):
    """Collate a batch of SciPy COO matrices into one torch sparse COO tensor.

    Each item in ``batch`` is a ``scipy.sparse.coo_matrix`` representing one
    sample (one row); sample ``i`` becomes row ``i`` of the batched tensor.

    Args:
        batch: List of ``scipy.sparse.coo_matrix`` samples, each of shape
            ``(1, num_columns)``.

    Returns:
        A ``torch.sparse_coo_tensor`` of shape ``(len(batch), num_columns)``.
    """
    data_parts = []
    row_parts = []
    col_parts = []
    for i, sample in enumerate(batch):
        if sample.nnz > 0:
            data_parts.append(torch.from_numpy(sample.data))
            # The batch row index is the sample position, not the source
            # matrix's (always-zero) row index.
            row_parts.append(torch.full((sample.nnz,), i, dtype=torch.long))
            # SciPy stores indices as int32 by default; sparse_coo_tensor
            # indices (and torch.stack with the int64 rows) require int64.
            col_parts.append(torch.from_numpy(sample.col).to(torch.long))

    num_columns = batch[0].shape[1] if batch else 0
    batch_size = len(batch)

    if not data_parts:
        # Entire batch is empty: return an all-zero sparse tensor.
        return torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long), torch.empty(0), (batch_size, num_columns)
        )

    indices = torch.stack([torch.cat(row_parts), torch.cat(col_parts)], dim=0)
    return torch.sparse_coo_tensor(indices, torch.cat(data_parts), (batch_size, num_columns))
def construct_sparse_array_dataloader(
    array_uri: str,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    batch_size: int = 1000,
    num_workers_dl: int = 2,
) -> Optional[DataLoader]:
    """Construct an instance of `SparseArrayDataset` with PyTorch DataLoader.

    Args:
        array_uri: URI of the TileDB array.
        attribute_name: Name of the attribute to read from.
        num_rows: The total number of rows in the TileDB array.
        num_columns: The total number of columns in the TileDB array.
        batch_size: Number of samples per batch drawn by the DataLoader.
        num_workers_dl: Number of worker processes for the DataLoader.

    Returns:
        A DataLoader yielding batched torch sparse COO tensors, or ``None``
        if the underlying dataset is empty.
    """
    tiledb_ctx_config = {
        "sm.tile_cache_size": 1000 * 1024**2,
        "sm.num_reader_threads": 4,
    }

    # COO samples are what sparse_coo_collate_fn expects.
    dataset = SparseArrayDataset(
        array_uri=array_uri,
        attribute_name=attribute_name,
        num_rows=num_rows,
        num_columns=num_columns,
        sparse_format=sp.coo_matrix,
        cellarr_ctx_config=tiledb_ctx_config,
    )

    if len(dataset) == 0:
        print("Dataset is empty, cannot create DataLoader.")
        return None

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers_dl,
        collate_fn=sparse_coo_collate_fn,
        pin_memory=False,
        # Persistent workers only make sense when workers exist at all.
        persistent_workers=num_workers_dl > 0,
    )