Upload modeling_skywork_chat.py
Browse files- modeling_skywork_chat.py +4 -1
modeling_skywork_chat.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import warnings
|
|
|
|
| 2 |
from typing import List, Optional, Tuple, Union
|
| 3 |
|
| 4 |
import torch.utils.checkpoint
|
|
@@ -251,7 +252,7 @@ class SkyworkChatModel(PreTrainedModel):
|
|
| 251 |
|
| 252 |
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
|
| 253 |
num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
|
| 254 |
-
verbose=False):
|
| 255 |
|
| 256 |
if history is None and pixel_values is not None and '<image>' not in question:
|
| 257 |
question = '<image>\n' + question
|
|
@@ -275,6 +276,8 @@ class SkyworkChatModel(PreTrainedModel):
|
|
| 275 |
template.append_message(template.roles[0], question)
|
| 276 |
template.append_message(template.roles[1], None)
|
| 277 |
query = template.get_prompt()
|
|
|
|
|
|
|
| 278 |
|
| 279 |
|
| 280 |
if verbose and pixel_values is not None:
|
|
|
|
| 1 |
import warnings
|
| 2 |
+
import re
|
| 3 |
from typing import List, Optional, Tuple, Union
|
| 4 |
|
| 5 |
import torch.utils.checkpoint
|
|
|
|
| 252 |
|
| 253 |
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
|
| 254 |
num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
|
| 255 |
+
verbose=False, mode="think"):
|
| 256 |
|
| 257 |
if history is None and pixel_values is not None and '<image>' not in question:
|
| 258 |
question = '<image>\n' + question
|
|
|
|
| 276 |
template.append_message(template.roles[0], question)
|
| 277 |
template.append_message(template.roles[1], None)
|
| 278 |
query = template.get_prompt()
|
| 279 |
+
if mode != "think":
|
| 280 |
+
query = re.sub(r'\n<think>', '', query, count=1)
|
| 281 |
|
| 282 |
|
| 283 |
if verbose and pixel_values is not None:
|