| """Run the bio-experiment environment with Qwen3.5-0.8B as the planning agent.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
|
| from models import ( |
| ActionType, |
| ExperimentAction, |
| ExperimentObservation, |
| OutputType, |
| build_agent_observation_context, |
| build_agent_system_prompt, |
| ) |
| from server.hackathon_environment import BioExperimentEnvironment |
|
|
# File-based IPC with the dashboard process: this agent writes world-state
# snapshots to the state file and consumes one-shot commands from the cmd file.
DASHBOARD_STATE_PATH = Path(__file__).parent / "_dashboard_state.json"
DASHBOARD_CMD_PATH = Path(__file__).parent / "_dashboard_cmd.json"

# Opt-in to the transformers pipeline() fast path (default: off).
USE_PIPELINE = os.getenv("RUN_AGENT_USE_PIPELINE", "0").strip().lower() not in {"0", "false", "off"}
|
|
| def _parse_thinking_flag() -> bool: |
| import sys |
| if "--no-thinking" in sys.argv: |
| return False |
| if "--thinking" in sys.argv: |
| return True |
| return os.getenv("RUN_AGENT_ENABLE_THINKING", "1").strip().lower() not in {"0", "false", "off"} |
|
|
# Resolved once at import time; CLI flags override the environment variable.
ENABLE_THINKING = _parse_thinking_flag()

# Local Hugging Face model used as the planning agent.
MODEL_ID = "Qwen/Qwen3.5-2B"
# Hard cap on actions per episode; the final step always forces a conclusion.
MAX_EPISODE_STEPS = int(os.getenv("RUN_AGENT_MAX_EPISODE_STEPS", "20"))
PIPELINE_TASK = "text-generation"

# Canonical action names, plus tolerated model-emitted synonyms mapped back
# onto their canonical values.
ACTION_TYPES = [a.value for a in ActionType]
ACTION_TYPE_ALIASES = {
    "collect_samples": ActionType.COLLECT_SAMPLE.value,
    "collect_sample_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
    "collect_samples_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
    "prepare_sc_library": ActionType.PREPARE_LIBRARY.value,
    "sequence_single_cells": ActionType.SEQUENCE_CELLS.value,
    "qc": ActionType.RUN_QC.value,
    "run_quality_control": ActionType.RUN_QC.value,
    "cluster": ActionType.CLUSTER_CELLS.value,
    "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "differential_expression_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "trajectory_inference": ActionType.TRAJECTORY_ANALYSIS.value,
    "infer_trajectory": ActionType.TRAJECTORY_ANALYSIS.value,
    "network_inference": ActionType.REGULATORY_NETWORK_INFERENCE.value,
    "select_markers": ActionType.MARKER_SELECTION.value,
    "final_conclusion": ActionType.SYNTHESIZE_CONCLUSION.value,
}

SYSTEM_PROMPT = build_agent_system_prompt()

# Nominal experiment-pipeline order; used to tell the model which steps are
# completed and which remain (see format_observation).
STANDARD_PIPELINE_ORDER = [
    ActionType.COLLECT_SAMPLE,
    ActionType.SELECT_COHORT,
    ActionType.PREPARE_LIBRARY,
    ActionType.SEQUENCE_CELLS,
    ActionType.RUN_QC,
    ActionType.FILTER_DATA,
    ActionType.NORMALIZE_DATA,
    ActionType.INTEGRATE_BATCHES,
    ActionType.CLUSTER_CELLS,
    ActionType.DIFFERENTIAL_EXPRESSION,
    ActionType.PATHWAY_ENRICHMENT,
    ActionType.MARKER_SELECTION,
    ActionType.TRAJECTORY_ANALYSIS,
    ActionType.REGULATORY_NETWORK_INFERENCE,
    ActionType.SYNTHESIZE_CONCLUSION,
]

# Max characters of raw model output echoed to the console log.
MODEL_RESPONSE_PREVIEW_CHARS = int(
    os.getenv("RUN_AGENT_MODEL_RESPONSE_PREVIEW_CHARS", "240")
)
|
|
|
|
def compact_preview(value: Any, max_chars: int = 160) -> str:
    """Render *value* as a single-line string clipped to *max_chars*.

    Tries a deterministic JSON encoding first (sorted keys, ASCII-safe) and
    falls back to ``str()`` for anything json cannot encode. Circular
    structures raise ``ValueError`` (not ``TypeError``) from ``json.dumps``,
    so both are caught — previously a circular value crashed the preview.
    """
    try:
        text = json.dumps(value, ensure_ascii=True, sort_keys=True)
    except (TypeError, ValueError):
        # TypeError: unserializable object / unsortable mixed-type keys.
        # ValueError: circular reference detected.
        text = str(value)
    # Collapse all whitespace runs so the preview stays on one line.
    text = re.sub(r"\s+", " ", text).strip()
    if len(text) <= max_chars:
        return text
    # Reserve room for the ellipsis; guard against pathological max_chars < 3.
    return text[: max(0, max_chars - 3)] + "..."
|
|
|
|
def format_observation(obs: ExperimentObservation) -> str:
    """Render the observation into the user prompt for the planner LLM.

    Includes task framing, remaining budget/time, recent action history,
    completed vs remaining pipeline steps, a preview of the latest data,
    violations and discovered markers, and ends with a strict JSON-only
    output instruction that parse_action expects.
    """
    parts = [
        f"TASK: {obs.task.problem_statement}",
        f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
        f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
        f"Step: {obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d",
    ]
    # Project-provided tool/assay context; capped to keep the prompt short.
    context = build_agent_observation_context(obs, max_tools=5, max_assays=2)
    if context:
        parts.append(context)
    if obs.pipeline_history:
        # Only the five most recent actions — older history adds little signal.
        last5 = obs.pipeline_history[-5:]
        parts.append("Recent history:")
        for h in last5:
            tag = "OK" if h.success else "FAIL"
            line = f" [{tag}] {h.action_type.value}"
            if h.method:
                line += f" ({h.method})"
            line += f": {h.output_summary[:80]}"
            parts.append(line)

    # Spell out what is done and what remains, to discourage repeated steps.
    completed = {h.action_type for h in obs.pipeline_history if h.success}
    if completed:
        parts.append(f"Completed steps (do NOT repeat): {', '.join(sorted(a.value for a in completed))}")
    remaining = [a.value for a in STANDARD_PIPELINE_ORDER if a not in completed]
    if remaining:
        parts.append(f"Remaining steps (choose one): {', '.join(remaining)}")

    if obs.latest_output and obs.latest_output.data:
        parts.append(
            f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
        )
    if obs.rule_violations:
        parts.append(f"VIOLATIONS: {obs.rule_violations}")
    if obs.discovered_markers:
        parts.append(f"Markers found so far: {obs.discovered_markers[:5]}")

    # Hard output-format instruction; parse_action expects exactly this schema.
    parts.append(
        'Output ONLY a single JSON object with these exact keys, no comments, no extra text:\n'
        '{"action_type": "<one of the remaining steps>", "method": null, "parameters": {}, "justification": "<why>", "confidence": 0.8}'
    )
    return "\n".join(parts)
|
|
|
|
| def _repair_truncated_json(text: str) -> Optional[str]: |
| """Try to repair JSON truncated mid-value (common with small LLMs).""" |
| s = text.strip() |
| if not s.startswith("{"): |
| return None |
|
|
| |
| s = re.sub(r',\s*"[^"\n]*$', '', s) |
| s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s) |
|
|
| in_string = False |
| escape = False |
| for ch in s: |
| if escape: |
| escape = False |
| continue |
| if ch == "\\": |
| escape = True |
| continue |
| if ch == '"': |
| in_string = not in_string |
|
|
| if in_string: |
| s += '"' |
|
|
| open_braces = s.count("{") - s.count("}") |
| open_brackets = s.count("[") - s.count("]") |
| s += "]" * max(0, open_brackets) |
| s += "}" * max(0, open_braces) |
|
|
| try: |
| obj = json.loads(s) |
| if isinstance(obj, dict): |
| return s |
| except json.JSONDecodeError: |
| pass |
|
|
| s = re.sub(r',\s*([}\]])', r'\1', s) |
| try: |
| obj = json.loads(s) |
| if isinstance(obj, dict): |
| return s |
| except json.JSONDecodeError: |
| pass |
| return None |
|
|
|
|
def _normalize_jsonish_text(text: str) -> str:
    """Rewrite frequent near-JSON slip-ups (Python literals, JS comments) into JSON."""
    cleaned = _strip_js_comments(text)
    # Python-style literals appearing as values -> their JSON equivalents.
    for pythonish, jsonish in (("None", "null"), ("True", "true"), ("False", "false")):
        cleaned = re.sub(rf'(?<=:\s)\b{pythonish}\b', jsonish, cleaned)
    # `"key:"` written with the colon inside the quotes and the value missing.
    cleaned = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', cleaned)
    return cleaned
|
|
|
|
| def _strip_js_comments(text: str) -> str: |
| """Remove // and /* */ comments that small LLMs inject into JSON.""" |
| text = re.sub(r'//[^\n]*', '', text) |
| text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL) |
| return text |
|
|
|
|
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object (dict) from free-form LLM text.

    Strategy: normalize near-JSON artifacts, unwrap an outer quoted payload,
    strip a markdown code fence, collect every brace-balanced ``{...}`` span
    plus a truncation-repaired candidate, then return the longest candidate
    that parses to a dict. Returns None when nothing parses.
    """
    stripped = _normalize_jsonish_text(text).strip()
    # Some models return the JSON wrapped in an outer quoted string.
    if stripped.startswith('"') and stripped.endswith('"'):
        try:
            unwrapped = json.loads(stripped)
        except json.JSONDecodeError:
            unwrapped = None
        if isinstance(unwrapped, str):
            stripped = _normalize_jsonish_text(unwrapped).strip()
    # Strip a ``` code fence surrounding the payload.
    fence_prefix = "```"
    if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
        lines = stripped.splitlines()
        if len(lines) >= 3:
            stripped = "\n".join(lines[1:-1]).strip()

    # Candidate spans: the whole text plus every brace-balanced substring.
    candidates: List[str] = [stripped]
    start = stripped.find("{")
    while start != -1:
        depth = 0
        for idx in range(start, len(stripped)):
            char = stripped[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    candidates.append(stripped[start:idx + 1])
                    break
        start = stripped.find("{", start + 1)

    # Also try repairing a truncated tail (common with small models).
    repaired = None
    first_brace = stripped.find("{")
    if first_brace != -1:
        repaired = _repair_truncated_json(stripped[first_brace:])
    if repaired is not None:
        candidates.append(repaired)

    # Prefer the longest parseable dict — the most complete payload.
    candidates.sort(key=len, reverse=True)

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return None
|
|
|
|
| def _edit_distance(a: str, b: str) -> int: |
| if len(a) < len(b): |
| return _edit_distance(b, a) |
| if not b: |
| return len(a) |
| prev = list(range(len(b) + 1)) |
| for i, ca in enumerate(a): |
| curr = [i + 1] |
| for j, cb in enumerate(b): |
| curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb))) |
| prev = curr |
| return prev[-1] |
|
|
|
|
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
    """Fetch the first matching key: exact, then case-insensitive, then fuzzy.

    Fuzzy matching tolerates small typos via edit distance, with a threshold
    that scales with key length. Returns None when nothing matches.
    """
    # Pass 1: exact key match, in caller-preferred order.
    for name in names:
        if name in payload:
            return payload[name]

    # Pass 2: case-insensitive match.
    folded = {str(key).lower(): value for key, value in payload.items()}
    for name in names:
        lowered_name = name.lower()
        if lowered_name in folded:
            return folded[lowered_name]

    # Pass 3: typo-tolerant match over the folded keys.
    for key, value in folded.items():
        for name in names:
            tolerance = max(2, len(name) // 3)
            if _edit_distance(key, name.lower()) <= tolerance:
                return value
    return None
|
|
|
|
def normalize_optional_string(value: Any) -> Optional[str]:
    """Coerce *value* into a stripped, non-empty string, or None.

    Booleans count as absent (checked before int, since bool is an int
    subclass); numbers are stringified; other objects get a compact preview.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        stripped = value.strip()
        return stripped if stripped else None
    if isinstance(value, (int, float)):
        return str(value)
    return compact_preview(value, 80)
|
|
|
|
def normalize_action_type(raw_action_type: Any) -> Optional[str]:
    """Map a model-emitted action string onto a canonical ActionType value.

    Resolution order: exact value / alias match, snake_case-normalised match,
    then keyword heuristics. Returns None when nothing matches.
    """
    if not isinstance(raw_action_type, str):
        return None

    candidate = raw_action_type.strip().lower()
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]

    # Collapse punctuation/whitespace into underscores and retry.
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]

    # Keyword heuristics, checked in order. More specific fragment sets must
    # precede their generic subsets: previously ("marker",) came before
    # ("validat", "marker"), which made the VALIDATE_MARKER heuristic
    # unreachable for any phrase containing "marker" (e.g. "validate markers"
    # was misclassified as marker_selection).
    heuristics = [
        (("collect", "sample"), ActionType.COLLECT_SAMPLE.value),
        (("library",), ActionType.PREPARE_LIBRARY.value),
        (("sequence",), ActionType.SEQUENCE_CELLS.value),
        (("qc",), ActionType.RUN_QC.value),
        (("quality", "control"), ActionType.RUN_QC.value),
        (("filter",), ActionType.FILTER_DATA.value),
        (("normal",), ActionType.NORMALIZE_DATA.value),
        (("integrat", "batch"), ActionType.INTEGRATE_BATCHES.value),
        (("cluster",), ActionType.CLUSTER_CELLS.value),
        (("differential", "expression"), ActionType.DIFFERENTIAL_EXPRESSION.value),
        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
        (("trajectory",), ActionType.TRAJECTORY_ANALYSIS.value),
        (("network",), ActionType.REGULATORY_NETWORK_INFERENCE.value),
        (("validat", "marker"), ActionType.VALIDATE_MARKER.value),
        (("marker",), ActionType.MARKER_SELECTION.value),
        (("followup",), ActionType.DESIGN_FOLLOWUP.value),
        (("review",), ActionType.REQUEST_SUBAGENT_REVIEW.value),
        (("conclusion",), ActionType.SYNTHESIZE_CONCLUSION.value),
    ]
    for fragments, normalized in heuristics:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None
|
|
|
|
def should_block_failed_reattempt(
    history: List[Any], action_type: ActionType
) -> bool:
    """Return True when retrying *action_type* should be vetoed.

    A retry is blocked when the most recent attempt of that action failed and
    no *different* action has succeeded since — the world state is unchanged,
    so the retry would almost certainly fail the same way.
    """
    latest_failure = None
    latest_success = None
    for position, record in enumerate(history):
        if record.action_type != action_type:
            continue
        if record.success:
            latest_success = position
        else:
            latest_failure = position

    # Never failed: nothing to block.
    if latest_failure is None:
        return False

    # A later success of the same action supersedes the failure.
    if latest_success is not None and latest_success > latest_failure:
        return False

    # Any other action succeeding afterwards may have changed the state
    # enough to justify a retry.
    return not any(
        record.success and record.action_type != action_type
        for record in history[latest_failure + 1:]
    )
|
|
|
|
def parse_action(text: str) -> Optional[ExperimentAction]:
    """Parse the model's raw response into an ExperimentAction.

    Primary path: extract a JSON object and validate its fields. Fallback
    path: pull individual fields out of the text with regexes when no
    parseable object exists. Returns None only when even the fallback cannot
    find an action_type. Confidence is always clamped to [0, 1].
    """
    d = extract_json_object(text)
    if d is not None:
        action_type = normalize_action_type(get_payload_value(d, "action_type"))
        if action_type is None:
            # A dict was found but with no usable action_type:
            # fall through to the regex path below.
            pass
        else:
            parameters = get_payload_value(d, "parameters", "params") or {}
            if not isinstance(parameters, dict):
                parameters = {}

            confidence = get_payload_value(d, "confidence")
            if confidence is None:
                confidence = 0.5
            try:
                confidence = float(confidence)
            except (TypeError, ValueError):
                confidence = 0.5

            # "justifyement" is a deliberately tolerated model misspelling.
            justification = get_payload_value(
                d, "justification", "justifyement", "reasoning", "rationale", "reason"
            )
            if justification is not None and not isinstance(justification, str):
                justification = compact_preview(justification, 200)
            method = normalize_optional_string(get_payload_value(d, "method"))

            return ExperimentAction(
                action_type=ActionType(action_type),
                method=method,
                parameters=parameters,
                justification=justification,
                confidence=min(1.0, max(0.0, confidence)),
            )

    # Regex fallback: salvage fields from non-JSON text.
    action_match = re.search(
        r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
        text,
        re.IGNORECASE,
    )
    if not action_match:
        return None

    action_type = normalize_action_type(action_match.group(1))
    if action_type is None:
        return None

    method_match = re.search(
        r'["\']method["\']\s*:\s*("((?:[^"\\]|\\.)*)"|null|none|true|false|-?\d+(?:\.\d+)?)',
        text,
        re.IGNORECASE,
    )
    confidence_match = re.search(
        r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
        text,
        re.IGNORECASE,
    )
    justification_match = re.search(
        r'["\'](?:justif\w*|reasoning|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
        text,
        re.DOTALL | re.IGNORECASE,
    )

    confidence = 0.5
    if confidence_match:
        try:
            confidence = float(confidence_match.group(1))
        except ValueError:
            confidence = 0.5

    justification = None
    if justification_match:
        # Decode JSON escape sequences; keep the raw capture on failure.
        try:
            justification = json.loads(f'"{justification_match.group(1)}"')
        except json.JSONDecodeError:
            justification = justification_match.group(1)

    method = None
    if method_match:
        raw_method = method_match.group(1)
        if raw_method.startswith('"') and raw_method.endswith('"'):
            try:
                method = json.loads(raw_method)
            except json.JSONDecodeError:
                method = raw_method.strip('"')
        elif raw_method.lower() not in {"null", "none", "true", "false"}:
            method = raw_method
    method = normalize_optional_string(method)

    return ExperimentAction(
        action_type=ActionType(action_type),
        method=method,
        parameters={},
        justification=justification,
        confidence=min(1.0, max(0.0, confidence)),
    )
|
|
|
|
def should_force_terminal_conclusion(
    action: ExperimentAction,
    completed_types: set[ActionType],
) -> bool:
    """True when a meta action is being repeated without a conclusion yet.

    design_followup / request_subagent_review are each legitimate once;
    repeating one after it already completed signals the agent is stalling,
    so the caller should substitute synthesize_conclusion.
    """
    repeatable_meta_actions = {
        ActionType.DESIGN_FOLLOWUP,
        ActionType.REQUEST_SUBAGENT_REVIEW,
    }
    if action.action_type not in repeatable_meta_actions:
        return False
    if action.action_type not in completed_types:
        return False
    return ActionType.SYNTHESIZE_CONCLUSION not in completed_types
|
|
|
|
def _unique_nonempty(items: List[str], limit: int = 5) -> List[str]:
    """Normalise *items*, dropping blanks and case-insensitive duplicates.

    First-seen order is preserved; collection stops at *limit* entries.
    """
    seen_keys: set[str] = set()
    collected: List[str] = []
    for raw_item in items:
        normalized = normalize_optional_string(raw_item)
        if not normalized:
            continue
        dedupe_key = normalized.upper()
        if dedupe_key in seen_keys:
            continue
        seen_keys.add(dedupe_key)
        collected.append(normalized)
        if len(collected) >= limit:
            break
    return collected
|
|
|
|
def _infer_conclusion_evidence(
    obs: ExperimentObservation,
) -> tuple[List[str], List[str], Dict[str, float]]:
    """Scrape markers, mechanisms and pathway scores from prior outputs.

    Starts from what the observation already discovered, then walks the
    successful outputs newest-first, filling each still-empty slot from the
    most recent output that can provide it. Stops early once all three slots
    are populated. Each list is capped at five entries.
    """
    top_markers = _unique_nonempty(list(obs.discovered_markers), limit=5)
    causal_mechanisms = _unique_nonempty(list(obs.candidate_mechanisms), limit=5)
    predicted_pathways: Dict[str, float] = {}

    for output in reversed(obs.all_outputs):
        if not output.success:
            continue

        data = output.data or {}
        # Markers: prefer explicit marker results, else fall back to top DE genes.
        if not top_markers:
            if output.output_type == OutputType.MARKER_RESULT:
                top_markers = _unique_nonempty(list(data.get("markers", [])), limit=5)
            elif output.output_type == OutputType.DE_RESULT:
                top_markers = _unique_nonempty(
                    [item.get("gene") for item in data.get("top_genes", []) if isinstance(item, dict)],
                    limit=5,
                )

        # Pathway scores: the most recent pathway result wins, capped at 5 entries.
        if output.output_type == OutputType.PATHWAY_RESULT and not predicted_pathways:
            for item in data.get("top_pathways", []):
                if not isinstance(item, dict):
                    continue
                pathway = normalize_optional_string(item.get("pathway"))
                score = item.get("score")
                if pathway and isinstance(score, (int, float)):
                    predicted_pathways[pathway] = float(score)
                if len(predicted_pathways) >= 5:
                    break

        # Mechanisms: pathway names, else top regulators from network inference.
        if not causal_mechanisms:
            if output.output_type == OutputType.PATHWAY_RESULT:
                causal_mechanisms = _unique_nonempty(
                    [item.get("pathway") for item in data.get("top_pathways", []) if isinstance(item, dict)],
                    limit=5,
                )
            elif output.output_type == OutputType.NETWORK_RESULT:
                causal_mechanisms = _unique_nonempty(
                    list(data.get("top_regulators", [])),
                    limit=5,
                )

        if top_markers and causal_mechanisms and predicted_pathways:
            break

    return top_markers, causal_mechanisms, predicted_pathways
|
|
|
|
def ensure_conclusion_claims(
    obs: ExperimentObservation,
    action: ExperimentAction,
) -> ExperimentAction:
    """Guarantee a synthesize_conclusion action carries a usable claims list.

    Non-conclusion actions pass through untouched. Model-provided claims are
    kept (filtered to dict-shaped entries); otherwise a single claim is
    assembled from evidence accumulated during the episode. Returns a copy
    when parameters change (the action model is treated as immutable).
    """
    if action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
        return action

    parameters = dict(action.parameters or {})
    raw_claims = parameters.get("claims")
    if isinstance(raw_claims, list) and raw_claims:
        # Keep only dict-shaped claims the environment can score.
        normalized_claims = [claim for claim in raw_claims if isinstance(claim, dict)]
        if normalized_claims:
            parameters["claims"] = normalized_claims
            if parameters != action.parameters:
                return action.model_copy(update={"parameters": parameters})
            return action

    # No usable claims supplied: build one from the discovered evidence.
    top_markers, causal_mechanisms, predicted_pathways = _infer_conclusion_evidence(obs)
    claim_type = "causal" if causal_mechanisms else "correlational"
    conditions = " vs ".join(obs.task.conditions[:2]) if obs.task.conditions else "the task conditions"
    claim = action.justification or f"Final synthesis for {conditions}."

    parameters["claims"] = [{
        "top_markers": top_markers,
        "causal_mechanisms": causal_mechanisms,
        "predicted_pathways": predicted_pathways,
        "confidence": action.confidence,
        "claim_type": claim_type,
        "claim": claim,
    }]
    if not action.justification:
        action = action.model_copy(update={"justification": claim})
    return action.model_copy(update={"parameters": parameters})
|
|
|
|
def write_dashboard_state(
    env: BioExperimentEnvironment,
    obs: ExperimentObservation,
    *,
    step: int,
    cumulative_reward: float,
    model_response: str = "",
    model_thinking: str = "",
    action: Optional[ExperimentAction] = None,
    gen_time: float = 0.0,
    episode_done: bool = False,
) -> None:
    """Serialise the full world state (observable + latent) for the dashboard.

    Writes a JSON snapshot to DASHBOARD_STATE_PATH. Write failures are
    swallowed so dashboard I/O can never break the episode loop. Reads
    env._latent, which is private to the environment — the dashboard
    deliberately gets to see the hidden ground truth.
    """
    latent = env._latent
    snapshot: Dict[str, Any] = {
        "timestamp": time.time(),
        "step": step,
        "episode_done": episode_done,
        "cumulative_reward": cumulative_reward,
        "gen_time_s": round(gen_time, 2),
        # Model text is clipped so snapshots stay small.
        "model_response_raw": model_response[:600],
        "model_thinking": model_thinking[:800],
        "thinking_enabled": ENABLE_THINKING,
    }

    snapshot["task"] = {
        "problem_statement": obs.task.problem_statement,
        "organism": obs.task.organism,
        "tissue": obs.task.tissue,
        "modality": obs.task.modality,
        "conditions": obs.task.conditions,
        "budget_limit": obs.task.budget_limit,
        "time_limit_days": obs.task.time_limit_days,
    }

    snapshot["resources"] = {
        "budget_used": round(obs.resource_usage.budget_used, 2),
        "budget_remaining": round(obs.resource_usage.budget_remaining, 2),
        "time_used_days": round(obs.resource_usage.time_used_days, 1),
        "time_remaining_days": round(obs.resource_usage.time_remaining_days, 1),
        "samples_consumed": obs.resource_usage.samples_consumed,
        "compute_hours_used": round(obs.resource_usage.compute_hours_used, 2),
    }

    snapshot["pipeline_history"] = [
        {
            "step_index": h.step_index,
            "action_type": h.action_type.value,
            "method": h.method,
            "output_summary": h.output_summary[:120],
            "success": h.success,
            "quality_score": round(h.quality_score, 3),
            "resource_cost": round(h.resource_cost, 2),
            "time_cost_days": round(h.time_cost_days, 1),
        }
        for h in obs.pipeline_history
    ]

    if action:
        snapshot["current_action"] = {
            "action_type": action.action_type.value,
            "method": action.method,
            "parameters": action.parameters,
            "justification": action.justification,
            "confidence": action.confidence,
        }

    if obs.latest_output:
        lo = obs.latest_output
        snapshot["latest_output"] = {
            "summary": lo.summary,
            "success": lo.success,
            "quality_score": round(lo.quality_score, 3),
            "uncertainty": round(lo.uncertainty, 3),
            "warnings": lo.warnings,
            "data_preview": compact_preview(lo.data, 300) if lo.data else None,
        }

    snapshot["discovered_markers"] = obs.discovered_markers[:20]
    snapshot["candidate_mechanisms"] = obs.candidate_mechanisms[:20]
    snapshot["rule_violations"] = obs.rule_violations
    snapshot["uncertainty_summary"] = {
        k: round(v, 3) for k, v in obs.uncertainty_summary.items()
    }
    snapshot["reward_breakdown"] = {
        k: round(v, 4) for k, v in obs.step_reward_breakdown.items()
    }

    if obs.conclusions:
        snapshot["conclusions"] = [
            {
                "claim": c.claim,
                "claim_type": c.claim_type,
                "confidence": c.confidence,
                "top_markers": c.top_markers,
                "causal_mechanisms": c.causal_mechanisms,
                "predicted_pathways": c.predicted_pathways,
            }
            for c in obs.conclusions
        ]

    # Hidden ground truth — shown only to the dashboard, never to the agent.
    if latent:
        bio = latent.biology
        snapshot["latent"] = {
            "cell_populations": [
                {
                    "name": cp.name,
                    "proportion": round(cp.proportion, 3),
                    "marker_genes": cp.marker_genes[:8],
                    "state": cp.state,
                }
                for cp in bio.cell_populations
            ],
            "true_markers": bio.true_markers,
            "causal_mechanisms": bio.causal_mechanisms,
            "true_pathways": {
                k: round(v, 3) for k, v in list(bio.true_pathways.items())[:15]
            },
            "true_de_genes_count": sum(
                len(genes) for genes in bio.true_de_genes.values()
            ),
            "true_regulatory_network_size": sum(
                len(targets) for targets in bio.true_regulatory_network.values()
            ),
            "confounders": bio.confounders,
            "n_true_cells": bio.n_true_cells,
            "technical": {
                "ambient_rna_fraction": latent.technical.ambient_rna_fraction,
                "doublet_rate": latent.technical.doublet_rate,
                "dropout_rate": latent.technical.dropout_rate,
                "sample_quality": latent.technical.sample_quality,
                "library_complexity": latent.technical.library_complexity,
                "capture_efficiency": latent.technical.capture_efficiency,
            },
            "progress": latent.progress.model_dump(),
            "hidden_failure_conditions": latent.hidden_failure_conditions,
        }

    # Best-effort write: the dashboard is optional and must never crash us.
    try:
        DASHBOARD_STATE_PATH.write_text(
            json.dumps(snapshot, indent=2, default=str), encoding="utf-8"
        )
    except Exception:
        pass
|
|
|
|
def log(msg: str) -> None:
    """Print *msg* immediately (flushed) so tailing dashboards stay live."""
    print(msg, flush=True)
|
|
|
|
def build_observation_prompt(obs: ExperimentObservation) -> str:
    """Thin alias kept so prompt construction has a single named entry point."""
    return format_observation(obs)
|
|
|
|
def run_with_pipeline(pipe, prompt: str) -> str:
    """Generate a completion via a transformers pipeline; '' on any failure.

    Best-effort by design: any pipeline exception yields an empty string so
    the caller can substitute a fallback prompt/response.
    """
    try:
        # Thinking mode needs far more tokens for the reasoning preamble.
        token_budget = 2048 if ENABLE_THINKING else 300
        raw = pipe(prompt, max_new_tokens=token_budget, return_full_text=False)
    except Exception:
        return ""

    # Pipelines return a list of dicts; tolerate bare dicts/strings too.
    if isinstance(raw, list) and raw:
        raw = raw[0]
    if isinstance(raw, dict):
        text = raw.get("generated_text") or raw.get("text") or raw.get("answer")
    elif isinstance(raw, str):
        text = raw
    else:
        text = ""
    return text.strip() if isinstance(text, str) else ""
|
|
|
|
def resolve_torch_runtime() -> Dict[str, Any]:
    """Probe torch for the best available device/dtype combination.

    Prefers bfloat16 on CUDA hardware that supports it, float16 on other
    CUDA devices, and float32 on CPU.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        supports_bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
        chosen_dtype = torch.bfloat16 if supports_bf16 else torch.float16
    else:
        chosen_dtype = torch.float32
    return {
        "use_cuda": cuda_available,
        "device": "cuda:0" if cuda_available else "cpu",
        "dtype": chosen_dtype,
        "device_map": "auto" if cuda_available else None,
        "device_name": torch.cuda.get_device_name(0) if cuda_available else "cpu",
    }
|
|
|
|
def main():
    """Load the planner LLM, run an initial episode, then serve dashboard commands.

    Model loading tries the transformers pipeline first (when USE_PIPELINE is
    set) and falls back to an explicit tokenizer + model pair. Afterwards it
    loops forever, running one episode per dashboard command, until "quit".
    """
    tokenizer = None
    model = None
    eos_ids: List[int] = []
    active_pipeline = None

    runtime = resolve_torch_runtime()
    log(
        f"Using local model runtime: device={runtime['device']} "
        f"name={runtime['device_name']} dtype={runtime['dtype']}"
    )

    # Optional fast path: a single transformers pipeline object.
    if USE_PIPELINE:
        log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
        try:
            active_pipeline = pipeline(
                PIPELINE_TASK,
                model=MODEL_ID,
                trust_remote_code=True,
                dtype=runtime["dtype"],
                device=0 if runtime["use_cuda"] else -1,
            )
            log("Pipeline loaded.")
        except Exception as exc:
            log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")

    # Default path: explicit tokenizer + causal-LM model.
    if active_pipeline is None:
        log(f"Loading tokenizer for {MODEL_ID} ...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True,
        )
        log("Tokenizer loaded. Loading model (this may download files on first run) ...")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            dtype=runtime["dtype"],
            device_map=runtime["device_map"],
            trust_remote_code=True,
        )
        log(f"Model loaded. Device: {model.device}")

        # Collect every plausible end-of-sequence id (Qwen chat uses <|im_end|>).
        if tokenizer.eos_token_id is not None:
            eos_ids.append(tokenizer.eos_token_id)
        extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
        for tid in extra:
            if isinstance(tid, int) and tid not in eos_ids:
                eos_ids.append(tid)
        log(f"EOS token ids: {eos_ids}")

    def check_dashboard_command() -> Optional[Dict[str, Any]]:
        """Read and consume a command file written by the dashboard."""
        try:
            raw = DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
            try:
                DASHBOARD_CMD_PATH.unlink(missing_ok=True)
            except OSError:
                # Best-effort delete; a stale file would just be re-read.
                pass
            return json.loads(raw)
        except (FileNotFoundError, json.JSONDecodeError):
            return None

    def run_episode(
        scenario_name: Optional[str] = None,
        custom_ground_truth: Optional[Dict[str, Any]] = None,
    ):
        """Run one full episode, optionally overriding the latent ground truth."""
        env = BioExperimentEnvironment(scenario_name=scenario_name)
        obs = env.reset()

        # Dashboard-supplied ground truth overrides the generated latent state.
        if custom_ground_truth and env._latent:
            gt = custom_ground_truth
            bio = env._latent.biology
            if gt.get("true_markers"):
                bio.true_markers = gt["true_markers"]
            if gt.get("causal_mechanisms"):
                bio.causal_mechanisms = gt["causal_mechanisms"]
            if gt.get("true_pathways"):
                bio.true_pathways = {
                    k: float(v) for k, v in gt["true_pathways"].items()
                }

        log("\n" + "=" * 70)
        log(f"TASK: {obs.task.problem_statement}")
        log(f"Conditions: {obs.task.conditions}")
        log(f"Budget: ${obs.task.budget_limit:,.0f} | Time: {obs.task.time_limit_days:.0f} days")
        if ENABLE_THINKING:
            log("Reasoning mode: ENABLED")
        log("=" * 70)

        cumulative_reward = 0.0
        write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)

        for step in range(MAX_EPISODE_STEPS):
            # A dashboard restart aborts the current episode early.
            cmd = check_dashboard_command()
            if cmd and cmd.get("action") == "restart":
                log("\n[DASHBOARD] Restart requested — ending episode early.")
                break

            user_msg = build_observation_prompt(obs)

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ]

            if active_pipeline is not None:
                # Pipelines take raw text; emulate the chat layout crudely.
                prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
            else:
                try:
                    prompt = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=ENABLE_THINKING,
                    )
                except TypeError:
                    # Older chat templates do not accept enable_thinking.
                    prompt = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

            t0 = time.time()
            if active_pipeline is not None:
                response = run_with_pipeline(active_pipeline, prompt)
                if not response:
                    response = format_observation(obs)
            else:
                assert tokenizer is not None and model is not None
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                n_input = inputs["input_ids"].shape[1]
                # Thinking mode needs a much larger budget for the <think> block.
                max_new = 2048 if ENABLE_THINKING else 300
                with torch.no_grad():
                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=max_new,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.8,
                        top_k=20,
                        repetition_penalty=1.3,
                        eos_token_id=eos_ids if eos_ids else None,
                    )
                new_tokens = output_ids[0][n_input:]
                response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            gen_time = time.time() - t0

            # Split a <think>...</think> preamble out of the visible response.
            thinking = ""
            if ENABLE_THINKING:
                think_match = re.search(
                    r"<think>(.*?)</think>", response, re.DOTALL
                )
                if think_match:
                    thinking = think_match.group(1).strip()
                    response = response[think_match.end():].strip()
                elif response.startswith("<think>"):
                    parts = response.split("</think>", 1)
                    if len(parts) == 2:
                        thinking = parts[0].replace("<think>", "").strip()
                        response = parts[1].strip()

            is_last_step = (step == MAX_EPISODE_STEPS - 1)

            action = parse_action(response)
            if action is None:
                # Unparseable output: skip the step, or force a conclusion at the end.
                if is_last_step:
                    log(f"\n [!] Parse failed on final step — forcing synthesize_conclusion.")
                    action = ExperimentAction(
                        action_type=ActionType.SYNTHESIZE_CONCLUSION,
                        justification="forced terminal conclusion",
                        confidence=0.5,
                    )
                else:
                    log(f"\n [!] Parse failed, skipping step. Raw: {response[:150]}")
                    continue

            completed_types = {
                r.action_type for r in obs.pipeline_history if r.success
            }
            failed_types = {
                r.action_type
                for r in obs.pipeline_history
                if not r.success
            }

            # Repeating an already-completed meta step means the agent is
            # stalling: substitute synthesize_conclusion.
            if should_force_terminal_conclusion(action, completed_types):
                log(
                    f"\n [!] repeated completed meta step {action.action_type.value} "
                    f"— forcing synthesize_conclusion."
                )
                action = ExperimentAction(
                    action_type=ActionType.SYNTHESIZE_CONCLUSION,
                    justification="repeated completed meta step forced terminal conclusion",
                    confidence=action.confidence,
                )
                completed_types = {
                    r.action_type for r in obs.pipeline_history if r.success
                }

            # Veto repeats of completed steps and pointless retries of failures.
            skip_reason = None
            if action.action_type in completed_types:
                skip_reason = (
                    f"blocked repeat of completed step {action.action_type.value}"
                )
            elif action.action_type in failed_types:
                if should_block_failed_reattempt(
                    obs.pipeline_history, action.action_type
                ):
                    skip_reason = (
                        f"blocked re-attempt of failed step {action.action_type.value}"
                    )

            if skip_reason:
                if is_last_step:
                    log(f"\n [!] {skip_reason} on final step — forcing synthesize_conclusion.")
                    action = ExperimentAction(
                        action_type=ActionType.SYNTHESIZE_CONCLUSION,
                        justification="forced terminal conclusion",
                        confidence=0.5,
                    )
                else:
                    log(f"\n [!] {skip_reason}, skipping step.")
                    continue

            # The final step must always end with a conclusion.
            if is_last_step and action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
                log(f"\n [!] Final step — overriding {action.action_type.value} with synthesize_conclusion.")
                action = ExperimentAction(
                    action_type=ActionType.SYNTHESIZE_CONCLUSION,
                    justification="forced terminal conclusion",
                    confidence=action.confidence,
                )

            action = ensure_conclusion_claims(obs, action)

            log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
            if thinking:
                log(f" Thinking: {thinking[:200]}")
            if action.justification:
                log(f" Rationale: {action.justification}")
            else:
                log(" Rationale: [model did not provide one]")
            if action.parameters:
                log(f" Parameters: {compact_preview(action.parameters, 200)}")
            elif not action.justification and response:
                log(
                    f" Model response: "
                    f"{compact_preview(response, MODEL_RESPONSE_PREVIEW_CHARS)}"
                )

            obs = env.step(action)

            if obs.latest_output:
                lo = obs.latest_output
                status = "OK" if lo.success else "FAIL"
                log(f" [{status}] {lo.summary}")
                if lo.warnings:
                    log(f" Warnings: {lo.warnings}")

            step_reward = obs.reward
            cumulative_reward += step_reward
            log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
            log(f" Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d")

            write_dashboard_state(
                env, obs,
                step=step + 1,
                cumulative_reward=cumulative_reward,
                model_response=response,
                model_thinking=thinking,
                action=action,
                gen_time=gen_time,
                episode_done=obs.done,
            )

            if obs.rule_violations:
                log(f" Violations: {obs.rule_violations}")

            if obs.done:
                break

        # Episode summary.
        log(f"\n{'=' * 70}")
        log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
        log(f" Steps: {obs.step_index}")
        log(f" Total reward: {cumulative_reward:+.3f}")
        log(f" Budget used: ${obs.resource_usage.budget_used:,.0f}")
        log(f" Time used: {obs.resource_usage.time_used_days:.0f} days")
        if obs.conclusions:
            log(" Conclusions:")
            for c in obs.conclusions:
                log(f" [{c.claim_type}, conf={c.confidence:.2f}] {c.claim}")
                if c.top_markers:
                    log(f" Markers: {c.top_markers}")
                if c.causal_mechanisms:
                    log(f" Mechanisms: {c.causal_mechanisms}")
                if c.predicted_pathways:
                    log(f" Pathways: {c.predicted_pathways}")
        log("=" * 70)

    # Discard any stale command, then run the first episode immediately.
    try:
        DASHBOARD_CMD_PATH.unlink(missing_ok=True)
    except OSError:
        pass
    run_episode()

    # Serve dashboard commands until "quit".
    while True:
        log("\nWaiting for dashboard command (restart / new task) ...")
        while True:
            cmd = check_dashboard_command()
            if cmd:
                break
            time.sleep(1.0)

        action_type = cmd.get("action", "restart")
        if action_type == "quit":
            log("Quit requested.")
            break

        scenario = cmd.get("scenario_name")
        ground_truth = cmd.get("ground_truth")
        log(f"\n[DASHBOARD] {action_type} — scenario={scenario}")
        run_episode(scenario_name=scenario, custom_ground_truth=ground_truth)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|