from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field, create_model

from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)
from app.utils.converter import to_snake_case
|
| |
|
| |
|
def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> "Type[BaseModel]":
    """Build a Pydantic ``Product`` model that asks for a score per allowed value.

    For every attribute in *cf_style_schema*:

    * A nested model named ``"Class_" + attribute.capitalize()`` is created.
      It has one required ``int`` field per allowed value. Value names are
      lower-cased, and spaces and hyphens become underscores.
    * ``Product`` gets a field named after the attribute, typed with that
      nested model, with default ``""``. Its description is
      ``"<kind>, <attribute description>"``, where ``<kind>`` is
      ``"multi-label classification"`` when ``"list"`` occurs in the
      attribute's ``data_type`` and ``"single-label classification"``
      otherwise.

    The models are built with ``pydantic.create_model``, not by generating
    and ``exec``-ing source code. As a result, descriptions or values taken
    from external input cannot inject code, quotes or newlines cannot break
    the build, and no names leak into this module's globals.

    Args:
        cf_style_schema: Mapping of attribute name to a spec object exposing
            ``data_type``, ``description`` and ``allowed_values``. The
            attribute names are used verbatim as ``Product`` field names.

    Returns:
        The dynamically created ``Product`` model class.
    """
    product_fields = {}
    for attribute, attribute_info in cf_style_schema.items():
        multiple = "list" in attribute_info.data_type
        multiple_desc = (
            "multi-label classification" if multiple else "single-label classification"
        )
        # One required integer score per allowed value.
        value_fields = {
            value.lower().replace(" ", "_").replace("-", "_"): (int, ...)
            for value in attribute_info.allowed_values
        }
        value_model = create_model("Class_" + attribute.capitalize(), **value_fields)
        product_fields[attribute] = (
            value_model,
            Field("", description=f"{multiple_desc}, {attribute_info.description}"),
        )
    return create_model("Product", **product_fields)
|
| |
|
def build_attributes_types_prompt(attributes):
    """Render a prompt fragment listing each attribute with its ``data_type``.

    Args:
        attributes: Mapping of attribute name to a spec object exposing
            ``data_type``.

    Returns:
        A header line followed by one ``- <name>: <data_type>`` line per
        attribute, in mapping order.
    """
    header = "\n List of attributes types:\n"
    entries = [f"- {name}: {spec.data_type}\n" for name, spec in attributes.items()]
    return header + "".join(entries)
|
| |
|
| |
|
class BaseAttributionService(ABC):
    """Abstract base for attribute-extraction services.

    Subclasses implement the model-facing coroutines (``extract_attributes``,
    ``reevaluate_atributes``, ``follow_schema``). This class wraps them with
    schema construction and JSON validation.
    """

    # Minimum score a value needs to be kept for a multi-label attribute.
    # Subclasses may override.
    MULTI_LABEL_THRESHOLD = 60

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Return data for *attributes_model* produced by *ai_model*.

        NOTE(review): ``extract_attributes_with_validation`` calls this with a
        fifth positional argument (``product_data``) and an ``img_paths``
        keyword that this signature does not declare. Concrete subclasses
        must accept them; confirm the intended signature.
        """

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Return re-evaluated data for *attributes_model*.

        NOTE(review): ``extract_attributes_with_validation`` passes the
        stringified first-pass result as the fifth positional argument, plus
        an ``img_paths`` keyword; confirm concrete subclasses accept both.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return *data* reshaped to follow the JSON *schema*."""

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract attributes, re-score them, and collapse scores into labels.

        Steps:

        1. Rename each key of *attributes* to ``<to_snake_case(key)>_<index>``
           so it is a unique field name.
        2. Build a scoring model with ``cf_style_to_pydantic_percentage_shema``
           and take its JSON schema.
        3. Call ``extract_attributes`` and validate its result with
           ``validate_json_data``.
        4. Call ``reevaluate_atributes``, passing the first result as
           ``str(data)``.
        5. Select labels from the re-evaluated scores (see ``_select_labels``)
           and map them back to the original attribute keys.

        Args:
            attributes: Attribute name -> spec object (``data_type``,
                ``description``, ``allowed_values``).
            ai_model: Model identifier forwarded to the subclass calls.
            img_urls: Image URLs forwarded to the subclass calls.
            product_taxonomy: Taxonomy label; ``""`` is replaced by ``"main"``.
            product_data: Forwarded as the fifth positional argument of
                ``extract_attributes``.
            pil_images: Accepted but not forwarded. NOTE(review): confirm
                whether subclasses should receive it.
            img_paths: Forwarded as the ``img_paths`` keyword.
            appended_prompt: Extra prompt text appended after the generated
                attribute-type listing.

        Returns:
            ``(data, labels)``. ``data`` is the validated first-pass result,
            keyed by the renamed fields. ``labels`` maps each original
            attribute key to a single value (single-label) or a list of
            values (multi-label). Returned values are the normalized value
            names used as model field names, not the original allowed values.

        Raises:
            AssertionError: if a model field's description does not start
                with ``single-label`` or ``multi-label``.
        """
        # Unique, snake_case field names; reverse_mapping restores the originals.
        forward_mapping: Dict[str, str] = {}
        reverse_mapping: Dict[str, str] = {}
        for index, key in enumerate(attributes):
            safe_key = f"{to_snake_case(key)}_{index}"
            forward_mapping[key] = safe_key
            reverse_mapping[safe_key] = key
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        prompt = build_attributes_types_prompt(attributes)
        if appended_prompt:
            # Previously the caller's appended_prompt was silently dropped.
            prompt += appended_prompt
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        percentage_model = cf_style_to_pydantic_percentage_shema(transformed_attributes)
        schema = percentage_model.model_json_schema()

        data = await self.extract_attributes(
            percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            img_paths=img_paths,
            appended_prompt=prompt,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            img_paths=img_paths,
            appended_prompt=prompt,
        )

        selected = self._select_labels(percentage_model, reevaluate_data)
        reverse_data = {reverse_mapping[key]: value for key, value in selected.items()}
        return data, reverse_data

    def _select_labels(
        self, percentage_model: Type[BaseModel], scores_by_field: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Collapse per-value scores into labels, keyed by model field name.

        A field's kind is read from the prefix of its description.
        Single-label fields get the highest-scoring value; on a tie, the first
        value wins. The field is omitted when no value scores above 0.
        Multi-label fields get every value scoring at least
        ``MULTI_LABEL_THRESHOLD``.
        """
        selected: Dict[str, Any] = {}
        for field_name, field in percentage_model.model_fields.items():
            kind = (field.description or "").lower()
            if kind.startswith("single-label"):
                scores = scores_by_field[field_name]
                if scores:
                    best = max(scores, key=scores.get)
                    if scores[best] > 0:
                        selected[field_name] = best
            elif kind.startswith("multi-label"):
                scores = scores_by_field[field_name]
                selected[field_name] = [
                    value
                    for value, score in scores.items()
                    if score >= self.MULTI_LABEL_THRESHOLD
                ]
            else:
                # Explicit raise: a bare assert would be stripped under -O.
                raise AssertionError(
                    f"The description does not contain 'single-label' or 'multi-label': {field.description}"
                )
        return selected

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate *schema*, reshape *data* via ``follow_schema``, validate the result."""
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data
|
| |
|