def __init__(
    self,
    array_uri: str,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    cellarr_ctx_config: Optional[dict] = None,
    transform=None,
):
    """PyTorch Dataset for dense TileDB arrays accessed via DenseCellArray.

    Args:
        array_uri:
            URI of the TileDB dense array.

        attribute_name:
            Name of the attribute to read from.

        num_rows:
            Total number of rows in the dataset.
            If None, it is inferred from ``array.shape[0]``.

        num_columns:
            The number of columns in the dataset.
            If None, it is inferred from ``array.shape[1]`` (or set to 1
            for a 1-dimensional array).

        cellarr_ctx_config:
            Optional TileDB context configuration dict for CellArray.

        transform:
            Optional transform to be applied on a sample.

    Raises:
        ValueError:
            If the array has to be probed and cannot be opened, if the
            probed array is neither 1- nor 2-dimensional, or if the
            resulting number of columns is not positive for a non-empty
            dataset.
    """
    self.array_uri = array_uri
    self.attribute_name = attribute_name
    self.cellarr_ctx_config = cellarr_ctx_config
    self.transform = transform
    # The array handle is created later by `_init_worker_state`.
    self.cell_array_instance = None

    if num_rows is not None and num_columns is not None:
        self._len = num_rows
        self.num_columns = num_columns
    else:
        # At least one dimension is missing: open the array once to read its shape.
        print(f"Dataset '{array_uri}': num_rows or num_columns not provided. Probing array...")
        init_ctx_config = tiledb.Config(self.cellarr_ctx_config) if self.cellarr_ctx_config else None

        # Only the probe itself is guarded, so that the explicit
        # "ndim not supported" error below is not re-wrapped in a
        # misleading "must be provided" message.
        try:
            temp_arr = DenseCellArray(
                uri=self.array_uri,
                attr=self.attribute_name,
                config_or_context=init_ctx_config,
            )
            probed_ndim = temp_arr.ndim
            probed_shape = temp_arr.shape
        except Exception as e:
            # This branch is only reached when a dimension was not supplied,
            # so there is nothing to fall back on.
            raise ValueError(
                f"num_rows and num_columns must be provided if inferring array shape fails for '{array_uri}'."
            ) from e

        if probed_ndim == 1:
            self._len = num_rows if num_rows is not None else probed_shape[0]
            self.num_columns = 1
        elif probed_ndim == 2:
            self._len = num_rows if num_rows is not None else probed_shape[0]
            self.num_columns = num_columns if num_columns is not None else probed_shape[1]
        else:
            raise ValueError(f"Array ndim {probed_ndim} not supported.")

        print(f"Dataset '{array_uri}': Inferred shape. Rows: {self._len}, Columns: {self.num_columns}")

    # An empty dataset is allowed to report a non-positive column count.
    if self.num_columns is None or (self.num_columns <= 0 and self._len > 0):
        raise ValueError(
            f"num_columns ({self.num_columns}) is invalid or could not be determined for array '{array_uri}'."
        )

    # Alias read by `__getitem__`; previously it was never assigned on any
    # path that completes construction.
    self.feature_dim = self.num_columns

    if self._len == 0:
        warn(f"Dataset for '{array_uri}' has length 0.", RuntimeWarning)
def _init_worker_state(self):
    """Create the DenseCellArray handle for the current worker on first use.

    Does nothing when ``self.cell_array_instance`` is already set; otherwise
    opens the array at ``self.array_uri`` in read mode and stores the handle
    on the instance.
    """
    if self.cell_array_instance is not None:
        return

    ctx_settings = self.cellarr_ctx_config
    if ctx_settings:
        worker_ctx = tiledb.Ctx(ctx_settings)
    else:
        worker_ctx = None

    # NOTE: the opened array's shape is not compared with the dimensions
    # recorded on the dataset.
    self.cell_array_instance = DenseCellArray(
        uri=self.array_uri,
        attr=self.attribute_name,
        mode="r",
        config_or_context=worker_ctx,
    )
def __getitem__(self, idx):
    """Read one row of the array and return it as a tensor.

    Args:
        idx:
            Zero-based row index; must satisfy ``0 <= idx < len``.
            Negative indices are rejected.

    Returns:
        The sample as a ``torch.Tensor``. For a 2-dimensional array the
        leading row axis is dropped; for a 1-dimensional array the sample
        keeps its length-1 axis. If a transform is set, its output is
        converted with ``torch.from_numpy`` unless it is already a tensor.

    Raises:
        IndexError: If ``idx`` is outside ``[0, len)``.
        ValueError: If the opened array is neither 1- nor 2-dimensional.
    """
    if not 0 <= idx < self._len:
        raise IndexError(f"Index {idx} out of bounds for dataset of length {self._len}.")

    # Make sure the array handle exists before reading.
    self._init_worker_state()
    array = self.cell_array_instance

    if array.ndim == 2:
        item_slice = (slice(idx, idx + 1), slice(None))
    elif array.ndim == 1:
        item_slice = slice(idx, idx + 1)
    else:
        raise ValueError(f"Array ndim {array.ndim} not supported in __getitem__.")

    sample = array[item_slice]

    if sample.ndim == 2 and sample.shape[0] == 1:
        # Drop the length-1 row axis produced by the slice.
        sample = sample.squeeze(0)
    elif sample.ndim == 0 and self.num_columns == 1:
        # A scalar read of a single-column dataset is promoted to shape (1,).
        # `num_columns` is used because `feature_dim` was never assigned by
        # the constructor, which made every 1-D read fail with AttributeError.
        sample = np.array([sample])

    if self.transform:
        sample = self.transform(sample)

    # A transform may already have produced a tensor; `torch.from_numpy`
    # only accepts ndarrays.
    if isinstance(sample, torch.Tensor):
        return sample
    return torch.from_numpy(sample)
def construct_dense_array_dataloader(
    array_uri: str,
    attribute_name: str = "data",
    num_rows: Optional[int] = None,
    num_columns: Optional[int] = None,
    batch_size: int = 1000,
    num_workers_dl: int = 2,
) -> Optional[DataLoader]:
    """Construct an instance of `DenseArrayDataset` with PyTorch DataLoader.

    Args:
        array_uri:
            URI of the TileDB array.

        attribute_name:
            Name of the attribute to read from.

        num_rows:
            The total number of rows in the TileDB array.

        num_columns:
            The total number of columns in the TileDB array.

        batch_size:
            Number of random samples per batch generated by the dataset.

        num_workers_dl:
            Number of worker processes for the DataLoader. Use 0 to load
            data in the calling process.

    Returns:
        A shuffling ``DataLoader`` over the dataset, or ``None`` when the
        dataset is empty.
    """
    tiledb_ctx_config = {
        "sm.tile_cache_size": 1000 * 1024**2,  # 1000 MB tile cache per worker
        "sm.num_reader_threads": 4,
    }

    dataset = DenseArrayDataset(
        array_uri=array_uri,
        attribute_name=attribute_name,
        num_rows=num_rows,
        num_columns=num_columns,
        cellarr_ctx_config=tiledb_ctx_config,
    )

    if len(dataset) == 0:
        print("Dataset is empty, cannot create DataLoader.")
        return None

    use_workers = num_workers_dl > 0
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": num_workers_dl,
        "pin_memory": True,
        "persistent_workers": use_workers,
    }
    if use_workers:
        # `prefetch_factor` is only valid with worker processes; recent
        # PyTorch versions raise ValueError if it is set while num_workers == 0.
        loader_kwargs["prefetch_factor"] = 2

    return DataLoader(dataset, **loader_kwargs)