Atacformer Module
atacformer
Classes
TrainingTokenizer
TrainingTokenizer(*args, **kwargs)
Bases: Tokenizer, PreTrainedTokenizerBase
A special training tokenizer. It subclasses both our Tokenizer and Hugging Face's PreTrainedTokenizerBase because the data collator requires an object that satisfies both interfaces; this is a workaround to make our Tokenizer usable with the collator.
Attributes
all_special_ids (property)
Returns a list of all special token ids.
AtacformerModel
AtacformerModel(config)
Bases: AtacformerPreTrainedModel
Atacformer model with a simple embedding layer that skips positional encoding.
Functions
forward
forward(input_ids=None, attention_mask=None, return_dict=None, **kwargs)
Forward pass of the model.
Args:
input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.
Raises:
ValueError: If input_ids is not provided.
Returns:
torch.Tensor: The output of the model. It will be a tensor of shape (batch_size, sequence_length, hidden_size).
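A minimal forward-pass sketch. The import path and the from-scratch construction are assumptions; in practice you would load a trained checkpoint.

```python
import torch

from atacformer import AtacformerConfig, AtacformerModel  # assumed import path

# Build a small model from a config; a real workflow would load a trained
# checkpoint instead (e.g. AtacformerModel.from_pretrained(<checkpoint>)).
config = AtacformerConfig(num_hidden_layers=2)
model = AtacformerModel(config)
model.eval()

# Dummy batch: region-token ids for two cells, padded to the same length.
input_ids = torch.randint(0, 1000, (2, 128))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# Per the Returns note above, this is a tensor of shape
# (batch_size, sequence_length, hidden_size).
print(outputs.shape)
```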
AtacformerForMaskedLM
AtacformerForMaskedLM(config)
Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel
Functions
forward
forward(input_ids=None, attention_mask=None, labels=None, return_dict=None)
Forward pass of the model.
Args:
input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
labels (torch.LongTensor, optional): Labels for masked language modeling.
return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.
Returns:
MaskedLMOutput: The output of the model.
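A hedged sketch of a masked-LM forward pass, assuming the import path and the standard Hugging Face convention that label positions set to -100 are ignored by the loss.

```python
import torch

from atacformer import AtacformerConfig, AtacformerForMaskedLM  # assumed import path

config = AtacformerConfig(num_hidden_layers=2)
model = AtacformerForMaskedLM(config)

input_ids = torch.randint(0, 1000, (2, 64))
attention_mask = torch.ones_like(input_ids)

# Only ~15% of positions act as prediction targets; the rest are set to -100
# so the loss ignores them (standard HF masked-LM convention, assumed here).
labels = input_ids.clone()
target = torch.rand(labels.shape) < 0.15
labels[~target] = -100

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs.loss)  # MaskedLMOutput carries the loss (and logits)
```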
AtacformerForReplacedTokenDetection
AtacformerForReplacedTokenDetection(config)
Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel
Atacformer model for replaced token detection. This model uses the ELECTRA framework to train a discriminator (this model) to detect replaced tokens.
https://arxiv.org/abs/2003.10555
Functions
forward
forward(input_ids=None, attention_mask=None, labels=None, return_dict=None)
Forward pass of the model.
Args:
input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
labels (torch.LongTensor, optional): Per-token 0/1 labels for replaced token detection.
return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.
Returns:
TokenClassifierOutput: The output of the model.
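A hedged sketch of a discriminator forward pass. The import path is an assumption, and the hand-built labels stand in for the ones produced by DataCollatorForReplacedTokenDetection (documented below).

```python
import torch

from atacformer import AtacformerConfig, AtacformerForReplacedTokenDetection  # assumed import path

config = AtacformerConfig(num_hidden_layers=2)
model = AtacformerForReplacedTokenDetection(config)

input_ids = torch.randint(0, 1000, (2, 64))
attention_mask = torch.ones_like(input_ids)

# ELECTRA-style targets: 1 where a token was replaced, 0 where it is original.
labels = torch.zeros_like(input_ids)
labels[:, ::7] = 1

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs.loss)  # TokenClassifierOutput
```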
AtacformerForCellClustering
AtacformerForCellClustering(config)
Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel
Atacformer model for cell clustering. It follows a similar learning framework to SentenceBERT (SBERT), where the model is trained to minimize the distance between positive pairs and maximize the distance between negative pairs.
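For illustration, a minimal sketch of the SBERT-style triplet objective described above; the model's actual loss may differ in its distance metric and margin.

```python
import torch
import torch.nn.functional as F

# Pull the anchor towards the positive cell embedding and push it away from
# the negative one by at least `margin` (illustration only).
def triplet_loss(anchor, positive, negative, margin=1.0):
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()

# Dummy cell embeddings with hidden_size 384, as produced by the encoder.
anchor, positive, negative = torch.randn(3, 8, 384).unbind(0)
print(triplet_loss(anchor, positive, negative))
```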
AtacformerForUnsupervisedBatchCorrection
AtacformerForUnsupervisedBatchCorrection(config)
Bases: EncodeTokenizedCellsMixin, AtacformerPreTrainedModel
Atacformer model for batch correction. It follows a similar learning framework to the one used in domain adaptation, where the model is trained to correct batch effects in the embeddings.
Functions
forward
forward(input_ids=None, attention_mask=None, labels=None, batch_labels=None, return_dict=None)
Forward pass of the model.
Args:
input_ids (torch.LongTensor): Input tensor of shape (batch_size, sequence_length).
attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
labels (torch.Tensor, optional): Labels for masked language modeling (ELECTRA).
batch_labels (torch.Tensor, optional): Labels for batch prediction. Should be of shape (batch_size,).
return_dict (bool, optional): Whether to return the outputs as a dict or a tuple.
Returns:
BaseModelOutput: The output of the model.
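For illustration, a minimal sketch of the gradient reversal trick commonly used in this domain-adaptation setup (see grl_alpha in AtacformerConfig); the layer inside the model may differ.

```python
import torch

# Forward pass is the identity; backward pass flips (and scales) the gradient,
# so the encoder is pushed towards batch-invariant embeddings while the batch
# classifier still learns to predict the batch.
class GradientReversal(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None

embeddings = torch.randn(4, 384, requires_grad=True)
reversed_embeddings = GradientReversal.apply(embeddings, 1.0)  # grl_alpha
reversed_embeddings.sum().backward()
print(embeddings.grad[0, :3])  # gradients are negated
```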
AtacformerConfig
AtacformerConfig(use_pos_embeddings=True, vocab_size=890711, max_position_embeddings=8192, hidden_size=384, intermediate_size=1536, num_hidden_layers=6, num_attention_heads=8, pad_token_id=890705, eos_token_id=890708, bos_token_id=890709, cls_token_id=890707, sep_token_id=890710, sparse_prediction=True, norm_eps=1e-05, embedding_dropout=0.0, initializer_range=0.02, initializer_cutoff_factor=2.0, tie_word_embeddings=True, num_batches=None, lambda_adv=1.0, grl_alpha=1.0, bc_unfreeze_last_n_layers=2, **kwargs)
Bases: PretrainedConfig
This is the configuration class to store the configuration of an AtacformerModel. It inherits from ModernBertConfig and expands it for Atacformer-specific settings. Instantiating a configuration with the defaults will yield a configuration similar to that of the ModernBERT base configuration.
Args:
use_pos_embeddings (bool, optional, defaults to True): Whether to use positional embeddings.
vocab_size (int, optional, defaults to 890711): Vocabulary size tailored for genomic regions.
max_position_embeddings (int, optional, defaults to 8192): The maximum sequence length that this model might ever be used with.
hidden_size (int, optional, defaults to 384): The size of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 1536): The size of the "intermediate" (often named feed-forward) layer in the Transformer.
num_hidden_layers (int, optional, defaults to 6): The number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 8): The number of attention heads in each attention layer.
pad_token_id (int, optional, defaults to 890705): The id of the token used for padding.
eos_token_id (int, optional, defaults to 890708): The id of the token used for the end of a sequence.
bos_token_id (int, optional, defaults to 890709): The id of the token used for the beginning of a sequence.
cls_token_id (int, optional, defaults to 890707): The id of the token used for classification tasks.
sep_token_id (int, optional, defaults to 890710): The id of the token used to separate segments in a sequence.
sparse_prediction (bool, optional, defaults to True): Whether to use sparse prediction for the output layer.
norm_eps (float, optional, defaults to 1e-5): The epsilon value used for layer normalization.
embedding_dropout (float, optional, defaults to 0.0): The dropout probability for the embedding layer.
initializer_range (float, optional, defaults to 0.02): The standard deviation of the truncated-normal initializer for initializing all weight matrices.
initializer_cutoff_factor (float, optional, defaults to 2.0): The cutoff factor for the truncated normal initializer.
tie_word_embeddings (bool, optional, defaults to True): Whether to tie the word embeddings with the output layer.
num_batches (int, optional, defaults to None): The number of batches when doing batch-correction training.
lambda_adv (float, optional, defaults to 1.0): The weight for the adversarial loss.
grl_alpha (float, optional, defaults to 1.0): The alpha value for the gradient reversal layer.
bc_unfreeze_last_n_layers (int, optional, defaults to 2): The number of last layers to unfreeze during training for batch correction.
kwargs (additional keyword arguments, optional): Additional configuration parameters.
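A minimal configuration sketch, assuming the import path; unspecified fields keep the defaults documented above.

```python
from atacformer import AtacformerConfig, AtacformerModel  # assumed import path

config = AtacformerConfig(
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=8,
    num_batches=12,              # only relevant for batch-correction training
    bc_unfreeze_last_n_layers=2,
)
model = AtacformerModel(config)
print(config.vocab_size, config.max_position_embeddings)
```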
DataCollatorForReplacedTokenDetection
DataCollatorForReplacedTokenDetection(tokenizer, mlm_probability=0.15, seed=None)
Bases: WandbMixin, DataCollatorForLanguageModeling
Like HF's MLM collator but:
• never uses [MASK]
• picks replacement tokens from a user-supplied distribution
• returns per-token 0/1 labels for ELECTRA-style discrimination
Simple data collator for ELECTRA-style token replacement detection.
Args:
tokenizer (TrainingTokenizer): The tokenizer to use.
vocab_counts (torch.Tensor | None): 1-D tensor of size vocab_size; log-probs or probs.
mlm_probability (float): Probability of masking a token.
seed (int | None): Random seed for reproducibility.
Functions
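A hedged usage sketch for DataCollatorForReplacedTokenDetection, assuming the import path and an already-working TrainingTokenizer.

```python
from atacformer import DataCollatorForReplacedTokenDetection, TrainingTokenizer  # assumed import path

# Placeholder: construct the TrainingTokenizer however your pipeline does it.
tokenizer = TrainingTokenizer(...)

collator = DataCollatorForReplacedTokenDetection(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    seed=42,
)

# Typical use: pass it as `data_collator` to transformers.Trainer together with
# AtacformerForReplacedTokenDetection as the model.
```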
DataCollatorForTripletLoss
DataCollatorForTripletLoss(tokenizer, max_position_embeddings=None)
A simple data collator for triplet loss, used to fine-tune Atacformer for cell-type clustering.
ModelParameterChangeCallback
ModelParameterChangeCallback(initial_params)
Bases: WandbMixin, TrainerCallback
A callback to log the changes in model parameters after training.
Functions
on_log
on_log(args, state, control, **kwargs)
Log the changes in model parameters after training.
AdjustedRandIndexCallback
AdjustedRandIndexCallback(input_ids, cell_type_labels, pad_token_id, batch_size=128, log_every_n_steps=500)
Bases: WandbMixin, TrainerCallback
A callback to log the adjusted Rand index (ARI) during training.
Functions
on_log
on_log(args, state, control, **kwargs)
Log the adjusted Rand index (ARI) during training.
Functions
freeze_except_last_n
freeze_except_last_n(model, n=2)
Freeze all parameters except the last n layers of the encoder. Also keeps all layer norms trainable for stability.
Args:
model (AtacformerModel): The model to freeze.
n (int): The number of last layers to keep trainable.
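A usage sketch, assuming the import path; it freezes everything except the last two encoder layers and reports how many parameters remain trainable.

```python
from atacformer import AtacformerConfig, AtacformerModel, freeze_except_last_n  # assumed import path

model = AtacformerModel(AtacformerConfig(num_hidden_layers=6))
freeze_except_last_n(model, n=2)

# Count how many parameters remain trainable after freezing.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"{trainable}/{total} parameters trainable")
```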
patch_atacformer_model_for_mps
patch_atacformer_model_for_mps(model)
Look for any TransformerEncoder layers in the model and patch them by setting enable_nested_tensor to False and use_nested_tensor to False.
Args:
model (nn.Module): The model to patch.
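A usage sketch, assuming the import path and an Apple-silicon machine with MPS available.

```python
from atacformer import (  # assumed import path
    AtacformerConfig,
    AtacformerModel,
    patch_atacformer_model_for_mps,
)

model = AtacformerModel(AtacformerConfig(num_hidden_layers=2))
patch_atacformer_model_for_mps(model)  # disable nested-tensor fast paths
model = model.to("mps")                # safe to move to the MPS device now
```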
get_git_hash
get_git_hash()
Get the current git hash of the repository.
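A likely-equivalent sketch using subprocess; the actual helper may handle errors or fallbacks differently.

```python
import subprocess

def current_git_hash() -> str:
    # Ask git for the full hash of the current HEAD commit.
    return (
        subprocess.check_output(["git", "rev-parse", "HEAD"])
        .decode("utf-8")
        .strip()
    )
```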
get_decaying_cosine_with_hard_restarts_schedule_with_warmup
get_decaying_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1, last_epoch=-1)
Very similar to Hugging Face's built-in cosine-with-restarts schedule, except that the amplitude slowly decreases so the "kick-ups" at each restart are less aggressive.
Create a schedule with a learning rate that decreases following the values of the cosine function from the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (~torch.optim.Optimizer): The optimizer for which to schedule the learning rate.
num_warmup_steps (int): The number of steps for the warmup phase.
num_training_steps (int): The total number of training steps.
num_cycles (int, optional, defaults to 1): The number of hard restarts to use.
last_epoch (int, optional, defaults to -1): The index of the last epoch when resuming training.
Returns:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
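A usage sketch, assuming the import path; the scheduler is stepped once per optimizer step.

```python
import torch

from atacformer import get_decaying_cosine_with_hard_restarts_schedule_with_warmup  # assumed import path

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = get_decaying_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=10_000,
    num_cycles=4,
)

for step in range(10_000):
    # ... forward, loss, backward ...
    optimizer.step()
    scheduler.step()
```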
tokenize_anndata
tokenize_anndata(adata, tokenizer)
Tokenize an AnnData object. This is more involved, so it gets its own function.
Args:
adata (sc.AnnData): The AnnData object to tokenize.
tokenizer (Tokenizer): The tokenizer to use.
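A hedged usage sketch; the import path, tokenizer construction, and file path are placeholders.

```python
import scanpy as sc

from atacformer import TrainingTokenizer, tokenize_anndata  # assumed import path

adata = sc.read_h5ad("cells.h5ad")   # placeholder path
tokenizer = TrainingTokenizer(...)   # placeholder construction (see above)

tokens = tokenize_anndata(adata, tokenizer)
# The exact structure of `tokens` depends on the tokenizer implementation.
```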