|
|
import ast
import json
import os
from datetime import datetime, date
from typing import Union, Type, Any, List, Dict, get_origin, get_args
from uuid import uuid4

import regex
import yaml
from pydantic import BaseModel
from pydantic_core import PydanticUndefined, ValidationError

from .logging import logger
|
|
|
|
|
def make_parent_folder(path: str):
    """Ensure the parent directory of ``path`` exists, creating it if needed.

    A bare filename (no directory component) requires no action.
    """
    parent = os.path.dirname(path)
    if not parent.strip():
        # No directory part: nothing to create.
        return
    if not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)
|
|
|
|
|
def generate_id():
    """Return a random 32-character lowercase hexadecimal identifier."""
    # Equivalent to uuid4().hex: the canonical string minus its dashes.
    return str(uuid4()).replace("-", "")
|
|
|
|
|
def get_timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    # format(dt, spec) delegates to dt.__format__, i.e. strftime.
    return format(datetime.now(), "%Y-%m-%d %H:%M:%S")
|
|
|
|
|
def load_json(path: str, type: str="json"):
    """
    Load data from a JSON or JSON Lines file.

    Args:
        path(str): Path of the file to read.
        type(str): File format, chosen from ["json", "jsonl"].

    Returns:
        The parsed data: a JSON value for type=="json", a list of parsed
        objects for type=="jsonl", or None if the file is missing or the
        "json" file cannot be parsed.
    """
    assert type in ["json", "jsonl"]
    if not os.path.exists(path):
        logger.error(f"File \"{path}\" does not exist!")
        # BUG FIX: previously execution fell through and crashed later with
        # an unrelated UnboundLocalError / FileNotFoundError.
        return None

    if type == "json":
        # BUG FIX: initialize so a failed parse no longer raises
        # UnboundLocalError at the return statement below.
        outputs = None
        try:
            with open(path, "r", encoding="utf-8") as file:
                outputs = json.loads(file.read())
        except Exception:
            logger.error(f"File \"{path}\" is not a valid json file!")
    else:  # type == "jsonl" (guaranteed by the assert above)
        outputs = []
        with open(path, "r", encoding="utf-8") as fin:
            for line in fin:
                # Skip blank lines so a trailing newline does not crash.
                if line.strip():
                    outputs.append(json.loads(line))

    return outputs
|
|
|
|
|
def save_json(data, path: str, type: str="json", use_indent: bool=True) -> str:
    """
    save data to a json file

    Args:
        data: The json data to be saved. It can be a JSON str or a Serializable object when type=="json" or a list of JSON str or Serializable object when type=="jsonl".
        path(str): The path of the saved json file.
        type(str): The type of the json file, chosen from ["json" or "jsonl"].
        use_indent: Whether to use indent when saving the json file.

    Returns:
        path: the path where the json data is saved.
    """
    assert type in ["json", "jsonl"]

    # Create the parent directory if it is missing, so the write succeeds.
    parent = os.path.dirname(path)
    if parent.strip() and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

    def render(obj, indent=None):
        # Strings are assumed to be JSON already; everything else is dumped.
        return obj if isinstance(obj, str) else json.dumps(obj, indent=indent)

    if type == "json":
        with open(path, "w", encoding="utf-8") as fout:
            fout.write(render(data, indent=4 if use_indent else None))
    elif type == "jsonl":
        with open(path, "w", encoding="utf-8") as fout:
            fout.writelines(f"{render(item)}\n" for item in data)

    return path
|
|
|
|
|
def escape_json_values(string: str) -> str:
    """Attempt to repair a JSON string whose values contain unescaped quotes
    or raw newlines.

    Returns the input unchanged if it already parses as JSON, the repaired
    string if the repair produces valid JSON, or the partially transformed
    string if the repair attempt still fails.
    """

    def escape_value(match):
        # Re-quote a matched simple value, escaping raw newlines inside it.
        raw_value = match.group(1)
        raw_value = raw_value.replace('\n', '\\n')
        return f'"{raw_value}"'

    def fix_json(match):
        # Re-quote a key/value pair whose value embeds nested JSON braces,
        # escaping newlines and any still-unescaped quotes in the value.
        raw_key = match.group(1)
        raw_value = match.group(2)
        raw_value = raw_value.replace("\n", "\\n")
        raw_value = regex.sub(r'(?<!\\)"', '\\\"', raw_value)
        return f'"{raw_key}": "{raw_value}"'

    # Fast path: already valid JSON, nothing to repair.
    try:
        json.loads(string)
        return string
    except json.JSONDecodeError:
        pass

    # Repair strategy: escape every quote, then selectively restore the
    # quotes that actually delimit keys and values. Step order matters.
    try:
        # 1) Escape all quotes that are not already escaped.
        string = regex.sub(r'(?<!\\)"', '\\\"', string)
        # 2) Restore quotes around keys (a quoted token followed by a colon).
        pattern_key = r'\\"([^"]+)\\"(?=\s*:\s*)'
        string = regex.sub(pattern_key, r'"\1"', string)
        # 3) Restore quotes around simple values (preceded by a colon).
        pattern_value = r'(?<=:\s*)\\"((?:\\.|[^"\\])*)\\"'
        string = regex.sub(pattern_value, escape_value, string, flags=regex.DOTALL)
        # 4) Handle values that themselves contain a JSON object in braces.
        pattern_nested_json = r'"([^"]+)"\s*:\s*\\"([^"]*\{+[\S\s]*?\}+)[\r\n\\n]*"'
        string = regex.sub(pattern_nested_json, fix_json, string, flags=regex.DOTALL)
        json.loads(string)
        return string
    except json.JSONDecodeError:
        pass

    # Repair failed: hand back the transformed string as a best effort.
    return string
|
|
|
|
|
def fix_json_booleans(string: str) -> str:
    r"""Lowercase Python-style booleans so the string conforms to JSON.

    Replaces every standalone ``True``/``False`` token with ``true``/``false``.
    The ``\b`` word boundaries in the patterns ensure only whole words match,
    so substrings such as "True" inside "IsTrue" are left untouched.

    Args:
        string (str): The input JSON-like string.

    Returns:
        str: The modified JSON string with booleans in lowercase.
    """
    result = string
    for python_literal, json_literal in (("True", "true"), ("False", "false")):
        result = regex.sub(rf'\b{python_literal}\b', json_literal, result)
    return result
|
|
|
|
|
|
|
|
def fix_json(string: str) -> str:
    """Normalize a JSON-like string: lowercase booleans, then repair quoting."""
    return escape_json_values(fix_json_booleans(string))
|
|
|
|
|
|
|
|
def parse_json_from_text(text: str) -> List[str]:
    """
    Autoregressively extract JSON object from text

    Args:
        text (str): a text that includes JSON data

    Returns:
        List[str]: a list of parsed JSON data
    """
    # Recursive pattern ((?R), a `regex`-module feature): matches balanced
    # {...} objects or [...] arrays, including arbitrarily nested ones.
    json_pattern = r"""(?:\{(?:[^{}]*|(?R))*\}|\[(?:[^\[\]]*|(?R))*\])"""
    compiled = regex.compile(json_pattern, regex.VERBOSE)
    return [fix_json(candidate) for candidate in compiled.findall(text)]
|
|
|
|
|
|
|
|
def parse_xml_from_text(text: str, label: str) -> List[str]:
    """Extract the stripped contents of every ``<label>...</label>`` pair.

    Matching is non-greedy and spans newlines (DOTALL). Returns an empty
    list when no tagged span is found.
    """
    tag_pattern = rf"<{label}>(.*?)</{label}>"
    found: List[str] = regex.findall(tag_pattern, text, regex.DOTALL)
    return [item.strip() for item in found] if found else []
|
|
|
|
|
def parse_data_from_text(text: str, datatype: str):
    """Convert ``text`` to a Python value of the requested ``datatype``.

    Args:
        text (str): The raw textual value.
        datatype (str): One of "str", "int", "float", "bool", "list", "dict".
            Any other value returns ``text`` unchanged.

    Returns:
        The converted value, or ``text`` itself for unknown datatypes.

    Raises:
        ValueError: If ``text`` cannot be parsed as the requested type.
    """
    if datatype == "str":
        return text
    if datatype == "int":
        return int(text)
    if datatype == "float":
        return float(text)
    if datatype == "bool":
        # Accept common truthy spellings, case-insensitively. The original
        # tuple also listed "True", which text.lower() can never equal.
        return text.lower() in ("true", "yes", "1", "on")
    if datatype in ("list", "dict"):
        # SECURITY FIX: ast.literal_eval only evaluates Python literals,
        # unlike eval() which would execute arbitrary code from potentially
        # untrusted (e.g. model-generated) text.
        return ast.literal_eval(text)
    return text
|
|
|
|
|
def parse_json_from_llm_output(text: str) -> dict:
    """
    Extract JSON str from LLM outputs and convert it to dict.

    Takes the first JSON candidate found in ``text``.

    Raises:
        ValueError: If ``text`` contains no JSON string, or the extracted
            candidate cannot be parsed.
    """
    json_list = parse_json_from_text(text=text)
    if not json_list:
        # BUG FIX: corrected the typo "follwoing" in the error message.
        raise ValueError(f"The following generated text does not contain JSON string!\n{text}")
    json_text = json_list[0]
    try:
        # yaml.safe_load tolerates minor JSON deviations (e.g. single quotes)
        # while still parsing strict JSON correctly.
        data = yaml.safe_load(json_text)
    except Exception:
        raise ValueError(f"The following generated text is not a valid JSON string!\n{json_text}")
    return data
|
|
|
|
|
def extract_code_blocks(text: str, return_type: bool = False) -> Union[List[str], List[tuple]]:
    """
    Extract code blocks from text enclosed in triple backticks.

    Args:
        text (str): The text containing code blocks
        return_type (bool): If True, returns tuples of (language, code), otherwise just code

    Returns:
        Union[List[str], List[tuple]]: Either list of code blocks or list of (language, code) tuples
    """
    fence_pattern = r"```((?:[a-zA-Z]*)?)\n*(.*?)\n*```"
    found = regex.findall(fence_pattern, text, regex.DOTALL)

    # No fenced block at all: treat the whole text as one anonymous block.
    if not found:
        stripped = text.strip()
        return [(None, stripped)] if return_type else [stripped]

    if not return_type:
        return [body.strip() for _, body in found]
    # An empty language tag is normalized to None.
    return [(lang.strip() or None, body.strip()) for lang, body in found]
|
|
|
|
|
def remove_repr_quotes(json_string):
    """Strip surrounding quotes from repr-style values such as ``"Foo(...)"``.

    Turns e.g. ``"Point(1, 2)"`` into ``Point(1, 2)`` so object reprs
    embedded by the serializer are not rendered as JSON strings.
    """
    repr_pattern = r'"([A-Za-z_]\w*\(.*\))"'
    return regex.sub(repr_pattern, r'\1', json_string)
|
|
|
|
|
def custom_serializer(obj: Any):
    """Fallback serializer for ``json.dumps(default=custom_serializer)``.

    Handles bytes, datetimes/dates, sets, file-like objects, and callables;
    any other object falls back to its repr().
    """
    if isinstance(obj, (bytes, bytearray)):
        return obj.decode()
    if isinstance(obj, (datetime, date)):
        return obj.strftime("%Y-%m-%d %H:%M:%S")
    if isinstance(obj, set):
        return list(obj)
    if hasattr(obj, "read") and hasattr(obj, "name"):
        # Duck-typed file object: render a short placeholder, not contents.
        return f"<FileObject name={getattr(obj, 'name', 'unknown')}>"
    if callable(obj):
        return obj.__name__
    # BUG FIX: the original guarded this with hasattr(obj, "__class__") and
    # hasattr(obj, "__repr__"), both of which are always true in Python, so
    # the trailing TypeError was unreachable. Fall back to repr() explicitly.
    return repr(obj)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_type_name(typ):
    """Render a (possibly generic) type annotation as a readable string.

    Handles plain classes, ``Union`` (rendered with " | "), ``Type[...]``,
    ``list``/``tuple`` and ``dict`` generics; any other parameterized type
    falls back to ``str(origin)``.
    """
    origin = get_origin(typ)
    if origin is None:
        # Plain class or special form with no generic origin.
        return getattr(typ, "__name__", str(typ))

    # BUG FIX: fetch args once, up front. Previously `args` was assigned
    # only inside the Union branch, so the `origin is type` branch raised
    # UnboundLocalError for annotations like Type[int].
    args = get_args(typ)

    if origin is Union:
        return " | ".join(get_type_name(arg) for arg in args)

    if origin is type:
        return f"Type[{get_type_name(args[0])}]" if args else "Type[Any]"

    if origin in (list, tuple):
        return f"{origin.__name__}[{', '.join(get_type_name(arg) for arg in args)}]"

    if origin is dict:
        key_type, value_type = get_args(typ)
        return f"dict[{get_type_name(key_type)}, {get_type_name(value_type)}]"

    return str(origin)
|
|
|
|
|
def get_pydantic_field_types(model: Type[BaseModel]) -> Dict[str, Union[str, dict]]:
    """Map each field of a pydantic model to a readable type-name string.

    Fields whose annotation is itself a pydantic model are expanded
    recursively into nested dicts.
    """
    resolved: Dict[str, Union[str, dict]] = {}
    for name, info in model.model_fields.items():
        annotation = info.annotation
        # A "model_fields" attribute marks a nested pydantic model.
        if hasattr(annotation, "model_fields"):
            resolved[name] = get_pydantic_field_types(annotation)
        else:
            resolved[name] = get_type_name(annotation)
    return resolved
|
|
|
|
|
def get_pydantic_required_field_types(model: Type[BaseModel]) -> Dict[str, str]:
    """Map each *required* field of a pydantic model to its type-name string.

    Fields that are optional, or that carry a default value or a default
    factory, are skipped.
    """
    required: Dict[str, str] = {}
    for name, info in model.model_fields.items():
        if not info.is_required():
            continue
        has_default = info.default is not PydanticUndefined or info.default_factory is not None
        if has_default:
            continue
        required[name] = get_type_name(info.annotation)
    return required
|
|
|
|
|
def format_pydantic_field_types(field_types: Dict[str, str]) -> str:
    """Render a field-name → type-name mapping as a one-line JSON-like object.

    Note that only the keys are quoted; type names appear bare, e.g.
    ``{"age": int, "name": str}``.
    """
    parts = [f'"{name}": {type_name}' for name, type_name in field_types.items()]
    return "{" + ", ".join(parts) + "}"
|
|
|
|
|
def get_error_message(errors: List[Union[ValidationError, Exception]]) -> str:
    """Format a mixed list of pydantic ValidationErrors and plain exceptions
    into one human-readable report, grouped by kind.

    A single error (not wrapped in a list) is accepted as well.
    """
    if not isinstance(errors, list):
        errors = [errors]

    validation_errors = [e for e in errors if isinstance(e, ValidationError)]
    exceptions = [e for e in errors if not isinstance(e, ValidationError)]

    sections = []
    if validation_errors:
        header = f" >>>>>>>> {len(validation_errors)} Validation Errors: <<<<<<<<\n\n"
        sections.append(header + "\n\n".join(str(e) for e in validation_errors))
    if exceptions:
        header = f">>>>>>>> {len(exceptions)} Exception Errors: <<<<<<<<\n\n"
        sections.append(header + "\n\n".join(f"{type(e).__name__}: {e}" for e in exceptions))
    return "\n\n".join(sections)
|
|
|
|
|
def get_base_module_init_error_message(cls, data: Dict[str, Any], errors: List[Union[ValidationError, Exception]]) -> str:
    """Build a detailed message for a failed instantiation of ``cls``,
    showing the offending input data followed by the grouped errors."""
    if not isinstance(errors, list):
        errors = [errors]

    # Pretty-print the data; repr-style values are unquoted for readability.
    rendered = remove_repr_quotes(json.dumps(data, indent=4, default=custom_serializer))
    return (
        f"Can not instantiate {cls.__name__} from: "
        + rendered
        + "\n\n"
        + get_error_message(errors)
    )
|
|
|
|
|
|