import json
import re

from datasets import load_dataset

import verifiers as vf

# Fenced code block, optionally tagged "json" in any letter case (```json, ```JSON).
# The closing fence may be indented; the caller strips the captured body.
_CODE_BLOCK_RE = re.compile(r"```(?:json)?\s*\n(.*?)\n\s*```", re.DOTALL | re.IGNORECASE)

# Reward granted by each criterion when it passes (three criteria -> max total 0.99).
_CRITERION_REWARD = 0.33

# Failures that mean "this completion/example cannot be scored" -> reward 0.0.
# KeyError covers records missing verification_info/ground_truth; RecursionError
# covers pathologically deep model output, so scoring never crashes on one rollout.
_SCORING_ERRORS = (
    json.JSONDecodeError,
    ValueError,
    TypeError,
    AttributeError,
    KeyError,
    RecursionError,
)


def _extract_json_from_completion(completion):
    """Return the JSON candidate text from a completion.

    Prefers the body of the last fenced code block; otherwise returns the whole
    (stripped) text. Accepts a plain string, or a list of message dicts, in which
    case the last message's ``content`` is used.
    """
    if not completion:
        return ""
    if isinstance(completion, list):
        # A message may carry content=None; treat that as empty text.
        content = completion[-1].get("content") or ""
    else:
        content = str(completion)
    blocks = _CODE_BLOCK_RE.findall(content)
    if blocks:
        return blocks[-1].strip()
    return content.strip()


def _parse_response(parser, completion):
    """Parse the model's answer as JSON; raises json.JSONDecodeError when invalid or empty."""
    text = (parser.parse_answer(completion) or "").strip()
    return json.loads(text)


def _load_ground_truth(info):
    """Return ``info["verification_info"]["ground_truth"]``.

    ``verification_info`` is stored as a JSON string by this environment's dataset
    mapping, but an already-decoded dict is accepted as well.
    """
    verification_info = info["verification_info"]
    if isinstance(verification_info, str):
        verification_info = json.loads(verification_info)
    return verification_info["ground_truth"]


def _key_paths(obj, prefix=()):
    """Return the set of key paths (tuples of keys) through nested dicts.

    Only dict nesting is followed; lists are not descended into, so keys of dicts
    inside lists are not part of the result.
    """
    paths = set()
    if isinstance(obj, dict):
        for key, value in obj.items():
            path = prefix + (key,)
            paths.add(path)
            paths |= _key_paths(value, path)
    return paths


def _strict_equal(a, b):
    """Deep equality that also requires identical types at every node.

    Stricter than ``==``: ``1`` vs ``1.0`` and ``True`` vs ``1`` are unequal.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        return a.keys() == b.keys() and all(_strict_equal(a[k], b[k]) for k in a)
    if isinstance(a, list):
        return len(a) == len(b) and all(_strict_equal(x, y) for x, y in zip(a, b))
    return a == b


def load_environment(
    num_train_examples=7000,
    num_eval_examples=1000,
    **kwargs
):
    """
    Environment for verifying complex JSON output from models.

    The task requires models to:
    1. Parse multi-question prompts
    2. Generate valid JSON responses
    3. Match the expected structure with correct keys and values

    Rewards (no penalties, only positive rewards):
    - Formatting (valid JSON dict): 0.33 if pass, 0 if fail
    - All keys match: 0.33 if pass, 0 if fail
    - Answer values match: 0.33 if pass, 0 if fail

    Total max reward: 0.99 (3 x 0.33).

    Args:
        num_train_examples: Rows taken (from the start of the HF "train" split) for
            training. Clamped to the number of available rows.
        num_eval_examples: Rows following the training rows used for evaluation.
            Clamped to the rows remaining after the training slice.
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        A ``vf.SingleTurnEnv`` wired to the parser and three-criterion rubric.
    """
    # Load dataset from HuggingFace
    dataset = load_dataset("Delta-Vector/Tauri-Complex-JSON-Formatting", split="train")

    # Map to expected format - keep verification_info as a string to avoid
    # schema-inference issues with heterogeneous ground-truth structures.
    def format_example(example):
        return {
            "question": example["prompt"],
            "info": {"verification_info": example["verification_info"]},
        }

    dataset = dataset.map(format_example, remove_columns=dataset.column_names)

    # Split into contiguous train / eval slices, clamped so a smaller dataset
    # does not make `select` raise IndexError.
    total = len(dataset)
    num_train = max(0, min(num_train_examples, total))
    num_eval = max(0, min(num_eval_examples, total - num_train))
    train_dataset = dataset.select(range(num_train))
    eval_dataset = dataset.select(range(num_train, num_train + num_eval))

    # Simple Parser with a custom extract function (code-fence aware)
    parser = vf.Parser(extract_fn=_extract_json_from_completion)

    def format_reward(completion, **kwargs) -> float:
        """
        Reward for valid JSON formatting.
        Returns 0.33 for a valid JSON dict, 0 otherwise.
        """
        try:
            parsed = _parse_response(parser, completion)
        except _SCORING_ERRORS:
            return 0.0
        # Must be a dict (the ground truth is always a dict)
        return _CRITERION_REWARD if isinstance(parsed, dict) else 0.0

    def keys_match_reward(completion, info, **kwargs) -> float:
        """
        Reward for matching keys in the JSON structure.
        Returns 0.33 if the set of nested-dict key paths matches exactly, 0 otherwise.
        """
        try:
            parsed_response = _parse_response(parser, completion)
            ground_truth = _load_ground_truth(info)
            if not isinstance(parsed_response, dict):
                return 0.0
            matched = _key_paths(ground_truth) == _key_paths(parsed_response)
        except _SCORING_ERRORS:
            return 0.0
        return _CRITERION_REWARD if matched else 0.0

    def values_match_reward(completion, info, **kwargs) -> float:
        """
        Reward for matching values in the JSON structure.
        Returns 0.33 if the response deep-equals the ground truth (with identical
        types at every node), 0 otherwise.
        """
        try:
            parsed_response = _parse_response(parser, completion)
            ground_truth = _load_ground_truth(info)
            matched = _strict_equal(parsed_response, ground_truth)
        except _SCORING_ERRORS:
            return 0.0
        return _CRITERION_REWARD if matched else 0.0

    # Create rubric with all reward functions
    rubric = vf.Rubric(
        parser=parser,
        funcs=[
            format_reward,
            keys_match_reward,
            values_match_reward,
        ],
        weights=[1.0, 1.0, 1.0],  # Equal weights for all three criteria
    )

    # SingleTurnEnv since this is a one-shot task.
    # No system prompt - let the dataset prompt speak for itself.
    vf_env = vf.SingleTurnEnv(
        dataset=train_dataset,
        eval_dataset=eval_dataset,
        parser=parser,
        rubric=rubric,
    )

    return vf_env