Craft Module
craft
Classes
CraftModel
CraftModel(config)
Bases: PreTrainedModel
CRAFT Model with a masked language modeling head.
Functions
forward
forward(gene_input_ids=None, gene_attention_mask=None, gene_token_type_ids=None, atac_input_ids=None, atac_attention_mask=None, return_dict=None, **kwargs)
Forward pass through the CRAFT model.
Args:
    gene_input_ids (torch.Tensor): Input IDs for the gene encoder.
    gene_attention_mask (torch.Tensor): Attention mask for the gene encoder.
    gene_token_type_ids (torch.Tensor): Token type IDs for the gene encoder.
    atac_input_ids (torch.Tensor): Input IDs for the ATAC encoder.
    atac_attention_mask (torch.Tensor): Attention mask for the ATAC encoder.
    return_dict (bool): Whether to return a dictionary or tuple.
Returns: torch.Tensor: The logits from the CRAFT model.
CraftForContrastiveLearning
CraftForContrastiveLearning(config)
Bases: PreTrainedModel
CRAFT model for contrastive learning between gene and ATAC embeddings. While this looks redundant with CraftModel, it makes it easier to use the model for further tasks like gene activity prediction without needing to instantiate CraftModel directly.
Mostly used for pre-training tasks.
Functions
forward
forward(gene_input_ids, gene_attention_mask, gene_token_type_ids, atac_input_ids, atac_attention_mask)
Forward pass through the model.
Args:
    gene_input_ids (torch.Tensor): Input IDs for the gene encoder.
    gene_attention_mask (torch.Tensor): Attention mask for the gene encoder.
    gene_token_type_ids (torch.Tensor): Token type IDs for the gene encoder.
    atac_input_ids (torch.Tensor): Input IDs for the ATAC encoder.
    atac_attention_mask (torch.Tensor): Attention mask for the ATAC encoder.
Returns: CraftOutput: The output of the CRAFT model containing loss and logits.
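The reference above does not say how the loss in CraftOutput is computed. A common objective for aligning two modalities is the symmetric cross-entropy used by CLIP-style models, and the sketch below assumes that objective purely for illustration. The function names (`l2_normalize`, `cross_entropy`, `contrastive_loss`) are hypothetical and are not part of the craft module; plain Python lists stand in for tensors so the example has no dependencies.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cross_entropy(logits_row, target_index):
    """Numerically stable cross-entropy for a single row of logits."""
    row_max = max(logits_row)
    log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in logits_row))
    return log_sum_exp - logits_row[target_index]


def contrastive_loss(gene_embeds, atac_embeds, logit_scale):
    """Symmetric contrastive loss over a batch of paired embeddings.

    Assumption: row i of gene_embeds is paired with row i of atac_embeds,
    and logit_scale is stored in log space (CLIP convention).
    """
    gene = [l2_normalize(v) for v in gene_embeds]
    atac = [l2_normalize(v) for v in atac_embeds]
    scale = math.exp(logit_scale)

    # logits[i][j] = scaled cosine similarity between gene i and ATAC j
    logits = [
        [scale * sum(g * a for g, a in zip(gene_vec, atac_vec)) for atac_vec in atac]
        for gene_vec in gene
    ]
    n = len(logits)

    # Gene-to-ATAC direction: the correct ATAC match for gene i is index i.
    gene_to_atac = sum(cross_entropy(logits[i], i) for i in range(n)) / n

    # ATAC-to-gene direction: same computation on the transposed matrix.
    columns = [list(col) for col in zip(*logits)]
    atac_to_gene = sum(cross_entropy(columns[i], i) for i in range(n)) / n

    return (gene_to_atac + atac_to_gene) / 2
```

With correctly paired embeddings the loss is close to zero, and it grows when the pairing is shuffled, which is the signal a contrastive pre-training step would minimize.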
CraftForGeneActivityPrediction
CraftForGeneActivityPrediction(config)
Bases: PreTrainedModel
CRAFT model for gene activity prediction.
Functions
forward
forward(atac_input_ids, atac_attention_mask, gene_activity=None, return_dict=True)
Forward pass through the model.
Args:
    atac_input_ids (torch.Tensor): Input IDs for the ATAC encoder.
    atac_attention_mask (torch.Tensor): Attention mask for the ATAC encoder.
    gene_activity (Optional[torch.Tensor]): Optional gene activity scores for computing loss.
Returns: torch.Tensor: The predicted gene activity scores.
CraftConfig
CraftConfig(geneformer_config=None, atacformer_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs)
Bases: PretrainedConfig
Configuration for the CRAFT model, a contrastive RNA-ATAC transformer that leverages Geneformer and Atacformer to learn a joint representation of RNA and ATAC data.
Joint configuration for the CRAFT model.
Args:
    geneformer_config (GeneformerConfig): Configuration for the Geneformer model.
    atacformer_config (AtacformerConfig): Configuration for the Atacformer model.
    projection_dim (int): Dimension of the projection layer.
    logit_scale_init_value (float): Initial value for the logit scale parameter.
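The default `logit_scale_init_value=2.6592` is the same default found in Hugging Face's `CLIPConfig`, where the parameter is stored in log space and exponentiated before it multiplies the similarity matrix. Whether CRAFT applies it the same way is an assumption here; under that assumption, the short calculation below shows what the default implies.

```python
import math

logit_scale_init_value = 2.6592  # default from the CraftConfig signature

# Assumption: the stored value is exponentiated before use (CLIP convention).
similarity_multiplier = math.exp(logit_scale_init_value)

# Equivalent softmax temperature implied by that multiplier.
implied_temperature = 1.0 / similarity_multiplier

print(f"multiplier: {similarity_multiplier:.2f}")
print(f"temperature: {implied_temperature:.3f}")
```

In other words, 2.6592 is approximately ln(1 / 0.07), so the default corresponds to an initial softmax temperature of about 0.07.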
Functions
DataCollatorForCraft
DataCollatorForCraft(gene_pad, atac_pad, gene_max_len=None, atac_max_len=None)
Pads and builds masks for gene/ATAC pairs used by CraftModel.
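As a rough illustration of what "pads and builds masks" means for a paired batch, here is a minimal dependency-free sketch. It is not the actual DataCollatorForCraft implementation: the helper names (`pad_and_mask`, `collate_pairs`), the right-padding choice, the truncation behavior of `max_len`, and the per-example dictionary keys are all assumptions. Only the output key names are taken from the `forward` signatures above, and `gene_token_type_ids` is left out for brevity.

```python
def pad_and_mask(sequences, pad_id, max_len=None):
    """Right-pad token ID lists to a common length and build 0/1 masks."""
    sequences = [list(seq) for seq in sequences]
    if max_len is not None:
        # Assumption: sequences longer than max_len are truncated.
        sequences = [seq[:max_len] for seq in sequences]
    target_len = max(len(seq) for seq in sequences)

    input_ids = [seq + [pad_id] * (target_len - len(seq)) for seq in sequences]
    # 1 marks a real token, 0 marks padding.
    attention_mask = [
        [1] * len(seq) + [0] * (target_len - len(seq)) for seq in sequences
    ]
    return input_ids, attention_mask


def collate_pairs(batch, gene_pad, atac_pad, gene_max_len=None, atac_max_len=None):
    """Pad the gene and ATAC sides independently, each with its own pad ID."""
    gene_ids, gene_mask = pad_and_mask(
        [example["gene_input_ids"] for example in batch], gene_pad, gene_max_len
    )
    atac_ids, atac_mask = pad_and_mask(
        [example["atac_input_ids"] for example in batch], atac_pad, atac_max_len
    )
    return {
        "gene_input_ids": gene_ids,
        "gene_attention_mask": gene_mask,
        "atac_input_ids": atac_ids,
        "atac_attention_mask": atac_mask,
    }


batch = [
    {"gene_input_ids": [5, 6, 7], "atac_input_ids": [1]},
    {"gene_input_ids": [8], "atac_input_ids": [2, 3]},
]
collated = collate_pairs(batch, gene_pad=0, atac_pad=9)
```

The two modalities are padded separately because their sequence lengths and pad token IDs generally differ, which is why the collator takes both `gene_pad` and `atac_pad`.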
DataCollatorForCraftGeneActivityPrediction
DataCollatorForCraftGeneActivityPrediction(atac_pad, atac_max_len=None)
Pads and builds masks for ATAC-gene pairs used by CraftForGeneActivityPrediction.
Gene activity always has the same shape for every example: it's just a set of floats representing the activity of each gene in the genome, so we don't pad it.
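The point about not padding gene activity can be sketched as follows. This is a hypothetical stand-in, not the library's collator: the function name `collate_gene_activity`, the per-example dictionary keys, and the right-padding choice are assumptions. The output keys mirror the `forward(atac_input_ids, atac_attention_mask, gene_activity=None, ...)` signature above.

```python
def collate_gene_activity(batch, atac_pad):
    """Pad only the ATAC side; gene activity vectors are stacked unchanged."""
    # Every gene activity vector covers the same set of genes, so the
    # lengths must already agree and no padding is needed.
    activity_lengths = {len(example["gene_activity"]) for example in batch}
    if len(activity_lengths) != 1:
        raise ValueError("gene_activity vectors must all have the same length")

    target_len = max(len(example["atac_input_ids"]) for example in batch)
    atac_input_ids, atac_attention_mask = [], []
    for example in batch:
        ids = list(example["atac_input_ids"])
        n_pad = target_len - len(ids)
        atac_input_ids.append(ids + [atac_pad] * n_pad)
        atac_attention_mask.append([1] * len(ids) + [0] * n_pad)

    return {
        "atac_input_ids": atac_input_ids,
        "atac_attention_mask": atac_attention_mask,
        "gene_activity": [list(example["gene_activity"]) for example in batch],
    }


examples = [
    {"atac_input_ids": [4, 5, 6], "gene_activity": [0.1, 0.0, 2.5]},
    {"atac_input_ids": [7], "gene_activity": [1.2, 0.3, 0.0]},
]
activity_batch = collate_gene_activity(examples, atac_pad=0)
```

Here the ATAC sequences have different lengths and get padded, while both gene activity vectors already have three entries and pass through as-is.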