simran40 committed on
Commit 173cb04 · verified · 1 Parent(s): 45001af

Update app.py

Files changed (1)
  1. app.py +224 -110
app.py CHANGED
@@ -3,208 +3,322 @@ import fitz # PyMuPDF
 import re
 import faiss
 import numpy as np
-
 from sentence_transformers import SentenceTransformer
 from transformers import pipeline


 # =================================================
 # MODEL LOADING (ONCE)
 # =================================================

-# Embedding model for semantic retrieval
-embedding_model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
-
-# Extractive QA model (accurate answers)
-qa_pipeline = pipeline(
-    "question-answering",
-    model="deepset/roberta-base-squad2",
-    tokenizer="deepset/roberta-base-squad2"
-)
-
-# Summarization model (clean summary)
-summarizer = pipeline(
-    "summarization",
-    model="facebook/bart-large-cnn",
-    tokenizer="facebook/bart-large-cnn"
-)


 # =================================================
-# PDF PROCESSING
 # =================================================

 def extract_text_from_pdf(pdf_path):
     doc = fitz.open(pdf_path)
     text = ""
     for page in doc:
-        text += page.get_text()
     return text


 def clean_text(text):
     text = re.sub(r"\s+", " ", text)
-    text = re.sub(r"Table of contents.*?Introduction", "", text, flags=re.I)
-    text = re.sub(r"\bPage \d+\b", "", text)
     return text.strip()


-def chunk_text(text, chunk_size=350, overlap=80):
     chunks = []
     start = 0
     while start < len(text):
         end = start + chunk_size
         chunks.append(text[start:end])
-        start = end - overlap
     return chunks


-def chunk_text_for_summary(text, chunk_size=900, overlap=100):
     chunks = []
     start = 0
     while start < len(text):
         end = start + chunk_size
         chunks.append(text[start:end])
-        start = end - overlap
     return chunks


 # =================================================
-# VECTOR DATABASE (FAISS)
 # =================================================

 def build_faiss_index(chunks):
-    embeddings = embedding_model.encode(chunks)
     embeddings = np.array(embeddings).astype("float32")
     index = faiss.IndexFlatL2(embeddings.shape[1])
     index.add(embeddings)
     return index, chunks


 def retrieve_relevant_chunks(question, index, chunks, top_k=5):
     query_embedding = embedding_model.encode([question]).astype("float32")
     distances, indices = index.search(query_embedding, top_k)

     results = []
     for i, idx in enumerate(indices[0]):
         results.append((chunks[idx], distances[0][i]))

     results.sort(key=lambda x: x[1])
     return [r[0] for r in results]


 # =================================================
-# QUESTION ANSWERING (ACCURATE)
-# =================================================
-
-def generate_answer(question, context_chunks):
-    best_answer = ""
-    best_score = 0.0
-
-    for chunk in context_chunks:
-        result = qa_pipeline(
-            question=question,
-            context=chunk
-        )
-
-        if result["score"] > best_score:
-            best_score = result["score"]
-            best_answer = result["answer"]
-
-    if best_score < 0.3 or best_answer.strip() == "":
-        return "Information not found in the document."
-
-    return best_answer
-
-
-# =================================================
-# SUMMARIZATION
-# =================================================
-
-def generate_summary(chunks):
-    summaries = []
-
-    for chunk in chunks:
-        summary = summarizer(
-            chunk,
-            max_length=150,
-            min_length=60,
-            do_sample=False
-        )[0]["summary_text"]
-
-        summaries.append(summary)
-
-    return " ".join(summaries)
-
-
-# =================================================
-# MAIN FUNCTIONS
 # =================================================

-def pdf_qa(pdf_file, question):
-    if pdf_file is None or question.strip() == "":
-        return "Please upload a PDF and ask a question."
-
-    text = extract_text_from_pdf(pdf_file.name)
-    text = clean_text(text)
-
-    chunks = chunk_text(text)
-    index, chunks = build_faiss_index(chunks)
-
-    relevant_chunks = retrieve_relevant_chunks(question, index, chunks)
-    return generate_answer(question, relevant_chunks)

-
-def pdf_summary(pdf_file):
     if pdf_file is None:
         return "Please upload a PDF document."

-    text = extract_text_from_pdf(pdf_file.name)
-    text = clean_text(text)

-    chunks = chunk_text_for_summary(text)
-    return generate_summary(chunks)


 # =================================================
-# GRADIO UI (QA + SUMMARY)
 # =================================================

 with gr.Blocks() as demo:

     gr.Markdown("""
-    # 📄 PDF Question Answering & Summarization System

-    This system supports **two functionalities**:
-    - 🔍 **Ask Questions** (Accurate answers from PDF)
-    - 📝 **Generate Summary** (Concise document summary)

-    Built using **RAG architecture with open-source AI models**.
-    """)

     with gr.Row():
         with gr.Column(scale=1):
-            pdf_input = gr.File(label="📤 Upload PDF", file_types=[".pdf"])
-
             question_input = gr.Textbox(
-                label="❓ Ask a question (for Q&A)",
-                placeholder="e.g. Whose report is this?",
                 lines=2
             )

-            qa_btn = gr.Button("🔍 Get Answer")
-            summary_btn = gr.Button("📝 Generate Summary")

-        with gr.Column(scale=2):
-            output_box = gr.Textbox(label="📌 Output", lines=12)

-    qa_btn.click(pdf_qa, [pdf_input, question_input], output_box)
-    summary_btn.click(pdf_summary, [pdf_input], output_box)

     gr.Markdown("""
     ---
-    **© Simranpreet Kaur**
-    **NIELIT Ropar | AIML Six Months Training | 2026**
     """)

-demo.launch()
 import re
 import faiss
 import numpy as np
+import time
 from sentence_transformers import SentenceTransformer
 from transformers import pipeline

+# --- Global State and Initialization ---
+# These variables will hold the processed document data
+qa_index = None
+qa_chunks = []
+summarizer_chunks = []
+is_initialized = False

 # =================================================
 # MODEL LOADING (ONCE)
+# WARNING: This step is the primary cause of slow startup.
 # =================================================

+try:
+    # Embedding model for semantic retrieval
+    print("Loading Sentence Transformer model...")
+    embedding_model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
+
+    # Extractive QA model (accurate answers)
+    print("Loading Extractive QA model...")
+    qa_pipeline = pipeline(
+        "question-answering",
+        model="deepset/roberta-base-squad2",
+        tokenizer="deepset/roberta-base-squad2"
+    )
+
+    # Summarization model (clean summary)
+    print("Loading Summarization model...")
+    summarizer = pipeline(
+        "summarization",
+        model="facebook/bart-large-cnn",
+        tokenizer="facebook/bart-large-cnn"
+    )
+    is_initialized = True
+    print("All models loaded successfully.")
+
+except Exception as e:
+    print(f"ERROR: Failed to load required models. Please check dependencies (requirements.txt). Error: {e}")
+    # Leave is_initialized False so the handlers return an error message
+    is_initialized = False


 # =================================================
+# PDF PROCESSING UTILITIES
 # =================================================

 def extract_text_from_pdf(pdf_path):
+    """Extracts raw text content from a PDF file using PyMuPDF."""
     doc = fitz.open(pdf_path)
     text = ""
     for page in doc:
+        text += page.get_text() + "\n\n"
     return text


 def clean_text(text):
+    """Performs common cleanup on raw PDF text."""
+    # Collapse excessive whitespace
     text = re.sub(r"\s+", " ", text)
+    # Attempt to remove table of contents, headers, footers (often document-specific)
+    text = re.sub(r"Table of Contents.*?Introduction", "", text, flags=re.I | re.DOTALL)
+    text = re.sub(r"\bPage \d+ of \d+\b|\bPage \d+\b", "", text)
     return text.strip()


+def chunk_text(text, chunk_size=400, overlap=100):
+    """Chunks text for QA retrieval (smaller chunks for better context focus)."""
     chunks = []
     start = 0
     while start < len(text):
         end = start + chunk_size
         chunks.append(text[start:end])
+        start = end - overlap if end < len(text) else len(text)
     return chunks


+def chunk_text_for_summary(text, chunk_size=1024, overlap=150):
+    """Chunks text for summarization (larger chunks to maintain context flow)."""
     chunks = []
     start = 0
     while start < len(text):
         end = start + chunk_size
         chunks.append(text[start:end])
+        start = end - overlap if end < len(text) else len(text)
     return chunks


 # =================================================
+# FAISS AND CONTEXT RETRIEVAL
 # =================================================

 def build_faiss_index(chunks):
+    """Builds a FAISS index from text chunks."""
+    print(f"Encoding {len(chunks)} chunks...")
+    embeddings = embedding_model.encode(chunks, show_progress_bar=False)
     embeddings = np.array(embeddings).astype("float32")
+
+    # Initialize the FAISS index (L2 distance over 'multi-qa-MiniLM-L6-cos-v1' embeddings)
     index = faiss.IndexFlatL2(embeddings.shape[1])
     index.add(embeddings)
+    print("FAISS index built.")
     return index, chunks


 def retrieve_relevant_chunks(question, index, chunks, top_k=5):
+    """Retrieves the most relevant chunks for a given question."""
+    # Ensure the FAISS index is ready
+    if index is None:
+        return []
+
+    # Encode the query
     query_embedding = embedding_model.encode([question]).astype("float32")
+
+    # Search the index
     distances, indices = index.search(query_embedding, top_k)

     results = []
     for i, idx in enumerate(indices[0]):
+        # Smaller L2 distance means a closer match
         results.append((chunks[idx], distances[0][i]))

+    # Sort by distance (smallest first)
     results.sort(key=lambda x: x[1])
     return [r[0] for r in results]


 # =================================================
+# HANDLERS FOR GRADIO INPUT
 # =================================================

+def process_pdf(pdf_file):
+    """
+    Initial PDF processing step: extracts text, cleans it, chunks it,
+    and builds the FAISS index for retrieval. Updates global state.
+    """
+    global qa_index, qa_chunks, summarizer_chunks

+    if not is_initialized:
+        return "ERROR: AI models failed to load. Please check the console for details."
+
     if pdf_file is None:
+        # Clear state if no file is provided
+        qa_index = None
+        qa_chunks = []
+        summarizer_chunks = []
         return "Please upload a PDF document."

+    try:
+        start_time = time.time()
+        print("Starting PDF processing...")
+
+        # 1. Extraction and cleaning
+        raw_text = extract_text_from_pdf(pdf_file.name)
+        cleaned_text = clean_text(raw_text)
+
+        # 2. Chunking for QA and summary
+        qa_chunks = chunk_text(cleaned_text)
+        # Summarizer chunks are larger to keep sequential context
+        summarizer_chunks = chunk_text_for_summary(cleaned_text)
+
+        # 3. Building the FAISS index for QA
+        qa_index, qa_chunks = build_faiss_index(qa_chunks)
+
+        end_time = time.time()
+
+        return (f"Document successfully processed and indexed! "
+                f"Total chunks: {len(qa_chunks)}. "
+                f"Ready for Q&A and Summary. (Processing time: {end_time - start_time:.2f} seconds)")
+
+    except Exception as e:
+        return f"An error occurred during PDF processing: {e}"
+
+
+def get_answer(question):
+    """Handles the Question Answering functionality."""
+    if not is_initialized:
+        return "ERROR: AI models failed to load. Cannot answer questions."
+
+    if qa_index is None:
+        return "Please upload and process a document first."
+
+    if not question or question.strip() == "":
+        return "Please enter a question to get an answer."
+
+    try:
+        start_time = time.time()
+        # 1. Retrieval (RAG component)
+        relevant_chunks = retrieve_relevant_chunks(question, qa_index, qa_chunks)
+
+        # Combine the retrieved chunks into a single context
+        context = " ".join(relevant_chunks)
+
+        # 2. Answer extraction (extractive QA component)
+        # Pass the question and the combined, relevant context to the QA model
+        result = qa_pipeline(
+            question=question,
+            context=context,
+            # Upper bound on the length of the extracted answer span
+            max_answer_len=256,
+        )

+        answer = result["answer"]
+        score = result["score"]
+
+        # Require a minimum confidence for a valid answer
+        if score < 0.4 or answer.strip() == "":
+            return "Information not found in the most relevant sections of the document (confidence too low)."
+
+        end_time = time.time()
+        return (f"Answer: {answer}\n\n"
+                f"Confidence Score: {score:.2f}\n"
+                f"Time taken: {end_time - start_time:.2f} seconds")
+
+    except Exception as e:
+        return f"An error occurred during Q&A generation: {e}"
+
+
+def get_summary():
+    """Handles the Summarization functionality."""
+    if not is_initialized:
+        return "ERROR: AI models failed to load. Cannot generate a summary."
+
+    if not summarizer_chunks:
+        return "Please upload and process a document first."
+
+    try:
+        start_time = time.time()
+        summaries = []
+
+        # Summarize each chunk sequentially
+        for i, chunk in enumerate(summarizer_chunks):
+            print(f"Summarizing chunk {i+1}/{len(summarizer_chunks)}")
+            summary_output = summarizer(
+                chunk,
+                max_length=150,
+                min_length=50,
+                do_sample=False,
+                truncation=True  # Crucial to handle inputs slightly over the model's max length
+            )[0]["summary_text"]
+            summaries.append(summary_output)
+
+        # Join the sequential summaries
+        merged_summary_text = " ".join(summaries)
+
+        # If the merged summary is still too long, run a final summary pass
+        if len(merged_summary_text) > 1024:
+            print("Running final merge summary...")
+            final_summary_output = summarizer(
+                merged_summary_text,
+                max_length=400,
+                min_length=150,
+                do_sample=False,
+                truncation=True
+            )[0]["summary_text"]
+        else:
+            final_summary_output = merged_summary_text
+
+        end_time = time.time()
+        return (f"--- Document Summary ---\n\n{final_summary_output}\n\n"
+                f"Time taken: {end_time - start_time:.2f} seconds")
+
+    except Exception as e:
+        return f"An error occurred during summarization: {e}"


 # =================================================
+# GRADIO UI
 # =================================================

 with gr.Blocks() as demo:

     gr.Markdown("""
+    # 📄 Open-Source RAG Document Analysis System (Python/Gradio)
+
+    This system uses three open-source models for **Retrieval-Augmented Generation (RAG)**:
+    1. **`multi-qa-MiniLM-L6-cos-v1`**: for fast, accurate context retrieval.
+    2. **`deepset/roberta-base-squad2`**: for highly accurate, extractive Question Answering.
+    3. **`facebook/bart-large-cnn`**: for multi-step, high-quality Summarization.
+
+    ⚠️ **Warning**: Initial model loading is very slow. Please be patient after the app starts.
+    """)

+    with gr.Row():
+        pdf_input = gr.File(label="📤 Upload PDF Document", file_types=[".pdf"])
+        process_status = gr.Textbox(label="Processing Status", interactive=False, value="Upload a PDF to begin.")

+    process_btn = gr.Button("1. Process & Index Document", variant="primary")
+    process_btn.click(process_pdf, [pdf_input], process_status)

+    gr.Markdown("---")
+
     with gr.Row():
         with gr.Column(scale=1):
             question_input = gr.Textbox(
+                label="❓ Step 2: Ask a Question",
+                placeholder="e.g. What were the Q4 revenue figures?",
                 lines=2
             )
+            qa_btn = gr.Button("🔍 Get Accurate Answer", variant="secondary")

+        with gr.Column(scale=1):
+            summary_btn = gr.Button("📝 Step 2: Generate Full Summary", variant="secondary")

+    output_box = gr.Textbox(label="📌 Output / Result", lines=10, interactive=False)

+    # Bind events
+    qa_btn.click(get_answer, [question_input], output_box)
+    summary_btn.click(get_summary, [], output_box)

     gr.Markdown("""
     ---
+    *Disclaimer: Due to the size of the models, expect longer processing times for Q&A and Summarization than API-based solutions.*
     """)

+# To run the Gradio application, you would typically call:
+# demo.launch()
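
Review note: as committed, `demo.launch()` is commented out, so running `python app.py` builds the UI but never serves it. A minimal sketch of a fix (not part of this commit) is to put the launch behind a main guard, so the module still imports cleanly in tests or notebooks while serving when run directly:

if __name__ == "__main__":
    # Start the Gradio server only when app.py is executed as a script
    demo.launch()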
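Review note: `multi-qa-MiniLM-L6-cos-v1` is tuned for cosine similarity, while `build_faiss_index` scores raw embeddings with `IndexFlatL2`. A hedged alternative sketch (not what this commit does): L2-normalize the embeddings and use an inner-product index, which then ranks by cosine similarity. `build_cosine_index` below is a hypothetical helper, not part of app.py:

import faiss
import numpy as np

def build_cosine_index(embedding_model, chunks):
    # L2-normalize so that inner product equals cosine similarity
    embeddings = np.array(embedding_model.encode(chunks)).astype("float32")
    faiss.normalize_L2(embeddings)                    # in-place normalization
    index = faiss.IndexFlatIP(embeddings.shape[1])    # inner-product index
    index.add(embeddings)
    return index

With an inner-product index, larger scores mean closer matches, so the sort in `retrieve_relevant_chunks` would need `reverse=True`.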
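Review note: a hypothetical smoke test for the new two-step flow, run without the Gradio UI. It assumes a local file "sample.pdf"; `FakeUpload` mimics the `.name` attribute that `gr.File` uploads expose:

class FakeUpload:
    # Stand-in for the object Gradio passes to process_pdf
    name = "sample.pdf"

print(process_pdf(FakeUpload()))                    # step 1: extract, chunk, index
print(get_answer("What is this document about?"))   # step 2a: retrieval + extractive QA
print(get_summary())                                # step 2b: chunked summarization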