| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, Union |
| |
|
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.generation.utils import GenerationMixin |
| |
|
| | from .shared_space_config import SharedSpaceDecoderConfig |
| | from .shared_space_decoder import ( |
| | SharedSpaceDecoderPreTrainedModel, |
| | SharedSpaceDecoderModel, |
| | DeepseekV3RMSNorm |
| | ) |
| |
|
def create_norm_layer(hidden_size: int, config: "SharedSpaceDecoderConfig") -> nn.Module:
    """
    Create a normalization layer based on the config's ``norm_type``.

    Args:
        hidden_size: The dimension to normalize over.
        config: Configuration carrying ``norm_type`` plus the matching
            epsilon value (``layer_norm_eps`` for LayerNorm,
            ``rms_norm_eps`` for RMSNorm).

    Returns:
        Either an ``nn.LayerNorm`` or a ``DeepseekV3RMSNorm`` instance.

    Raises:
        ValueError: If ``config.norm_type`` is neither ``"layernorm"``
            nor ``"rmsnorm"``.
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    if config.norm_type == "rmsnorm":
        # DeepseekV3RMSNorm is already imported at module level; the
        # redundant function-local re-import has been removed.
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    raise ValueError(f"Unknown norm_type: {config.norm_type}")
| |
|
| |
|
class SharedSpaceDecoderForCausalLM(GenerationMixin, SharedSpaceDecoderPreTrainedModel):
    """
    Subspace Decoder model with a causal language modeling head.

    This model extends the SharedSpaceDecoderModel with:
    - A language modeling head that projects hidden states to vocabulary logits
    - Support for computing cross-entropy loss for language modeling
    - Proper HuggingFace compatibility for causal language modeling tasks
    - Decoder-specific initialization strategies

    The model can be used for:
    - Text generation
    - Language modeling pretraining
    - Fine-tuning on downstream tasks
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        """
        Build the backbone, the final normalization layer, and the LM head.

        Args:
            config: Model configuration (hidden size, vocab size, norm type, ...).
        """
        super().__init__(config)

        # Backbone decoder mapping token ids to hidden states.
        self.model = SharedSpaceDecoderModel(config)

        # Final normalization applied before projecting to logits.
        # LayerNorm or RMSNorm depending on config.norm_type.
        self.norm = create_norm_layer(config.hidden_size, config)

        # Language modeling head: hidden states -> vocabulary logits.
        # No bias, matching the usual tied-embedding convention.
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # HuggingFace hook: runs weight init (and weight tying if configured).
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """
        Decoder-specific weight initialization with special handling for the
        language modeling head.

        Key differences from encoder initialization:
        - The LM head gets a specialized std for stability
        - When the model uses a vocab subspace projection (``vocab_proj``),
          the head is initialized with half the usual std

        Note: the original implementation initialized ``lm_head`` twice when
        ``vocab_proj`` was present (full std, then overwritten with half std).
        The end distribution was the half-std one; this version draws once
        with the correct std and skips the wasted first draw.
        """
        # Let the base class handle all standard modules first.
        super()._init_weights(module)

        if module is self.lm_head:
            std = self.config.initializer_range
            if self.model.vocab_proj is not None:
                # Vocab-subspace models get a smaller init for the head.
                std *= 0.5
            module.weight.data.normal_(mean=0.0, std=std)

    def get_input_embeddings(self):
        """Return the input embedding layer for compatibility with HuggingFace."""
        return self.model.vocab_embed

    def set_input_embeddings(self, value):
        """Set the input embedding layer for compatibility with HuggingFace."""
        self.model.vocab_embed = value

    def get_output_embeddings(self):
        """Return the output embedding layer (lm_head) for compatibility."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer for compatibility."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """
        Tie the input and output embedding weights.

        Sets the LM head's weight to share storage with the input embedding,
        reducing parameter count — a common practice in modern LMs.

        Note: when the model uses a vocab subspace (``vocab_proj`` is not
        None), the embedding and head operate in different dimensions, so
        tying is skipped.
        """
        if getattr(self.model, "vocab_proj", None) is None:
            self._tie_or_clone_weights(self.lm_head, self.model.vocab_embed)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[CausalLMOutputWithPast, tuple]:
        """
        Forward pass for causal language modeling.

        Args:
            input_ids: Token ids of shape [batch_size, seq_len].
            attention_mask: Attention mask of shape [batch_size, seq_len]
                (1 for real tokens, 0 for padding). Defaults to all ones.
            labels: Ground truth token ids for computing loss. Same shape as
                input_ids. If provided, loss will be computed (shifting is
                handled by ``self.loss_function``).

        Returns:
            CausalLMOutputWithPast containing:
            - logits: Prediction logits of shape [batch_size, seq_len, vocab_size]
            - loss: Cross-entropy loss if labels provided, else None
            - hidden_states: Final-layer hidden states only (not the full
              per-layer tuple HF usually returns), and only when
              ``output_hidden_states=True`` is passed in kwargs.
        """
        # Default to attending over every position when no mask is given.
        if attention_mask is None and input_ids is not None:
            attention_mask = torch.ones(
                (input_ids.size(0), input_ids.size(1)),
                dtype=torch.long,
                device=input_ids.device,
            )

        # Backbone forward: [batch, seq] -> [batch, seq, hidden].
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Final normalization before the head.
        hidden_states = self.norm(hidden_states)

        # Project to vocabulary logits: [batch, seq, vocab_size].
        logits = self.lm_head(hidden_states)

        # Compute LM loss via the HF-provided loss function (handles the
        # standard one-token shift internally).
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,  # this model does not use a KV cache
            hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Minimal generation hook: no KV cache, so the full sequence is fed
        every step. ``past_key_values`` is accepted for API compatibility
        and ignored.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past_key_values, beam_idx):
        """No cache is kept, so there is nothing to reorder for beam search."""
        return past_key_values
| |
|
| |
|