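"""Tests for syntetic_issue_report_data_generation.modeling.train.

Covers data loading and preparation, argument parsing, SetFit and Transformers
training, error handling, edge cases, and output validation.
"""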
from pathlib import Path
import sys
import tempfile
from unittest.mock import patch
from datasets import Dataset
import numpy as np
import pandas as pd
import pytest
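# Make the project package importable when tests are run directly from the tests/ directory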
sys.path.insert(0, str(Path(__file__).parent.parent))
from syntetic_issue_report_data_generation.config import MODEL_CONFIGS, DATASET_CONFIGs
from syntetic_issue_report_data_generation.modeling.train import (
init_parser,
load_and_prepare_data,
train_model_setfit,
train_model_transformers,
)
@pytest.fixture
def temp_data_dir():
"""Create a temporary directory for test data files."""
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
@pytest.fixture
def sample_train_data():
"""Create sample training data with balanced classes."""
return pd.DataFrame(
{
"title": [
"Bug in login",
"Feature request",
"Performance issue",
"UI problem",
"Bug in logout",
"Add search",
"Memory leak",
"New API endpoint",
"Crash on startup",
"Enhancement needed",
],
"body": [
"Cannot login to system",
"Add dark mode feature",
"Slow loading times",
"Button misaligned",
"Cannot logout properly",
"Need search functionality",
"High memory usage",
"REST API needed",
"Application crashes",
"Improve user experience",
],
"label": [
"bug",
"enhancement",
"bug",
"bug",
"bug",
"enhancement",
"bug",
"enhancement",
"bug",
"enhancement",
],
}
)
@pytest.fixture
def sample_imbalanced_data():
"""Create sample data with imbalanced classes (for stratified sampling test)."""
return pd.DataFrame(
{
"title": [
"Bug 1",
"Bug 2",
"Bug 3",
"Bug 4",
"Bug 5",
"Bug 6",
"Bug 7",
"Bug 8",
"Enhancement 1",
"Enhancement 2",
],
"body": [
"Bug body 1",
"Bug body 2",
"Bug body 3",
"Bug body 4",
"Bug body 5",
"Bug body 6",
"Bug body 7",
"Bug body 8",
"Enhancement body 1",
"Enhancement body 2",
],
"label": [
"bug",
"bug",
"bug",
"bug",
"bug",
"bug",
"bug",
"bug",
"enhancement",
"enhancement",
],
}
)
@pytest.fixture
def train_config_with_title(temp_data_dir, sample_train_data):
"""Create train config with title and body columns."""
train_path = temp_data_dir / "train_with_title.csv"
sample_train_data.to_csv(train_path, index=False)
return {
"data_path": "train_with_title.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
@pytest.fixture
def imbalanced_train_config(temp_data_dir, sample_imbalanced_data):
"""Create train config with imbalanced data."""
train_path = temp_data_dir / "train_imbalanced.csv"
sample_imbalanced_data.to_csv(train_path, index=False)
return {
"data_path": "train_imbalanced.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
@pytest.fixture
def minimal_train_data():
"""Create minimal training data for quick training tests."""
return pd.DataFrame(
{
"title": [
"Bug 1",
"Bug 2",
"Enhancement 1",
"Enhancement 2",
"Bug 3",
"Enhancement 3",
],
"body": [
"Bug body 1",
"Bug body 2",
"Enh body 1",
"Enh body 2",
"Bug body 3",
"Enh body 3",
],
"label": ["bug", "bug", "enhancement", "enhancement", "bug", "enhancement"],
}
)
@pytest.fixture
def minimal_train_config(temp_data_dir, minimal_train_data):
"""Create train config with minimal data for fast training."""
train_path = temp_data_dir / "minimal_train.csv"
minimal_train_data.to_csv(train_path, index=False)
return {
"data_path": "minimal_train.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
@pytest.fixture
def minimal_model_config_setfit():
"""Create minimal SetFit model configuration for testing."""
return {
"model_checkpoint": "sentence-transformers/paraphrase-MiniLM-L3-v2", # Small, fast model
"params": {"num_epochs": 1, "batch_size": 4, "num_iterations": 5, "max_length": 64},
}
@pytest.fixture
def minimal_model_config_transformers():
"""Create minimal Transformers model configuration for testing."""
return {
"model_checkpoint": "prajjwal1/bert-tiny", # Very small BERT model
"params": {
"num_train_epochs": 1,
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 2,
"learning_rate": 5e-5,
"warmup_steps": 0,
"weight_decay": 0.01,
"logging_steps": 1,
},
}
class TestDataLoadingAndPreparation:
"""Test class for data loading and preparation functionality."""
def test_load_data_with_valid_config(self, train_config_with_title, temp_data_dir):
"""Verify that data loads correctly with valid train dataset configuration."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check that datasets were created
assert isinstance(train_ds, Dataset)
assert isinstance(test_ds, Dataset)
# Check that datasets have correct columns
assert set(train_ds.column_names) == {"text", "label"}
assert set(test_ds.column_names) == {"text", "label"}
# Check that datasets are not empty
assert len(train_ds) > 0
assert len(test_ds) > 0
# Check total samples
assert len(train_ds) + len(test_ds) == 10
# Check label encoder is attached
assert hasattr(train_ds, "label_encoder")
assert hasattr(test_ds, "label_encoder")
# Check that labels are integers
assert all(isinstance(label, int) for label in train_ds["label"])
assert all(isinstance(label, int) for label in test_ds["label"])
def test_load_data_creates_holdout_split(self, train_config_with_title, temp_data_dir):
"""Verify holdout split is created when no test dataset is provided."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check eval strategy is holdout
assert eval_strategy == "holdout"
# Check that both datasets exist
assert len(train_ds) > 0
assert len(test_ds) > 0
# Check that split is approximately correct (80/20 split of 10 samples)
total_samples = len(train_ds) + len(test_ds)
assert total_samples == 10
assert len(test_ds) == 2 # 20% of 10 = 2
assert len(train_ds) == 8 # 80% of 10 = 8
# Verify no data leakage (no overlap between train and test)
train_texts = set(train_ds["text"])
test_texts = set(test_ds["text"])
assert len(train_texts.intersection(test_texts)) == 0
def test_label_encoding_consistency(self, train_config_with_title, temp_data_dir):
"""Verify labels are encoded consistently across train/test sets."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check that label encoders are the same object
assert train_ds.label_encoder is test_ds.label_encoder
# Check that labels are integers
assert all(isinstance(label, int) for label in train_ds["label"])
assert all(isinstance(label, int) for label in test_ds["label"])
# Check that label classes are consistent
train_label_classes = train_ds.label_encoder.classes_
test_label_classes = test_ds.label_encoder.classes_
assert list(train_label_classes) == list(test_label_classes)
# Check that we have the expected classes (bug and enhancement)
expected_classes = sorted(["bug", "enhancement"])
actual_classes = sorted(train_ds.label_encoder.classes_)
assert actual_classes == expected_classes
# Check that encoded labels are in valid range [0, num_classes)
num_classes = len(train_ds.label_encoder.classes_)
assert num_classes == 2
assert all(0 <= label < num_classes for label in train_ds["label"])
assert all(0 <= label < num_classes for label in test_ds["label"])
# Check that both classes appear in train set (due to stratification)
train_unique_labels = set(train_ds["label"])
assert len(train_unique_labels) == 2
def test_text_column_creation_with_title_and_body(
self, train_config_with_title, temp_data_dir
):
"""Verify text column combines title and body correctly."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
train_ds, test_ds, eval_strategy = load_and_prepare_data(
train_config_with_title, test_config=None, test_size=0.2
)
# Check that text column exists and is the only text column
assert "text" in train_ds.column_names
assert "text" in test_ds.column_names
assert "title" not in train_ds.column_names
assert "body" not in train_ds.column_names
# Check that all text entries are non-empty strings
assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])
# Check that text contains content (not just whitespace)
assert all(len(text.strip()) > 0 for text in train_ds["text"])
assert all(len(text.strip()) > 0 for text in test_ds["text"])
# Check that text is longer than just title or body alone
# (indicating concatenation happened)
for text in train_ds["text"]:
# Text should have reasonable length (at least 10 chars)
assert len(text) >= 10
def test_max_train_samples_stratified_sampling(self, imbalanced_train_config, temp_data_dir):
"""Verify stratified sampling works correctly when max_train_samples is specified."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Original data has 10 samples: 8 bugs, 2 enhancements
# Request only 4 samples
train_ds, test_ds, eval_strategy = load_and_prepare_data(
imbalanced_train_config, test_config=None, test_size=0.2, max_train_samples=4
)
# Check that train dataset size is reduced to 4
assert len(train_ds) == 4
# Check that we still have both classes (stratified sampling)
unique_labels = set(train_ds["label"])
assert len(unique_labels) == 2, "Stratified sampling should preserve both classes"
# Check that class distribution is approximately maintained
# Original: 80% bug, 20% enhancement
# With 4 samples: should have ~3 bugs, ~1 enhancement
label_counts = {}
for label in train_ds["label"]:
label_name = train_ds.label_encoder.inverse_transform([label])[0]
label_counts[label_name] = label_counts.get(label_name, 0) + 1
# At least one sample from minority class
assert label_counts.get("enhancement", 0) >= 1
# Majority class should have more samples
assert label_counts.get("bug", 0) >= label_counts.get("enhancement", 0)
# Test dataset should remain at 20% of original (2 samples)
assert len(test_ds) == 2
# Total should be less than original
assert len(train_ds) + len(test_ds) < 10
class TestConfiguration:
"""Test class for configuration and argument parsing functionality."""
def test_parser_accepts_valid_arguments(self):
"""Verify parser accepts all valid combinations of arguments."""
parser = init_parser()
# Get first valid dataset and model from configs
valid_dataset = list(DATASET_CONFIGs.keys())[0]
valid_model = list(MODEL_CONFIGS.keys())[0]
# Test 1: Minimal required arguments
args = parser.parse_args(["--train-dataset", valid_dataset, "--model-name", valid_model])
assert args.train_dataset == valid_dataset
assert args.model_name == valid_model
assert args.test_dataset is None
assert args.test_size == 0.2 # default value
assert args.max_train_samples is None # default value
assert args.use_setfit is False # default value
assert args.run_name is None # default value
# Test 2: All arguments provided
if len(DATASET_CONFIGs.keys()) > 1:
valid_test_dataset = list(DATASET_CONFIGs.keys())[1]
else:
valid_test_dataset = valid_dataset
args = parser.parse_args(
[
"--train-dataset",
valid_dataset,
"--test-dataset",
valid_test_dataset,
"--model-name",
valid_model,
"--use-setfit",
"--test-size",
"0.3",
"--max-train-samples",
"100",
"--run-name",
"test_run",
]
)
assert args.train_dataset == valid_dataset
assert args.test_dataset == valid_test_dataset
assert args.model_name == valid_model
assert args.use_setfit is True
assert args.test_size == 0.3
assert args.max_train_samples == 100
assert args.run_name == "test_run"
# Test 3: Only use-setfit flag
args = parser.parse_args(
["--train-dataset", valid_dataset, "--model-name", valid_model, "--use-setfit"]
)
assert args.use_setfit is True
# Test 4: Custom test-size
args = parser.parse_args(
["--train-dataset", valid_dataset, "--model-name", valid_model, "--test-size", "0.15"]
)
assert args.test_size == 0.15
# Test 5: Custom max-train-samples
args = parser.parse_args(
[
"--train-dataset",
valid_dataset,
"--model-name",
valid_model,
"--max-train-samples",
"500",
]
)
assert args.max_train_samples == 500
def test_parser_rejects_invalid_dataset_names(self):
"""Verify parser rejects dataset names not in DATASET_CONFIGs."""
parser = init_parser()
# Get valid model
valid_model = list(MODEL_CONFIGS.keys())[0]
# Test 1: Invalid train dataset
with pytest.raises(SystemExit):
parser.parse_args(
["--train-dataset", "invalid_dataset_name", "--model-name", valid_model]
)
# Test 2: Invalid test dataset
valid_dataset = list(DATASET_CONFIGs.keys())[0]
with pytest.raises(SystemExit):
parser.parse_args(
[
"--train-dataset",
valid_dataset,
"--test-dataset",
"invalid_test_dataset",
"--model-name",
valid_model,
]
)
# Test 3: Invalid model name
with pytest.raises(SystemExit):
parser.parse_args(
["--train-dataset", valid_dataset, "--model-name", "invalid_model_name"]
)
# Test 4: Missing required argument (train-dataset)
with pytest.raises(SystemExit):
parser.parse_args(["--model-name", valid_model])
# Test 5: Missing required argument (model-name)
with pytest.raises(SystemExit):
parser.parse_args(["--train-dataset", valid_dataset])
class TestTrainingPipeline:
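"""Test class for end-to-end training pipeline functionality."""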
@pytest.mark.slow
def test_setfit_training_completes(
self, minimal_train_config, minimal_model_config_setfit, temp_data_dir
):
"""Verify SetFit training runs without errors (using minimal data)."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data
train_ds, test_ds, eval_strategy = load_and_prepare_data(
minimal_train_config,
test_config=None,
test_size=0.33, # 2 samples for test, 4 for train
)
# Mock MLflow to avoid logging during tests
with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
# Train model
result = train_model_setfit(minimal_model_config_setfit, train_ds, test_ds)
# Check that result is returned
assert result is not None
# Check that result has expected structure
model, metrics, y_true, y_pred, model_type = result
# Check model type
assert model_type == "setfit"
# Check that model is returned
assert model is not None
# Check that metrics are computed
assert isinstance(metrics, dict)
assert "accuracy" in metrics
assert "f1_macro" in metrics
assert "f1_weighted" in metrics
# Check that metrics are in valid range [0, 1]
assert 0 <= metrics["accuracy"] <= 1
assert 0 <= metrics["f1_macro"] <= 1
assert 0 <= metrics["f1_weighted"] <= 1
# Check that predictions are returned
assert y_true is not None
assert y_pred is not None
# Check that predictions have correct length
assert len(y_true) == len(test_ds)
assert len(y_pred) == len(test_ds)
# Check that predictions are in valid label space
num_classes = len(train_ds.label_encoder.classes_)
assert all(0 <= pred < num_classes for pred in y_pred)
assert all(0 <= true < num_classes for true in y_true)
@pytest.mark.slow
def test_transformers_training_completes(
self, minimal_train_config, minimal_model_config_transformers, temp_data_dir
):
"""Verify Transformers training runs without errors (using minimal data)."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data
train_ds, test_ds, eval_strategy = load_and_prepare_data(
minimal_train_config,
test_config=None,
test_size=0.33, # 2 samples for test, 4 for train
)
# Mock MLflow to avoid logging during tests
with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
# Train model
result = train_model_transformers(
minimal_model_config_transformers, train_ds, test_ds
)
# Check that result is returned
assert result is not None
# Check that result has expected structure
model_tuple, metrics, y_true, y_pred, model_type = result
# Check model type
assert model_type == "transformers"
# Check that model and tokenizer are returned
assert model_tuple is not None
assert isinstance(model_tuple, tuple)
assert len(model_tuple) == 2
model, tokenizer = model_tuple
assert model is not None
assert tokenizer is not None
# Check that metrics are computed
assert isinstance(metrics, dict)
# Transformers returns metrics with eval_ prefix
assert any("accuracy" in key for key in metrics.keys())
# Extract accuracy value (could be 'accuracy' or 'eval_accuracy')
accuracy_key = [k for k in metrics.keys() if "accuracy" in k][0]
accuracy = metrics[accuracy_key]
assert 0 <= accuracy <= 1
# Check that predictions are returned
assert y_true is not None
assert y_pred is not None
# Check that predictions have correct length
assert len(y_true) == len(test_ds)
assert len(y_pred) == len(test_ds)
# Check that predictions are in valid label space
num_classes = len(train_ds.label_encoder.classes_)
assert all(0 <= pred < num_classes for pred in y_pred)
assert all(0 <= true < num_classes for true in y_true)
# Check that predictions are numpy arrays or lists of integers
assert all(isinstance(pred, (int, np.integer)) for pred in y_pred)
assert all(isinstance(true, (int, np.integer)) for true in y_true)
class TestErrorHandling:
"""Test class for error handling functionality."""
def test_missing_train_file_raises_error(self, temp_data_dir):
"""Verify appropriate error when train file doesn't exist."""
# Create config pointing to non-existent file
missing_file_config = {
"data_path": "non_existent_file.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Should call sys.exit(1) when file doesn't exist
with pytest.raises(SystemExit) as excinfo:
load_and_prepare_data(missing_file_config, test_config=None, test_size=0.2)
# Check that exit code is 1
assert excinfo.value.code == 1
def test_invalid_label_column(self, temp_data_dir):
"""Verify error handling when specified label column doesn't exist."""
# Create data with specific columns
sample_data = pd.DataFrame(
{
"title": ["Bug 1", "Enhancement 1"],
"body": ["Bug body", "Enhancement body"],
"type": ["bug", "enhancement"], # Different column name
}
)
# Save to file
train_path = temp_data_dir / "invalid_label_col.csv"
sample_data.to_csv(train_path, index=False)
# Create config with wrong label column name
invalid_label_config = {
"data_path": "invalid_label_col.csv",
"label_col": "label", # This column doesn't exist
"title_col": "title",
"body_col": "body",
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Should call sys.exit(1) when label column doesn't exist
with pytest.raises(SystemExit) as excinfo:
load_and_prepare_data(invalid_label_config, test_config=None, test_size=0.2)
# Check that exit code is 1
assert excinfo.value.code == 1
def test_invalid_body_column(self, temp_data_dir):
"""Verify error handling when specified body column doesn't exist."""
# Create data with specific columns
sample_data = pd.DataFrame(
{
"title": ["Bug 1", "Enhancement 1"],
"description": ["Bug body", "Enhancement body"], # Different column name
"label": ["bug", "enhancement"],
}
)
# Save to file
train_path = temp_data_dir / "invalid_body_col.csv"
sample_data.to_csv(train_path, index=False)
# Create config with wrong body column name
invalid_body_config = {
"data_path": "invalid_body_col.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body", # This column doesn't exist
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Should call sys.exit(1) when body column doesn't exist
with pytest.raises(SystemExit) as excinfo:
load_and_prepare_data(invalid_body_config, test_config=None, test_size=0.2)
# Check that exit code is 1
assert excinfo.value.code == 1
class TestEdgeCases:
"""Test class for edge case scenarios."""
def test_very_small_dataset(self, temp_data_dir):
"""Verify training with very small datasets (< 10 samples)."""
# Create very small dataset (6 samples total, 3 per class)
very_small_data = pd.DataFrame(
{
"title": ["Bug 1", "Bug 2", "Bug 3", "Enh 1", "Enh 2", "Enh 3"],
"body": [
"Small bug 1",
"Small bug 2",
"Small bug 3",
"Small enh 1",
"Small enh 2",
"Small enh 3",
],
"label": ["bug", "bug", "bug", "enhancement", "enhancement", "enhancement"],
}
)
# Save to file
train_path = temp_data_dir / "very_small.csv"
very_small_data.to_csv(train_path, index=False)
small_config = {
"data_path": "very_small.csv",
"label_col": "label",
"title_col": "title",
"body_col": "body",
"sep": ",",
}
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data with small test split to ensure at least 1 sample per class in train
train_ds, test_ds, eval_strategy = load_and_prepare_data(
small_config,
test_config=None,
test_size=0.33, # 2 samples for test (1 per class), 4 for train
)
# Check that datasets were created despite small size
assert isinstance(train_ds, Dataset)
assert isinstance(test_ds, Dataset)
# Check that both datasets have samples
assert len(train_ds) > 0
assert len(test_ds) > 0
# Check total is preserved
assert len(train_ds) + len(test_ds) == 6
# Check that stratification preserved both classes in train set
train_unique_labels = set(train_ds["label"])
assert len(train_unique_labels) >= 1 # At least one class
# Check that we have valid label encoding
num_classes = len(train_ds.label_encoder.classes_)
assert num_classes == 2
assert all(0 <= label < num_classes for label in train_ds["label"])
assert all(0 <= label < num_classes for label in test_ds["label"])
# Check that text was properly created
assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])
class TestOutputValidation:
"""Test class for output validation functionality."""
@pytest.mark.slow
def test_predictions_match_label_space(
self,
minimal_train_config,
minimal_model_config_setfit,
minimal_model_config_transformers,
temp_data_dir,
):
"""Verify predictions are within valid label space."""
with patch(
"syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
temp_data_dir,
):
# Load data
train_ds, test_ds, eval_strategy = load_and_prepare_data(
minimal_train_config,
test_config=None,
test_size=0.33,  # 2 samples for test, 4 for train
)
# Get the valid label space
num_classes = len(train_ds.label_encoder.classes_)
valid_label_space = set(range(num_classes))
# Mock MLflow to avoid logging during tests
with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
# Test with SetFit
model, metrics, y_true, y_pred, model_type = train_model_setfit(
minimal_model_config_setfit, train_ds, test_ds
)
# Check that all predictions are in valid label space
assert all(
pred in valid_label_space for pred in y_pred
), f"SetFit predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred)}"
# Check that all true labels are in valid label space
assert all(
true in valid_label_space for true in y_true
), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true)}"
# Check that predictions are within [0, num_classes)
assert all(
0 <= pred < num_classes for pred in y_pred
), f"SetFit predictions out of range [0, {num_classes})"
# Check that y_true matches the original test labels
assert list(y_true) == list(
test_ds["label"]
), "True labels don't match original test dataset labels"
# Test with Transformers
(model_t, tokenizer), metrics_t, y_true_t, y_pred_t, model_type_t = (
train_model_transformers(minimal_model_config_transformers, train_ds, test_ds)
)
# Check that all predictions are in valid label space
assert all(
pred in valid_label_space for pred in y_pred_t
), f"Transformers predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred_t)}"
# Check that all true labels are in valid label space
assert all(
true in valid_label_space for true in y_true_t
), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true_t)}"
# Check that predictions are within [0, num_classes)
assert all(
0 <= pred < num_classes for pred in y_pred_t
), f"Transformers predictions out of range [0, {num_classes})"
# Check that y_true matches the original test labels
assert list(y_true_t) == list(
test_ds["label"]
), "True labels don't match original test dataset labels"
# Additional check: verify predictions are integers
assert all(
isinstance(pred, (int, np.integer)) for pred in y_pred
), "SetFit predictions must be integers"
assert all(
isinstance(pred, (int, np.integer)) for pred in y_pred_t
), "Transformers predictions must be integers"
# Check that at least some predictions were made (not all same)
# This is a sanity check - with random initialization, we should get some variation
# (Though with very small data, it's possible all predictions are the same)
unique_preds = len(set(y_pred))
unique_preds_t = len(set(y_pred_t))
assert unique_preds >= 1, "SetFit made no predictions"
assert unique_preds_t >= 1, "Transformers made no predictions"