| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch_cluster import radius_graph |
| | from torch_geometric.nn import MessagePassing |
| |
|
| |
|
class CosineCutoff(nn.Module):
    """Cosine envelope that decays from 1 at distance 0 to 0 at ``cutoff``.

    Distances at or beyond ``cutoff`` are mapped to exactly zero.
    """

    def __init__(self, cutoff):
        super().__init__()
        self.cutoff = cutoff

    def forward(self, distances):
        """Return the envelope value for every entry of ``distances``."""
        # 1.0 strictly inside the cutoff radius, 0.0 everywhere else.
        inside = (distances < self.cutoff).float()
        envelope = 0.5 * (torch.cos(distances * math.pi / self.cutoff) + 1.0)
        return envelope * inside
| |
|
| |
|
class ExpNormalSmearing(nn.Module):
    """Exponential-normal radial basis expansion with a cosine cutoff.

    Each distance ``d`` is expanded into ``num_rbf`` values
    ``cutoff_fn(d) * exp(-beta_k * (exp(-alpha * d) - mean_k) ** 2)``.
    ``means`` and ``betas`` are parameters when ``trainable`` is true and
    buffers otherwise.
    """

    def __init__(self, cutoff=5.0, num_rbf=50, trainable=True):
        super().__init__()
        self.cutoff = cutoff
        self.num_rbf = num_rbf
        self.trainable = trainable

        self.cutoff_fn = CosineCutoff(cutoff)
        self.alpha = 5.0 / cutoff

        # Register "means" first, then "betas", either as parameters or buffers.
        for name, value in zip(("means", "betas"), self._initial_params()):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Build the initial ``(means, betas)`` tensors, each of length ``num_rbf``."""
        # exp(-cutoff) is the smallest value exp(-d) takes for d in [0, cutoff].
        lowest = torch.exp(torch.scalar_tensor(-self.cutoff))
        centers = torch.linspace(lowest, 1, self.num_rbf)
        width = (2 / self.num_rbf * (1 - lowest)) ** -2
        widths = torch.tensor([width] * self.num_rbf)
        return centers, widths

    def reset_parameters(self):
        """Overwrite ``means`` and ``betas`` in place with their initial values."""
        targets = (self.means, self.betas)
        for target, fresh in zip(targets, self._initial_params()):
            target.data.copy_(fresh)

    def forward(self, dist):
        """Expand ``dist`` along a new trailing axis of size ``num_rbf``."""
        expanded = dist.unsqueeze(-1)
        envelope = self.cutoff_fn(expanded)
        deviation = torch.exp(self.alpha * (-expanded)) - self.means
        return envelope * torch.exp(-self.betas * deviation ** 2)
| |
|
| |
|
class GaussianSmearing(nn.Module):
    """Gaussian radial basis expansion on an evenly spaced grid.

    The ``num_rbf`` centres (``offset``) are spaced linearly on
    ``[0, cutoff]``; ``coeff`` is ``-0.5 / spacing ** 2``. Both are parameters
    when ``trainable`` is true and buffers otherwise.
    """

    def __init__(self, cutoff=5.0, num_rbf=50, trainable=True):
        super().__init__()
        self.cutoff = cutoff
        self.num_rbf = num_rbf
        self.trainable = trainable

        centers, scale = self._initial_params()
        # "coeff" is registered before "offset".
        for name, value in (("coeff", scale), ("offset", centers)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return the initial ``(offset, coeff)`` pair."""
        centers = torch.linspace(0, self.cutoff, self.num_rbf)
        spacing = centers[1] - centers[0]
        return centers, -0.5 / spacing ** 2

    def reset_parameters(self):
        """Overwrite ``offset`` and ``coeff`` in place with their initial values."""
        centers, scale = self._initial_params()
        self.offset.data.copy_(centers)
        self.coeff.data.copy_(scale)

    def forward(self, dist):
        """Expand ``dist`` along a new trailing axis of size ``num_rbf``."""
        shifted = dist.unsqueeze(-1) - self.offset
        return torch.exp(self.coeff * shifted.pow(2))
| |
|
| |
|
# Radial-basis expansion classes keyed by their short name.
rbf_class_mapping = {
    "gauss": GaussianSmearing,
    "expnorm": ExpNormalSmearing,
}
| |
|
| |
|
class ShiftedSoftplus(nn.Module):
    """Softplus activation shifted down by ``log(2)`` so that an input of 0 maps to 0."""

    def __init__(self):
        super().__init__()
        # log(2) evaluated through torch and stored as a plain Python float.
        self.shift = torch.tensor(2.0).log().item()

    def forward(self, x):
        """Return ``softplus(x) - log(2)`` element-wise."""
        activated = F.softplus(x)
        return activated - self.shift
| |
|
| |
|
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
| |
|
| |
|
# Activation module classes keyed by their short name.
act_class_mapping = {
    "ssp": ShiftedSoftplus,
    "silu": nn.SiLU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "swish": Swish,
}
| |
|
| |
|
class Sphere(nn.Module):
    """Expand 3-D edge vectors into polynomial spherical-harmonic components.

    With ``l=1`` the output holds the three degree-1 components ``(x, y, z)``;
    with ``l=2`` the five degree-2 components are appended, giving eight
    values per vector. The input is used as given (no normalisation is
    applied here).
    """

    def __init__(self, l=2):
        super().__init__()
        self.l = l

    def forward(self, edge_vec):
        """Return the harmonic components of ``edge_vec``.

        Args:
            edge_vec: tensor whose last axis holds the ``(x, y, z)`` components.

        Returns:
            Tensor with the same leading axes and a last axis of size 3
            (``l=1``) or 8 (``l=2``).

        Raises:
            ValueError: if ``self.l`` is neither 1 nor 2.
        """
        return self._spherical_harmonics(
            self.l, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2]
        )

    @staticmethod
    def _spherical_harmonics(lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Stack the degree-1 (and, for ``lmax == 2``, degree-2) components.

        Raises:
            ValueError: if ``lmax`` is neither 1 nor 2. Without this check the
                function would fall off the end and silently return ``None``.
        """
        if lmax not in (1, 2):
            raise ValueError(
                f"Sphere only supports l=1 or l=2, got l={lmax!r}"
            )

        sh_1_0, sh_1_1, sh_1_2 = x, y, z

        if lmax == 1:
            return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1)

        sqrt3 = math.sqrt(3.0)
        x2 = x.pow(2)
        y2 = y.pow(2)
        z2 = z.pow(2)

        sh_2_0 = sqrt3 * x * z
        sh_2_1 = sqrt3 * x * y
        sh_2_2 = y2 - 0.5 * (x2 + z2)
        sh_2_3 = sqrt3 * y * z
        sh_2_4 = sqrt3 / 2.0 * (z2 - x2)

        return torch.stack(
            [sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4],
            dim=-1,
        )
| |
|
| |
|
class VecLayerNorm(nn.Module):
    """Normalise vector features and scale each hidden channel by ``weight``.

    ``forward`` indexes its input as ``(batch, components, hidden_channels)``,
    where ``components`` is 3, or 8 (split into a 3-block and a 5-block that
    are normalised separately). ``norm_type`` selects ``"rms"``,
    ``"max_min"``, or, for any other value, no normalisation.
    """

    def __init__(self, hidden_channels, trainable, norm_type="max_min"):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.eps = 1e-12

        initial_weight = torch.ones(self.hidden_channels)
        if trainable:
            self.register_parameter("weight", nn.Parameter(initial_weight))
        else:
            self.register_buffer("weight", initial_weight)

        # Any unrecognised norm_type falls back to the identity.
        chosen = self.none_norm
        if norm_type == "rms":
            chosen = self.rms_norm
        elif norm_type == "max_min":
            chosen = self.max_min_norm
        self.norm = chosen

        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight`` to all ones in place."""
        self.weight.data.copy_(torch.ones(self.hidden_channels))

    def none_norm(self, vec):
        """Return ``vec`` unchanged."""
        return vec

    def rms_norm(self, vec):
        """Divide each sample by the root-mean-square of its per-channel vector lengths."""
        lengths = torch.norm(vec, dim=1)

        if (lengths == 0).all():
            return torch.zeros_like(vec)

        lengths = lengths.clamp(min=self.eps)
        rms = torch.sqrt(torch.mean(lengths ** 2, dim=-1))
        # Two trailing axes so the per-sample scale broadcasts over the input.
        return vec / F.relu(rms)[..., None, None]

    def max_min_norm(self, vec):
        """Rescale per-channel vector lengths to [0, 1] per sample, keeping directions."""
        lengths = torch.norm(vec, dim=1, keepdim=True)

        if (lengths == 0).all():
            return torch.zeros_like(vec)

        lengths = lengths.clamp(min=self.eps)
        unit = vec / lengths

        longest = lengths.max(dim=-1).values
        shortest = lengths.min(dim=-1).values
        span = (longest - shortest).view(-1)
        # A zero span would divide by zero; use 1 so those samples map to 0.
        span = torch.where(span == 0, torch.ones_like(span), span)
        rescaled = (lengths - shortest.view(-1, 1, 1)) / span.view(-1, 1, 1)

        return F.relu(rescaled) * unit

    def forward(self, vec):
        """Normalise ``vec`` and multiply by ``weight``.

        Raises:
            ValueError: if ``vec.shape[1]`` is neither 3 nor 8.
        """
        components = vec.shape[1]
        if components == 3:
            normed = self.norm(vec)
        elif components == 8:
            first, second = torch.split(vec, [3, 5], dim=1)
            first = self.norm(first)
            second = self.norm(second)
            normed = torch.cat([first, second], dim=1)
        else:
            raise ValueError("VecLayerNorm only support 3 or 8 channels")
        return normed * self.weight.unsqueeze(0).unsqueeze(0)
| |
|
| |
|
class Distance(nn.Module):
    """Build a radius graph over positions and return edge lengths and vectors.

    Args:
        cutoff: radius passed to ``radius_graph`` as ``r``.
        max_num_neighbors: forwarded to ``radius_graph``.
        loop: whether ``radius_graph`` should include self-loops.
    """

    def __init__(self, cutoff, max_num_neighbors=32, loop=True):
        super().__init__()
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.loop = loop

    def forward(self, pos, batch):
        """Return ``(edge_index, edge_weight, edge_vec)``.

        ``edge_vec`` is ``pos[edge_index[0]] - pos[edge_index[1]]`` and
        ``edge_weight`` is its norm over the last axis. When ``loop`` is set,
        self-loop edges get a weight of exactly 0.
        """
        edge_index = radius_graph(
            pos,
            r=self.cutoff,
            batch=batch,
            loop=self.loop,
            max_num_neighbors=self.max_num_neighbors,
        )
        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

        if self.loop:
            # Only take the norm of real edges; self-loops keep weight 0
            # (presumably to avoid differentiating the norm of a zero vector).
            mask = edge_index[0] != edge_index[1]
            # new_zeros inherits dtype and device from edge_vec, so the masked
            # assignment below also works when pos is not the default dtype.
            edge_weight = edge_vec.new_zeros(edge_vec.size(0))
            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
        else:
            edge_weight = torch.norm(edge_vec, dim=-1)

        return edge_index, edge_weight, edge_vec
| |
|
| |
|
class NeighborEmbedding(MessagePassing):
    """Combine node features with a distance-weighted sum of neighbour embeddings.

    Each node's embedding (looked up from ``z``) is sent along every
    non-self-loop edge, scaled by a projection of the edge attributes times a
    cosine cutoff of the edge length, summed per node, concatenated with ``x``
    and passed through a linear layer.
    """

    def __init__(self, hidden_channels, num_rbf, cutoff, max_z=100):
        super().__init__(aggr="add")
        self.embedding = nn.Embedding(max_z, hidden_channels)
        self.distance_proj = nn.Linear(num_rbf, hidden_channels)
        self.combine = nn.Linear(hidden_channels * 2, hidden_channels)
        self.cutoff = CosineCutoff(cutoff)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding and both linear layers (Xavier weights, zero bias)."""
        self.embedding.reset_parameters()
        for layer in (self.distance_proj, self.combine):
            nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, z, x, edge_index, edge_weight, edge_attr):
        """Return the updated node features.

        ``z`` indexes the embedding table (presumably atomic numbers; confirm
        against the caller). ``edge_weight`` and ``edge_attr`` are indexed per
        edge, in the same order as the columns of ``edge_index``.
        """
        # Drop self-loops so a node does not receive its own embedding.
        not_self_loop = edge_index[0] != edge_index[1]
        if not not_self_loop.all():
            edge_index = edge_index[:, not_self_loop]
            edge_weight = edge_weight[not_self_loop]
            edge_attr = edge_attr[not_self_loop]

        envelope = self.cutoff(edge_weight)
        W = self.distance_proj(edge_attr) * envelope.view(-1, 1)

        embedded = self.embedding(z)
        aggregated = self.propagate(edge_index, x=embedded, W=W, size=None)
        return self.combine(torch.cat([x, aggregated], dim=1))

    def message(self, x_j, W):
        """Scale each source-node embedding by its edge filter."""
        return x_j * W
| |
|
| | |
class EdgeEmbedding(MessagePassing):
    """Per-edge features: the sum of both endpoint features times projected edge attributes.

    ``aggregate`` is overridden to return the messages untouched, so no
    reduction over edges is performed by this class.
    """

    def __init__(self, num_rbf, hidden_channels):
        super().__init__(aggr=None)
        self.edge_proj = nn.Linear(num_rbf, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection (Xavier weights, zero bias)."""
        projection = self.edge_proj
        nn.init.xavier_uniform_(projection.weight)
        projection.bias.data.fill_(0)

    def forward(self, edge_index, edge_attr, x):
        """Run message passing and return its result directly."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Combine both endpoint features and gate them with the projected edge attributes."""
        endpoints = x_i + x_j
        return endpoints * self.edge_proj(edge_attr)

    def aggregate(self, features, index):
        """Skip aggregation and hand the per-edge messages back unchanged."""
        return features