import torch
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
def train_model(
    model=None,
    *,
    dataset_name="dance_videos_dataset",
    split="train",
    **training_overrides,
):
    """Fine-tune ``model`` on a Hugging Face dataset using ``transformers.Trainer``.

    Args:
        model: The model to train. When omitted, a module-level ``model`` is
            used, which preserves the original zero-argument call style.
        dataset_name: Dataset identifier passed to ``datasets.load_dataset``.
        split: Dataset split to train on. Passing a split makes
            ``load_dataset`` return a single ``Dataset``; without it the call
            returns a ``DatasetDict`` of every split, which is not what
            ``Trainer`` expects for ``train_dataset``.
        **training_overrides: Keyword arguments that add to or override the
            default ``TrainingArguments`` below (e.g. ``num_train_epochs=3``,
            ``output_dir="./runs"``, ``fp16=False``).

    Returns:
        The ``TrainOutput`` returned by ``Trainer.train()``.

    Raises:
        ValueError: If no model is passed and no module-level ``model`` exists.
    """
    if model is None:
        # The original body referenced a global ``model``; keep that fallback so
        # existing ``train_model()`` callers behave the same when it is defined.
        model = globals().get("model")
        if model is None:
            raise ValueError(
                "train_model() needs a model: pass one explicitly or define a "
                "module-level `model`."
            )

    args_kwargs = {
        "output_dir": "./checkpoints",
        "num_train_epochs": 100,
        "per_device_train_batch_size": 4,
        # 4 examples per device x 4 accumulation steps = 16 examples per
        # optimizer step on each device.
        "gradient_accumulation_steps": 4,
        "learning_rate": 1e-4,
        # fp16 AMP needs an accelerator (CUDA in most transformers versions);
        # an unconditional fp16=True makes TrainingArguments raise on CPU-only
        # machines. Callers can still force it via training_overrides.
        "fp16": torch.cuda.is_available(),
        "save_steps": 500,
    }
    args_kwargs.update(training_overrides)
    training_args = TrainingArguments(**args_kwargs)

    train_dataset = load_dataset(dataset_name, split=split)
    # NOTE(review): raw video examples usually need preprocessing and/or a
    # data_collator before a model's forward() can consume them — confirm the
    # dataset columns already match the model's expected inputs.

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    return trainer.train()