TurtleDetector / training /segmentation_train.py
WildlifeDatasets's picture
Added training scripts
bb14d6a unverified
raw
history blame contribute delete
669 Bytes
import os
from ultralytics import YOLO
# Root directory for all training runs; each stage creates its own run folder here.
project = f"{os.getcwd()}/runs"
# Training device. Defaults to the original hard-coded "cuda:2"; set the
# TURTLE_TRAIN_DEVICE environment variable (e.g. "0", "cuda:0", "cpu") to train on a
# machine with a different GPU layout instead of editing this file.
device = os.environ.get("TURTLE_TRAIN_DEVICE", "cuda:2")
imgsz = 640
epochs = 20


def shared_train_args():
    """Return the ``YOLO.train`` keyword arguments common to both training stages.

    Stage-specific arguments (``data``, ``name``, ``freeze``) are supplied by the caller.
    """
    return {
        "project": project,
        "epochs": epochs,
        "imgsz": imgsz,
        "device": device,
        # Disable horizontal and vertical flip augmentation.
        # NOTE(review): presumably because image orientation is meaningful for these
        # labels -- confirm before re-enabling.
        "fliplr": 0,
        "flipud": 0,
    }


def main():
    """Run the two-stage segmentation training.

    Stage 1 trains from the ``yolo11s-seg.pt`` checkpoint on ``segmentation_stage1.yaml``.
    Stage 2 continues from stage 1's ``last.pt`` on ``segmentation_stage2.yaml`` with the
    first 5 model layers frozen.
    """
    # Stage 1: Pretrain on SeaTurtleID2022 (large dataset)
    model = YOLO("yolo11s-seg.pt")
    model.train(data="segmentation_stage1.yaml", name="stage1", **shared_train_args())

    # Ultralytics does not overwrite an existing run directory: if runs/stage1 already
    # exists from an earlier invocation, this run is saved as runs/stage12, stage13, ...
    # Hard-coding "runs/stage1/weights/last.pt" would then silently start stage 2 from
    # the stale first run, so take the checkpoint path from the trainer that just ran.
    stage1_last = str(model.trainer.last)

    # Stage 2: Fine-tune on combined dataset (balanced)
    model = YOLO(stage1_last)
    model.train(
        data="segmentation_stage2.yaml",
        name="stage2",
        freeze=5,  # freeze the first 5 model layers during fine-tuning
        **shared_train_args(),
    )


if __name__ == "__main__":
    main()