| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import argparse |
| |
|
| | import torch |
| |
|
| | from transformers import ( |
| | AddedToken, |
| | AutoConfig, |
| | AutoTokenizer, |
| | ) |
| | from configuration_llava import LlavaConfig |
| | from modeling_llava import LlavaForConditionalGeneration |
| |
|
| |
|
# Substring → replacement map from the original LLaVA checkpoint key layout to
# the HF LlavaForConditionalGeneration layout.
# NOTE: insertion order matters — convert_state_dict_to_hf applies these
# substring replacements in order, so the more specific
# "transformer.vision_tower.vision_tower" entry must be tried before the bare
# "transformer" entry.
KEYS_TO_MODIFY_MAPPING = {
    "transformer.vision_tower.vision_tower": "vision_model",
    "transformer.mm_projector": "multi_modal_projector",
    "transformer": "language_model.transformer",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.transformer",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
}
| |
|
| |
|
def convert_state_dict_to_hf(state_dict):
    """Rename checkpoint keys to the HF layout.

    Each key is rewritten by applying the substring replacements from
    ``KEYS_TO_MODIFY_MAPPING`` in mapping order (replacements chain: a key may
    be rewritten by several entries in sequence). Values are kept untouched.
    """
    renamed = {}
    for old_key, tensor in state_dict.items():
        new_key = old_key
        for pattern, replacement in KEYS_TO_MODIFY_MAPPING.items():
            if pattern in new_key:
                new_key = new_key.replace(pattern, replacement)
        renamed[new_key] = tensor
    return renamed
| |
|
| |
|
def convert_llava_llama_to_hf(text_model_id, vision_model_id, projector_tokens_num, output_path, old_state_dict_path):
    """Convert an original LLaVA-style checkpoint into the HF format and save it.

    Args:
        text_model_id: Hub id of the base text model (config/tokenizer source,
            loaded with ``trust_remote_code=True``).
        vision_model_id: Hub id of the vision tower, recorded in the config.
        projector_tokens_num: Number of projector tokens passed to ``LlavaConfig``.
        output_path: Directory where the converted model and tokenizer are saved.
        old_state_dict_path: Path to the raw ``torch.save``-d state dict
            (expected filename: ``model_state_dict.bin``).
    """
    # The original checkpoint weights are fp16; make newly created params match.
    torch.set_default_dtype(torch.float16)
    text_config = AutoConfig.from_pretrained(text_model_id, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    config = LlavaConfig(
        text_config=text_config,
        vocab_size=51200,
        vision_tower_name=vision_model_id,
        projector_tokens_num=projector_tokens_num,
    )
    # Keep the text backbone's vocab size in sync with the wrapper config.
    config.text_config.vocab_size = config.vocab_size

    # NOTE(review): instantiating under torch.device("cuda") requires a GPU at
    # conversion time — confirm this is intended (a "meta"/"cpu" device would
    # also work for a weight-only conversion).
    with torch.device("cuda"):
        model = LlavaForConditionalGeneration(config)

    # SECURITY: torch.load unpickles arbitrary objects — only run this on
    # trusted checkpoints (consider weights_only=True where available).
    state_dict = torch.load(old_state_dict_path, map_location="cpu")
    state_dict = convert_state_dict_to_hf(state_dict)
    model.load_state_dict(state_dict, strict=True, assign=True)

    # (Removed two no-op self-assignments of config.vocab_size /
    # config.text_config.vocab_size that had no effect.)
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
| |
|
| |
|
def main():
    """CLI entry point: parse the conversion arguments and run the converter."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--text_model_id", help="Hub location of the text model")
    parser.add_argument("--vision_model_id", help="Hub location of the vision model")
    parser.add_argument("--output_path", help="Location of the converted model")
    parser.add_argument(
        "--old_state_dict_path",
        help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
    )
    parser.add_argument("--tokens_num", type=int, default=1)
    args = parser.parse_args()
    convert_llava_llama_to_hf(
        args.text_model_id,
        args.vision_model_id,
        args.tokens_num,
        args.output_path,
        args.old_state_dict_path,
    )
| |
|
| |
|
# Run the converter only when executed as a script, not on import.
if __name__ == "__main__":
    main()