Train Atacformer from scratch

This tutorial walks through training an Atacformer model from scratch on genomic interval data. It covers data preparation, model configuration, and training.

Overview

The pre-training process uses a replaced token detection objective (similar to ELECTRA) to learn representations of genomic regions and their accessibility patterns.
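To make the objective concrete, here is a minimal pure-Python sketch of replaced token detection: some tokens are swapped for random ones, and the model is trained to predict, per position, whether the token was replaced. The function name `corrupt_tokens` and the uniform sampling are illustrative assumptions, not the library's implementation; the actual collator may sample replacements differently and treat special tokens separately.

```python
import random

def corrupt_tokens(input_ids, vocab_size, replace_prob, rng):
    """Replace each token with a different random token with probability
    replace_prob. Returns (corrupted_ids, labels), where labels[i] == 1
    marks a replaced position and 0 marks an untouched one."""
    corrupted, labels = [], []
    for tok in input_ids:
        if rng.random() < replace_prob:
            # Draw uniformly from the vocabulary, skipping the original token.
            new_tok = rng.randrange(vocab_size - 1)
            if new_tok >= tok:
                new_tok += 1
            corrupted.append(new_tok)
            labels.append(1)
        else:
            corrupted.append(tok)
            labels.append(0)
    return corrupted, labels

rng = random.Random(0)
original = [12, 7, 93, 41, 5, 66, 28, 3]  # toy region ids
corrupted, labels = corrupt_tokens(original, vocab_size=100, replace_prob=0.45, rng=rng)
print(corrupted)
print(labels)
```

The discriminator receives `corrupted` as input and `labels` as the per-token binary target.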

Prerequisites

Before starting, ensure you have:

  • Access to pre-tokenized single-cell ATAC-seq data
  • A universe file (.bed.gz) defining the genomic regions of interest
  • GPU resources for training

Model Configuration

You can configure the Atacformer model however you like. The original paper used the following settings (the tokenizer referenced here is created in the Tokenizer Setup step):

config = AtacformerConfig(
    vocab_size=tokenizer.vocab_size,    # Based on universe file
    hidden_size=192,                    # Hidden dimension size
    num_hidden_layers=6,                # Number of transformer layers
    num_attention_heads=8,              # Number of attention heads
    intermediate_size=768,              # Feed-forward network size
    max_position_embeddings=8192,       # Maximum sequence length
    pad_token_id=tokenizer.pad_token_id,
)
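As a rough sanity check on these settings, the sketch below verifies that the hidden size divides evenly across attention heads and estimates the parameter count. It assumes a standard BERT-style encoder layer (Q/K/V/output projections, a two-layer feed-forward network, two LayerNorms); Atacformer's actual layers may differ, and the vocabulary size is a hypothetical placeholder since it depends on your universe file. The call to `model.num_parameters()` later in this tutorial reports the real figure.

```python
def estimate_encoder_params(vocab_size, hidden_size, num_layers, intermediate_size):
    """Approximate weights + biases of a BERT-style encoder.
    Ignores position embeddings and any task-specific head."""
    embeddings = vocab_size * hidden_size
    attention = 4 * (hidden_size * hidden_size + hidden_size)  # Q, K, V, output
    feed_forward = (
        hidden_size * intermediate_size + intermediate_size    # up projection
        + intermediate_size * hidden_size + hidden_size        # down projection
    )
    layer_norms = 2 * 2 * hidden_size                          # two LayerNorms (scale + bias)
    return embeddings + num_layers * (attention + feed_forward + layer_norms)

hidden_size, num_heads = 192, 8
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_attention_heads"
head_dim = hidden_size // num_heads  # 24 dimensions per head

HYPOTHETICAL_VOCAB_SIZE = 1_000_000  # assumption: set by the universe file in practice
encoder_only = estimate_encoder_params(0, hidden_size, 6, 768)
total = estimate_encoder_params(HYPOTHETICAL_VOCAB_SIZE, hidden_size, 6, 768)
print(f"encoder layers: {encoder_only:,} params, with embeddings: {total:,} params")
```

With a large universe, the embedding table dominates the parameter count, so the vocabulary size matters more for memory than the depth of the network.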

Module imports

First, import the necessary modules:

import torch

from datasets import Dataset
from transformers import TrainingArguments, Trainer
from atacformer import (
    AtacformerConfig,
    AtacformerForReplacedTokenDetection,
    DataCollatorForReplacedTokenDetection,
    TrainingTokenizer,
    get_decaying_cosine_with_hard_restarts_schedule_with_warmup,
)

Training setup

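Configure PyTorch to use reduced-precision float32 matrix multiplications, which trade a small amount of precision for speed on supported GPUs: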

torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("medium")

Define the training hyperparameters. MLM_PROBABILITY is passed to the data collator as mlm_probability, and BATCH_SIZE is the per-device batch size:

MLM_PROBABILITY = 0.45
BATCH_SIZE = 16
MAX_LEARNING_RATE = 1.5e-4

RUN_NAME = "atacformer-pretraining"
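To judge whether these values are reasonable for your data, it helps to translate the batch size into optimizer steps. The sketch below assumes a single GPU, no gradient accumulation, and a hypothetical training set of 90,000 cells; the epoch count and warmup steps match the values passed to TrainingArguments later in this tutorial.

```python
import math

BATCH_SIZE = 16
NUM_EPOCHS = 25      # num_train_epochs in TrainingArguments
WARMUP_STEPS = 500   # warmup_steps in TrainingArguments

def optimizer_steps(num_sequences, batch_size, num_devices=1, grad_accum=1, epochs=1):
    """Number of optimizer updates for a dataset of num_sequences rows."""
    effective_batch = batch_size * num_devices * grad_accum
    steps_per_epoch = math.ceil(num_sequences / effective_batch)
    return steps_per_epoch * epochs

HYPOTHETICAL_NUM_CELLS = 90_000  # assumption: replace with len(tokenized_dataset["train"])
total_steps = optimizer_steps(HYPOTHETICAL_NUM_CELLS, BATCH_SIZE, epochs=NUM_EPOCHS)
warmup_fraction = WARMUP_STEPS / total_steps
print(f"{total_steps} steps, warmup covers {warmup_fraction:.2%} of training")
```

If the warmup fraction turns out very small or very large for your dataset, adjust the warmup steps or the batch size accordingly.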

Data Preparation

Then load your dataset. The training expects a pre-tokenized dataset in Parquet format with the following columns:

  • input_ids: Tokenized genomic regions
  • Additional metadata columns (removed during training)

dataset_path = "path/to/dataset.parquet"
tokenized_dataset = Dataset.from_parquet(dataset_path)
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.1, shuffle=True, seed=42)
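Before training, it can be worth spot-checking a few rows against the expected format. The helper below is a hypothetical sketch, not part of atacformer: it assumes each row is a dict whose input_ids is a list of integer token ids, that ids must fall inside the tokenizer vocabulary, and that sequences should fit within max_position_embeddings.

```python
def validate_record(record, vocab_size, max_length):
    """Return a list of problems found in one dataset row (empty if none)."""
    ids = record.get("input_ids")
    if not isinstance(ids, list) or not ids:
        return ["input_ids must be a non-empty list"]
    problems = []
    if not all(isinstance(t, int) for t in ids):
        problems.append("input_ids must contain only integers")
    elif min(ids) < 0 or max(ids) >= vocab_size:
        problems.append("token id outside [0, vocab_size)")
    if len(ids) > max_length:
        problems.append(f"sequence length {len(ids)} exceeds {max_length}")
    return problems

# Toy rows standing in for tokenized_dataset["train"][i]
rows = [
    {"input_ids": [4, 17, 250], "cell_type": "T cell"},
    {"input_ids": [4, 9999]},
    {"cell_type": "B cell"},
]
for row in rows:
    print(validate_record(row, vocab_size=1000, max_length=8192))
```

An empty list means the row passed every check.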

Tokenizer Setup

Next, set up the tokenizer. It is created from a universe file that defines the genomic regions:

tokenizer = TrainingTokenizer("path/to/universe.bed.gz")
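Conceptually, a universe file acts as a vocabulary: each region gets an integer id, and a cell's accessible intervals are mapped to the ids of the universe regions they overlap. The sketch below illustrates that idea with an in-memory BED snippet and a simple linear scan. It is not how TrainingTokenizer is implemented; the function names are hypothetical, and a real tokenizer also reserves ids for special tokens such as padding.

```python
UNIVERSE_BED = "chr1\t100\t200\nchr1\t300\t450\nchr2\t50\t150\n"

def load_universe(bed_text):
    """Map each universe region (chrom, start, end) to an id, in file order."""
    regions = []
    for line in bed_text.splitlines():
        if not line.strip():
            continue
        chrom, start, end = line.split("\t")[:3]
        regions.append((chrom, int(start), int(end)))
    return {region: idx for idx, region in enumerate(regions)}

def tokenize(intervals, vocab):
    """Return sorted ids of universe regions overlapped by any query interval."""
    ids = set()
    for chrom, start, end in intervals:
        for (u_chrom, u_start, u_end), idx in vocab.items():
            # Half-open BED intervals overlap when each starts before the other ends.
            if u_chrom == chrom and u_start < end and start < u_end:
                ids.add(idx)
    return sorted(ids)

vocab = load_universe(UNIVERSE_BED)
cell_peaks = [("chr1", 150, 320), ("chr2", 150, 200)]
print(tokenize(cell_peaks, vocab))
```

The second query interval starts exactly where the chr2 universe region ends, so under half-open BED coordinates it does not overlap and contributes no token.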

Model and Training Arguments

Instantiate a new Atacformer model with the replaced token detection (RTD) objective. We cast the model to bfloat16 because it is much faster on NVIDIA Ampere GPUs:

config = AtacformerConfig(
    use_pos_embeddings=False,
    vocab_size=tokenizer.vocab_size,
    hidden_size=192,
    num_hidden_layers=6,
    num_attention_heads=8,
    intermediate_size=768,
    max_position_embeddings=8192,
    pad_token_id=tokenizer.pad_token_id,
)
model = AtacformerForReplacedTokenDetection(config)
model = model.to(torch.bfloat16)

print(f"Model size: {model.num_parameters()} parameters")
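The speed of bfloat16 comes at the cost of precision: it keeps float32's 8 exponent bits (so the same dynamic range) but only 7 mantissa bits. The sketch below emulates this with the standard library by truncating a float32 encoding to its top 16 bits. Truncation is a simplification chosen here for clarity; hardware and PyTorch typically round to nearest instead.

```python
import struct

def to_bfloat16(x):
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding
    (1 sign bit, 8 exponent bits, 7 mantissa bits). Truncates, does not round."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

print(to_bfloat16(1.0 + 2**-7))  # smallest step above 1.0 that bfloat16 can represent
print(to_bfloat16(1.0 + 2**-8))  # finer than the mantissa allows, collapses to 1.0
print(to_bfloat16(3.0e38))       # large magnitudes survive, unlike in float16

def weight_bytes(num_params, bytes_per_param):
    """Memory needed to store the weights alone."""
    return num_params * bytes_per_param
```

Because each parameter occupies 2 bytes instead of 4, `weight_bytes(n, 2)` is half of `weight_bytes(n, 4)` for the same model.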

Data Collation

The DataCollatorForReplacedTokenDetection handles:

  • Random token replacement based on mlm_probability
  • Proper masking and attention handling
  • Batch preparation for training

Create the data collator:

data_collator = DataCollatorForReplacedTokenDetection(
    tokenizer=tokenizer,
    mlm_probability=MLM_PROBABILITY,
)
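The batch preparation step can be pictured as padding variable-length sequences to a common length and recording which positions are real. The sketch below is a simplified illustration, not the collator's actual code: the function name is hypothetical, and it assumes padded label positions are marked with -100, the conventional "ignore" value in PyTorch losses, which the real collator may or may not use.

```python
def pad_batch(examples, pad_token_id, ignore_index=-100):
    """Pad each example to the longest sequence in the batch.
    attention_mask is 1 for real tokens and 0 for padding."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        n_real = len(ex["input_ids"])
        n_pad = max_len - n_real
        batch["input_ids"].append(ex["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        # Padded positions get ignore_index so they do not contribute to the loss.
        batch["labels"].append(ex["labels"] + [ignore_index] * n_pad)
    return batch

examples = [
    {"input_ids": [5, 81, 12, 40], "labels": [0, 1, 0, 0]},
    {"input_ids": [7, 33], "labels": [1, 0]},
]
batch = pad_batch(examples, pad_token_id=0)
for key, rows in batch.items():
    print(key, rows)
```

In real training these lists would be converted to tensors before being passed to the model.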

Set up the Trainer

Now, set up the Trainer with the model, training arguments, and data collator:

training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=25,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_strategy="steps",
    logging_steps=10,
    run_name=RUN_NAME,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_steps=500,
    learning_rate=MAX_LEARNING_RATE,
    bf16=True,
    max_grad_norm=1.0,
)

# trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)
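The combination of lr_scheduler_type="cosine" and warmup_steps=500 produces a learning rate that rises linearly to MAX_LEARNING_RATE and then decays along a half cosine to zero. The sketch below reproduces that shape in plain Python so you can inspect values at specific steps. The total step count is a hypothetical placeholder, since it depends on your dataset size.

```python
import math

def cosine_with_warmup(step, warmup_steps, total_steps, max_lr):
    """Learning rate at a given optimizer step: linear warmup, then half-cosine decay."""
    if step < warmup_steps:
        return max_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

MAX_LEARNING_RATE = 1.5e-4
WARMUP_STEPS = 500
HYPOTHETICAL_TOTAL_STEPS = 10_500  # assumption: depends on dataset size and epochs

for step in (0, 250, 500, 5_500, 10_500):
    lr = cosine_with_warmup(step, WARMUP_STEPS, HYPOTHETICAL_TOTAL_STEPS, MAX_LEARNING_RATE)
    print(f"step {step:>6}: lr = {lr:.3e}")
```

The peak is reached exactly when warmup ends, and the rate falls to half of the peak midway through the decay phase.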

Training

Finally, start the training process and save the pre-trained model:

trainer.train()
model.save_pretrained("output/atacformer-pretrained")