import pandas as pd
import threading
import time
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple, Union, Any, Optional, Callable
import gradio as gr
from ..models.model_manager import ModelManager
from ..utils.data_processing import extract_file_dict, validate_data, extract_binary_output
from ..config.config_manager import ConfigManager
from ..utils.metrics import create_accuracy_table
from datetime import datetime
import boto3


class InferenceEngine:
    """Engine for handling batch inference and processing control.

    Runs the currently selected model over a set of uploaded images (optionally
    paired with a CSV holding 'Image Name' / 'Ground Truth' columns), supports
    cooperative cancellation via a stop flag that is checked before each image,
    and uploads every completed run's results to S3 on a best-effort basis.
    """

    # Image file extensions accepted from the uploaded files (compared lower-cased).
    _IMAGE_EXTENSIONS: Tuple[str, ...] = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
    # Columns shown in the UI results table; ``full_df`` additionally carries 'Image Path'.
    _DISPLAY_COLUMNS: Tuple[str, ...] = ('S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output')

    def __init__(self, model_manager: ModelManager, config_manager: ConfigManager):
        """
        Initialize the inference engine.

        Args:
            model_manager: Model manager instance
            config_manager: Configuration manager instance
        """
        self.model_manager = model_manager
        self.config_manager = config_manager
        # Guards ``stop_processing``; set_stop_flag() may be called while a batch is running.
        self.processing_lock = threading.Lock()
        self.stop_processing = False
        # Full results of the last run, including the 'Image Path' column the UI table
        # omits; rerun_with_new_prompt() reads image paths from it by row position.
        self.full_df: Optional[pd.DataFrame] = None

    # ------------------------------------------------------------------ #
    # Stop-flag control
    # ------------------------------------------------------------------ #
    def set_stop_flag(self) -> str:
        """Set the global stop flag to interrupt processing."""
        with self.processing_lock:
            self.stop_processing = True
            print("🛑 Stop signal received. Processing will halt after current image...")
        return "🛑 Stopping process... Please wait for current image to complete."

    def reset_stop_flag(self) -> None:
        """Reset the global stop flag before starting new processing."""
        with self.processing_lock:
            self.stop_processing = False

    def check_stop_flag(self) -> bool:
        """Check if processing should be stopped."""
        with self.processing_lock:
            return self.stop_processing

    # ------------------------------------------------------------------ #
    # Model loading
    # ------------------------------------------------------------------ #
    def _should_load_model(self, model_selection: str, quantization_type: str) -> bool:
        """
        Check if we need to load the model.

        Args:
            model_selection: Selected model name
            quantization_type: Selected quantization type

        Returns:
            True when no model is loaded, a different model is selected, or the same
            model is loaded with a different quantization; False otherwise.
        """
        current = self.model_manager.current_model
        if not current or not current.is_model_loaded():
            return True
        return (
            self.model_manager.current_model_name != model_selection
            or current.current_quantization != quantization_type
        )

    def _ensure_correct_model_loaded(self, model_selection: str, quantization_type: str,
                                     progress: gr.Progress) -> None:
        """
        Ensure the correct model with correct quantization is loaded.

        Args:
            model_selection: Selected model name
            quantization_type: Selected quantization type
            progress: Gradio progress tracker

        Raises:
            RuntimeError: if ``model_manager.load_model`` reports failure.
        """
        if not self._should_load_model(model_selection, quantization_type):
            print(f"✅ Correct model already loaded: {model_selection} with {quantization_type}")
            return

        progress(0, desc=f"🚀 Loading {model_selection} ({quantization_type})...")
        print(f"🚀 Loading {model_selection} with {quantization_type}...")
        if not self.model_manager.load_model(model_selection, quantization_type):
            raise RuntimeError(f"Failed to load model {model_selection} with {quantization_type}")

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _short_name(name: str, limit: int = 30) -> str:
        """Return ``name`` cut to ``limit`` characters plus '...' when it is longer."""
        return f"{name[:limit]}..." if len(name) > limit else name

    @classmethod
    def _filter_image_files(cls, file_dict: Dict[str, Path]) -> Dict[str, Path]:
        """Return the entries of ``file_dict`` whose names end in a known image extension."""
        return {
            name: path
            for name, path in file_dict.items()
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        }

    def _finalize_results(self, rows: List[Dict[str, Any]]) -> pd.DataFrame:
        """Store ``rows`` as ``self.full_df`` and return a copy restricted to display columns."""
        self.full_df = pd.DataFrame(rows)
        return self.full_df[list(self._DISPLAY_COLUMNS)]

    @staticmethod
    def _error_outputs(message: str) -> Tuple[Any, ...]:
        """UI outputs for a run that could not start: show ``message`` in the 4th slot."""
        return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
                message, gr.update(visible=False), "")

    @staticmethod
    def _result_outputs(display_df: pd.DataFrame, message: str) -> Tuple[Any, ...]:
        """UI outputs for a finished (or stopped) run: results table plus status message."""
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
                display_df, gr.update(visible=False), message)

    @staticmethod
    def _set_outputs(frame: pd.DataFrame, row: int, model_output: Any, binary_output: Any) -> None:
        """Write 'Model Output' and 'Binary Output' at positional ``row``; no-op past the end."""
        if row < len(frame):
            frame.iloc[row, frame.columns.get_loc("Model Output")] = model_output
            frame.iloc[row, frame.columns.get_loc("Binary Output")] = binary_output

    # ------------------------------------------------------------------ #
    # Batch processing
    # ------------------------------------------------------------------ #
    def process_folder_input(
        self,
        folder_path: List[Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress = gr.Progress(),
    ) -> Tuple[Any, ...]:
        """
        Process input folder with images and optional CSV.

        Args:
            folder_path: List of Path objects from Gradio
            prompt: Text prompt for inference
            quantization_type: Model quantization type
            model_selection: Selected model name
            progress: Gradio progress tracker (injected by Gradio via the default)

        Returns:
            Six UI outputs: three visibility updates, the results DataFrame (or an
            error message string when validation fails), a visibility update, and
            a status message.
        """
        # Reset stop flag at the beginning of processing
        self.reset_stop_flag()

        file_dict = extract_file_dict(folder_path)
        # Print all file names for debug
        for fname in file_dict:
            print(fname)

        validation_result, message = validate_data(file_dict)

        # validate_data returns either False or a status string; keep equality comparison.
        if validation_result == False:  # noqa: E712
            return self._error_outputs(message)
        if validation_result in ("no_csv", "multiple_csv"):
            return self._process_without_csv(file_dict, prompt, quantization_type, model_selection, progress)
        return self._process_with_csv(file_dict, prompt, quantization_type, model_selection, progress)

    def _process_without_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress,
    ) -> Tuple[Any, ...]:
        """Process images without CSV file (no ground truth is available)."""
        images = list(self._filter_image_files(file_dict).items())
        total_images = len(images)
        if total_images == 0:
            return self._error_outputs("No image files found.")

        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        progress(0, desc=f"🚀 Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")

        filtered_rows: List[Dict[str, Any]] = []
        for idx, (img_name, img_path) in enumerate(images):
            if self.check_stop_flag():
                print(f"🛑 Processing stopped by user at image {idx + 1}/{total_images}")
                # Keep the unprocessed images in the table so the user sees what was skipped.
                for offset, (remaining_name, remaining_path) in enumerate(images[idx:]):
                    filtered_rows.append({
                        'S.No': idx + offset + 1,
                        'Image Name': remaining_name,
                        'Ground Truth': '',
                        'Binary Output': 'Not processed (stopped)',
                        'Model Output': 'Processing stopped by user',
                        'Image Path': str(remaining_path),
                    })
                display_df = self._finalize_results(filtered_rows)
                final_message = f"🛑 Processing stopped by user. Completed {idx}/{total_images} images."
                print(final_message)
                return self._result_outputs(display_df, final_message)

            try:
                progress_msg = f"🔄 Processing image {idx + 1}/{total_images}: {self._short_name(img_name)}"
                progress(idx / total_images, desc=progress_msg)
                print(progress_msg)

                model_output = self.model_manager.inference(str(img_path), prompt) if prompt else "No prompt provided"
                # No ground truth exists for file-only input.
                binary_output = extract_binary_output(model_output, "", [])

                filtered_rows.append({
                    'S.No': idx + 1,
                    'Image Name': img_name,
                    'Ground Truth': '',  # Empty for manual input
                    'Binary Output': binary_output,
                    'Model Output': model_output,
                    'Image Path': str(img_path),
                })

                progress((idx + 1) / total_images, desc=f"✅ Completed {idx + 1}/{total_images} images")
                print(f"Successfully processed image {idx + 1} of {total_images}")
            except Exception as e:
                # One bad image must not abort the batch; record the error in its row.
                print(f"Error processing image {idx + 1} of {total_images}: {e}")
                filtered_rows.append({
                    'S.No': idx + 1,
                    'Image Name': img_name,
                    'Ground Truth': '',
                    'Binary Output': 'Enter the output manually',  # Default for errors
                    'Model Output': f"Error: {e}",
                    'Image Path': str(img_path),
                })
                progress((idx + 1) / total_images,
                         desc=f"⚠️ Processed {idx + 1}/{total_images} images (with errors)")

        if self.check_stop_flag():
            final_message = f"🛑 Processing stopped by user. Completed {len(filtered_rows)}/{total_images} images."
        else:
            final_message = f"🎉 Successfully completed processing all {total_images} images!"

        display_df = self._finalize_results(filtered_rows)
        self.save_results_to_s3(display_df)
        print(final_message)
        # The table is shown so the user can enter ground truth manually.
        return self._result_outputs(display_df, final_message)

    def _process_with_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress,
    ) -> Tuple[Any, ...]:
        """Process the images listed in the (first) CSV file that were also uploaded."""
        csv_name = next(fname for fname in file_dict if fname.lower().endswith('.csv'))
        df = pd.read_csv(file_dict[csv_name])

        # All non-empty ground-truth labels, used by extract_binary_output for keyword matching.
        all_ground_truths = [
            str(gt) for gt in df['Ground Truth'] if pd.notna(gt) and str(gt).strip()
        ]

        image_file_dict = self._filter_image_files(file_dict)
        # Only CSV rows whose image file was actually provided are processed.
        matching_rows = [row for _, row in df.iterrows() if row['Image Name'] in image_file_dict]
        total_images = len(matching_rows)
        if total_images == 0:
            return self._error_outputs("No matching images found for entries in CSV.")

        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        progress(0, desc=f"🚀 Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")

        filtered_rows: List[Dict[str, Any]] = []
        for position, row in enumerate(matching_rows):
            img_name = row['Image Name']
            image_number = position + 1

            if self.check_stop_flag():
                print(f"🛑 Processing stopped by user at image {image_number}/{total_images}")
                # Keep the unprocessed rows in the table so the user sees what was skipped.
                for remaining_row in matching_rows[position:]:
                    remaining_name = remaining_row['Image Name']
                    filtered_rows.append({
                        'S.No': len(filtered_rows) + 1,
                        'Image Name': remaining_name,
                        'Ground Truth': remaining_row['Ground Truth'],
                        'Binary Output': 'Not processed (stopped)',
                        'Model Output': 'Processing stopped by user',
                        'Image Path': str(image_file_dict[remaining_name]),
                    })
                display_df = self._finalize_results(filtered_rows)
                final_message = f"🛑 Processing stopped by user. Completed {position}/{total_images} images."
                print(final_message)
                return self._result_outputs(display_df, final_message)

            img_path = str(image_file_dict[img_name])
            try:
                progress_msg = f"🔄 Processing image {image_number}/{total_images}: {self._short_name(img_name)}"
                progress(position / total_images, desc=progress_msg)
                print(progress_msg)

                model_output = self.model_manager.inference(img_path, prompt)
                ground_truth = str(row['Ground Truth']) if pd.notna(row['Ground Truth']) else ""
                binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)

                filtered_rows.append({
                    'S.No': len(filtered_rows) + 1,
                    'Image Name': img_name,
                    'Ground Truth': row['Ground Truth'],
                    'Binary Output': binary_output,
                    'Model Output': model_output,
                    'Image Path': img_path,
                })

                progress(image_number / total_images, desc=f"✅ Completed {image_number}/{total_images} images")
                print(f"Successfully processed image {image_number} of {total_images}")
            except Exception as e:
                # One bad image must not abort the batch; record the error in its row.
                print(f"Error processing image {image_number} of {total_images}: {e}")
                filtered_rows.append({
                    'S.No': len(filtered_rows) + 1,
                    'Image Name': img_name,
                    'Ground Truth': row['Ground Truth'],
                    'Binary Output': 'Enter the output manually',  # Default for errors
                    'Model Output': f"Error: {e}",
                    'Image Path': img_path,
                })
                progress(image_number / total_images,
                         desc=f"⚠️ Processed {image_number}/{total_images} images (with errors)")

        if self.check_stop_flag():
            # Stop arrived while the last image was running: every row was still processed.
            final_message = f"🛑 Processing stopped by user. Completed {len(filtered_rows)}/{total_images} images."
        else:
            final_message = f"🎉 Successfully completed processing all {total_images} images!"

        display_df = self._finalize_results(filtered_rows)
        self.save_results_to_s3(display_df)
        print(final_message)
        return self._result_outputs(display_df, final_message)

    def rerun_with_new_prompt(
        self,
        df: pd.DataFrame,
        new_prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress = gr.Progress(),
    ) -> Tuple[Any, ...]:
        """
        Rerun processing with new prompt and clear accuracy data.

        Rows of ``df`` (the UI table, possibly with user-edited ground truth) are
        matched to ``self.full_df`` by position to find each image path.

        Returns:
            Seven UI outputs: the updated table, three ``None`` values that clear
            previously shown accuracy data, two hide updates, and a status message.
        """
        if df is None or not new_prompt or not new_prompt.strip():
            return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "⚠️ Please provide a valid prompt"

        # Reset stop flag at the beginning of reprocessing
        self.reset_stop_flag()

        updated_df = df.copy()
        total_images = len(updated_df)

        # All non-empty ground-truth labels, used by extract_binary_output for keyword matching.
        all_ground_truths = [
            str(gt) for gt in updated_df['Ground Truth'] if pd.notna(gt) and str(gt).strip()
        ]

        if self.full_df is None:
            return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "⚠️ No image data available"

        updated_full_df = self.full_df.copy()

        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)

        progress(0, desc=f"🚀 Starting to reprocess {total_images} images with new prompt...")
        print(f"🚀 Starting to reprocess {total_images} images with new prompt...")

        for i in range(total_images):
            if self.check_stop_flag():
                print(f"🛑 Reprocessing stopped by user at image {i + 1}/{total_images}")
                # Mark remaining images as not reprocessed in both dataframes.
                for j in range(i, total_images):
                    self._set_outputs(updated_df, j, "Reprocessing stopped by user", "Not reprocessed (stopped)")
                    self._set_outputs(updated_full_df, j, "Reprocessing stopped by user", "Not reprocessed (stopped)")
                self.full_df = updated_full_df
                final_message = f"🛑 Reprocessing stopped by user. Completed {i}/{total_images} images."
                print(final_message)
                return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

            try:
                # Raises IndexError (handled below) if the UI table has more rows than full_df.
                image_path = self.full_df.iloc[i]['Image Path']
                image_name = updated_df.iloc[i]['Image Name']
                gt_value = updated_df.iloc[i]['Ground Truth']
                ground_truth = str(gt_value) if pd.notna(gt_value) else ""

                progress_msg = f"🔄 Reprocessing image {i + 1}/{total_images}: {self._short_name(image_name)}"
                progress(i / total_images, desc=progress_msg)
                print(progress_msg)

                model_output = self.model_manager.inference(image_path, new_prompt)
                binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)

                self._set_outputs(updated_df, i, model_output, binary_output)
                self._set_outputs(updated_full_df, i, model_output, binary_output)

                progress((i + 1) / total_images, desc=f"✅ Completed {i + 1}/{total_images} images")
                print(f"✅ Successfully reprocessed image {i + 1}/{total_images}")
            except Exception as e:
                print(f"❌ Error reprocessing image {i + 1}/{total_images}: {e}")
                error_message = f"Error: {e}"
                # _set_outputs skips rows beyond the end of full_df instead of raising again.
                self._set_outputs(updated_df, i, error_message, "Enter the output manually")
                self._set_outputs(updated_full_df, i, error_message, "Enter the output manually")
                progress((i + 1) / total_images,
                         desc=f"⚠️ Processed {i + 1}/{total_images} images (with errors)")

        self.full_df = updated_full_df

        if self.check_stop_flag():
            final_message = "🛑 Reprocessing stopped by user. Completed reprocessing for some images."
        else:
            final_message = (f"🎉 Successfully completed reprocessing all {total_images} images with new prompt! "
                             "Click 'Generate Metrics' to see accuracy data.")

        # Upload the display table, like the other run paths: same schema, no local
        # 'Image Path' column, and it carries any ground truth the user edited.
        self.save_results_to_s3(updated_df)
        print(final_message)
        # Return updated dataframe and clear accuracy data (hide section 3)
        return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #
    def save_results_to_s3(self, df: pd.DataFrame) -> None:
        """Best-effort upload of a results table and its evaluation metrics to S3.

        Uploads ``<timestamp>_model_output.csv`` and, when ``create_accuracy_table``
        succeeds, ``<timestamp>_evaluation_metrics.txt`` to
        ``s3://$AWS_BUCKET/$AWS_PREFIX/<date>/``. Files are staged in a temporary
        directory that is always removed. Errors are printed and never raised, so a
        storage problem cannot discard a finished inference run.
        """
        s3_bucket = os.getenv('AWS_BUCKET')
        if not s3_bucket:
            print("AWS_BUCKET is not set; skipping upload of results to S3.")
            return

        prefix = os.getenv('AWS_PREFIX')
        now = datetime.now()  # single timestamp so the date folder and file names agree
        s3_path = f"{prefix}/{now.date()}" if prefix else str(now.date())
        date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
        csv_file_name = f'{date_time}_model_output.csv'
        text_file_name = f'{date_time}_evaluation_metrics.txt'

        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                uploads: List[Tuple[str, str]] = []

                try:
                    metrics_df, _, cm_values = create_accuracy_table(df)
                except Exception as e:
                    # e.g. "No valid data" / "Need at least 2 different ..." when ground
                    # truth is missing; the raw outputs are still worth keeping.
                    print(f"Could not compute evaluation metrics, uploading only csv file: {e}")
                else:
                    text_path = os.path.join(tmp_dir, text_file_name)
                    with open(text_path, 'w', encoding='utf-8') as f:
                        f.write(metrics_df.to_string() + '\n\n')
                        f.write(cm_values.to_string())
                    uploads.append((text_path, text_file_name))

                csv_path = os.path.join(tmp_dir, csv_file_name)
                df.to_csv(csv_path, index=False)
                uploads.append((csv_path, csv_file_name))

                for local_path, file_name in uploads:
                    object_key = f"{s3_path}/{file_name}"
                    status = self.upload_file(local_path, s3_bucket, object_key)
                    print(f"Status of uploading {file_name} to {s3_bucket}/{object_key}: {status}")
        except Exception as e:
            print(f"Error saving results to s3: {e}")

    def upload_file(self, file_name, bucket, object_name=None):
        """Upload a file to an S3 bucket.

        Credentials come from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY; when these are
        unset, boto3 falls back to its default credential chain.

        :param file_name: File to upload
        :param bucket: Bucket to upload to
        :param object_name: S3 object name. If not specified then the basename of file_name is used
        :return: True if file was uploaded, else False
        """
        access_key = os.getenv('AWS_ACCESS_KEY_ID')
        secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')

        if object_name is None:
            object_name = os.path.basename(file_name)

        try:
            # Client construction can fail too (bad config), so it belongs inside the try.
            s3_client = boto3.client('s3', aws_access_key_id=access_key, aws_secret_access_key=secret_key)
            s3_client.upload_file(file_name, bucket, object_name)
        except Exception as e:
            print(f"Error uploading {file_name} to s3: {e}")
            return False
        return True