Spaces:
Sleeping
Sleeping
| import os | |
| import traceback | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# Base model ID and LoRA adapter locations (local directory preferred,
# HuggingFace Hub repo as fallback — see _load_model).
DEFAULT_MODEL = "google/gemma-3-4b-it"
ADAPTER_PATH = "./gemma-lecture-adapter"
HUB_ADAPTER_ID = "noufwithy/gemma-lecture-adapter"

# System prompt for summarization: asks for Summary / Key Points / Action
# Points sections, which _is_good_summary's bullet-count heuristic relies on.
SUMMARIZE_SYSTEM_PROMPT = """You are a lecture summarization assistant.
Summarize the following lecture transcription into a comprehensive, structured summary with these sections:
- **Summary**: A concise overview of what the lecture covered
- **Key Points**: The main concepts, definitions, and important details covered in the lecture (use bullet points)
- **Action Points**: Any tasks, assignments, or follow-up actions mentioned by the lecturer
Cover ALL topics discussed. Do not omit any major points.
Output ONLY the summary. No explanations or extra commentary."""

# Quiz prompts match the training data format exactly (one question per call)
MCQ_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a multiple choice question
with 4 options labeled A-D and indicate the correct answer.
Format:
Q1. [Question]
A) [Option]
B) [Option]
C) [Option]
D) [Option]
Correct Answer: [Letter]
Output ONLY the question. No explanations or extra commentary."""

SHORT_ANSWER_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a short answer question
with the expected answer.
Format:
Q1. [Question]
Expected Answer: [Brief answer]
Output ONLY the question. No explanations or extra commentary."""

# How many questions of each type generate_quiz() attempts to produce.
NUM_MCQ = 5
NUM_SHORT_ANSWER = 3

# Lazily-initialized singletons populated by _load_model() on first use.
_model = None
_tokenizer = None
def _load_model(model_id: str = DEFAULT_MODEL, adapter_path: str = ADAPTER_PATH):
    """Lazily load and cache the tokenizer and (ideally LoRA-adapted) model.

    Results are cached in the module-level ``_model``/``_tokenizer`` globals,
    so repeated calls are free. The adapter is taken from the local directory
    when it exists, otherwise from the Hub; if attaching the adapter fails for
    any reason, the plain base model is used instead.

    Returns:
        (model, tokenizer) tuple.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Try local adapter first, then HuggingFace Hub, then base model
    adapter_source = adapter_path if os.path.isdir(adapter_path) else HUB_ADAPTER_ID
    # Load in bfloat16 (bitsandbytes 4-bit/8-bit quantization broken with Gemma 3)
    try:
        print(f"Loading model with LoRA adapter from {adapter_source}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16,
            attn_implementation="eager",
        )
        _model = PeftModel.from_pretrained(base_model, adapter_source)
        _model.eval()
        print("LoRA adapter loaded successfully on bfloat16 base model.")
    except Exception as e:
        print(f"LoRA adapter failed ({e}), falling back to base model...")
        traceback.print_exc()
        # Fix: the fallback previously omitted attn_implementation="eager"
        # and .eval(), making it inconsistent with the primary path above
        # (eager attention is used deliberately for Gemma 3 — see comment).
        _model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16,
            attn_implementation="eager",
        )
        _model.eval()
    return _model, _tokenizer
def _generate(messages, max_new_tokens=2048, do_sample=False, temperature=0.7):
    """Run one chat completion through model.generate() and return only the
    decoded continuation (the prompt tokens are stripped before decoding)."""
    model, tokenizer = _load_model()
    # Render the chat template to text first, then tokenize separately so
    # special-token handling stays explicit.
    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(rendered, return_tensors="pt", add_special_tokens=False)
    ids = encoded["input_ids"].to(model.device)
    mask = encoded["attention_mask"].to(model.device)
    print(f"[DEBUG] input length: {ids.shape[-1]} tokens")
    with torch.no_grad():
        generated = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            # Sampling knobs are only meaningful when sampling is on.
            temperature=temperature if do_sample else None,
            top_p=0.9 if do_sample else None,
            repetition_penalty=1.3,
        )
    # Keep only the tokens produced after the prompt.
    continuation = generated[0][ids.shape[-1]:]
    print(f"[DEBUG] generated {len(continuation)} new tokens")
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
| def _is_good_summary(text: str, transcript: str = "") -> bool: | |
| """Check if a summary meets minimum quality: long enough, not repetitive, not parroting.""" | |
| if len(text) < 100: | |
| return False | |
| # Check for excessive repetition (same line or sentence repeated 2+ times) | |
| from collections import Counter | |
| for chunks in [ | |
| [s.strip() for s in text.split("\n") if s.strip()], | |
| [s.strip() for s in text.split(".") if s.strip()], | |
| ]: | |
| if chunks: | |
| counts = Counter(chunks) | |
| most_common_count = counts.most_common(1)[0][1] | |
| if most_common_count >= 2: | |
| print(f"[QUALITY] Repetitive output detected ({most_common_count} repeats)") | |
| return False | |
| # Check if summary is just parroting the transcript (high word overlap) | |
| if transcript: | |
| summary_words = set(text.lower().split()) | |
| transcript_words = set(transcript.lower().split()) | |
| if summary_words and transcript_words: | |
| overlap = len(summary_words & transcript_words) / len(summary_words) | |
| if overlap > 0.85: | |
| print(f"[QUALITY] Summary too similar to transcript ({overlap:.0%} word overlap)") | |
| return False | |
| # Check if summary has enough key points (at least 3 bullet points) | |
| bullet_count = text.count("- ") | |
| has_key_points = "key points" in text.lower() | |
| if has_key_points and bullet_count < 3: | |
| print(f"[QUALITY] Summary has too few key points ({bullet_count})") | |
| return False | |
| # Check minimum unique content (summary should have substance) | |
| unique_lines = set(s.strip() for s in text.split("\n") if s.strip() and len(s.strip()) > 10) | |
| if len(unique_lines) < 5: | |
| print(f"[QUALITY] Summary too shallow ({len(unique_lines)} unique lines)") | |
| return False | |
| return True | |
def _generate_with_base_fallback(messages, transcript="", **kwargs):
    """Generate with the adapter enabled; if the output fails the summary
    quality gate and a LoRA adapter is attached, regenerate once with the
    adapter layers disabled (i.e. using the plain base model)."""
    candidate = _generate(messages, **kwargs)
    if _is_good_summary(candidate, transcript=transcript):
        return candidate
    # Adapter output failed the gate — see whether a base-model retry is possible.
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        return candidate  # already running the base model; nothing to disable
    print("[FALLBACK] Adapter output too short or repetitive, retrying with base model...")
    model.disable_adapter_layers()
    try:
        candidate = _generate(messages, **kwargs)
    finally:
        # Always re-enable the adapter, even if generation raises.
        model.enable_adapter_layers()
    print(f"[FALLBACK] base model response length: {len(candidate)}")
    return candidate
| def _truncate_transcript(transcript: str, max_words: int = 4000) -> str: | |
| """Truncate transcript to fit model's effective context (trained on 3072 tokens).""" | |
| words = transcript.split() | |
| if len(words) <= max_words: | |
| return transcript | |
| print(f"[TRUNCATE] Transcript has {len(words)} words, truncating to {max_words}") | |
| return " ".join(words[:max_words]) | |
def summarize_lecture(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Summarize a lecture transcript using Gemma.

    Args:
        transcript: Raw lecture transcription; empty/whitespace-only input
            short-circuits to "".
        model: NOTE(review): accepted but currently unused — the model is
            loaded lazily by _load_model() with its own defaults. Kept for
            API compatibility.

    Returns:
        The generated summary (possibly from the base model, if the adapter
        output failed the quality gate in _generate_with_base_fallback).
    """
    if not transcript or not transcript.strip():
        return ""
    # Trim to the model's effective context before prompting; the full
    # transcript is still passed below for the parroting quality check.
    truncated = _truncate_transcript(transcript)
    messages = [
        {"role": "system", "content": SUMMARIZE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Lecture transcription:\n\n{truncated}"},
    ]
    # Try adapter first, fall back to base model if quality is bad
    result = _generate_with_base_fallback(messages, transcript=transcript, do_sample=True, temperature=0.3)
    print(f"[DEBUG summarize] response length: {len(result)}")
    return result
| def _extract_question_text(result: str) -> str: | |
| """Extract just the question text (first line after Q number) for dedup comparison.""" | |
| import re | |
| match = re.search(r'Q\d+\.\s*(.+)', result) | |
| return match.group(1).strip().lower() if match else result.strip().lower() | |
| def _is_good_quiz_answer(result: str, transcript: str = "") -> bool: | |
| """Check if a generated quiz question is reasonable quality.""" | |
| # Reject if response doesn't match any expected format (no question generated) | |
| if "Correct Answer:" not in result and "Expected Answer:" not in result: | |
| print(f"[QUALITY] Response has no valid question format (missing Correct/Expected Answer)") | |
| return False | |
| # Reject if there's no actual question (Q1. pattern) | |
| if "Q1." not in result: | |
| print(f"[QUALITY] Response missing Q1. question marker") | |
| return False | |
| # Short answer: reject if expected answer is just a transcript fragment with no real content | |
| if "Expected Answer:" in result: | |
| answer = result.split("Expected Answer:")[-1].strip() | |
| # Reject vague/pointer answers like "right here", "this arrow", "at this point" | |
| vague_phrases = ["right here", "this arrow", "at this point", "this one", "over here", "right there"] | |
| if any(phrase in answer.lower() for phrase in vague_phrases): | |
| print(f"[QUALITY] Short answer too vague: {answer}") | |
| return False | |
| if len(answer.split()) < 2: | |
| print(f"[QUALITY] Short answer too short: {answer}") | |
| return False | |
| # MCQ: reject if it doesn't have 4 options or has duplicate options | |
| if "Correct Answer:" in result and "Expected Answer:" not in result: | |
| import re | |
| for label in ["A)", "B)", "C)", "D)"]: | |
| if label not in result: | |
| print(f"[QUALITY] MCQ missing option {label}") | |
| return False | |
| # Reject if options are mostly duplicated | |
| options = re.findall(r'[A-D]\)\s*(.+)', result) | |
| unique_options = set(opt.strip().lower() for opt in options) | |
| if len(unique_options) < 3: | |
| print(f"[QUALITY] MCQ has duplicate options ({len(unique_options)} unique out of {len(options)})") | |
| return False | |
| return True | |
| def _dedup_mcq_options(result: str) -> str: | |
| """Remove duplicate MCQ options, keeping unique ones only.""" | |
| import re | |
| options = re.findall(r'([A-D])\)\s*(.+)', result) | |
| if len(options) != 4: | |
| return result | |
| seen = {} | |
| unique = [] | |
| for label, text in options: | |
| key = text.strip().lower() | |
| if key not in seen: | |
| seen[key] = True | |
| unique.append((label, text.strip())) | |
| if len(unique) == len(options): | |
| return result # no duplicates | |
| print(f"[QUALITY] Removed {len(options) - len(unique)} duplicate MCQ option(s)") | |
| # Rebuild with correct labels | |
| lines = result.split("\n") | |
| new_lines = [] | |
| option_idx = 0 | |
| labels = ["A", "B", "C", "D"] | |
| for line in lines: | |
| if re.match(r'^[A-D]\)', line): | |
| if option_idx < len(unique): | |
| new_lines.append(f"{labels[option_idx]}) {unique[option_idx][1]}") | |
| option_idx += 1 | |
| else: | |
| new_lines.append(line) | |
| return "\n".join(new_lines) | |
def _generate_quiz_with_fallback(messages, transcript="", **kwargs):
    """Generate one quiz question with the adapter; if it fails the quiz
    quality gate and a LoRA adapter is attached, retry once with the
    adapter layers disabled (plain base model)."""
    candidate = _generate(messages, **kwargs)
    if _is_good_quiz_answer(candidate, transcript):
        return candidate
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        return candidate  # already on the base model; nothing to disable
    print("[FALLBACK] Quiz answer bad, retrying with base model...")
    model.disable_adapter_layers()
    try:
        candidate = _generate(messages, **kwargs)
    finally:
        # Always restore the adapter, even if generation raises.
        model.enable_adapter_layers()
    return candidate
| def _normalize_words(text: str) -> set[str]: | |
| """Strip punctuation from words for cleaner comparison.""" | |
| import re | |
| return set(re.sub(r'[^\w\s]', '', word) for word in text.split() if word.strip()) | |
def _is_duplicate(result: str, existing_parts: list[str]) -> bool:
    """Return True if *result*'s question shares >70% of its words with any
    question already collected in *existing_parts* (punctuation-stripped,
    normalized by the smaller word set)."""
    # The candidate's word set is loop-invariant, so compute it once.
    candidate_words = _normalize_words(_extract_question_text(result))
    for prior in existing_parts:
        prior_words = _normalize_words(_extract_question_text(prior))
        if not candidate_words or not prior_words:
            continue
        ratio = len(candidate_words & prior_words) / min(len(candidate_words), len(prior_words))
        if ratio > 0.7:
            print(f"[QUALITY] Duplicate question detected ({ratio:.0%} word overlap)")
            return True
    return False
def _collect_questions(transcript, system_prompt, count, label, parts, *, dedup_options):
    """Generate up to *count* questions with *system_prompt*, appending good,
    non-duplicate ones (renumbered sequentially) to *parts* in place.

    Each question gets up to 3 attempts (1 + 2 retries); questions that stay
    bad or duplicate are dropped. When *dedup_options* is True, duplicate MCQ
    options are removed via _dedup_mcq_options before numbering.
    """
    max_retries = 2  # extra attempts per question if duplicate
    for i in range(count):
        print(f"[DEBUG quiz] generating {label} {i + 1}/{count}...")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Lecture transcription:\n\n{transcript}"},
        ]
        good = False
        for attempt in range(1 + max_retries):
            result = _generate_quiz_with_fallback(messages, transcript=transcript, max_new_tokens=256, do_sample=True)
            if _is_good_quiz_answer(result, transcript) and not _is_duplicate(result, parts):
                good = True
                break
            print(f"[DEBUG quiz] {label} {i + 1} attempt {attempt + 1} was bad or duplicate, retrying...")
        if not good:
            print(f"[DEBUG quiz] {label} {i + 1} dropped (unreliable after {1 + max_retries} attempts)")
            continue
        if dedup_options:
            result = _dedup_mcq_options(result)
        # Renumber "Q1." to the question's position in the combined quiz.
        result = result.replace("Q1.", f"Q{len(parts) + 1}.", 1)
        parts.append(result)


def generate_quiz(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Generate quiz questions from a lecture transcript using Gemma.

    Generates questions one at a time to match training format, then combines
    them. Skips duplicate questions automatically. (*model* is accepted for
    API compatibility but currently unused — see _load_model.)

    Returns:
        All kept questions joined by blank lines, or "" for empty input.
    """
    if not transcript or not transcript.strip():
        return ""
    transcript = _truncate_transcript(transcript)
    parts: list[str] = []
    # MCQs first (matches training: one MCQ per example), then short answers.
    _collect_questions(transcript, MCQ_SYSTEM_PROMPT, NUM_MCQ, "MCQ", parts, dedup_options=True)
    _collect_questions(transcript, SHORT_ANSWER_SYSTEM_PROMPT, NUM_SHORT_ANSWER, "short answer", parts, dedup_options=False)
    combined = "\n\n".join(parts)
    print(f"[DEBUG quiz] total response length: {len(combined)}")
    return combined