Spaces:
Runtime error
Runtime error
| from torch import nn | |
class ClassificationHead(nn.Module):
    """Classification head for transformer encoders.

    Maps encoder hidden states to per-class logits with a single linear
    projection. No activation and no softmax are applied.

    Args:
        class_size: Number of output classes (``out_features`` of the
            projection).
        embed_size: Width of the incoming hidden state (``in_features`` of
            the projection).
    """

    def __init__(self, class_size: int, embed_size: int):
        super().__init__()
        # Stored for reference only; forward() does not read them.
        self.class_size = class_size
        self.embed_size = embed_size
        # nn.Module registers submodules under their attribute name, so
        # ``mlp`` determines the state_dict keys ("mlp.weight", "mlp.bias").
        # Renaming it would break loading of existing checkpoints.
        self.mlp = nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        """Return unnormalised class logits for ``hidden_state``.

        ``nn.Linear`` acts on the last dimension, so ``hidden_state`` must
        end in ``embed_size``. The result has the same leading dimensions,
        with the last one replaced by ``class_size``.

        NOTE(review): presumably ``hidden_state`` is a pooled encoder output
        of shape (batch, embed_size); confirm against the caller.
        """
        return self.mlp(hidden_state)