
Atacformer Module

atacformer

Classes

TrainingTokenizer

TrainingTokenizer(*args, **kwargs)

Bases: Tokenizer, PreTrainedTokenizerBase

A special training tokenizer that subclasses both our Tokenizer and PreTrainedTokenizerBase. The data collator requires a tokenizer that is both a Tokenizer and a PreTrainedTokenizerBase, so this class is a workaround that lets our Tokenizer work with it.
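The workaround relies on ordinary Python multiple inheritance: an instance of a class with two bases passes `isinstance` checks against either base. The sketch below shows the pattern with hypothetical stand-in classes. They are placeholders, not the real project Tokenizer or the Hugging Face `PreTrainedTokenizerBase`, and `collator_accepts` is a hypothetical check standing in for whatever the collator actually inspects.

```python
# Hypothetical stand-ins for illustration only; NOT the real Tokenizer
# or transformers' PreTrainedTokenizerBase.
class Tokenizer:
    """Stand-in for the project's own region tokenizer."""

    def tokenize(self, regions):
        return [f"tok:{r}" for r in regions]


class PreTrainedTokenizerBase:
    """Stand-in for the Hugging Face base class a collator may expect."""

    pad_token_id = 0


class TrainingTokenizer(Tokenizer, PreTrainedTokenizerBase):
    """Inherits from both bases, so it satisfies either isinstance check."""


def collator_accepts(tokenizer):
    # Hypothetical collator-side check that requires both interfaces.
    return isinstance(tokenizer, Tokenizer) and isinstance(
        tokenizer, PreTrainedTokenizerBase
    )


tok = TrainingTokenizer()
print(collator_accepts(tok))          # True
print(collator_accepts(Tokenizer()))  # False
```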

Attributes
all_special_ids (property)

Returns a list of all special token ids.

AtacformerModel

AtacformerModel(config)

Bases: AtacformerPreTrainedModel

Atacformer model with a simple embedding layer that skips positional encoding.

Functions
forward
forward(input_ids=None, attention_mask=None, return_dict=None, **kwargs)

Forward pass of the model.

Args:
    input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
    attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
    return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.

Raises:
    ValueError: If input_ids is not provided.

Returns:
    torch.Tensor: The output of the model, a tensor of shape (batch_size, sequence_length, hidden_size).

AtacformerForMaskedLM

AtacformerForMaskedLM(config)

Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel

Functions
forward
forward(input_ids=None, attention_mask=None, labels=None, return_dict=None)

Forward pass of the model.

Args:
    input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
    attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
    labels (torch.LongTensor, optional): Labels for masked language modeling.
    return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.

Returns:
    MaskedLMOutput: The output of the model.

AtacformerForReplacedTokenDetection

AtacformerForReplacedTokenDetection(config)

Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel

Atacformer model for replaced token detection. It uses the ELECTRA framework (https://arxiv.org/abs/2003.10555) to train this model as a discriminator that detects replaced tokens.

Functions
forward
forward(input_ids=None, attention_mask=None, labels=None, return_dict=None)

Forward pass of the model.

Args:
    input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
    attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
    labels (torch.LongTensor, optional): Per-token 0/1 labels marking which tokens were replaced, as produced by an ELECTRA-style collator such as DataCollatorForReplacedTokenDetection.
    return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.

Returns:
    TokenClassifierOutput: The output of the model.
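In the ELECTRA paper the discriminator emits one score per token and is trained with a sigmoid binary cross-entropy against the replaced/original labels. The sketch below computes that loss in plain Python, averaging only over non-padding positions. It is an assumption that AtacformerForReplacedTokenDetection reduces its loss exactly this way; `rtd_loss` is a hypothetical helper, not part of the module.

```python
import math


def rtd_loss(logits, labels, attention_mask):
    """Mean binary cross-entropy over non-padding tokens (hypothetical helper).

    logits: per-token discriminator scores (higher means "replaced")
    labels: 1 if the token was replaced, 0 if it is the original token
    attention_mask: 1 for real tokens, 0 for padding
    """
    total, count = 0.0, 0
    for z, y, m in zip(logits, labels, attention_mask):
        if not m:
            continue  # padding positions do not contribute to the loss
        # Numerically stable BCE-with-logits: max(z, 0) - z*y + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        count += 1
    return total / max(count, 1)


print(round(rtd_loss([0.0, 0.0], [1, 0], [1, 1]), 4))  # 0.6931, i.e. log(2)
```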

AtacformerForCellClustering

AtacformerForCellClustering(config)

Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel

Atacformer model for cell clustering. It follows a similar learning framework to SentenceBERT (SBERT), where the model is trained to minimize the distance between positive pairs and maximize the distance between negative pairs.
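The SBERT triplet objective pulls an anchor embedding toward a positive example and pushes it away from a negative one until the distance gap exceeds a margin: loss = max(d(a, p) - d(a, n) + margin, 0). A minimal sketch, assuming Euclidean distance and a margin of 1.0 (the values used by SBERT's triplet objective). The exact distance and margin Atacformer uses are not stated here.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge on the gap between anchor-positive and anchor-negative distances."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


# Positive is close and negative is far: the gap exceeds the margin, so loss is 0.
print(triplet_loss([0.0, 0.0], [0.0, 1.0], [3.0, 4.0]))  # 0.0
# Negative is closer than positive: the loss is positive.
print(triplet_loss([0.0, 0.0], [3.0, 4.0], [0.0, 1.0]))  # 5.0
```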

AtacformerForUnsupervisedBatchCorrection

AtacformerForUnsupervisedBatchCorrection(config)

Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel

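Atacformer model for batch correction. It follows an adversarial domain-adaptation framework: a batch classifier is trained on the embeddings through a gradient reversal layer (see lambda_adv and grl_alpha in AtacformerConfig), so the encoder learns embeddings in which batch effects are reduced.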

Functions
forward
forward(input_ids=None, attention_mask=None, labels=None, batch_labels=None, return_dict=None)

Forward pass of the model.

Args:
    input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
    attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
    labels (torch.Tensor, optional): Labels for the ELECTRA-style replaced-token-detection objective.
    batch_labels (torch.Tensor, optional): Labels for batch prediction, of shape (batch_size,).
    return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.

Returns:
    BaseModelOutput: The output of the model.
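A gradient reversal layer is the identity in the forward pass and multiplies the incoming gradient by -alpha in the backward pass. The batch classifier therefore learns to predict batch_labels, while the encoder receives the opposite signal and is pushed to hide batch identity. The scalar sketch below only illustrates that sign flip. How the model actually weights and combines its losses is an assumption here; `GradientReversal` and `encoder_gradient` are hypothetical names, with `lambda_adv` and `grl_alpha` borrowed from AtacformerConfig.

```python
class GradientReversal:
    """Identity forward; scales the gradient by -alpha on the way back."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def forward(self, x):
        return x

    def backward(self, grad_output):
        return -self.alpha * grad_output


def encoder_gradient(grad_task, grad_batch_head, lambda_adv=1.0, grl_alpha=1.0):
    """Gradient reaching the encoder from task loss + lambda_adv * batch loss.

    grad_batch_head is the gradient of the batch-classification loss with
    respect to the embedding; the GRL flips its sign before the encoder sees it.
    """
    grl = GradientReversal(grl_alpha)
    return grad_task + grl.backward(lambda_adv * grad_batch_head)


# The encoder moves against what would help the batch classifier.
print(encoder_gradient(0.5, 0.25))  # 0.25
```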

AtacformerConfig

AtacformerConfig(use_pos_embeddings=True, vocab_size=890711, max_position_embeddings=8192, hidden_size=384, intermediate_size=1536, num_hidden_layers=6, num_attention_heads=8, pad_token_id=890705, eos_token_id=890708, bos_token_id=890709, cls_token_id=890707, sep_token_id=890710, sparse_prediction=True, norm_eps=1e-05, embedding_dropout=0.0, initializer_range=0.02, initializer_cutoff_factor=2.0, tie_word_embeddings=True, num_batches=None, lambda_adv=1.0, grl_alpha=1.0, bc_unfreeze_last_n_layers=2, **kwargs)

Bases: PretrainedConfig

This is the configuration class that stores the configuration of an AtacformerModel. It is based on [ModernBertConfig] and extends it with Atacformer-specific settings. Instantiating a configuration with the defaults yields a configuration similar to the ModernBERT base configuration.

Args:
    use_pos_embeddings (bool, optional, defaults to True): Whether to use positional embeddings.
    vocab_size (int, optional, defaults to 890711): Vocabulary size, tailored for genomic regions.
    max_position_embeddings (int, optional, defaults to 8192): The maximum sequence length that this model might ever be used with.
    hidden_size (int, optional, defaults to 384): The size of the encoder layers and the pooler layer.
    intermediate_size (int, optional, defaults to 1536): The size of the "intermediate" (often called feed-forward) layer in the Transformer.
    num_hidden_layers (int, optional, defaults to 6): The number of hidden layers in the Transformer encoder.
    num_attention_heads (int, optional, defaults to 8): The number of attention heads in each attention layer.
    pad_token_id (int, optional, defaults to 890705): The id of the token used for padding.
    eos_token_id (int, optional, defaults to 890708): The id of the token used for the end of a sequence.
    bos_token_id (int, optional, defaults to 890709): The id of the token used for the beginning of a sequence.
    cls_token_id (int, optional, defaults to 890707): The id of the token used for classification tasks.
    sep_token_id (int, optional, defaults to 890710): The id of the token used to separate segments in a sequence.
    sparse_prediction (bool, optional, defaults to True): Whether to use sparse prediction for the output layer.
    norm_eps (float, optional, defaults to 1e-5): The epsilon value used for layer normalization.
    embedding_dropout (float, optional, defaults to 0.0): The dropout probability for the embedding layer.
    initializer_range (float, optional, defaults to 0.02): The standard deviation of the truncated normal initializer used for all weight matrices.
    initializer_cutoff_factor (float, optional, defaults to 2.0): The cutoff factor for the truncated normal initializer.
    tie_word_embeddings (bool, optional, defaults to True): Whether to tie the word embeddings to the output layer.
    num_batches (int, optional, defaults to None): The number of distinct batches (batch labels) to correct for during batch-correction training.
    lambda_adv (float, optional, defaults to 1.0): The weight of the adversarial loss.
    grl_alpha (float, optional, defaults to 1.0): The alpha value for the gradient reversal layer.
    bc_unfreeze_last_n_layers (int, optional, defaults to 2): The number of final encoder layers to unfreeze during batch-correction training.
    kwargs (optional): Additional configuration parameters.
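A few of the defaults constrain each other. Multi-head attention splits hidden_size evenly across heads, so 384 / 8 gives a per-head dimension of 48. Every special token id (890705 to 890710) must also fall below vocab_size (890711) to index the embedding table. The sketch below checks both using a hypothetical dataclass that mirrors some of the documented defaults. It is not the real AtacformerConfig.

```python
from dataclasses import dataclass


@dataclass
class AtacformerDefaults:
    """Hypothetical mirror of some AtacformerConfig defaults (not the real class)."""

    vocab_size: int = 890711
    hidden_size: int = 384
    intermediate_size: int = 1536
    num_attention_heads: int = 8
    pad_token_id: int = 890705
    cls_token_id: int = 890707
    eos_token_id: int = 890708
    bos_token_id: int = 890709
    sep_token_id: int = 890710

    def head_dim(self) -> int:
        # Multi-head attention splits hidden_size evenly across the heads.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        return self.hidden_size // self.num_attention_heads

    def check_special_ids(self) -> None:
        # Every special token id must index a row of the embedding table.
        for tid in (self.pad_token_id, self.cls_token_id, self.eos_token_id,
                    self.bos_token_id, self.sep_token_id):
            if not 0 <= tid < self.vocab_size:
                raise ValueError(f"special token id {tid} is outside the vocabulary")


cfg = AtacformerDefaults()
cfg.check_special_ids()
print(cfg.head_dim())                            # 48
print(cfg.intermediate_size // cfg.hidden_size)  # 4 (feed-forward expansion ratio)
```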

DataCollatorForReplacedTokenDetection

DataCollatorForReplacedTokenDetection(tokenizer, mlm_probability=0.15, seed=None)

Bases: WandbMixin, DataCollatorForLanguageModeling

Like Hugging Face's MLM collator, but it:
• never uses [MASK];
• picks replacement tokens from a user-supplied distribution;
• returns per-token 0/1 labels for ELECTRA-style discrimination.

Simple data collator for ELECTRA-style replaced token detection.

Args:
    tokenizer (TrainingTokenizer): The tokenizer to use.
    vocab_counts (torch.Tensor | None): 1-D tensor with one entry per vocabulary token, holding either log-probabilities or probabilities.
    mlm_probability (float): Probability that a token is selected for replacement.
    seed (int | None): Random seed for reproducibility.
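The corruption step described above can be sketched on a single sequence: each non-special token is selected with probability mlm_probability and swapped for a token sampled from a distribution over the vocabulary, and the label is 1 only where the token actually changed. As in ELECTRA, a sampled token that equals the original counts as "original". This sketch assumes plain probabilities (not log-probabilities) and uses Python's `random` module in place of torch. `replace_tokens` and `special_ids` are hypothetical names.

```python
import random


def replace_tokens(input_ids, vocab_probs, mlm_probability=0.15, special_ids=(), seed=None):
    """ELECTRA-style corruption without [MASK] (hypothetical helper).

    vocab_probs: one non-negative weight per vocabulary id (plain probabilities).
    Returns (corrupted_ids, labels) where labels[i] == 1 iff the token changed.
    """
    rng = random.Random(seed)
    vocab = list(range(len(vocab_probs)))
    corrupted, labels = [], []
    for tok in input_ids:
        if tok not in special_ids and rng.random() < mlm_probability:
            new = rng.choices(vocab, weights=vocab_probs, k=1)[0]
        else:
            new = tok
        corrupted.append(new)
        # A sampled token identical to the original is not "replaced".
        labels.append(int(new != tok))
    return corrupted, labels


# All probability mass on token 2, and every token selected:
print(replace_tokens([0, 1, 2], [0.0, 0.0, 1.0], mlm_probability=1.0))  # ([2, 2, 2], [1, 1, 0])
```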


DataCollatorForTripletLoss

DataCollatorForTripletLoss(tokenizer, max_position_embeddings=None)

A simple data collator for triplet loss, used to fine-tune Atacformer for cell-type clustering.
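A triplet collator has to pad the anchor, positive and negative sequences separately, truncating to max_position_embeddings when it is set, and build a matching attention mask for each. The sketch below shows one way to do this with plain lists. The example dict keys (`anchor`, `positive`, `negative`) and output keys such as `anchor_input_ids` are hypothetical; the real collator's field names and tensor types may differ.

```python
def pad_batch(sequences, pad_token_id, max_len=None):
    """Right-pad token-id lists to a common length and build attention masks."""
    if max_len is not None:
        sequences = [seq[:max_len] for seq in sequences]  # truncate long cells
    width = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_token_id] * (width - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in sequences]
    return input_ids, attention_mask


def collate_triplets(examples, pad_token_id, max_position_embeddings=None):
    """examples: list of dicts with hypothetical 'anchor', 'positive', 'negative' keys."""
    batch = {}
    for role in ("anchor", "positive", "negative"):
        ids, mask = pad_batch([ex[role] for ex in examples], pad_token_id,
                              max_position_embeddings)
        batch[f"{role}_input_ids"] = ids
        batch[f"{role}_attention_mask"] = mask
    return batch
```

Padding each role independently keeps short positives or negatives from being padded to the length of an unusually long anchor.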

ModelParameterChangeCallback

ModelParameterChangeCallback(initial_params)

Bases: WandbMixin, TrainerCallback

A callback that logs how much the model parameters have changed relative to their initial values (initial_params).

Functions
on_log
on_log(args, state, control, **kwargs)

Log the changes in model parameters relative to their initial values at each logging step.
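One natural measure of "change" is the L2 norm of the difference between current and initial values, per parameter and overall. Which metric the callback actually logs is not stated here, so this choice is an assumption. The sketch uses flat lists of floats in place of tensors, and `parameter_change` is a hypothetical helper.

```python
import math


def parameter_change(initial_params, current_params):
    """Per-parameter L2 norm of (current - initial), plus the overall norm.

    Both arguments map parameter names to flat lists of floats.
    """
    per_param = {}
    total_sq = 0.0
    for name, init in initial_params.items():
        sq = sum((c - i) ** 2 for c, i in zip(current_params[name], init))
        per_param[name] = math.sqrt(sq)
        total_sq += sq
    return per_param, math.sqrt(total_sq)


print(parameter_change({"a": [0.0, 0.0], "b": [1.0]}, {"a": [3.0, 4.0], "b": [1.0]}))
# ({'a': 5.0, 'b': 0.0}, 5.0)
```

With frozen layers (for example after freeze_except_last_n), the frozen parameters should report a change of 0.0, which makes this a quick sanity check on freezing.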

AdjustedRandIndexCallback

AdjustedRandIndexCallback(input_ids, cell_type_labels, pad_token_id, batch_size=128, log_every_n_steps=500)

Bases: WandbMixin, TrainerCallback

A callback to log the adjusted Rand index (ARI) during training.

Functions
on_log
on_log(args, state, control, **kwargs)

Log the adjusted Rand index (ARI) during training.
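The adjusted Rand index compares two labelings of the same cells (for example, known cell types versus clusters of the embeddings), corrected for chance: 1.0 means identical partitions up to relabeling, and values near 0 mean chance-level agreement. The sketch below computes it from the contingency table in pure Python. How the callback clusters the embeddings before scoring is not described here. In practice `sklearn.metrics.adjusted_rand_score` computes the same quantity.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index (Hubert & Arabie) from the contingency table."""
    n = len(labels_true)
    if n < 2:
        return 1.0  # too few samples to form any pair
    contingency = Counter(zip(labels_true, labels_pred))
    sum_cells = sum(comb(c, 2) for c in contingency.values())
    sum_rows = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_cols = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = sum_rows * sum_cols / comb(n, 2)
    max_index = (sum_rows + sum_cols) / 2
    if max_index == expected:  # degenerate case, e.g. one cluster on both sides
        return 1.0
    return (sum_cells - expected) / (max_index - expected)


print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0 (same partition, relabeled)
```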

Functions

freeze_except_last_n

freeze_except_last_n(model, n=2)

Freeze all parameters except the last n layers of the encoder. Also keeps all layer norms trainable for stability.

Args:
    model (AtacformerModel): The model to freeze.
    n (int): The number of last layers to keep trainable.
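The freezing rule above can be expressed as a decision per parameter name: keep it trainable if it belongs to one of the last n encoder layers or is a layer norm, and freeze everything else (including the embeddings). The sketch assumes hypothetical parameter names of the form `encoder.layers.<i>.…` and that layer-norm parameters contain `norm` in their name. The real module's naming may differ.

```python
import re


def trainable_mask(param_names, num_layers, n=2):
    """Map each parameter name to True (trainable) or False (frozen).

    Assumes hypothetical names like 'encoder.layers.<i>.<...>' for encoder
    layers and that layer-norm parameters contain 'norm'.
    """
    first_trainable = num_layers - n
    mask = {}
    for name in param_names:
        match = re.search(r"\.layers\.(\d+)\.", name)
        in_last_n = match is not None and int(match.group(1)) >= first_trainable
        mask[name] = in_last_n or "norm" in name  # layer norms stay trainable
    return mask


names = ["embeddings.weight", "encoder.layers.0.attn.weight",
         "encoder.layers.0.norm.weight", "encoder.layers.5.mlp.weight"]
print(trainable_mask(names, num_layers=6, n=2))
```

With PyTorch, the mask would then be applied by setting `p.requires_grad = mask[name]` while iterating over `model.named_parameters()`.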

patch_atacformer_model_for_mps

patch_atacformer_model_for_mps(model)

Find any TransformerEncoder layers in the model and patch them by setting both enable_nested_tensor and use_nested_tensor to False, so the model can run on the MPS backend.

Args: model (nn.Module): The model to patch.

get_git_hash

get_git_hash()

Get the current git hash of the repository.
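A common way to get the current commit hash is to run `git rev-parse HEAD` and read its output. The sketch below does this with `subprocess.run` and returns None when git is missing or the directory is not a repository. The optional `repo_dir` parameter is an addition for illustration. The documented function takes no arguments, and its failure behavior is not specified here.

```python
import subprocess


def get_git_hash(repo_dir="."):
    """Return the current commit hash, or None if git or the repo is unavailable."""
    try:
        out = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir,
            capture_output=True,
            text=True,
            check=True,  # raise CalledProcessError on a non-zero exit code
        )
    except (OSError, subprocess.CalledProcessError):
        return None  # git not installed, bad directory, or not a repository
    return out.stdout.strip()
```

Logging this hash alongside a training run makes it possible to tie a checkpoint back to the exact code that produced it.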

get_decaying_cosine_with_hard_restarts_schedule_with_warmup

get_decaying_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1, last_epoch=-1)

Very similar to Hugging Face's built-in cosine schedule with hard restarts, except that the amplitude slowly decreases, so the "kick-ups" at each restart are less aggressive.

Create a schedule with a warmup period, during which the learning rate increases linearly from 0 to the initial lr set in the optimizer, followed by a decrease that follows the cosine function from the initial lr to 0, with several hard restarts.

Args:
    optimizer ([~torch.optim.Optimizer]): The optimizer for which to schedule the learning rate.
    num_warmup_steps (int): The number of steps for the warmup phase.
    num_training_steps (int): The total number of training steps.
    num_cycles (int, optional, defaults to 1): The number of hard restarts to use.
    last_epoch (int, optional, defaults to -1): The index of the last epoch when resuming training.

Returns:
    torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
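LambdaLR multiplies the optimizer's initial lr by a function of the step. The sketch below reproduces the shape of Hugging Face's cosine-with-hard-restarts multiplier (linear warmup, then a cosine that restarts num_cycles times) and adds a decay. The decay rule, scaling each point by (1 - progress) so that later restarts peak lower, is an assumption chosen for illustration; the documented function may decay differently. `decaying_cosine_restarts_lambda` is a hypothetical name.

```python
import math


def decaying_cosine_restarts_lambda(num_warmup_steps, num_training_steps, num_cycles=1):
    """Return a multiplier function suitable for torch.optim.lr_scheduler.LambdaLR.

    Assumption: the cosine is scaled by (1 - progress), so each restart's
    peak is lower than the previous one.
    """

    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear warmup from 0 to 1
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        if progress >= 1.0:
            return 0.0
        cosine = 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0)))
        return (1.0 - progress) * cosine  # decaying amplitude

    return lr_lambda


f = decaying_cosine_restarts_lambda(10, 110, num_cycles=2)
print(f(10), f(60))  # 1.0 0.5 (the second restart peaks at half the first)
```

With PyTorch, the function would be passed as `LambdaLR(optimizer, f)`.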

tokenize_anndata

tokenize_anndata(adata, tokenizer)

Tokenize an AnnData object. This is more involved, so it gets its own function.

Args:
    adata (sc.AnnData): The AnnData object to tokenize.
    tokenizer (Tokenizer): The tokenizer to use.
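The core idea of tokenizing a cell-by-region matrix is to turn each cell's row into the list of token ids for the regions that are accessible (non-zero) in that cell. The sketch below does this with plain lists and an exact-name vocabulary lookup. That lookup is a simplification: a real region tokenizer may instead match regions to its vocabulary by interval overlap. `tokenize_cells`, `region_to_id`, the `chr:start-end` naming, and `unk_id` are all hypothetical; the actual function works on an AnnData object and a Tokenizer.

```python
def tokenize_cells(rows, region_names, region_to_id, unk_id=None):
    """Turn each cell's accessibility row into a list of region token ids.

    rows: per-cell lists of counts aligned with region_names
    region_names: hypothetical names such as "chr1:100-200"
    region_to_id: hypothetical vocabulary mapping region name -> token id
    Zero-count regions are skipped; unknown regions map to unk_id, or are
    dropped when unk_id is None.
    """
    tokenized = []
    for row in rows:
        ids = []
        for count, region in zip(row, region_names):
            if count == 0:
                continue  # region not accessible in this cell
            tid = region_to_id.get(region, unk_id)
            if tid is not None:
                ids.append(tid)
        tokenized.append(ids)
    return tokenized
```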