# src/modeling_multitask.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoModel


class MultiTaskConfig(PretrainedConfig):
    model_type = "multitask_xlm"

    def __init__(
        self,
        base_model_name="xlm-roberta-base",
        num_dept_labels=4,
        num_urg_labels=3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_dept_labels = num_dept_labels
        self.num_urg_labels = num_urg_labels


class MultiTaskForSequenceClassification(PreTrainedModel):
    config_class = MultiTaskConfig
    base_model_prefix = "encoder"

    def __init__(self, config: MultiTaskConfig):
        super().__init__(config)
        # Shared XLM-R encoder with two independent classification heads on top.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        hidden_size = self.encoder.config.hidden_size
        self.dept_head = nn.Linear(hidden_size, config.num_dept_labels)
        self.urg_head = nn.Linear(hidden_size, config.num_urg_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, dept_labels=None, urg_labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the representation of the first (<s> / CLS) token.
        cls_output = outputs.last_hidden_state[:, 0, :]
        dept_logits = self.dept_head(cls_output)
        urg_logits = self.urg_head(cls_output)

        loss = None
        if dept_labels is not None and urg_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_dept = loss_fct(dept_logits.view(-1, self.config.num_dept_labels), dept_labels.view(-1))
            loss_urg = loss_fct(urg_logits.view(-1, self.config.num_urg_labels), urg_labels.view(-1))
            # Unweighted sum of the two task losses.
            loss = loss_dept + loss_urg

        return {"loss": loss, "logits": (dept_logits, urg_logits)}
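

# Minimal usage sketch (not part of the original module): it builds the model
# from a fresh MultiTaskConfig and runs an untrained forward pass to show the
# expected logit shapes. The example sentences are illustrative placeholders,
# and the tokenizer simply follows config.base_model_name.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = MultiTaskConfig(base_model_name="xlm-roberta-base", num_dept_labels=4, num_urg_labels=3)
    model = MultiTaskForSequenceClassification(config)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)

    batch = tokenizer(
        ["Mi pedido llegó dañado", "Where is my invoice?"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

    dept_logits, urg_logits = out["logits"]
    print(dept_logits.shape, urg_logits.shape)  # torch.Size([2, 4]) torch.Size([2, 3])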