import torch
import esm
import numpy as np
import pandas as pd
import pickle

Load ESM Model

## First time: download the pretrained weights automatically.
## Cached at: /home/lwb/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

## Already downloaded: load the checkpoint from a local path
model, alphabet = esm.pretrained.load_model_and_alphabet_local('/data/lwb/WorkSpace/PSPs/ESM/model/esm2_t33_650M_UR50D.pt')
batch_converter = alphabet.get_batch_converter()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
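
As an optional sanity check (not part of the original script), the batch converter can be run on a short made-up sequence to confirm tokenization works; the alphabet adds a BOS token at the start and an EOS token at the end, so the token count is the sequence length plus two.

# Optional check with a hypothetical 10-residue sequence
_, _, toks = batch_converter([("demo", "MKTAYIAKQR")])
print(toks.shape)  # torch.Size([1, 12]): 10 residues + BOS + EOS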

Get Feature

## Input: protein name and sequence
## The GPU cannot fit batches here, so sequences are embedded one at a time
def genfeature_esm(name, seq):
    data = [(name, seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33]
    # take the 0th (and only) sequence in the batch and drop the BOS/EOS tokens
    esmfea = token_representations[0][1:-1].cpu().numpy()
    return esmfea
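
For these 650M-parameter checkpoints the per-residue embedding dimension is 1280, so a sequence of length L yields an array of shape (L, 1280) with the BOS/EOS rows already stripped. A minimal call on a made-up name and sequence:

fea = genfeature_esm('demo', 'MKTAYIAKQR')  # hypothetical name and sequence
print(fea.shape)                            # (10, 1280)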

Split Long Seq

def split_sequence(name, sequence, segment_length=1024):
    num_segments = len(sequence) // segment_length
    segments = [sequence[i * segment_length: (i + 1) * segment_length] for i in range(num_segments)]
    if len(sequence) % segment_length != 0:
        segments.append(sequence[num_segments * segment_length:])
    # embed each segment separately and stitch the per-residue features back together
    array_list = []
    for i in segments:
        array_list.append(genfeature_esm(name, i))

    result_array = np.concatenate(array_list, axis=0)
    return result_array
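
Note that each segment is embedded independently, so residues near a cut point do not see context from the neighboring segment; the concatenated array still contains one row per residue of the full sequence. A quick shape check with the ESM-2 checkpoint loaded above (sequence is synthetic):

long_seq = 'M' * 2500                        # synthetic 2500-residue sequence
fea = split_sequence('demo_long', long_seq)  # splits into 1024 + 1024 + 452
print(fea.shape)                             # (2500, 1280)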

Demo

data = pd.read_csv('...')  # dataframe with 'name' and 'seq' columns
namelist = data['name'].to_list()
seqlist = data['seq'].to_list()

for i in range(len(data)):
    if len(seqlist[i]) <= 1024:
        tmpfea = genfeature_esm(namelist[i], seqlist[i])
    else:
        tmpfea = split_sequence(namelist[i], seqlist[i], segment_length=1024)

    outpath = '...' + namelist[i] + '.npy'  # '...' is the output directory prefix
    np.save(outpath, tmpfea)
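
The saved features can be reloaded later with np.load; the filename below is hypothetical and depends on the protein names in the CSV.

fea = np.load('.../P12345.npy')  # hypothetical output path
print(fea.shape)                 # (sequence_length, 1280)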