import verifiers as vf
"""
# install
vf-install complex-json-output (-p /path/to/environments)
# quick eval
vf-eval complex-json-output (-m model_name in endpoints.py)
inference:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 vf-vllm --model Qwen/Qwen2.5-1.5B-Instruct \
--data-parallel-size 6 --enforce-eager --disable-log-requests
training:
CUDA_VISIBLE_DEVICES=6,7 accelerate launch --num-processes 2 \
--config-file configs/zero3.yaml examples/grpo/train_complex_json_output.py
"""
# Hyperparameters tuned below (reference list only; not read anywhere else in this script)
HPARAMS = [
    "per_device_train_batch_size",
    "num_generations",
    "gradient_accumulation_steps",
    "max_tokens",
    "max_seq_len",
    "max_prompt_length",
    "max_completion_length",
    "temperature",
    "learning_rate",
    "max_steps",
    "warmup_steps",
    "eval_steps",
    "save_steps",
    "beta",
    "loss_type",
]
# Load environment
vf_env = vf.load_environment(
    env_id="complex-json-output",
    num_train_examples=8000,  # Use a subset for faster training
    num_eval_examples=50,
)
# Model configuration (local checkpoint path; point this at your own model)
model_name = "/raid/workspace/Mango/verifiers/MS3.2-0.35-Beta"
run_name = "complex-json-grpo_" + model_name.split("/")[-1].lower()
# Load model and tokenizer
model, tokenizer = vf.get_model_and_tokenizer(model_name)
# Training arguments
training_args = vf.grpo_defaults(run_name=run_name)
# Batch configuration
training_args.per_device_train_batch_size = 2
training_args.num_generations = 16
training_args.gradient_accumulation_steps = 2
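# Rough batch math (a sketch; assumes standard GRPO accounting where each prompt is
# sampled num_generations times): with the 2-GPU training launch from the docstring,
# 2 (per device) * 2 (GPUs) * 2 (grad accum) = 8 sequences feed each optimizer step,
# while every prompt in a rollout batch yields 16 sampled completions.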
# Generation parameters
training_args.max_tokens = 2048 # JSON can be long
training_args.max_seq_len = 16000
training_args.max_prompt_length = 8192 # Allow long prompts (questions can be lengthy)
training_args.max_completion_length = 4096 # Allow long completions
training_args.temperature = 0.1 # Low temperature: near-deterministic sampling keeps JSON well-formed
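# Length-budget check (optional sketch; assumes prompt and completion together must fit
# within max_seq_len): 8192 + 4096 = 12288 tokens <= 16000, so the limits above are consistent.
assert training_args.max_prompt_length + training_args.max_completion_length <= training_args.max_seq_len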
# Training schedule
training_args.learning_rate = 5e-6
training_args.max_steps = 1000
training_args.warmup_steps = 15
# Evaluation (disabled here; the two settings below only take effect if eval_strategy is enabled)
training_args.eval_strategy = "none"
training_args.eval_steps = 50
training_args.per_device_eval_batch_size = 8
# Checkpointing
training_args.save_strategy = "steps"
training_args.save_steps = 100
# GRPO parameters
training_args.beta = 0.001 # Conservative KL penalty
training_args.loss_type = "dr_grpo" # Recommended: no length bias
# Logging
training_args.logging_steps = 1
training_args.log_completions = True
training_args.num_completions_to_print = 3
training_args.report_to = "wandb" # Log metrics to Weights & Biases
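# To train without Weights & Biases, the usual transformers alternative is:
# training_args.report_to = "none"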
# Create trainer
trainer = vf.GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    env=vf_env,
    args=training_args,
    peft_config=vf.lora_defaults(r=8, alpha=16),  # Use LoRA for efficiency
)
# Train
trainer.train()
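# Persist the result (a minimal sketch; assumes vf.GRPOTrainer exposes the standard
# transformers Trainer.save_model API; the output directory name is hypothetical).
trainer.save_model("outputs/" + run_name)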