Upload 4 files
app.py
ADDED
@@ -0,0 +1,478 @@
import os
import re
import ast
import sys  # used below so the sandbox subprocess reuses the current interpreter
import tempfile
import textwrap
import subprocess
import gc
import traceback
from io import StringIO

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
from pylint.lint import Run
from pylint.reporters.text import TextReporter

# ==========================================
# CONFIGURATION & STATE
# ==========================================

MODEL_CONFIGS = {
    "Qwen 2.5 Coder 1.5B": {
        "base": "Qwen/Qwen2.5-Coder-1.5B",
        "adapter": "weights/docstring_codegen_qwen15b_lora"
    },
    "StarCoder2 3B": {
        "base": "bigcode/starcoder2-3b",
        "adapter": "weights/starcoder2_3b_qlora_docstring"
    },
    "CodeLlama 7B": {
        "base": "codellama/CodeLlama-7b-Python-hf",
        "adapter": "weights/codellama7b_python_qlora_docstring"
    }
}

# Determine the best device available
DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE in ["cuda", "mps"] else torch.float32

current_model_name = None
global_tokenizer = None
global_model = None

# ==========================================
# MODEL MANAGEMENT
# ==========================================

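# Only one model is kept resident at a time: switching the dropdown frees the
# previous weights first, so a single-GPU (or CPU-only) machine never holds
# two checkpoints in memory at once.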
def load_model_if_needed(model_choice):
    global current_model_name, global_tokenizer, global_model

    if current_model_name == model_choice and global_model is not None:
        return global_tokenizer, global_model

    print(f"\n[INFO] Switching to {model_choice}... Freeing memory.")

    # 1. Clear old models from memory to prevent crashes
    if global_model is not None:
        del global_model
    if global_tokenizer is not None:
        del global_tokenizer

    global_model = None
    global_tokenizer = None
    gc.collect()

    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    elif DEVICE == "mps":
        torch.mps.empty_cache()

    config = MODEL_CONFIGS[model_choice]
    base_id = config["base"]
    adapter_dir = config["adapter"]

    try:
        print(f"[INFO] Loading Tokenizer: {base_id}")
        tokenizer = AutoTokenizer.from_pretrained(base_id, use_fast=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print(f"[INFO] Loading Base Model: {base_id} onto {DEVICE}")
        # Using .to(DEVICE) instead of device_map="auto" for better stability across all OS
        base_model = AutoModelForCausalLM.from_pretrained(
            base_id,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
        ).to(DEVICE)

        if os.path.exists(adapter_dir):
            print(f"[INFO] Attaching LoRA Adapter from: {adapter_dir}")
            model = PeftModel.from_pretrained(base_model, adapter_dir)
        else:
            print(f"[WARNING] Adapter directory '{adapter_dir}' not found! Running base model only.")
            model = base_model

        model.eval()

        current_model_name = model_choice
        global_tokenizer = tokenizer
        global_model = model

        print("[INFO] Model loaded successfully.\n")
        return tokenizer, model

    except Exception as e:
        print("[ERROR] Model load failed:")
        traceback.print_exc()
        raise RuntimeError(f"Failed to load {base_id}. Ensure you have Hugging Face access and correct paths. Error: {str(e)}") from e

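# Greedy decoding (do_sample=False) keeps the base-vs-LoRA comparison
# deterministic: any difference in output comes from the weights, not from
# sampling noise.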
@torch.no_grad()
def generate_raw(model, tokenizer, prompt: str, max_new_tokens: int = 160):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

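# The models emit free-form text, so the raw completion is normalised into a
# plausible function body: markdown fences are stripped, generation is cut at
# the first new top-level def/class/import, and under-indented lines are
# pushed to at least one 4-space level.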
def clean_generated_body(text: str) -> str:
    fallback_pass = "    pass\n"
    if not text or not text.strip():
        return fallback_pass

    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"```[a-zA-Z]*", "", text).replace("```", "")

    lines = text.splitlines()
    cleaned = []

    for line in lines:
        if not line.strip():
            cleaned.append("")
            continue

        if line.startswith("def ") or line.startswith("class ") or line.startswith("import "):
            break

        leading_spaces = len(line) - len(line.lstrip())

        if leading_spaces == 0:
            cleaned.append("    " + line)
        elif leading_spaces < 4:
            cleaned.append("    " + line.lstrip())
        else:
            cleaned.append(line)

    body = "\n".join(cleaned).rstrip()
    if not body:
        return fallback_pass

    return body + "\n"

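# When the loaded model is a PeftModel, disable_adapter() temporarily bypasses
# the LoRA weights, so the same object can produce both the base-model and the
# fine-tuned completion.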
def generate_body(model, tokenizer, prompt: str, max_new_tokens: int = 160, is_base: bool = False):
    if is_base and hasattr(model, "disable_adapter"):
        with model.disable_adapter():
            raw = generate_raw(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    else:
        raw = generate_raw(model, tokenizer, prompt, max_new_tokens=max_new_tokens)

    return clean_generated_body(raw)

# ==========================================
# PRE-PROCESSING & VALIDATION
# ==========================================

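# Users often paste the docstring line without its leading indentation; any
# unindented non-empty line that follows a `def` is re-indented so the prompt
# parses as valid Python before generation.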
def auto_fix_prompt_indentation(prompt: str) -> str:
    lines = prompt.splitlines()
    fixed = []
    inside_func = False

    for line in lines:
        if line.startswith("def ") and line.strip().endswith(":"):
            inside_func = True
            fixed.append(line)
            continue

        if inside_func and line.strip() and not line.startswith((" ", "\t")):
            fixed.append("    " + line)
        else:
            fixed.append(line)

    return "\n".join(fixed) + "\n"

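# Best-effort denylist for the sandbox: AST-level scanning of calls, names and
# dunder attributes blocks the obvious escape hatches, but it is not a real
# security boundary. The subprocess timeout below is the harder backstop.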
BLOCKED_CALLS = {"eval", "exec", "compile", "__import__", "open", "input", "breakpoint"}
BLOCKED_NAMES = {"os", "sys", "subprocess", "socket", "shutil", "pathlib", "resource", "signal", "multiprocessing", "threading", "asyncio"}
BLOCKED_ATTR_PREFIX = {"os.", "sys.", "subprocess.", "socket.", "shutil.", "pathlib."}
SANDBOX_TIMEOUT_SEC = 5

def ast_syntax_check(code: str):
    try:
        ast.parse(code)
        return True, "Syntax OK"
    except SyntaxError as e:
        return False, f"SyntaxError: {e}"

def static_safety_scan(code: str):
    try:
        tree = ast.parse(code)
    except Exception as e:
        return False, [f"AST parse failed: {e}"]

    issues = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            issues.append("Import statements are not allowed in sandbox mode.")
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id in BLOCKED_CALLS:
                issues.append(f"Blocked function call: {node.func.id}")
            if isinstance(node.func, ast.Attribute):
                try:
                    full = ast.unparse(node.func)
                except Exception:
                    full = None
                if full:
                    for prefix in BLOCKED_ATTR_PREFIX:
                        if full.startswith(prefix):
                            issues.append(f"Blocked attribute call: {full}")
        if isinstance(node, ast.Name) and node.id in BLOCKED_NAMES:
            issues.append(f"Blocked name reference: {node.id}")
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            issues.append(f"Blocked dunder attribute access: {node.attr}")

    return len(issues) == 0, sorted(set(issues))

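# The candidate program plus the user's asserts are written to a throwaway
# script and executed with the interpreter's isolated mode (-I: no user
# site-packages, no PYTHON* env vars) under a hard timeout. "SANDBOX_OK" on
# stdout marks success, e.g.:
#   run_in_sandbox("def f():\n    return 1\n", "assert f() == 1")
#   -> {"ok": True, "stage": "runtime", "stdout": "SANDBOX_OK\n", ...}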
def run_in_sandbox(full_code: str, test_code: str = ""):
    syntax_ok, syntax_msg = ast_syntax_check(full_code)
    if not syntax_ok:
        return {"ok": False, "stage": "syntax", "stdout": "", "stderr": syntax_msg, "returncode": None}

    safe_ok, issues = static_safety_scan(full_code + "\n" + (test_code or ""))
    if not safe_ok:
        return {"ok": False, "stage": "static_safety", "stdout": "", "stderr": "\n".join(issues), "returncode": None}

    runner = f"""
{full_code}

def __run_user_tests__():
{textwrap.indent(test_code if test_code.strip() else "pass", "    ")}

if __name__ == "__main__":
    try:
        __run_user_tests__()
        print("SANDBOX_OK")
    except AssertionError as e:
        print("ASSERTION FAILED")
        raise e
"""
    with tempfile.TemporaryDirectory() as td:
        script_path = os.path.join(td, "runner.py")
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(runner)

        try:
            proc = subprocess.run(
                [sys.executable, "-I", script_path],  # sys.executable is more portable than a bare "python"
                capture_output=True, text=True, timeout=SANDBOX_TIMEOUT_SEC, cwd=td,
            )
            ok = (proc.returncode == 0) and ("SANDBOX_OK" in proc.stdout)
            return {"ok": ok, "stage": "runtime", "stdout": proc.stdout, "stderr": proc.stderr, "returncode": proc.returncode}
        except subprocess.TimeoutExpired as e:
            return {"ok": False, "stage": "timeout", "stdout": e.stdout or "", "stderr": "Execution timed out.", "returncode": None}

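# Pylint is driven programmatically (pylint.lint.Run with exit=False) and its
# text report captured in a StringIO; the numeric score is recovered from the
# "rated at X/10" line of the report.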
def run_pylint_on_code(full_code: str):
    syntax_ok, syntax_msg = ast_syntax_check(full_code)
    if not syntax_ok:
        return {"score": None, "report": syntax_msg}

    with tempfile.TemporaryDirectory() as td:
        file_path = os.path.join(td, "candidate.py")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(full_code)

        output = StringIO()
        reporter = TextReporter(output)
        try:
            Run([file_path, "--score=y", "--reports=n"], reporter=reporter, exit=False)
            report_text = output.getvalue()
            score_match = re.search(r"rated at ([\-0-9\.]+)/10", report_text)
            score = float(score_match.group(1)) if score_match else None
            return {"score": score, "report": report_text.strip() if report_text.strip() else "No pylint messages."}
        except Exception as e:
            return {"score": None, "report": f"Pylint failed: {repr(e)}"}

# ==========================================
# GRADIO UI & PIPELINE
# ==========================================

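# Main pipeline behind the "Generate + Analyze" button. It always returns an
# 8-tuple (summary, base body/full/diagnostics, fine-tuned body/full/
# diagnostics, meta) in the exact order of the `outputs=` list wired up below.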
def analyze_prompt(
    model_choice: str,
    prompt: str,
    unit_tests: str,
    max_new_tokens: int,
    use_base_model: bool,
    use_finetuned_model: bool,
):
    if not prompt or not prompt.strip():
        return ("Please enter a prompt.", "", "", "", "", "", "", "")

    prompt = auto_fix_prompt_indentation(prompt.rstrip())

    # Rewrite `candidate(...)` in the tests to the actual function name from the prompt
    func_name_match = re.search(r"def\s+([a-zA-Z0-9_]+)\s*\(", prompt)
    if func_name_match and unit_tests:
        actual_func_name = func_name_match.group(1)
        unit_tests = unit_tests.replace("candidate(", f"{actual_func_name}(")

    try:
        tokenizer, model = load_model_if_needed(model_choice)
    except Exception as e:
        # Surface the exact loading error in the Run Summary box
        err_msg = f"❌ MODEL LOAD ERROR ❌\n\n{str(e)}"
        return (err_msg, "", "", err_msg, "", "", err_msg, "")

    def process_one(label, is_base_run):
        body = generate_body(model, tokenizer, prompt, max_new_tokens=max_new_tokens, is_base=is_base_run)
        full_code = prompt + body

        syntax_ok, syntax_msg = ast_syntax_check(full_code)
        safe_ok, safe_issues = static_safety_scan(full_code)

        sandbox_result = run_in_sandbox(full_code, unit_tests or "")
        pylint_result = run_pylint_on_code(full_code)

        safety_text = "Safe check: PASS" if safe_ok else "Safe check: FAIL\n" + "\n".join(safe_issues)
        sandbox_text = (
            f"Sandbox: {'PASS' if sandbox_result['ok'] else 'FAIL'}\n"
            f"Stage: {sandbox_result['stage']}\n"
            f"Return code: {sandbox_result['returncode']}\n"
            f"STDOUT:\n{sandbox_result['stdout']}\n"
            f"STDERR:\n{sandbox_result['stderr']}"
        )
        pylint_text = f"Pylint score: {pylint_result['score']}\n\n{pylint_result['report']}"

        return {
            "label": label, "body": body, "full_code": full_code,
            "syntax": syntax_msg, "safety": safety_text, "sandbox": sandbox_text, "pylint": pylint_text,
        }

    base_result, ft_result = None, None

    if use_base_model:
        base_result = process_one("Base Model", is_base_run=True)
    if use_finetuned_model:
        ft_result = process_one("Fine-Tuned Model", is_base_run=False)

    base_body = base_result["body"] if base_result else ""
    base_full = base_result["full_code"] if base_result else ""
    base_diag = (f"{base_result['syntax']}\n\n{base_result['safety']}\n\n{base_result['sandbox']}\n\n{base_result['pylint']}" if base_result else "")

    ft_body = ft_result["body"] if ft_result else ""
    ft_full = ft_result["full_code"] if ft_result else ""
    ft_diag = (f"{ft_result['syntax']}\n\n{ft_result['safety']}\n\n{ft_result['sandbox']}\n\n{ft_result['pylint']}" if ft_result else "")

    verdict = []
    if base_result: verdict.append("✅ Base model ran.")
    if ft_result: verdict.append("✅ Fine-tuned model ran.")
    if ft_result and base_result: verdict.append("Use the side-by-side outputs below to compare results.")

    return ("\n".join(verdict), base_body, base_full, base_diag, ft_body, ft_full, ft_diag, f"Prompt length: {len(prompt)} chars")

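# Preset examples. `candidate(...)` in the tests is rewritten to the real
# function name by analyze_prompt before execution.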
EXAMPLE_PROMPTS = [
    [
        "Qwen 2.5 Coder 1.5B",
        '''def add_numbers(a, b):\n    """Return the sum of two integers."""\n''',
        '''assert candidate(2, 3) == 5\nassert candidate(-1, 1) == 0''',
        64, True, True,
    ],
    [
        "StarCoder2 3B",
        '''def is_even(n):\n    """Return True if n is even, otherwise False."""\n''',
        '''assert candidate(4) is True\nassert candidate(5) is False''',
        64, True, True,
    ],
]

with gr.Blocks(title="Multi-Model CodeGen Demo") as demo:
    # ---------------------------------------------------------
    # 1. HEADER & BOLD DISCLAIMER
    # ---------------------------------------------------------
    gr.Markdown("# Multi-Model CodeGen Demo (Base vs LoRA)")
    gr.HTML("""
    <div style="background-color: #ffe6e6; border-left: 6px solid #ff4d4d; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
        <h3 style="color: #cc0000; margin-top: 0;">⚠️ IMPORTANT DISCLAIMER</h3>
        <p style="color: #333; font-weight: bold; font-size: 14px;">
        1. Please be patient: the selected model must fully load into memory before anything is generated. Loading status is printed to the terminal where you ran app.py.<br>
        2. Fill in the input boxes exactly as per the instructions.<br>
        3. These models top out at roughly 55% accuracy, so the generated code may occasionally fail the tests.
        </p>
    </div>
    """)

    # ---------------------------------------------------------
    # 2. MAIN TWO-COLUMN LAYOUT
    # ---------------------------------------------------------
    with gr.Row():

        # LEFT COLUMN (Inputs)
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                value="Qwen 2.5 Coder 1.5B",
                label="Select Base Architecture"
            )

            prompt_in = gr.Textbox(
                label="Prompt (function signature + docstring)", lines=10,
                placeholder='def add_numbers(a, b):\n    """Return the sum of two integers."""\n',
            )
            tests_in = gr.Textbox(
                label="Optional Unit Tests (Use 'candidate' or the actual function name)", lines=8,
                placeholder='assert candidate(2, 3) == 5\nassert candidate(-1, 1) == 0',
            )

            with gr.Row():
                max_tokens_in = gr.Slider(minimum=16, maximum=512, value=96, step=16, label="Max New Tokens")

            with gr.Row():
                use_base_in = gr.Checkbox(value=True, label="Run Base Model")
                use_ft_in = gr.Checkbox(value=True, label="Run Fine-Tuned Model")

            run_btn = gr.Button("Generate + Analyze", variant="primary")
            clear_btn = gr.Button("Clear")

        # RIGHT COLUMN (Outputs & Examples)
        with gr.Column(scale=1):
            verdict_out = gr.Textbox(label="Run Summary", lines=5)
            meta_out = gr.Textbox(label="Meta", lines=2)

            # Examples sit directly below the Meta box in the right column
            gr.Markdown("### ⬇️ Try an Example Below:")
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=[model_dropdown, prompt_in, tests_in, max_tokens_in, use_base_in, use_ft_in],
            )

    # ---------------------------------------------------------
    # 3. OUTPUT TABS
    # ---------------------------------------------------------
    with gr.Tab("Base Model"):
        base_body_out = gr.Code(label="Base Model: Generated Body", language="python")
        base_full_out = gr.Code(label="Base Model: Full Reconstructed Code", language="python")
        base_diag_out = gr.Textbox(label="Base Model Diagnostics", lines=22)

    with gr.Tab("Fine-Tuned Model"):
        ft_body_out = gr.Code(label="Fine-Tuned: Generated Body", language="python")
        ft_full_out = gr.Code(label="Fine-Tuned: Full Reconstructed Code", language="python")
        ft_diag_out = gr.Textbox(label="Fine-Tuned Diagnostics", lines=22)

    # ---------------------------------------------------------
    # 4. EVENT LISTENERS
    # ---------------------------------------------------------
    run_btn.click(
        fn=analyze_prompt,
        inputs=[model_dropdown, prompt_in, tests_in, max_tokens_in, use_base_in, use_ft_in],
        outputs=[verdict_out, base_body_out, base_full_out, base_diag_out, ft_body_out, ft_full_out, ft_diag_out, meta_out],
    )
    clear_btn.click(
        fn=lambda: ("", "", "", "", "", "", "", ""), inputs=[],
        outputs=[verdict_out, base_body_out, base_full_out, base_diag_out, ft_body_out, ft_full_out, ft_diag_out, meta_out],
    )

# Alternative for server deployment:
# if __name__ == "__main__":
#     demo.queue().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    demo.queue().launch(share=True)
req.txt
ADDED
@@ -0,0 +1,11 @@
transformers>=4.45,<5
peft>=0.12,<1
accelerate>=0.34,<2
gradio>=4.44,<6
pylint>=3.2,<4
bitsandbytes>=0.49.0
safetensors
tree-sitter
tree-sitter-python
torch
datasets
weights/docstring_codegen_qwen15b_lora/adapter_config.json
ADDED
@@ -0,0 +1,46 @@
{
  "alora_invocation_tokens": null,
  "alpha_pattern": {},
  "arrow_config": null,
  "auto_mapping": null,
  "base_model_name_or_path": "Qwen/Qwen2.5-Coder-1.5B",
  "bias": "none",
  "corda_config": null,
  "ensure_weight_tying": false,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_bias": false,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "peft_version": "0.18.1",
  "qalora_group_size": 16,
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "k_proj",
    "v_proj",
    "up_proj",
    "q_proj",
    "down_proj",
    "o_proj",
    "gate_proj"
  ],
  "target_parameters": null,
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_qalora": false,
  "use_rslora": false
}
weights/docstring_codegen_qwen15b_lora/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:962b7838c14339a4476f554e956080c44096a23a63a564c69e5e9817f3f6506b
size 73911112