from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
import numpy as np
import time
import pandas as pd
import pickle
# model has already been downloaded locally
model_path = '/data/lwb/WorkSpace/ProtTrans/prot_t5_xl_half_uniref50-enc'
tokenizer = T5Tokenizer.from_pretrained(model_path, do_lower_case=False, legacy=False)
model = T5EncoderModel.from_pretrained(model_path)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
model = model.eval()
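Since the checkpoint above is the half-precision encoder, it is common to cast the model to fp16 when running on a GPU to save memory. A minimal optional sketch (skip the cast on CPU, where half-precision inference is typically unsupported or very slow):

# Optional: run the encoder in fp16 when a CUDA device is available
if device.type == 'cuda':
    model = model.half()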
def extract_fea(seqs):
    # wrap the single sequence in a list so it can be batch-encoded
    sequence_examples = [seqs]
    # e.g. a batch of two sequences: sequence_examples = ["PRTEINO", "SEQWENCE"]
    # replace rare/ambiguous amino acids (U, Z, O, B) with X and insert spaces between residues
    sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]
    # tokenize, pad to the longest sequence in the batch, and move tensors to the device
    ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)
    # run the encoder without gradient tracking and take the per-residue embeddings
    with torch.no_grad():
        embedding_rpr = model(input_ids=input_ids, attention_mask=attention_mask)
        emb_0 = embedding_rpr.last_hidden_state[0, :]

    return emb_0.detach().cpu().numpy()
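A quick sanity check of the output (the sequence below is an arbitrary example): the array has one 1024-dimensional row per residue plus one for the trailing special token, and averaging the residue rows is a common way to obtain a single per-protein vector:

emb = extract_fea("MKTAYIAKQR")       # arbitrary 10-residue example sequence
print(emb.shape)                      # (11, 1024): 10 residues + the </s> token
per_protein = emb[:-1].mean(axis=0)   # optional per-protein embedding via mean pooling
print(per_protein.shape)              # (1024,)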

Demo

# read protein names and sequences from a CSV file
data = pd.read_csv('...')
namelist = list(data['name'])
seqlist = list(data['seq'])

# extract features for each sequence and save each one as a .npy file
for i in range(len(data)):
    out = extract_fea(seqlist[i])
    out_path = '/home/lwb/.../' + namelist[i] + '.npy'
    np.save(out_path, out)
  • Refer to the ESM blog post if the sequence is too long (see the sketch below).
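One common workaround (a sketch, not part of the original post) is to embed a long sequence in overlapping windows and average the overlapping regions; the extract_fea_windowed helper and the window/overlap values below are illustrative:

def extract_fea_windowed(seq, window=1000, overlap=100):
    """Embed a long sequence in overlapping chunks and average the overlaps.
    window must be larger than overlap; both values are illustrative and should
    be tuned to fit GPU memory."""
    if len(seq) <= window:
        return extract_fea(seq)[:len(seq)]                       # drop the trailing special token
    dim = 1024                                                    # ProtT5-XL hidden size
    summed = np.zeros((len(seq), dim), dtype=np.float32)
    counts = np.zeros((len(seq), 1), dtype=np.float32)
    start = 0
    while start < len(seq):
        end = min(start + window, len(seq))
        chunk_emb = extract_fea(seq[start:end])[:end - start]    # per-residue rows only
        summed[start:end] += chunk_emb
        counts[start:end] += 1
        if end == len(seq):
            break
        start = end - overlap
    return summed / counts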