
Craft Module

craft

Classes

CraftModel

CraftModel(config)

Bases: PreTrainedModel

CRAFT Model with a masked language modeling head.

Functions
forward
forward(gene_input_ids=None, gene_attention_mask=None, gene_token_type_ids=None, atac_input_ids=None, atac_attention_mask=None, return_dict=None, **kwargs)

Forward pass through the CRAFT model.

Args:
  gene_input_ids (torch.Tensor): Input IDs for the gene encoder.
  gene_attention_mask (torch.Tensor): Attention mask for the gene encoder.
  gene_token_type_ids (torch.Tensor): Token type IDs for the gene encoder.
  atac_input_ids (torch.Tensor): Input IDs for the ATAC encoder.
  atac_attention_mask (torch.Tensor): Attention mask for the ATAC encoder.
  return_dict (bool): Whether to return a dictionary or a tuple.

Returns:
  torch.Tensor: The logits from the CRAFT model.
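The arguments above come in pairs: each `*_input_ids` tensor has a matching attention mask (and, for genes, token type IDs) of the same shape, and the gene and ATAC inputs describe the same paired cells. The sketch below checks those shape relationships on plain Python lists before a batch is handed to `forward`. `check_craft_batch` is a hypothetical helper, not part of the CRAFT API, and it assumes both modalities are passed.

```python
from typing import Dict, List

Batch = Dict[str, List[List[int]]]


def check_craft_batch(batch: Batch) -> int:
    """Hypothetical helper: validate shapes for CraftModel.forward, return batch size."""
    # Each ID tensor must share its (batch, seq_len) shape with its companions.
    pairs = [
        ("gene_input_ids", "gene_attention_mask"),
        ("gene_input_ids", "gene_token_type_ids"),
        ("atac_input_ids", "atac_attention_mask"),
    ]
    for ids_key, other_key in pairs:
        if other_key not in batch:
            continue  # optional inputs (e.g. token type IDs) may be omitted
        ids, other = batch[ids_key], batch[other_key]
        if [len(row) for row in ids] != [len(row) for row in other]:
            raise ValueError(f"{ids_key} and {other_key} have different shapes")
    # Assumption: row i of the gene inputs and row i of the ATAC inputs are one pair.
    if len(batch["gene_input_ids"]) != len(batch["atac_input_ids"]):
        raise ValueError("gene and ATAC batch sizes differ")
    return len(batch["gene_input_ids"])


batch = {
    "gene_input_ids": [[5, 9, 0], [7, 0, 0]],
    "gene_attention_mask": [[1, 1, 0], [1, 0, 0]],
    "atac_input_ids": [[3, 4], [8, 0]],
    "atac_attention_mask": [[1, 1], [1, 0]],
}
print(check_craft_batch(batch))  # prints 2
```

Note that the gene and ATAC sequence lengths are allowed to differ; only the batch dimension has to agree across modalities.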

CraftForContrastiveLearning

CraftForContrastiveLearning(config)

Bases: PreTrainedModel

CRAFT model for contrastive learning between gene and ATAC embeddings. Although it may look redundant with CraftModel, it makes the model easier to reuse for downstream tasks such as gene activity prediction without instantiating CraftModel directly.

Mostly used for pre-training.

Functions
forward
forward(gene_input_ids, gene_attention_mask, gene_token_type_ids, atac_input_ids, atac_attention_mask)

Forward pass through the model.

Args:
  gene_input_ids (torch.Tensor): Input IDs for the gene encoder.
  gene_attention_mask (torch.Tensor): Attention mask for the gene encoder.
  gene_token_type_ids (torch.Tensor): Token type IDs for the gene encoder.
  atac_input_ids (torch.Tensor): Input IDs for the ATAC encoder.
  atac_attention_mask (torch.Tensor): Attention mask for the ATAC encoder.

Returns:
  CraftOutput: The output of the CRAFT model containing loss and logits.
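The exact loss is not documented here, but the `projection_dim` and `logit_scale_init_value` parameters of CraftConfig match those of a CLIP-style setup. Assuming that, contrastive pre-training would score every gene embedding against every ATAC embedding in the batch and apply a symmetric cross-entropy in which the matching pair (same row index) is the correct class. A minimal pure-Python sketch of that assumed loss:

```python
import math
from typing import List

Vector = List[float]


def _l2_normalize(v: Vector) -> Vector:
    # Assumes v is not the zero vector.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _cross_entropy(logits: Vector, target: int) -> float:
    # -log softmax(logits)[target], computed stably via log-sum-exp.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def clip_style_loss(gene_emb: List[Vector], atac_emb: List[Vector], logit_scale: float) -> float:
    """Assumed symmetric InfoNCE loss; row i of both inputs is the same cell."""
    genes = [_l2_normalize(v) for v in gene_emb]
    atacs = [_l2_normalize(v) for v in atac_emb]
    n = len(genes)
    # logits[i][j] = logit_scale * cosine(gene_i, atac_j); true pairs sit on the diagonal.
    logits = [
        [logit_scale * sum(g * a for g, a in zip(genes[i], atacs[j])) for j in range(n)]
        for i in range(n)
    ]
    # Gene -> ATAC direction uses rows, ATAC -> gene direction uses columns.
    gene_to_atac = sum(_cross_entropy(logits[i], i) for i in range(n)) / n
    atac_to_gene = sum(_cross_entropy([logits[i][j] for i in range(n)], j) for j in range(n)) / n
    return (gene_to_atac + atac_to_gene) / 2


aligned = clip_style_loss([[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]], 10.0)
swapped = clip_style_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]], 10.0)
print(aligned < swapped)  # prints True
```

Correctly paired embeddings drive the loss toward zero, while mismatched pairs are penalized; the logit scale controls how sharply the softmax separates them.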

CraftForGeneActivityPrediction

CraftForGeneActivityPrediction(config)

Bases: PreTrainedModel

CRAFT model for gene activity prediction.

Functions
forward
forward(atac_input_ids, atac_attention_mask, gene_activity=None, return_dict=True)

Forward pass through the model.

Args:
  atac_input_ids (torch.Tensor): Input IDs for the ATAC encoder.
  atac_attention_mask (torch.Tensor): Attention mask for the ATAC encoder.
  gene_activity (Optional[torch.Tensor]): Optional gene activity scores for computing loss.

Returns:
  torch.Tensor: The predicted gene activity scores.
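Because `gene_activity` is optional and only used to compute a loss, the model can be called for inference without targets. The sketch below illustrates that pattern on plain lists; `gene_activity_output` is a hypothetical helper, and the mean-squared-error loss is an assumption, since the actual loss function is not documented here.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]  # (cells, genes)


def gene_activity_output(
    predicted: Matrix, gene_activity: Optional[Matrix] = None
) -> Tuple[Optional[float], Matrix]:
    """Hypothetical helper: return (loss, predictions); loss is None without targets."""
    loss = None
    if gene_activity is not None:
        if [len(r) for r in predicted] != [len(r) for r in gene_activity]:
            raise ValueError("predictions and gene_activity must have the same shape")
        # Assumed loss: mean squared error over every (cell, gene) entry.
        sq_errors = [
            (p - t) ** 2
            for p_row, t_row in zip(predicted, gene_activity)
            for p, t in zip(p_row, t_row)
        ]
        loss = sum(sq_errors) / len(sq_errors)
    return loss, predicted


# Inference: no targets, so no loss.
loss, preds = gene_activity_output([[1.0, 2.0], [0.0, 0.5]])
print(loss)  # prints None
```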

CraftConfig

CraftConfig(geneformer_config=None, atacformer_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs)

Bases: PretrainedConfig

Configuration for the CRAFT model, a contrastive RNA-ATAC transformer that leverages Geneformer and Atacformer to learn a joint representation of RNA and ATAC data.

Joint configuration for the CRAFT model.

Args:
  geneformer_config (GeneformerConfig): Configuration for the Geneformer model.
  atacformer_config (AtacformerConfig): Configuration for the Atacformer model.
  projection_dim (int): Dimension of the projection layer.
  logit_scale_init_value (float): Initial value for the logit scale parameter.
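The defaults `projection_dim=512` and `logit_scale_init_value=2.6592` are the same as those of Hugging Face's `CLIPConfig`. Assuming CRAFT follows the CLIP convention of storing the scale in log space and exponentiating it before multiplying the similarities, the default corresponds to a softmax temperature of about 0.07, as this short calculation shows:

```python
import math

# Default taken from the CraftConfig signature above.
logit_scale_init_value = 2.6592

# Assumption (CLIP convention): the parameter is a log-scale, so the factor that
# multiplies cosine similarities is exp(value), and the temperature is its inverse.
logit_scale = math.exp(logit_scale_init_value)
temperature = 1.0 / logit_scale
print(f"logit scale ~ {logit_scale:.2f}, temperature ~ {temperature:.3f}")
# prints: logit scale ~ 14.28, temperature ~ 0.070
```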


DataCollatorForCraft

DataCollatorForCraft(gene_pad, atac_pad, gene_max_len=None, atac_max_len=None)

Pads the gene and ATAC token sequences and builds their attention masks for the gene/ATAC pairs consumed by CraftModel.
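A minimal sketch of what such a collator does, on plain lists rather than tensors. It assumes each example is a dict with `gene_input_ids` and `atac_input_ids` lists, that `*_max_len` truncates sequences, and that each modality is then padded to the longest sequence in the batch; `collate_craft` is a hypothetical stand-in, not the real class.

```python
from typing import Dict, List, Optional, Tuple


def _pad(seqs: List[List[int]], pad_id: int, max_len: Optional[int]) -> Tuple[List[List[int]], List[List[int]]]:
    # Assumed behavior: truncate to max_len if given, then pad to the longest row.
    if max_len is not None:
        seqs = [s[:max_len] for s in seqs]
    width = max(len(s) for s in seqs)
    ids = [s + [pad_id] * (width - len(s)) for s in seqs]
    # 1 marks a real token, 0 marks padding.
    mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]
    return ids, mask


def collate_craft(
    examples: List[Dict[str, List[int]]],
    gene_pad: int,
    atac_pad: int,
    gene_max_len: Optional[int] = None,
    atac_max_len: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch of DataCollatorForCraft; keys mirror CraftModel.forward."""
    gene_ids, gene_mask = _pad([ex["gene_input_ids"] for ex in examples], gene_pad, gene_max_len)
    atac_ids, atac_mask = _pad([ex["atac_input_ids"] for ex in examples], atac_pad, atac_max_len)
    return {
        "gene_input_ids": gene_ids,
        "gene_attention_mask": gene_mask,
        "atac_input_ids": atac_ids,
        "atac_attention_mask": atac_mask,
    }


examples = [
    {"gene_input_ids": [5, 6, 7], "atac_input_ids": [1]},
    {"gene_input_ids": [8], "atac_input_ids": [2, 3]},
]
out = collate_craft(examples, gene_pad=0, atac_pad=0)
print(out["gene_input_ids"])  # prints [[5, 6, 7], [8, 0, 0]]
```

Gene and ATAC sequences are padded independently, each with its own pad ID, so the two modalities can have different lengths within the same batch.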

DataCollatorForCraftGeneActivityPrediction

DataCollatorForCraftGeneActivityPrediction(atac_pad, atac_max_len=None)

Pads the ATAC token sequences and builds their attention masks for the ATAC/gene-activity pairs consumed by CraftForGeneActivityPrediction.

Gene activity has the same shape for every example: it is simply a vector of floats giving the activity of each gene in the genome, so it is not padded.
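A minimal sketch of this collator on plain lists, assuming each example is a dict with `atac_input_ids` and an optional `gene_activity` vector. Only the ATAC side is padded; the gene activity vectors are stacked as-is after checking that they really do share one length. `collate_gene_activity` is a hypothetical stand-in for the real class.

```python
from typing import Dict, List, Optional


def collate_gene_activity(
    examples: List[Dict[str, list]],
    atac_pad: int,
    atac_max_len: Optional[int] = None,
) -> Dict[str, list]:
    """Hypothetical sketch of DataCollatorForCraftGeneActivityPrediction."""
    seqs = [ex["atac_input_ids"] for ex in examples]
    if atac_max_len is not None:
        seqs = [s[:atac_max_len] for s in seqs]  # assumed truncation behavior
    width = max(len(s) for s in seqs)
    batch: Dict[str, list] = {
        "atac_input_ids": [s + [atac_pad] * (width - len(s)) for s in seqs],
        "atac_attention_mask": [[1] * len(s) + [0] * (width - len(s)) for s in seqs],
    }
    if all("gene_activity" in ex for ex in examples):
        activities = [ex["gene_activity"] for ex in examples]
        # Every vector covers the same genes, so they stack without padding.
        if len({len(a) for a in activities}) != 1:
            raise ValueError("gene_activity vectors must all have the same length")
        batch["gene_activity"] = activities
    return batch


examples = [
    {"atac_input_ids": [4, 5, 6], "gene_activity": [0.1, 0.0]},
    {"atac_input_ids": [7], "gene_activity": [0.3, 0.2]},
]
out = collate_gene_activity(examples, atac_pad=0)
print(out["atac_input_ids"])  # prints [[4, 5, 6], [7, 0, 0]]
```

The output keys mirror the arguments of CraftForGeneActivityPrediction.forward, and `gene_activity` is simply left out of the batch at inference time when no targets are available.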