# Source header (originally pasted from the Hugging Face file page, not valid Python):
# pawakumar — "Update app.py" — commit 59cda2f (verified)
import ast
import gc
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import traceback
from io import StringIO

import gradio as gr
import torch
from peft import PeftModel
from pylint.lint import Run
from pylint.reporters.text import TextReporter
from transformers import AutoTokenizer, AutoModelForCausalLM
# ==========================================
# CONFIGURATION & STATE
# ==========================================
# Maps the user-facing dropdown label to the Hugging Face base-model id and
# the local LoRA adapter directory produced by fine-tuning.
MODEL_CONFIGS = {
    "Qwen 2.5 Coder 1.5B": {
        "base": "Qwen/Qwen2.5-Coder-1.5B",
        "adapter": "weights/docstring_codegen_qwen15b_lora"  # Checked against your screenshot
    },
    "StarCoder2 3B": {
        "base": "bigcode/starcoder2-3b",
        "adapter": "weights/starcoder2_3b_qlora_docstring"
    },
    "CodeLlama 7B": {
        "base": "codellama/CodeLlama-7b-Python-hf",
        "adapter": "weights/codellama7b_python_qlora_docstring"
    }
}
# Determine the best device available: CUDA GPU, then Apple-Silicon MPS, then CPU.
DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# Half precision on accelerators; full float32 on CPU (fp16 CPU support is limited).
DTYPE = torch.float16 if DEVICE in ["cuda", "mps"] else torch.float32
# Lazily populated module state — at most one tokenizer/model pair is kept in
# memory at a time (see load_model_if_needed).
current_model_name = None  # key of MODEL_CONFIGS currently loaded, or None
global_tokenizer = None
global_model = None
# ==========================================
# MODEL MANAGEMENT
# ==========================================
def load_model_if_needed(model_choice):
    """Load (or reuse) the tokenizer/model pair for *model_choice*.

    Keeps at most one model resident: if *model_choice* is already loaded,
    the cached pair is returned; otherwise the previous model is released
    first so switching models does not exhaust GPU/CPU memory.

    Args:
        model_choice: A key of MODEL_CONFIGS selecting base model + adapter.

    Returns:
        (tokenizer, model) ready for inference, with the model in eval mode.

    Raises:
        RuntimeError: if loading the tokenizer, base model or adapter fails
            (the original exception is chained as the cause).
    """
    global current_model_name, global_tokenizer, global_model
    # Fast path: requested model is already resident.
    if current_model_name == model_choice and global_model is not None:
        return global_tokenizer, global_model
    print(f"\n[INFO] Switching to {model_choice}... Freeing memory.")
    # 1. Clear old models from memory to prevent crashes
    if global_model is not None:
        del global_model
    if global_tokenizer is not None:
        del global_tokenizer
    global_model = None
    global_tokenizer = None
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    elif DEVICE == "mps":
        torch.mps.empty_cache()
    config = MODEL_CONFIGS[model_choice]
    base_id = config["base"]
    adapter_dir = config["adapter"]
    try:
        print(f"[INFO] Loading Tokenizer: {base_id}")
        tokenizer = AutoTokenizer.from_pretrained(base_id, use_fast=True)
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # padding during generation does not crash.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print(f"[INFO] Loading Base Model: {base_id} onto {DEVICE}")
        # Using .to(DEVICE) instead of device_map="auto" for better stability across all OS
        base_model = AutoModelForCausalLM.from_pretrained(
            base_id,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
        ).to(DEVICE)
        if os.path.exists(adapter_dir):
            print(f"[INFO] Attaching LoRA Adapter from: {adapter_dir}")
            model = PeftModel.from_pretrained(base_model, adapter_dir)
        else:
            print(f"[WARNING] Adapter directory '{adapter_dir}' not found! Running base only.")
            model = base_model
        model.eval()
        current_model_name = model_choice
        global_tokenizer = tokenizer
        global_model = model
        print("[INFO] Model loaded successfully.\n")
        return tokenizer, model
    except Exception as e:
        # Plain literal: the original f-string had no placeholders (pylint W1309).
        print("[ERROR] Model load failed:\n")
        traceback.print_exc()
        # Chain the original exception so the real cause is preserved in tracebacks.
        raise RuntimeError(f"Failed to load {base_id}. Ensure you have HuggingFace access and correct paths. Error: {str(e)}") from e
@torch.no_grad()
def generate_raw(model, tokenizer, prompt: str, max_new_tokens: int = 160):
    """Greedy-decode a continuation of *prompt*; return only the new text.

    The prompt is truncated to 1024 tokens, generation is deterministic
    (no sampling), and the prompt tokens are sliced off before decoding.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)
    generated = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    prompt_len = encoded["input_ids"].shape[1]
    continuation = generated[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
def clean_generated_body(text: str) -> str:
    """Normalize raw model output into an indented function body.

    Strips markdown code fences, stops at any new top-level ``def``/``class``/
    ``import`` the model hallucinated, and re-indents every non-blank line to
    at least four spaces so the result can be appended directly under a
    ``def`` header. Returns an indented ``pass`` when nothing usable remains.
    """
    fallback_pass = "    pass\n"
    if not text or not text.strip():
        return fallback_pass
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    normalized = re.sub(r"```[a-zA-Z]*", "", normalized).replace("```", "")
    kept = []
    for raw_line in normalized.splitlines():
        if not raw_line.strip():
            kept.append("")
            continue
        # A fresh top-level definition means the body is over.
        if raw_line.startswith(("def ", "class ", "import ")):
            break
        indent_width = len(raw_line) - len(raw_line.lstrip())
        if indent_width == 0:
            kept.append("    " + raw_line)
        elif indent_width < 4:
            kept.append("    " + raw_line.lstrip())
        else:
            kept.append(raw_line)
    body = "\n".join(kept).rstrip()
    return body + "\n" if body else fallback_pass
def generate_body(model, tokenizer, prompt: str, max_new_tokens: int = 160, is_base: bool = False):
    """Generate a cleaned function body; optionally bypass the LoRA adapter.

    When *is_base* is True and the model exposes ``disable_adapter`` (i.e. it
    is a PeftModel), generation runs with the adapter temporarily switched
    off so the raw base model's behavior is measured.
    """
    run_without_adapter = is_base and hasattr(model, "disable_adapter")
    if run_without_adapter:
        with model.disable_adapter():
            raw_text = generate_raw(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    else:
        raw_text = generate_raw(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    return clean_generated_body(raw_text)
# ==========================================
# PRE-PROCESSING & VALIDATION
# ==========================================
def auto_fix_prompt_indentation(prompt: str) -> str:
    """Indent flush-left lines that follow a ``def`` header.

    Users often paste a signature plus an unindented docstring or body; any
    non-blank, non-indented line appearing after a ``def ...:`` header is
    pushed in by four spaces. A trailing newline is always appended.
    """
    repaired = []
    seen_def_header = False
    for line in prompt.splitlines():
        if line.startswith("def ") and line.strip().endswith(":"):
            seen_def_header = True
            repaired.append(line)
        elif seen_def_header and line.strip() and not line.startswith((" ", "\t")):
            repaired.append("    " + line)
        else:
            repaired.append(line)
    return "\n".join(repaired) + "\n"
# Built-in callables the sandbox refuses to execute (arbitrary code / IO / REPL hooks).
BLOCKED_CALLS = {"eval", "exec", "compile", "__import__", "open", "input", "breakpoint"}
# Module names whose mere reference is rejected (system, network, process, concurrency access).
BLOCKED_NAMES = {"os", "sys", "subprocess", "socket", "shutil", "pathlib", "resource", "signal", "multiprocessing", "threading", "asyncio"}
# Dotted-call prefixes rejected even if a module reference slipped through.
BLOCKED_ATTR_PREFIX = {"os.", "sys.", "subprocess.", "socket.", "shutil.", "pathlib."}
# Hard wall-clock limit (seconds) for user code executed in the sandbox subprocess.
SANDBOX_TIMEOUT_SEC = 5
def ast_syntax_check(code: str):
    """Return ``(is_valid, message)`` after attempting to parse *code*."""
    try:
        ast.parse(code)
    except SyntaxError as err:
        return False, f"SyntaxError: {err}"
    return True, "Syntax OK"
def static_safety_scan(code: str):
    """Scan *code*'s AST for sandbox-policy violations.

    Rejects import statements, blocked built-in calls (eval/exec/...),
    references to system modules, blocked dotted calls (os.*, sys.*, ...)
    and any dunder attribute access.

    Returns:
        ``(True, [])`` when clean, else ``(False, sorted unique issues)``.
    """
    try:
        tree = ast.parse(code)
    except Exception as err:
        return False, [f"AST parse failed: {err}"]
    problems = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            problems.append("Import statements are not allowed in sandbox mode.")
        if isinstance(node, ast.Call):
            callee = node.func
            if isinstance(callee, ast.Name) and callee.id in BLOCKED_CALLS:
                problems.append(f"Blocked function call: {callee.id}")
            if isinstance(callee, ast.Attribute):
                try:
                    dotted = ast.unparse(callee)
                except Exception:
                    dotted = None
                # At most one prefix can match; duplicates are deduped below anyway.
                if dotted and any(dotted.startswith(prefix) for prefix in BLOCKED_ATTR_PREFIX):
                    problems.append(f"Blocked attribute call: {dotted}")
        if isinstance(node, ast.Name) and node.id in BLOCKED_NAMES:
            problems.append(f"Blocked name reference: {node.id}")
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            problems.append(f"Blocked dunder attribute access: {node.attr}")
    return not problems, sorted(set(problems))
def run_in_sandbox(full_code: str, test_code: str = ""):
    """Execute *full_code* plus optional asserts in an isolated subprocess.

    Pipeline: AST syntax check -> static safety scan -> write a runner script
    to a temp dir and execute it with ``python -I`` (isolated mode) under a
    hard timeout.

    Args:
        full_code: Complete function source (prompt + generated body).
        test_code: Optional assert statements exercising the function.

    Returns:
        Dict with keys ``ok``, ``stage`` (syntax / static_safety / runtime /
        timeout), ``stdout``, ``stderr`` and ``returncode``.
    """
    syntax_ok, syntax_msg = ast_syntax_check(full_code)
    if not syntax_ok:
        return {"ok": False, "stage": "syntax", "stdout": "", "stderr": syntax_msg, "returncode": None}
    safe_ok, issues = static_safety_scan(full_code + "\n" + (test_code or ""))
    if not safe_ok:
        return {"ok": False, "stage": "static_safety", "stdout": "", "stderr": "\n".join(issues), "returncode": None}
    # Wrap the user's asserts in a function; an empty test suite becomes `pass`.
    indented_tests = textwrap.indent(test_code if test_code.strip() else "pass", "    ")
    runner = f"""
{full_code}

def __run_user_tests__():
{indented_tests}

if __name__ == "__main__":
    try:
        __run_user_tests__()
        print("SANDBOX_OK")
    except AssertionError as e:
        print("ASSERTION FAILED")
        raise e
"""
    with tempfile.TemporaryDirectory() as td:
        script_path = os.path.join(td, "runner.py")
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(runner)
        try:
            proc = subprocess.run(
                # sys.executable guarantees the same interpreter as this app;
                # a bare "python" may be missing or resolve elsewhere on PATH.
                [sys.executable, "-I", script_path],
                capture_output=True, text=True, timeout=SANDBOX_TIMEOUT_SEC, cwd=td,
            )
            ok = (proc.returncode == 0) and ("SANDBOX_OK" in proc.stdout)
            return {"ok": ok, "stage": "runtime", "stdout": proc.stdout, "stderr": proc.stderr, "returncode": proc.returncode}
        except subprocess.TimeoutExpired as e:
            return {"ok": False, "stage": "timeout", "stdout": e.stdout or "", "stderr": "Execution timed out.", "returncode": None}
def run_pylint_on_code(full_code: str):
    """Lint *full_code* with pylint; return ``{"score": float|None, "report": str}``.

    The code is written to a temp file because pylint operates on paths; the
    numeric score is parsed out of the text report when present.
    """
    parses, parse_msg = ast_syntax_check(full_code)
    if not parses:
        return {"score": None, "report": parse_msg}
    with tempfile.TemporaryDirectory() as workdir:
        candidate_path = os.path.join(workdir, "candidate.py")
        with open(candidate_path, "w", encoding="utf-8") as handle:
            handle.write(full_code)
        buffer = StringIO()
        try:
            Run([candidate_path, "--score=y", "--reports=n"], reporter=TextReporter(buffer), exit=False)
            report_text = buffer.getvalue().strip()
            score_match = re.search(r"rated at ([\-0-9\.]+)/10", report_text)
            parsed_score = float(score_match.group(1)) if score_match else None
            return {"score": parsed_score, "report": report_text if report_text else "No pylint messages."}
        except Exception as err:
            return {"score": None, "report": f"Pylint failed: {repr(err)}"}
# ==========================================
# GRADIO UI & PIPELINE
# ==========================================
def analyze_prompt(
    model_choice: str,
    prompt: str,
    unit_tests: str,
    max_new_tokens: int,
    use_base_model: bool,
    use_finetuned_model: bool,
):
    """Full pipeline behind the "Generate + Analyze" button.

    Generates a body with the selected model (base run, fine-tuned run, or
    both), reconstructs the full function, then runs the syntax check, static
    safety scan, sandbox execution and pylint on each result.

    Returns:
        An 8-tuple matching the wired Gradio outputs: (verdict, base_body,
        base_full, base_diagnostics, ft_body, ft_full, ft_diagnostics, meta).
    """
    if not prompt or not prompt.strip():
        return ("Please enter a prompt.", "", "", "", "", "", "", "")
    prompt = auto_fix_prompt_indentation(prompt.rstrip())
    # If tests use the HumanEval-style "candidate(" placeholder, rewrite it to
    # the actual function name parsed from the prompt signature.
    func_name_match = re.search(r"def\s+([a-zA-Z0-9_]+)\s*\(", prompt)
    if func_name_match and unit_tests:
        actual_func_name = func_name_match.group(1)
        unit_tests = unit_tests.replace("candidate(", f"{actual_func_name}(")
    try:
        tokenizer, model = load_model_if_needed(model_choice)
    except Exception as e:
        # Will output the exact loading error into the Run Summary box
        err_msg = f"❌ MODEL LOAD ERROR ❌\n\n{str(e)}"
        return (err_msg, "", "", err_msg, "", "", err_msg, "")

    def process_one(label, is_base_run):
        # Generate one body, rebuild the full function, run every analysis
        # pass on it, and format the diagnostics for display.
        body = generate_body(model, tokenizer, prompt, max_new_tokens=max_new_tokens, is_base=is_base_run)
        full_code = prompt + body
        syntax_ok, syntax_msg = ast_syntax_check(full_code)
        safe_ok, safe_issues = static_safety_scan(full_code)
        sandbox_result = run_in_sandbox(full_code, unit_tests or "")
        pylint_result = run_pylint_on_code(full_code)
        safety_text = "Safe check: PASS" if safe_ok else "Safe check: FAIL\n" + "\n".join(safe_issues)
        sandbox_text = (
            f"Sandbox: {'PASS' if sandbox_result['ok'] else 'FAIL'}\n"
            f"Stage: {sandbox_result['stage']}\n"
            f"Return code: {sandbox_result['returncode']}\n"
            f"STDOUT:\n{sandbox_result['stdout']}\n"
            f"STDERR:\n{sandbox_result['stderr']}"
        )
        pylint_text = f"Pylint score: {pylint_result['score']}\n\n{pylint_result['report']}"
        return {
            "label": label, "body": body, "full_code": full_code,
            "syntax": syntax_msg, "safety": safety_text, "sandbox": sandbox_text, "pylint": pylint_text,
        }

    base_result, ft_result = None, None
    if use_base_model:
        base_result = process_one("Base Model", is_base_run=True)
    if use_finetuned_model:
        ft_result = process_one("Fine-Tuned Model", is_base_run=False)
    # Unpack each side into the flat strings Gradio expects; empty when skipped.
    base_body = base_result["body"] if base_result else ""
    base_full = base_result["full_code"] if base_result else ""
    base_diag = (f"{base_result['syntax']}\n\n{base_result['safety']}\n\n{base_result['sandbox']}\n\n{base_result['pylint']}" if base_result else "")
    ft_body = ft_result["body"] if ft_result else ""
    ft_full = ft_result["full_code"] if ft_result else ""
    ft_diag = (f"{ft_result['syntax']}\n\n{ft_result['safety']}\n\n{ft_result['sandbox']}\n\n{ft_result['pylint']}" if ft_result else "")
    verdict = []
    if base_result: verdict.append("✅ Base model ran.")
    if ft_result: verdict.append("✅ Fine-tuned model ran.")
    if ft_result and base_result: verdict.append("Use the side-by-side outputs below to compare results.")
    return ("\n".join(verdict), base_body, base_full, base_diag, ft_body, ft_full, ft_diag, f"Prompt length: {len(prompt)} chars")
EXAMPLE_PROMPTS = [
[
"Qwen 2.5 Coder 1.5B",
'''def add_numbers(a, b):\n """Return the sum of two integers."""\n''',
'''assert candidate(2, 3) == 5\nassert candidate(-1, 1) == 0''',
64, True, True,
],
[
"StarCoder2 3B",
'''def is_even(n):\n """Return True if n is even, otherwise False."""\n''',
'''assert candidate(4) is True\nassert candidate(5) is False''',
64, True, True,
],
]
# Top-level Gradio layout: two input/summary columns, then one tab of outputs
# per model run, wired to analyze_prompt.
with gr.Blocks(title="Multi-Model CodeGen Demo") as demo:
    # ---------------------------------------------------------
    # 1. HEADER & BOLD DISCLAIMER
    # ---------------------------------------------------------
    gr.Markdown("# Multi-Model CodeGen Demo (Base vs LoRA)")
    gr.HTML("""
    <div style="background-color: #ffe6e6; border-left: 6px solid #ff4d4d; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
        <h3 style="color: #cc0000; margin-top: 0;">⚠️ IMPORTANT DISCLAIMER</h3>
        <p style="color: #333; font-weight: bold; font-size: 14px;">
        1. When using this app, please be patient! You must wait for the selected model to fully load into memory.Model loading status can be seen at terminal where you have run app.py<br>
        2. Make sure to fill in the boxes exactly as per the instructions.<br>
        3. The maximum accuracy of these models is ~55%, so the generated code may occasionally fail tests.
        </p>
    </div>
    """)
    # ---------------------------------------------------------
    # 2. MAIN TWO-COLUMN LAYOUT
    # ---------------------------------------------------------
    with gr.Row():
        # LEFT COLUMN (Inputs)
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                value="Qwen 2.5 Coder 1.5B",
                label="Select Base Architecture"
            )
            prompt_in = gr.Textbox(
                label="Prompt (function signature + docstring)", lines=10,
                placeholder='def add_numbers(a, b):\n    """Return the sum of two integers."""\n',
            )
            tests_in = gr.Textbox(
                label="Optional Unit Tests (Use 'candidate' or the actual function name)", lines=8,
                placeholder='assert candidate(2, 3) == 5\nassert candidate(-1, 1) == 0',
            )
            with gr.Row():
                max_tokens_in = gr.Slider(minimum=16, maximum=512, value=96, step=16, label="Max New Tokens")
            with gr.Row():
                use_base_in = gr.Checkbox(value=True, label="Run Base Model")
                use_ft_in = gr.Checkbox(value=True, label="Run Fine-Tuned Model")
            run_btn = gr.Button("Generate + Analyze", variant="primary")
            clear_btn = gr.Button("Clear")
        # RIGHT COLUMN (Outputs & Examples)
        with gr.Column(scale=1):
            verdict_out = gr.Textbox(label="Run Summary", lines=5)
            meta_out = gr.Textbox(label="Meta", lines=2)
            # Moved Examples to exactly below the Meta box inside the right column!
            gr.Markdown("### ⬇️ Try an Example Below:")
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=[model_dropdown, prompt_in, tests_in, max_tokens_in, use_base_in, use_ft_in],
            )
    # ---------------------------------------------------------
    # 3. OUTPUT TABS
    # ---------------------------------------------------------
    with gr.Tab("Base Model"):
        base_body_out = gr.Code(label="Base Model: Generated Body", language="python")
        base_full_out = gr.Code(label="Base Model: Full Reconstructed Code", language="python")
        base_diag_out = gr.Textbox(label="Base Model Diagnostics", lines=22)
    with gr.Tab("Fine-Tuned Model"):
        ft_body_out = gr.Code(label="Fine-Tuned: Generated Body", language="python")
        ft_full_out = gr.Code(label="Fine-Tuned: Full Reconstructed Code", language="python")
        ft_diag_out = gr.Textbox(label="Fine-Tuned Diagnostics", lines=22)
    # ---------------------------------------------------------
    # 4. EVENT LISTENERS
    # ---------------------------------------------------------
    # The output order here must match analyze_prompt's returned 8-tuple.
    run_btn.click(
        fn=analyze_prompt,
        inputs=[model_dropdown, prompt_in, tests_in, max_tokens_in, use_base_in, use_ft_in],
        outputs=[verdict_out, base_body_out, base_full_out, base_diag_out, ft_body_out, ft_full_out, ft_diag_out, meta_out],
    )
    # Clear resets every output box to the empty string.
    clear_btn.click(
        fn=lambda: ("", "", "", "", "", "", "", ""), inputs=[],
        outputs=[verdict_out, base_body_out, base_full_out, base_diag_out, ft_body_out, ft_full_out, ft_diag_out, meta_out],
    )
# Entry point: queue() serializes requests so only one generation runs at a
# time; bind to all interfaces on the conventional Gradio port.
# (Removed a commented-out duplicate of this same guard.)
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)