Using Atacformer with Apple Silicon
Apple's M-series chips have become increasingly popular for local-first machine learning tasks. Conveniently, PyTorch has first-class support for Apple Silicon, making it a great choice for running Atacformer models on these devices.
In most cases, it can be as simple as moving your model to the `mps` device:
```python
from geniml.atacformer import AtacformerForCellClustering

model = AtacformerForCellClustering.from_pretrained("databio/atacformer-base-hg38")
model = model.to("mps")
```
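Before moving the model, it's worth confirming that the MPS backend is actually available on the machine; otherwise `.to("mps")` will raise an error. A minimal sketch using PyTorch's built-in availability check:

```python
import torch

# Select MPS when the backend is available, otherwise fall back to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device)
```

You can then pass `device` to `model.to(device)` so the same script runs on both Apple Silicon and other hardware.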
However, at the moment, not all tensor operations are supported on the `mps` device. One of these is the `aten::_nested_tensor_from_mask_left_aligned` operation inside the `TransformerEncoder` class, which is used in the Atacformer model.
To work around this, we've provided a patch that modifies the `TransformerEncoder` to use a different operation that is supported on the `mps` device. You can use it like so:
```python
from geniml.atacformer import AtacformerForCellClustering, patch_atacformer_model_for_mps

model = AtacformerForCellClustering.from_pretrained("databio/atacformer-base-hg38")
model = model.to("mps")
patch_atacformer_model_for_mps(model)

# use the model as usual
# ...
```
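If patching the model isn't an option, PyTorch also offers a general workaround: the `PYTORCH_ENABLE_MPS_FALLBACK` environment variable tells PyTorch to run operations that are not implemented on MPS on the CPU instead of raising an error. This is slower for the affected operations, and the variable must be set before the Python process imports `torch`:

```shell
# Run unsupported MPS ops on the CPU instead of raising an error.
# Must be exported before the Python process starts.
export PYTORCH_ENABLE_MPS_FALLBACK=1
# python my_clustering_script.py   # hypothetical script name
```

Note that the fallback applies process-wide, so the patch above is preferable when you only need to work around the single unsupported operation.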