| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from dataclasses import dataclass |
| | from typing import Callable, Optional, Tuple, Union |
| | from PIL import Image |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import Qwen3Model |
| | from transformers.cache_utils import Cache, DynamicCache |
| | from transformers.generation import GenerationMixin |
| | from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.processing_utils import Unpack |
| | from transformers.utils import TransformersKwargs, can_return_tuple, logging |
| |
|
| | from typing import Any, Literal, Optional, TypedDict, Union |
| |
|
| | from .configuration_step_vl import StepRoboticsConfig |
| | from .vision_encoder import StepRoboticsVisionEncoder |
# Module-level logger, following the transformers convention.
logger = logging.get_logger(__name__)
| |
|
class StepVLImagePixelInputs(TypedDict):
    """Raw pixel inputs for the vision encoder."""

    type: Literal["pixel_values"]
    # Base-resolution image tensor(s); cast to model dtype/device downstream.
    pixel_values: torch.Tensor
    # Optional extra high-resolution crops ("patches"); None when absent.
    patch_pixel_values: Optional[torch.Tensor]
    # Number of extra patches contributed by each image in the batch.
    num_patches: list[int]
| |
|
| |
|
class StepVLImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, bypassing the vision encoder."""

    type: Literal["image_embeds"]
    # Image features already projected/flattened to (num_tokens, hidden).
    image_embeds: torch.Tensor
| |
|
| |
|
# Image input may be either raw pixels or precomputed embeddings.
StepVLImageInputs = Union[StepVLImagePixelInputs,
                          StepVLImageEmbeddingInputs]
| |
|
| |
|
@dataclass
class StepVLCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    # Set by the base model (no LM head); the wrapper fills `logits` instead.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Annotated Optional to match the `None` default.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # NOTE(review): not populated anywhere in this file -- confirm callers.
    image_hidden_states: Optional[torch.FloatTensor] = None
| |
|
| | def _flatten_embeddings(embeddings) -> torch.Tensor: |
| | """ |
| | Recursively flattens and concatenates NestedTensors on all but the last |
| | dimension. |
| | """ |
| |
|
| | if isinstance(embeddings, torch.Tensor): |
| | |
| | return embeddings.flatten(0, -2) |
| |
|
| | return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) |
| |
|
| | def _embedding_count_expression(embeddings) -> str: |
| | """ |
| | Constructs a debugging representation of the number of embeddings in the |
| | NestedTensors. |
| | """ |
| |
|
| | if isinstance(embeddings, torch.Tensor): |
| | return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) |
| |
|
| | return " + ".join( |
| | _embedding_count_expression(inner) for inner in embeddings) |
| |
|
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings,
) -> torch.Tensor:
    """
    Overwrite the rows of ``inputs_embeds`` selected by the boolean mask
    ``is_multimodal`` with the flattened ``multimodal_embeddings``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    expected = is_multimodal.sum().item()
    assert isinstance(expected, int)

    flat = _flatten_embeddings(multimodal_embeddings)
    actual = flat.shape[0]
    if actual != expected:
        # Surface a readable count breakdown for the mismatched input.
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {actual} "
            f"multimodal tokens to {expected} placeholders")

    # Boolean-mask assignment scatters rows into the placeholder slots.
    mask = is_multimodal.to(inputs_embeds.device)
    inputs_embeds[mask] = flat.to(inputs_embeds.device)
    return inputs_embeds
| |
|
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the ``input_ids`` MUST MATCH the order of
    their embeddings in ``multimodal_embeddings`` since we need to
    slice-merge instead of individually scattering.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    Then the image embeddings (that correspond to I's) from vision encoder
    must be padded with embeddings of S, B, and E in the same order of
    input_ids for a correct embedding merge.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        # Several placeholder ids: build one membership mask over the list.
        placeholder_ids = torch.tensor(placeholder_token_id,
                                       device=input_ids.device)
        mask = torch.isin(input_ids, placeholder_ids)
    else:
        mask = input_ids == placeholder_token_id

    return _merge_multimodal_embeddings(inputs_embeds, mask,
                                        multimodal_embeddings)
| |
|
class StepRoboticsPreTrainedModel(PreTrainedModel):
    """Base class hooking the Step robotics VLM into the transformers
    PreTrainedModel machinery (weight init, checkpoint loading, dispatch)."""

    config_class = StepRoboticsConfig
    supports_gradient_checkpointing = True
    # Cache objects handle their own device placement.
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_static_cache = True
    _supports_attention_backend = True
| |
|
| |
|
class StepRoboticsModel(StepRoboticsPreTrainedModel, GenerationMixin):
    """Multimodal backbone: a vision encoder whose features are downsampled,
    projected to the Qwen3 hidden size, and merged into the token embedding
    sequence at image-placeholder positions.  Produces hidden states only;
    the LM head lives in the ``Step3VL10BForCausalLM`` wrapper."""

    config: StepRoboticsConfig
    base_model_prefix = ""

    def __init__(self, config: StepRoboticsConfig):
        super().__init__(config)
        self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
        self.language_model = Qwen3Model(config.text_config)
        self.vocab_size = config.text_config.vocab_size
        # Projects downsampled vision features into the LM hidden size.  The
        # `width * 4` input size presumably reflects the channel expansion of
        # the two vit_downsampler stages -- TODO confirm in vision_encoder.
        self.vit_large_projector = nn.Linear(
            config.vision_config.width * 4,
            config.text_config.hidden_size,
            bias=config.projector_bias)
        self.image_placeholder_token_id = config.image_token_id

        self.post_init()

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice vision features into the
        image-placeholder positions.

        NOTE(review): this overrides ``PreTrainedModel.get_input_embeddings``
        with an incompatible signature (returns a tensor, not the embedding
        module) -- confirm that nothing relies on the standard contract.
        """
        # Squeeze/unsqueeze of dim 0 implies a single-sequence batch --
        # TODO confirm batch size 1 is guaranteed by the caller.
        input_ids = input_ids.squeeze(0)
        if multimodal_embeddings is None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)
        else:
            # Embed only the text tokens; placeholder slots are filled below.
            is_text = input_ids != self.config.image_token_id
            text_ids = input_ids[is_text]
            text_embeds = self.language_model.embed_tokens(text_ids)

            inputs_embeds = torch.empty(input_ids.shape[0],
                                        text_embeds.shape[-1],
                                        dtype=text_embeds.dtype,
                                        device=text_embeds.device)
            inputs_embeds[is_text] = text_embeds
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_id)
        inputs_embeds = inputs_embeds.unsqueeze(0)
        return inputs_embeds

    def set_input_embeddings(self, value):
        # Delegates to the language model's standard setter.
        return self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[StepVLImageInputs]:
        """Extract image inputs from forward kwargs and normalize shapes.

        Returns ``None`` when no image data is present; otherwise a typed
        dict of raw pixels or precomputed embeddings, cast to the model's
        dtype and device.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            # Collapse any leading batch dims so images are (N, C, H, W).
            if pixel_values.dim() >= 3:
                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
            if patch_pixel_values is not None:
                patch_pixel_values = patch_pixel_values.view(
                    -1, *patch_pixel_values.shape[-3:])

                # An empty patch batch means no high-res crops were supplied.
                if patch_pixel_values.shape[0] == 0:
                    patch_pixel_values = None

            return StepVLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype).to(self.device),
                patch_pixel_values=patch_pixel_values.to(self.dtype).to(
                    self.device) if patch_pixel_values is not None else None,
                num_patches=num_patches,
            )

        if image_embeds is not None:
            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
            else:
                raise ValueError(
                    f"Unexpected shape for image_embeds: {image_embeds.shape}")

            return StepVLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds.to(self.dtype).to(self.device),
            )
        return None

    def _process_image_features(self,
                                image_features: torch.Tensor) -> torch.Tensor:
        """Downsample ViT patch features and project them to the LM width.

        Assumes ``image_features`` is (B, P, C) with P a perfect square
        (square patch grid) -- TODO confirm the vision encoder guarantees
        this.
        """
        B, P = image_features.shape[:2]
        HW = int(P ** 0.5)
        # (B, P, C) -> (B, C, H, W) so the conv downsamplers can run.
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vision_model.vit_downsampler1(image_features)
        image_features = self.vision_model.vit_downsampler2(image_features)

        # NOTE(review): `HW` is deliberately rebound to the downsampled
        # spatial size; both spatial dims are assumed equal.
        B, C, HW, HW = image_features.shape
        # Back to a (B, tokens, channels) layout, then project to LM width.
        image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self,
                                 input_tensor: torch.Tensor) -> torch.Tensor:
        # Thin wrapper; kept as a seam for subclasses/tests to override.
        return self.vision_model(input_tensor)

    def _process_image_input(
            self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
        """Encode image input into a list of per-image embedding tensors,
        prepending any high-resolution patch features to each image's base
        features."""

        if image_input["type"] == "image_embeds":
            image_features = image_input["image_embeds"]
        else:
            image_features = self._get_vision_model_output(
                image_input["pixel_values"])
            patch_image_features = self._get_vision_model_output(
                image_input["patch_pixel_values"]
            ) if image_input["patch_pixel_values"] is not None else None
            num_patches = image_input["num_patches"]

        # NOTE(review): the "image_embeds" branch leaves
        # `patch_image_features` / `num_patches` undefined, so the code
        # below would fail on that path -- confirm it is never taken here.
        image_features = self._process_image_features(image_features)
        patch_image_features = self._process_image_features(
            patch_image_features) if patch_image_features is not None else None

        # Stitch per-image features: `num_patch` high-res patch slices first,
        # then the base-image tokens.
        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx:cur_patch_idx + num_patch]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(
                -1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) >
                1 else cur_feature[0])

        return merged_image_features

    def get_multimodal_embeddings(self, **kwargs):
        """Return per-image embedding tensors, or None when the kwargs carry
        no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[list[Image.Image]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, StepVLCausalLMOutputWithPast]:
        r"""Run the multimodal backbone and return hidden states.

        Image tensors (``pixel_values``, ``patch_pixel_values``,
        ``num_patches``, ``image_embeds``) arrive via ``**kwargs`` and are
        parsed by :meth:`get_multimodal_embeddings`.  ``labels``, ``images``
        and ``logits_to_keep`` are accepted for API compatibility but are
        not used by the code below; logits/loss are produced by the
        ``Step3VL10BForCausalLM`` wrapper.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            input_ids = input_ids  # NOTE(review): no-op self-assignment
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        # NOTE(review): image kwargs are popped only from a local copy inside
        # _parse_and_validate_image_input, so they are still forwarded to the
        # language model here -- confirm Qwen3Model tolerates/ignores them.
        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        # NOTE(review): `outputs.hidden_states` is not propagated even when
        # `output_hidden_states` is set -- confirm whether intentional.
        output = StepVLCausalLMOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            attentions=outputs.attentions,
        )
        return output if return_dict else output.to_tuple()
| |
|
| |
|
| |
|
class Step3VL10BForCausalLM(StepRoboticsPreTrainedModel, GenerationMixin):
    """Step-VL 10B vision-language model with a causal LM head.

    Wraps :class:`StepRoboticsModel` (vision encoder + Qwen3 backbone) and
    adds an ``lm_head`` projecting hidden states to vocabulary logits.
    """

    # Remaps flat legacy checkpoint keys onto the nested `model.` layout.
    _checkpoint_conversion_mapping = {
        "^vision_model": "model.vision_model",
        r"^model(?!\.(language_model|vision_model))": "model.language_model",
        "^vit_large_projector": "model.vit_large_projector"
    }
    _tied_weights_keys = ["lm_head.weight"]
    config: StepRoboticsConfig

    def __init__(self, config: StepRoboticsConfig):
        super().__init__(config)
        self.model = StepRoboticsModel(config)
        # The backbone's output width is the text model's hidden size, so the
        # LM head must match it (the composite top-level config may not
        # expose `hidden_size` itself).
        self.lm_head = nn.Linear(config.text_config.hidden_size,
                                 config.text_config.vocab_size,
                                 bias=False)

        self.post_init()

    def get_input_embeddings(self):
        # The inner model overrides `get_input_embeddings` with an
        # incompatible (input_ids, ...) signature; return the embedding
        # module directly, per the PreTrainedModel contract.
        return self.model.language_model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        # The LM head lives on this wrapper, not on the inner model.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        # The inner model names its vision tower `vision_model`.
        return self.model.vision_model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        num_patches=None,
        patch_pixel_values=None,
        patch_newline_mask=None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, StepVLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Runs the multimodal backbone, applies the LM head, and (when
        ``labels`` is given) computes the causal-LM loss.  Image tensors
        (``pixel_values`` etc.) are forwarded to the backbone via ``kwargs``
        and the explicit patch arguments.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids=input_ids,
            num_patches=num_patches,
            patch_pixel_values=patch_pixel_values,
            patch_newline_mask=patch_newline_mask,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        # Fixed: was `los = None`, so the computed loss was discarded and
        # never returned to the caller.
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels,
                vocab_size=self.config.text_config.vocab_size)

        # Propagate cache/attention state; without `past_key_values` here,
        # cached generation cannot work.
        return StepVLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Forward image pixels only on the prefill step; later decode steps
        reuse the KV cache and must not re-run the vision encoder."""

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # Guard against a missing cache_position before indexing it.
        if cache_position is not None and cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs

    def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
        """Strip the legacy ``language_model.`` prefix so old checkpoints
        load into the current layout; the bool flags that a rename
        occurred."""
        if key.startswith("language_model."):
            return key[len("language_model."):], True

        return key, False
| |
|
| |
|