Molecule Attention Transformer (Part 2)

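This post applies the Transformer framework to molecular property prediction. Code: MAT; paper: Molecule Attention Transformer. Many variable and function names come from The Annotated Transformer, which is also explained in the book 《深入浅出Embedding》. Starting from a runnable example, this post walks through the code step by step. The overall model is as follows: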


Contents

  • 2. Model Construction
    • 2.1.make_model & run
    • 2.2.GraphTransformer
    • 2.3.Embedding
    • 2.4.Encoder
    • 2.5.Norm
    • 2.6.EncoderLayer
    • 2.7.SublayerConnection
    • 2.8.MultiHeadedAttention
    • 2.9.attention
    • 2.10.PositionwiseFeedForward
    • 2.11.Generator
    • 2.12.summary

2. Model Construction

from transformer import make_model

d_atom = X[0][0].shape[1]  # It depends on the used featurization.
model_params = {
    'd_atom': d_atom,
    'd_model': 1024,
    'N': 8,
    'h': 16,
    'N_dense': 1,
    'lambda_attention': 0.33,
    'lambda_distance': 0.33,
    'leaky_relu_slope': 0.1,
    'dense_output_nonlinearity': 'relu',
    'distance_matrix_kernel': 'exp',
    'dropout': 0.0,
    'aggregation_type': 'mean'
}
model = make_model(**model_params)
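  • make_model builds and returns the model. The hyperparameters are:
    • d_atom is the number of features per atom (28 here).
    • d_model is the dimension after the embedding.
    • N is the number of stacked Transformer (encoder) blocks.
    • h is the number of attention heads.
    • N_dense is the number of linear layers in each position-wise feed-forward block.
    The output dimension is set separately by n_output, which defaults to 1 for a scalar target. Overall, the model is built much like the original Transformer.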
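As a quick sanity check on these hyperparameters, the following sketch derives the two values that MultiHeadedAttention computes from them. The first is the per-head dimension d_k = d_model // h. The second is the adjacency weight lambda_adjacency = 1 - lambda_attention - lambda_distance. It is plain Python, and the helper name derived_params is hypothetical (not part of MAT).

```python
def derived_params(params):
    """Hypothetical helper: values MAT derives from model_params."""
    d_model, h = params['d_model'], params['h']
    # MultiHeadedAttention asserts that d_model is divisible by the number of heads
    assert d_model % h == 0, "d_model must be divisible by h"
    d_k = d_model // h
    # the adjacency weight is whatever remains after the attention and distance weights
    lambda_adjacency = 1.0 - params['lambda_attention'] - params['lambda_distance']
    return d_k, lambda_adjacency


model_params = {'d_atom': 28, 'd_model': 1024, 'N': 8, 'h': 16, 'N_dense': 1,
                'lambda_attention': 0.33, 'lambda_distance': 0.33}
d_k, lambda_adjacency = derived_params(model_params)
print(d_k, round(lambda_adjacency, 2))  # 64 0.34
```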

2.1.make_model & run

# Imports assumed by the code excerpts in this post (module scope of MAT's transformer.py)
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter, Sequential, Tanh


def make_model(d_atom, N=2, d_model=128, h=8, dropout=0.1,
               lambda_attention=0.3, lambda_distance=0.3, trainable_lambda=False,
               N_dense=2, leaky_relu_slope=0.0, aggregation_type='mean',
               dense_output_nonlinearity='relu', distance_matrix_kernel='softmax',
               use_edge_features=False, n_output=1,
               control_edges=False, integrated_distances=False,
               scale_norm=False, init_type='uniform', use_adapter=False, n_generator_layers=1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model, dropout, lambda_attention, lambda_distance, trainable_lambda,
                                distance_matrix_kernel, use_edge_features, control_edges, integrated_distances)
    ff = PositionwiseFeedForward(d_model, N_dense, dropout, leaky_relu_slope, dense_output_nonlinearity)
    model = GraphTransformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout, scale_norm, use_adapter), N, scale_norm),
        Embeddings(d_model, d_atom, dropout),
        Generator(d_model, aggregation_type, n_output, n_generator_layers, leaky_relu_slope, dropout, scale_norm))

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            if init_type == 'uniform':
                nn.init.xavier_uniform_(p)
            elif init_type == 'normal':
                nn.init.xavier_normal_(p)
            elif init_type == 'small_normal_init':
                xavier_normal_small_init_(p)   # helper defined elsewhere in the MAT source (not shown)
            elif init_type == 'small_uniform_init':
                xavier_uniform_small_init_(p)  # helper defined elsewhere in the MAT source (not shown)
    return model


# run: feed each batch through the model
for batch in data_loader:
    adjacency_matrix, node_features, distance_matrix, y = batch
    batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
    output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
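  • GraphTransformer consists of Embeddings, an Encoder, and a Generator, and is initialized according to the hyperparameters. In forward, src = node_features.
    • Shapes are traced below with an input of shape (batch_size, max_size, 28). After padding, every molecule "has" max_size atoms, and each atom is encoded as a 28-dimensional vector built from concatenated one-hot fields.
    • batch_mask flags which atoms were actually encoded. It is True for every real atom of a molecule and False for the padded, non-existent atoms, so it marks the valid length.
    • adj_matrix and distances_matrix are the adjacency and distance matrices used by the molecule self-attention.
    • Passing None means edges_att is not used; the paper reports that using edge features did not improve performance.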
batch_size=2
for batch in data_loader:
    adjacency_matrix, node_features, distance_matrix, y = batch
    batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
    print(node_features)
    print(batch_mask)
    break
"""
tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False, False]])
"""

2.2.GraphTransformer

class GraphTransformer(nn.Module):
    def __init__(self, encoder, src_embed, generator):
        super(GraphTransformer, self).__init__()
        self.encoder = encoder
        self.src_embed = src_embed
        self.generator = generator

    def forward(self, src, src_mask, adj_matrix, distances_matrix, edges_att):
        "Take in and process masked src and target sequences."
        return self.predict(self.encode(src, src_mask, adj_matrix, distances_matrix, edges_att), src_mask)

    def encode(self, src, src_mask, adj_matrix, distances_matrix, edges_att):
        return self.encoder(self.src_embed(src), src_mask, adj_matrix, distances_matrix, edges_att)

    def predict(self, out, out_mask):
        return self.generator(out, out_mask)
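  • The input is encoded by the Encoder, and the result is passed to the Generator to produce the prediction. Inside encode, src is first embedded by src_embed before it enters the Encoder.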
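The following sketch collects in one place the shape bookkeeping used throughout the sections below. It is a plain-Python function that traces the main tensor shapes through embed, encode, and predict. The function trace_shapes is hypothetical: it does shape arithmetic only and runs no model.

```python
def trace_shapes(batch_size, max_size, d_atom, d_model, h, n_output=1):
    """Hypothetical helper: shapes of the main intermediate tensors in GraphTransformer."""
    d_k = d_model // h
    return {
        'node_features': (batch_size, max_size, d_atom),
        'embedded': (batch_size, max_size, d_model),             # Embeddings: Linear(d_atom, d_model)
        'q/k/v per head': (batch_size, h, max_size, d_k),         # view + transpose(1, 2)
        'attention weights': (batch_size, h, max_size, max_size),
        'encoder output': (batch_size, max_size, d_model),        # heads merged back, final linear
        'pooled': (batch_size, d_model),                           # Generator aggregation over atoms
        'prediction': (batch_size, n_output),
    }


for name, shape in trace_shapes(2, 11, 28, 1024, 16).items():
    print(f'{name:>18}: {shape}')
```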

2.3.Embedding

class Embeddings(nn.Module):
    def __init__(self, d_model, d_atom, dropout):
        super(Embeddings, self).__init__()
        self.lut = nn.Linear(d_atom, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.lut(x))
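  • After the linear transformation and dropout, a molecule with max_size atoms is encoded as a (batch_size, max_size, 1024) tensor. An nn.Linear is used here rather than the nn.Embedding of the original Transformer implementation, because each atom is a real-valued feature vector rather than an integer token index.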
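For 0/1 feature vectors, a linear layer behaves like a sum of embedding lookups: multiplying by the weight matrix adds up the rows that belong to the active features, then adds the bias. The NumPy sketch below checks this with random weights and small, hypothetical sizes. Note that here W is stored as (d_atom, d_model); nn.Linear stores its weight transposed and computes x @ weight.T + bias.

```python
import numpy as np

rng = np.random.default_rng(0)
d_atom, d_model = 28, 8          # small d_model for illustration (the post uses 1024)
W = rng.normal(size=(d_atom, d_model))
b = rng.normal(size=d_model)

# one atom with three active one-hot fields (indices chosen arbitrarily)
x = np.zeros(d_atom)
active = [2, 13, 21]
x[active] = 1.0

linear_out = x @ W + b                   # what a Linear(d_atom, d_model) layer computes
lookup_out = W[active].sum(axis=0) + b   # sum of the "embedding rows" of the active features
print(np.allclose(linear_out, lookup_out))  # True

# batched: (batch_size, max_size, d_atom) -> (batch_size, max_size, d_model)
batch = np.zeros((2, 11, d_atom))
print((batch @ W + b).shape)  # (2, 11, 8)
```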

2.4.Encoder

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Encoder(nn.Module):
    "Core encoder is a stack of N layers"
    def __init__(self, layer, N, scale_norm):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = ScaleNorm(layer.size) if scale_norm else LayerNorm(layer.size)

    def forward(self, x, mask, adj_matrix, distances_matrix, edges_att):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask, adj_matrix, distances_matrix, edges_att)
        return self.norm(x)
  • The Encoder is a stack of N EncoderLayers followed by a final Norm layer, which is either ScaleNorm or LayerNorm.
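clones relies on copy.deepcopy, so the N layers share the same initial structure but hold independent parameters; they are not weight-shared. make_model also re-initializes every weight matrix with Xavier after construction, so the copies end up with different values as well. Below is a plain-Python sketch in which a hypothetical ToyLayer class stands in for EncoderLayer.

```python
import copy


class ToyLayer:
    """Hypothetical stand-in for EncoderLayer: one list of 'weights'."""
    def __init__(self):
        self.weights = [0.0, 0.0]


def clones(module, N):
    # same idea as the MAT helper, but returns a plain list instead of nn.ModuleList
    return [copy.deepcopy(module) for _ in range(N)]


layers = clones(ToyLayer(), 3)
layers[0].weights[0] = 1.0                   # update only the first copy
print([layer.weights for layer in layers])   # [[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
```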

2.5.Norm

class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class ScaleNorm(nn.Module):
    """ScaleNorm"""
    "All g’s in SCALE NORM are initialized to sqrt(d)"
    def __init__(self, scale, eps=1e-5):
        super(ScaleNorm, self).__init__()
        self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
        self.eps = eps

    def forward(self, x):
        norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        return x * norm
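  • These are the two normalization options. Both act on the feature dimension of each atom independently.
    • LayerNorm standardizes each feature vector by its mean and standard deviation.
    • ScaleNorm performs $l_2$ normalization and rescales by a single learnable scalar initialized to sqrt(d).
    Since scale_norm=False in this run, LayerNorm is used.
  • eps prevents division by zero.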
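The two normalizations can be compared side by side in a NumPy sketch that follows the formulas in the code above. LayerNorm centers each vector and divides by its standard deviation; PyTorch's x.std() is the unbiased estimator, hence ddof=1 here. ScaleNorm divides by the $l_2$ norm and multiplies by a scalar g initialized to sqrt(d). The learnable a_2, b_2, and g are fixed at their initial values in this sketch.

```python
import numpy as np


def layer_norm(x, a_2=1.0, b_2=0.0, eps=1e-6):
    # mirrors LayerNorm.forward; torch's x.std() is unbiased, hence ddof=1
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, ddof=1, keepdims=True)
    return a_2 * (x - mean) / (std + eps) + b_2


def scale_norm(x, eps=1e-5):
    # mirrors ScaleNorm.forward with g at its initial value sqrt(d)
    g = np.sqrt(x.shape[-1])
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x * g / np.maximum(norm, eps)   # like clamp(min=eps): no division by zero


x = np.array([[3.0, 4.0, 0.0, 0.0]])
print(layer_norm(x).mean(axis=-1))             # close to 0: each row is centered
print(np.linalg.norm(scale_norm(x), axis=-1))  # close to 2.0 = sqrt(d) for d = 4
print(scale_norm(np.zeros((1, 4))))            # an all-zero row stays zero (no NaN)
```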

2.6.EncoderLayer

class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout, scale_norm, use_adapter):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout, scale_norm, use_adapter), 2)
        self.size = size

    def forward(self, x, mask, adj_matrix, distances_matrix, edges_att):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, adj_matrix, distances_matrix, edges_att, mask))
        return self.sublayer[1](x, self.feed_forward)
  • An EncoderLayer contains two SublayerConnection layers: the first wraps the self-attention and the second wraps the feed-forward network. SublayerConnection abstracts the residual connection into a class.

2.7.SublayerConnection

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout, scale_norm, use_adapter):
        super(SublayerConnection, self).__init__()
        self.norm = ScaleNorm(size) if scale_norm else LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
        self.use_adapter = use_adapter
        self.adapter = Adapter(size, 8) if use_adapter else None

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        if self.use_adapter:
            return x + self.dropout(self.adapter(sublayer(self.norm(x))))
        return x + self.dropout(sublayer(self.norm(x)))
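  • It is not yet clear to me where Adapter comes from, but use_adapter is False in this run, so it has no effect. forward receives both the input and the sublayer function and applies the residual connection x + dropout(sublayer(norm(x))). As the docstring notes, the norm is applied before the sublayer (pre-norm) rather than after it.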
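Putting EncoderLayer and SublayerConnection together, one encoder layer computes x + attn(norm(x)) and then x + ff(norm(x)). Below is a NumPy sketch of that composition. The toy sublayers toy_attn and toy_ff are hypothetical stand-ins for the molecule attention and the feed-forward network. Dropout is omitted because p=0.0 in this run.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # LayerNorm with a_2 = 1 and b_2 = 0 (their initial values)
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, ddof=1, keepdims=True)
    return (x - mean) / (std + eps)


def sublayer_connection(x, sublayer):
    # SublayerConnection.forward without adapter or dropout: pre-norm residual
    return x + sublayer(layer_norm(x))


def encoder_layer(x, self_attn, feed_forward):
    # EncoderLayer.forward: residual self-attention block, then residual feed-forward block
    x = sublayer_connection(x, self_attn)
    return sublayer_connection(x, feed_forward)


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8))


def toy_attn(z):
    # hypothetical stand-in for attention: every atom receives the mean over atoms
    return np.repeat(z.mean(axis=1, keepdims=True), z.shape[1], axis=1)


def toy_ff(z):
    # hypothetical stand-in for the feed-forward block: one ReLU layer
    return np.maximum(z @ W, 0.0)


x = rng.normal(size=(2, 11, 8))            # (batch_size, max_size, d_model)
out = encoder_layer(x, toy_attn, toy_ff)
print(out.shape)  # (2, 11, 8)
```

If both sublayers return zeros, the layer reduces to the identity. This is one way to see why the residual path keeps information flowing through a deep stack.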

2.8.MultiHeadedAttention

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1, lambda_attention=0.3, lambda_distance=0.3, trainable_lambda=False,
                 distance_matrix_kernel='softmax', use_edge_features=False, control_edges=False, integrated_distances=False):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.trainable_lambda = trainable_lambda
        if trainable_lambda:
            lambda_adjacency = 1. - lambda_attention - lambda_distance
            lambdas_tensor = torch.tensor([lambda_attention, lambda_distance, lambda_adjacency], requires_grad=True)
            self.lambdas = Parameter(lambdas_tensor)
        else:
            lambda_adjacency = 1. - lambda_attention - lambda_distance
            self.lambdas = (lambda_attention, lambda_distance, lambda_adjacency)

        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        if distance_matrix_kernel == 'softmax':
            self.distance_matrix_kernel = lambda x: F.softmax(-x, dim=-1)
        elif distance_matrix_kernel == 'exp':
            self.distance_matrix_kernel = lambda x: torch.exp(-x)
        self.integrated_distances = integrated_distances
        self.use_edge_features = use_edge_features
        self.control_edges = control_edges
        if use_edge_features:
            d_edge = 11 if not integrated_distances else 12
            self.edges_feature_layer = EdgeFeaturesLayer(d_model, d_edge, h, dropout)

    def forward(self, query, key, value, adj_matrix, distances_matrix, edges_att, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (query, key, value))]

        # Prepare distances matrix
        distances_matrix = distances_matrix.masked_fill(mask.repeat(1, mask.shape[-1], 1) == 0, np.inf)
        distances_matrix = self.distance_matrix_kernel(distances_matrix)
        p_dist = distances_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)

        if self.use_edge_features:
            if self.integrated_distances:
                edges_att = torch.cat((edges_att, distances_matrix.unsqueeze(1)), dim=1)
            edges_att = self.edges_feature_layer(edges_att)

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn, self.self_attn = attention(query, key, value, adj_matrix, p_dist, edges_att,
                                                 mask=mask, dropout=self.dropout,
                                                 lambdas=self.lambdas,
                                                 trainable_lambda=self.trainable_lambda,
                                                 distance_matrix_kernel=self.distance_matrix_kernel,
                                                 use_edge_features=self.use_edge_features,
                                                 control_edges=self.control_edges)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
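  • The parameters are mostly the same as in the Transformer. self.lambdas is where MAT departs from the Transformer: it holds the weights of the attention, distance, and adjacency terms, with lambda_adjacency = 1 - lambda_attention - lambda_distance (0.34 in this run). When trainable_lambda is False, they are fixed constants.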

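  • self.linears roughly corresponds to $W^Q, W^K, W^V, W^O$ in the Transformer, but the dimensions differ. The code does not use separate per-head projections that are concatenated later. Instead, each linear layer maps d_model to d_model for all heads at once, and the result is then split into h heads of size d_k.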

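  • In forward, query, key, and value are all x, i.e., node_features after the embedding, with shape (batch_size, max_size, 1024). mask is batch_mask, which marks the valid length and has shape (batch_size, max_size). unsqueeze(1) inserts a new dimension at index 1, giving (batch_size, 1, max_size). For example: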

import torch
batch_size=2
max_size=14
mask = torch.ones((batch_size, max_size))
print(mask)
print(mask.unsqueeze(1))
"""
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])
"""
  • Next, the linear layers transform query, key, and value without changing their shape, which stays (batch_size, max_size, d_model). view(nbatches, -1, self.h, self.d_k).transpose(1, 2) then reshapes them to (batch_size, h, max_size, d_k). For example:
import torch
import torch.nn as nn

query = torch.randn(64, 14, 1024)
l = nn.Linear(1024, 1024)
nbatches, h, d_k = 64, 16, 64
print(l(query).view(nbatches, -1, h, d_k).transpose(1, 2).shape)  # torch.Size([64, 16, 14, 64])

# the mask expansion used when preparing the distance matrix
mask = torch.tensor([[True, True, True, True, True, True, True, True, True, True, True],
                     [True, True, True, True, True, True, True, True, True, False, False]])
mask = mask.unsqueeze(1).repeat(1, mask.shape[-1], 1)
print(mask.shape)  # torch.Size([2, 11, 11]); here batch_size=2, max_size=11
print(mask)
"""
tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True],[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True]],[[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False],[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,False]]])
"""
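  • Wherever mask is False, distance_matrix is filled with np.inf. It is then mapped through lambda x: torch.exp(-x), the 'exp' kernel used in this run, so entries with infinite distance become 0.
  • distances_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1) gives p_dist the shape (batch_size, h, max_size, max_size); after the transpose, query.shape[1] is h.
  • use_edge_features is False, so the tensors are passed straight into attention.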
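The distance-matrix preparation can be replayed in NumPy, which stands in here for masked_fill and torch.exp. The toy distances and mask below are made up. Padded atoms get distance inf, and exp(-inf) = 0, so under the 'exp' kernel used in this run they receive no distance weight. The 'softmax' kernel (F.softmax(-x, dim=-1)) would additionally normalize each row to sum to 1.

```python
import numpy as np

# toy molecule: 3 real atoms + 1 padded atom (max_size = 4)
distances = np.array([[0.0, 1.5, 2.5, 0.0],
                      [1.5, 0.0, 1.4, 0.0],
                      [2.5, 1.4, 0.0, 0.0],
                      [0.0, 0.0, 0.0, 0.0]])
mask = np.array([True, True, True, False])            # batch_mask for this molecule

# equivalent of mask.repeat(1, max_size, 1) == 0: padded *columns* are masked
col_mask = np.broadcast_to(mask, distances.shape)
masked = np.where(col_mask, distances, np.inf)

exp_kernel = np.exp(-masked)                           # 'exp' kernel: exp(-inf) -> 0
logits = -masked
shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
softmax_kernel = shifted / shifted.sum(axis=-1, keepdims=True)   # 'softmax' kernel, row-wise

print(exp_kernel[:, 3])              # zeros: no weight on the padded atom
print(softmax_kernel.sum(axis=-1))   # each row sums to 1
```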

2.9.attention

def attention(query, key, value, adj_matrix, distances_matrix, edges_att,
              mask=None, dropout=None,
              lambdas=(0.3, 0.3, 0.4), trainable_lambda=False,
              distance_matrix_kernel=None, use_edge_features=False, control_edges=False,
              eps=1e-6, inf=1e12):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1).repeat(1, query.shape[1], query.shape[2], 1) == 0, -inf)
    p_attn = F.softmax(scores, dim=-1)

    if use_edge_features:
        adj_matrix = edges_att.view(adj_matrix.shape)

    # Prepare adjacency matrix
    adj_matrix = adj_matrix / (adj_matrix.sum(dim=-1).unsqueeze(2) + eps)
    adj_matrix = adj_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)
    p_adj = adj_matrix

    p_dist = distances_matrix

    if trainable_lambda:
        softmax_attention, softmax_distance, softmax_adjacency = lambdas.cuda()
        p_weighted = softmax_attention * p_attn + softmax_distance * p_dist + softmax_adjacency * p_adj
    else:
        lambda_attention, lambda_distance, lambda_adjacency = lambdas
        p_weighted = lambda_attention * p_attn + lambda_distance * p_dist + lambda_adjacency * p_adj

    if dropout is not None:
        p_weighted = dropout(p_weighted)

    atoms_featrues = torch.matmul(p_weighted, value)
    return atoms_featrues, p_weighted, p_attn
  • scores holds the similarity between query and key: $S = QK^\top / \sqrt{d_k}$. Here $Q \in \mathbb{R}^{B \times h \times n \times d_k}$ and $K^\top \in \mathbb{R}^{B \times h \times d_k \times n}$, so $S \in \mathbb{R}^{B \times h \times n \times n}$, where $B$ = batch_size and $n$ = max_size.
    • mask has shape (batch_size, 1, max_size). After mask.unsqueeze(1).repeat(1, query.shape[1], query.shape[2], 1) it has shape (batch_size, h, max_size, max_size), matching scores.
    • The scores of padded positions are set to -inf (here -1e12), so their attention weights become 0 after the softmax.
  • adj_matrix has shape (batch_size, max_size, max_size). adj_matrix.sum(dim=-1) has shape (batch_size, max_size); for each molecule in the batch, it counts the atoms each atom is connected to, including itself. The first atom is the dummy node. Its row is all zeros, so eps keeps the division well defined and the row simply stays 0. Example:
import torch

batch_size = 1
eps = 1e-6
adj_matrix = torch.tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                            [0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                            [0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                            [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                            [0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
                            [0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                            [0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0.],
                            [0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
                            [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.],
                            [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0.],
                            [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0.],
                            [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0.],
                            [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.],
                            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]]])
print(adj_matrix / (adj_matrix.sum(dim=-1).unsqueeze(2) + eps))
"""
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.2500, 0.0000, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.2500, 0.0000, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.2500, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2500, 0.2500, 0.2500, 0.0000, 0.2500, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.3333, 0.3333, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333, 0.3333],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.5000]]])
"""
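  • p_adj now has shape (batch_size, h, max_size, max_size), the same as p_dist and p_attn. Multiplying each of the three by its coefficient (λ) and summing gives the final Molecule Attention weights; p_weighted also has shape (batch_size, h, max_size, max_size).

  • value has shape (batch_size, h, max_size, d_k), so atoms_featrues ends up with shape (batch_size, h, max_size, d_k).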
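In formula form, the non-trainable branch computes $\left(\lambda_a\,\mathrm{softmax}\!\left(QK^\top/\sqrt{d_k}\right) + \lambda_d\, g(D) + \lambda_g \hat{A}\right)V$. Here $g$ is the distance kernel and $\hat{A}$ is the row-normalized adjacency matrix. The NumPy sketch below implements this for a single head of one molecule, without batch or head dimensions, dropout, or edge features. The random Q, K, V and the toy adjacency and distance matrices are assumptions for illustration.

```python
import numpy as np


def molecule_attention(q, k, v, adj, dist_kernel, mask, lambdas=(0.33, 0.33, 0.34),
                       eps=1e-6, inf=1e12):
    """Single-head sketch of the non-trainable branch of attention(); q, k, v are (max_size, d_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    scores = np.where(mask[None, :], scores, -inf)          # padded keys get a huge negative score
    p_attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p_attn /= p_attn.sum(axis=-1, keepdims=True)            # softmax over keys
    p_adj = adj / (adj.sum(axis=-1, keepdims=True) + eps)   # row-normalized adjacency
    lam_a, lam_d, lam_g = lambdas
    p_weighted = lam_a * p_attn + lam_d * dist_kernel + lam_g * p_adj
    return p_weighted @ v, p_weighted


rng = np.random.default_rng(0)
max_size, d_k = 4, 8
q, k, v = (rng.normal(size=(max_size, d_k)) for _ in range(3))
mask = np.array([True, True, True, False])       # last atom is padding
adj = np.array([[1., 1., 0., 0.],                # toy 3-atom chain with self-loops
                [1., 1., 1., 0.],
                [0., 1., 1., 0.],
                [0., 0., 0., 0.]])
dist = np.array([[0.0, 1.5, 2.5, np.inf],        # distances already masked with inf
                 [1.5, 0.0, 1.4, np.inf],
                 [2.5, 1.4, 0.0, np.inf],
                 [0.0, 0.0, 0.0, np.inf]])
out, p_weighted = molecule_attention(q, k, v, adj, np.exp(-dist), mask)
print(out.shape)           # (4, 8)
print(p_weighted[:, 3])    # the padded column gets zero weight
```

With the 'exp' kernel, the distance term is not normalized, so a row of p_weighted generally does not sum to 1, unlike plain softmax attention.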

2.10.PositionwiseFeedForward

class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, N_dense, dropout=0.1, leaky_relu_slope=0.0, dense_output_nonlinearity='relu'):
        super(PositionwiseFeedForward, self).__init__()
        self.N_dense = N_dense
        self.linears = clones(nn.Linear(d_model, d_model), N_dense)
        self.dropout = clones(nn.Dropout(dropout), N_dense)
        self.leaky_relu_slope = leaky_relu_slope
        if dense_output_nonlinearity == 'relu':
            self.dense_output_nonlinearity = lambda x: F.leaky_relu(x, negative_slope=self.leaky_relu_slope)
        elif dense_output_nonlinearity == 'tanh':
            self.tanh = Tanh()
            self.dense_output_nonlinearity = lambda x: self.tanh(x)
        elif dense_output_nonlinearity == 'none':
            self.dense_output_nonlinearity = lambda x: x

    def forward(self, x):
        if self.N_dense == 0:
            return x
        for i in range(len(self.linears) - 1):
            x = self.dropout[i](F.leaky_relu(self.linears[i](x), negative_slope=self.leaky_relu_slope))
        return self.dropout[-1](self.dense_output_nonlinearity(self.linears[-1](x)))
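  • N_dense is the number of linear layers, and the output dimension is unchanged. The input to the next EncoderLayer therefore has shape (batch_size, max_size, d_model), the same as right after the embedding. After N repetitions, the result passes through the final Norm layer and then into the Generator. Note that the 'relu' option actually applies F.leaky_relu with negative_slope=leaky_relu_slope (0.1 in this run).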
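The following NumPy sketch mirrors PositionwiseFeedForward.forward with the settings of this run: N_dense=1, the 'relu' option, leaky_relu_slope=0.1, and dropout 0.0. As in the original, hidden layers (when N_dense > 1) also use leaky ReLU. The weights are random and the biases are omitted for brevity; both are assumptions of the sketch.

```python
import numpy as np


def leaky_relu(x, slope):
    return np.where(x >= 0, x, slope * x)


def positionwise_ff(x, weights, slope=0.1):
    """Sketch of PositionwiseFeedForward.forward; len(weights) == N_dense, each (d_model, d_model)."""
    if len(weights) == 0:               # N_dense == 0: identity
        return x
    for W in weights[:-1]:
        x = leaky_relu(x @ W, slope)
    return leaky_relu(x @ weights[-1], slope)   # the 'relu' output nonlinearity (leaky in practice)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 11, 8))                  # (batch_size, max_size, d_model)
weights = [rng.normal(size=(8, 8))]              # N_dense = 1
print(positionwise_ff(x, weights).shape)         # (2, 11, 8): d_model is preserved
print(leaky_relu(np.array([-2.0, 3.0]), 0.1))    # negative inputs are scaled by 0.1
```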

2.11.Generator

class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, aggregation_type='mean', n_output=1, n_layers=1,
                 leaky_relu_slope=0.01, dropout=0.0, scale_norm=False):
        super(Generator, self).__init__()
        if n_layers == 1:
            self.proj = nn.Linear(d_model, n_output)
        else:
            self.proj = []
            for i in range(n_layers - 1):
                self.proj.append(nn.Linear(d_model, d_model))
                self.proj.append(nn.LeakyReLU(leaky_relu_slope))
                self.proj.append(ScaleNorm(d_model) if scale_norm else LayerNorm(d_model))
                self.proj.append(nn.Dropout(dropout))
            self.proj.append(nn.Linear(d_model, n_output))
            self.proj = Sequential(*self.proj)
        self.aggregation_type = aggregation_type

    def forward(self, x, mask):
        mask = mask.unsqueeze(-1).float()
        out_masked = x * mask
        if self.aggregation_type == 'mean':
            out_sum = out_masked.sum(dim=1)
            mask_sum = mask.sum(dim=(1))
            out_avg_pooling = out_sum / mask_sum
        elif self.aggregation_type == 'sum':
            out_sum = out_masked.sum(dim=1)
            out_avg_pooling = out_sum
        elif self.aggregation_type == 'dummy_node':
            out_avg_pooling = out_masked[:, 0]
        projected = self.proj(out_avg_pooling)
        return projected
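  • In forward, mask is batch_mask with shape (batch_size, max_size), and x has shape (batch_size, max_size, d_model). After mask.unsqueeze(-1), the multiplication broadcasts over d_model and zeroes the padded atoms, so out_masked has the same shape as x. Aggregation then removes the max_size dimension. Finally, self.proj outputs predictions of shape (batch_size, n_output). self.proj is a single Linear layer when n_generator_layers=1, as in this run, and a Sequential otherwise.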
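The masked pooling in Generator.forward can be checked with a NumPy sketch. Padded atoms are zeroed by the broadcast multiply, and the 'mean' option divides by the number of real atoms rather than by max_size. The toy values are assumptions, and self.proj is left out.

```python
import numpy as np


def aggregate(x, mask, aggregation_type='mean'):
    """Sketch of the pooling part of Generator.forward (before self.proj)."""
    mask = mask[..., None].astype(float)       # (batch_size, max_size, 1)
    out_masked = x * mask                      # broadcast over d_model
    if aggregation_type == 'mean':
        return out_masked.sum(axis=1) / mask.sum(axis=1)
    if aggregation_type == 'sum':
        return out_masked.sum(axis=1)
    if aggregation_type == 'dummy_node':
        return out_masked[:, 0]                # the first atom is the dummy node
    raise ValueError(aggregation_type)


x = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])   # 1 molecule, max_size=3, d_model=2
mask = np.array([[True, True, False]])                     # third atom is padding
print(aggregate(x, mask, 'mean'))        # [[2. 3.]]: the padding does not leak into the mean
print(aggregate(x, mask, 'dummy_node'))  # [[1. 2.]]
```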

2.12.summary

  • Printing the final model gives:
GraphTransformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=1024, out_features=1024, bias=True)
            (1): Linear(in_features=1024, out_features=1024, bias=True)
            (2): Linear(in_features=1024, out_features=1024, bias=True)
            (3): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (linears): ModuleList(
            (0): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (dropout): ModuleList(
            (0): Dropout(p=0.0, inplace=False)
          )
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (1) ... (7): seven more EncoderLayer blocks, identical to (0)
    )
    (norm): LayerNorm()
  )
  (src_embed): Embeddings(
    (lut): Linear(in_features=28, out_features=1024, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (generator): Generator(
    (proj): Linear(in_features=1024, out_features=1, bias=True)
  )
)
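  • Model construction is essentially the same as in the Transformer, with two differences:
    • There is no positional encoding.
    • The attention is slightly different: besides self-attention, it also uses information from the adjacency and distance matrices.
    use_edge_features is not used here. The roles of PositionGenerator and Adapter are still unclear to me.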
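The parameter count can be derived by hand as a cross-check of the printed structure. Each Linear(n_in, n_out) has n_in·n_out + n_out parameters, and each LayerNorm has 2·d_model (a_2 and b_2). The λ values are a plain tuple here (trainable_lambda=False), so they add nothing. Below is a plain-Python sketch of this arithmetic; the function names are hypothetical. On the real model, the result could be compared against sum(p.numel() for p in model.parameters()).

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out                     # weight + bias


def mat_param_count(d_atom=28, d_model=1024, N=8, N_dense=1, n_output=1):
    """Hypothetical helper: parameters of the model printed above (n_generator_layers=1)."""
    layer_norm = 2 * d_model                        # a_2 and b_2
    attn = 4 * linear_params(d_model, d_model)      # W^Q, W^K, W^V, W^O
    ff = N_dense * linear_params(d_model, d_model)
    encoder_layer = attn + ff + 2 * layer_norm      # two SublayerConnection norms
    encoder = N * encoder_layer + layer_norm        # plus the final Encoder norm
    embed = linear_params(d_atom, d_model)
    generator = linear_params(d_model, n_output)
    return encoder + embed + generator


print(mat_param_count())  # 42049537, i.e. about 42M parameters
```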

Published: 2024-01-29 08:07:28

Article link: https://www.4u4v.net/it/170648685113877.html
