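"""Tests for the training pipeline in syntetic_issue_report_data_generation.modeling.train.

Covers data loading and preparation, argument parsing, SetFit and Transformers
training runs, error handling, edge cases, and output validation. Training tests
are marked ``slow`` and can be deselected with ``pytest -m "not slow"``
(assuming the marker is registered in the project's pytest configuration).
"""
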
from pathlib import Path
import sys
import tempfile
from unittest.mock import patch

from datasets import Dataset
import numpy as np
import pandas as pd
import pytest

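# Add the directory two levels above this file (expected to be the project root)
# to sys.path so the package imports below resolve without installing the project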
sys.path.insert(0, str(Path(__file__).parent.parent))

from syntetic_issue_report_data_generation.config import MODEL_CONFIGS, DATASET_CONFIGs
from syntetic_issue_report_data_generation.modeling.train import (
    init_parser,
    load_and_prepare_data,
    train_model_setfit,
    train_model_transformers,
)


@pytest.fixture
def temp_data_dir():
    """Create a temporary directory for test data files."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield Path(tmpdirname)


@pytest.fixture
def sample_train_data():
    """Create sample training data with balanced classes."""
    return pd.DataFrame(
        {
            "title": [
                "Bug in login",
                "Feature request",
                "Performance issue",
                "UI problem",
                "Bug in logout",
                "Add search",
                "Memory leak",
                "New API endpoint",
                "Crash on startup",
                "Enhancement needed",
            ],
            "body": [
                "Cannot login to system",
                "Add dark mode feature",
                "Slow loading times",
                "Button misaligned",
                "Cannot logout properly",
                "Need search functionality",
                "High memory usage",
                "REST API needed",
                "Application crashes",
                "Improve user experience",
            ],
            "label": [
                "bug",
                "enhancement",
                "bug",
                "bug",
                "bug",
                "enhancement",
                "bug",
                "enhancement",
                "bug",
                "enhancement",
            ],
        }
    )


@pytest.fixture
def sample_imbalanced_data():
    """Create sample data with imbalanced classes (for stratified sampling test)."""
    return pd.DataFrame(
        {
            "title": [
                "Bug 1",
                "Bug 2",
                "Bug 3",
                "Bug 4",
                "Bug 5",
                "Bug 6",
                "Bug 7",
                "Bug 8",
                "Enhancement 1",
                "Enhancement 2",
            ],
            "body": [
                "Bug body 1",
                "Bug body 2",
                "Bug body 3",
                "Bug body 4",
                "Bug body 5",
                "Bug body 6",
                "Bug body 7",
                "Bug body 8",
                "Enhancement body 1",
                "Enhancement body 2",
            ],
            "label": [
                "bug",
                "bug",
                "bug",
                "bug",
                "bug",
                "bug",
                "bug",
                "bug",
                "enhancement",
                "enhancement",
            ],
        }
    )


@pytest.fixture
def train_config_with_title(temp_data_dir, sample_train_data):
    """Create train config with title and body columns."""
    train_path = temp_data_dir / "train_with_title.csv"
    sample_train_data.to_csv(train_path, index=False)

    return {
        "data_path": "train_with_title.csv",
        "label_col": "label",
        "title_col": "title",
        "body_col": "body",
        "sep": ",",
    }


@pytest.fixture
def imbalanced_train_config(temp_data_dir, sample_imbalanced_data):
    """Create train config with imbalanced data."""
    train_path = temp_data_dir / "train_imbalanced.csv"
    sample_imbalanced_data.to_csv(train_path, index=False)

    return {
        "data_path": "train_imbalanced.csv",
        "label_col": "label",
        "title_col": "title",
        "body_col": "body",
        "sep": ",",
    }


@pytest.fixture
def minimal_train_data():
    """Create minimal training data for quick training tests."""
    return pd.DataFrame(
        {
            "title": [
                "Bug 1",
                "Bug 2",
                "Enhancement 1",
                "Enhancement 2",
                "Bug 3",
                "Enhancement 3",
            ],
            "body": [
                "Bug body 1",
                "Bug body 2",
                "Enh body 1",
                "Enh body 2",
                "Bug body 3",
                "Enh body 3",
            ],
            "label": ["bug", "bug", "enhancement", "enhancement", "bug", "enhancement"],
        }
    )


@pytest.fixture
def minimal_train_config(temp_data_dir, minimal_train_data):
    """Create train config with minimal data for fast training."""
    train_path = temp_data_dir / "minimal_train.csv"
    minimal_train_data.to_csv(train_path, index=False)

    return {
        "data_path": "minimal_train.csv",
        "label_col": "label",
        "title_col": "title",
        "body_col": "body",
        "sep": ",",
    }


@pytest.fixture
def minimal_model_config_setfit():
    """Create minimal SetFit model configuration for testing."""
    return {
        "model_checkpoint": "sentence-transformers/paraphrase-MiniLM-L3-v2",  # Small, fast model
        "params": {"num_epochs": 1, "batch_size": 4, "num_iterations": 5, "max_length": 64},
    }


@pytest.fixture
def minimal_model_config_transformers():
    """Create minimal Transformers model configuration for testing."""
    return {
        "model_checkpoint": "prajjwal1/bert-tiny",  # Very small BERT model
        "params": {
            "num_train_epochs": 1,
            "per_device_train_batch_size": 2,
            "per_device_eval_batch_size": 2,
            "learning_rate": 5e-5,
            "warmup_steps": 0,
            "weight_decay": 0.01,
            "logging_steps": 1,
        },
    }


class TestDataLoadingAndPreparation:
    """Test class for data loading and preparation functionality."""

    def test_load_data_with_valid_config(self, train_config_with_title, temp_data_dir):
        """Verify that data loads correctly with valid train dataset configuration."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                train_config_with_title, test_config=None, test_size=0.2
            )

            # Check that datasets were created
            assert isinstance(train_ds, Dataset)
            assert isinstance(test_ds, Dataset)

            # Check that datasets have correct columns
            assert set(train_ds.column_names) == {"text", "label"}
            assert set(test_ds.column_names) == {"text", "label"}

            # Check that datasets are not empty
            assert len(train_ds) > 0
            assert len(test_ds) > 0

            # Check total samples
            assert len(train_ds) + len(test_ds) == 10

            # Check label encoder is attached
            assert hasattr(train_ds, "label_encoder")
            assert hasattr(test_ds, "label_encoder")

            # Check that labels are integers
            assert all(isinstance(label, int) for label in train_ds["label"])
            assert all(isinstance(label, int) for label in test_ds["label"])

    def test_load_data_creates_holdout_split(self, train_config_with_title, temp_data_dir):
        """Verify holdout split is created when no test dataset is provided."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                train_config_with_title, test_config=None, test_size=0.2
            )

            # Check eval strategy is holdout
            assert eval_strategy == "holdout"

            # Check that both datasets exist
            assert len(train_ds) > 0
            assert len(test_ds) > 0

            # Check that split is approximately correct (80/20 split of 10 samples)
            total_samples = len(train_ds) + len(test_ds)
            assert total_samples == 10
            assert len(test_ds) == 2  # 20% of 10 = 2
            assert len(train_ds) == 8  # 80% of 10 = 8

            # Verify no data leakage (no overlap between train and test)
            train_texts = set(train_ds["text"])
            test_texts = set(test_ds["text"])
            assert len(train_texts.intersection(test_texts)) == 0

    def test_label_encoding_consistency(self, train_config_with_title, temp_data_dir):
        """Verify labels are encoded consistently across train/test sets."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                train_config_with_title, test_config=None, test_size=0.2
            )

            # Check that label encoders are the same object
            assert train_ds.label_encoder is test_ds.label_encoder

            # Check that labels are integers
            assert all(isinstance(label, int) for label in train_ds["label"])
            assert all(isinstance(label, int) for label in test_ds["label"])

            # Check that label classes are consistent
            train_label_classes = train_ds.label_encoder.classes_
            test_label_classes = test_ds.label_encoder.classes_
            assert list(train_label_classes) == list(test_label_classes)

            # Check that we have the expected classes (bug and enhancement)
            expected_classes = sorted(["bug", "enhancement"])
            actual_classes = sorted(train_ds.label_encoder.classes_)
            assert actual_classes == expected_classes

            # Check that encoded labels are in valid range [0, num_classes)
            num_classes = len(train_ds.label_encoder.classes_)
            assert num_classes == 2
            assert all(0 <= label < num_classes for label in train_ds["label"])
            assert all(0 <= label < num_classes for label in test_ds["label"])

            # Check that both classes appear in train set (due to stratification)
            train_unique_labels = set(train_ds["label"])
            assert len(train_unique_labels) == 2

    def test_text_column_creation_with_title_and_body(
        self, train_config_with_title, temp_data_dir
    ):
        """Verify text column combines title and body correctly."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                train_config_with_title, test_config=None, test_size=0.2
            )

            # Check that text column exists and is the only text column
            assert "text" in train_ds.column_names
            assert "text" in test_ds.column_names
            assert "title" not in train_ds.column_names
            assert "body" not in train_ds.column_names

            # Check that all text entries are non-empty strings
            assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
            assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])

            # Check that text contains content (not just whitespace)
            assert all(len(text.strip()) > 0 for text in train_ds["text"])
            assert all(len(text.strip()) > 0 for text in test_ds["text"])

            # Check that text is longer than just title or body alone
            # (indicating concatenation happened)
            for text in train_ds["text"]:
                # Text should have reasonable length (at least 10 chars)
                assert len(text) >= 10

    def test_max_train_samples_stratified_sampling(self, imbalanced_train_config, temp_data_dir):
        """Verify stratified sampling works correctly when max_train_samples is specified."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Original data has 10 samples: 8 bugs, 2 enhancements
            # Request only 4 samples
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                imbalanced_train_config, test_config=None, test_size=0.2, max_train_samples=4
            )

            # Check that train dataset size is reduced to 4
            assert len(train_ds) == 4

            # Check that we still have both classes (stratified sampling)
            unique_labels = set(train_ds["label"])
            assert len(unique_labels) == 2, "Stratified sampling should preserve both classes"

            # Check that class distribution is approximately maintained
            # Original: 80% bug, 20% enhancement
            # With 4 samples: should have ~3 bugs, ~1 enhancement
            label_counts = {}
            for label in train_ds["label"]:
                label_name = train_ds.label_encoder.inverse_transform([label])[0]
                label_counts[label_name] = label_counts.get(label_name, 0) + 1

            # At least one sample from minority class
            assert label_counts.get("enhancement", 0) >= 1
            # Majority class should have more samples
            assert label_counts.get("bug", 0) >= label_counts.get("enhancement", 0)

            # Test dataset should remain at 20% of original (2 samples)
            assert len(test_ds) == 2

            # Total should be less than original
            assert len(train_ds) + len(test_ds) < 10


class TestConfiguration:
    """Test class for configuration and argument parsing functionality."""

    def test_parser_accepts_valid_arguments(self):
        """Verify parser accepts all valid combinations of arguments."""
        parser = init_parser()

        # Get first valid dataset and model from configs
        valid_dataset = list(DATASET_CONFIGs.keys())[0]
        valid_model = list(MODEL_CONFIGS.keys())[0]

        # Test 1: Minimal required arguments
        args = parser.parse_args(["--train-dataset", valid_dataset, "--model-name", valid_model])
        assert args.train_dataset == valid_dataset
        assert args.model_name == valid_model
        assert args.test_dataset is None
        assert args.test_size == 0.2  # default value
        assert args.max_train_samples is None  # default value
        assert args.use_setfit is False  # default value
        assert args.run_name is None  # default value

        # Test 2: All arguments provided
        if len(DATASET_CONFIGs.keys()) > 1:
            valid_test_dataset = list(DATASET_CONFIGs.keys())[1]
        else:
            valid_test_dataset = valid_dataset

        args = parser.parse_args(
            [
                "--train-dataset",
                valid_dataset,
                "--test-dataset",
                valid_test_dataset,
                "--model-name",
                valid_model,
                "--use-setfit",
                "--test-size",
                "0.3",
                "--max-train-samples",
                "100",
                "--run-name",
                "test_run",
            ]
        )
        assert args.train_dataset == valid_dataset
        assert args.test_dataset == valid_test_dataset
        assert args.model_name == valid_model
        assert args.use_setfit is True
        assert args.test_size == 0.3
        assert args.max_train_samples == 100
        assert args.run_name == "test_run"

        # Test 3: Only use-setfit flag
        args = parser.parse_args(
            ["--train-dataset", valid_dataset, "--model-name", valid_model, "--use-setfit"]
        )
        assert args.use_setfit is True

        # Test 4: Custom test-size
        args = parser.parse_args(
            ["--train-dataset", valid_dataset, "--model-name", valid_model, "--test-size", "0.15"]
        )
        assert args.test_size == 0.15

        # Test 5: Custom max-train-samples
        args = parser.parse_args(
            [
                "--train-dataset",
                valid_dataset,
                "--model-name",
                valid_model,
                "--max-train-samples",
                "500",
            ]
        )
        assert args.max_train_samples == 500

    def test_parser_rejects_invalid_dataset_names(self):
        """Verify parser rejects dataset names not in DATASET_CONFIGs."""
        parser = init_parser()

        # Get valid model
        valid_model = list(MODEL_CONFIGS.keys())[0]

        # Test 1: Invalid train dataset
        with pytest.raises(SystemExit):
            parser.parse_args(
                ["--train-dataset", "invalid_dataset_name", "--model-name", valid_model]
            )

        # Test 2: Invalid test dataset
        valid_dataset = list(DATASET_CONFIGs.keys())[0]
        with pytest.raises(SystemExit):
            parser.parse_args(
                [
                    "--train-dataset",
                    valid_dataset,
                    "--test-dataset",
                    "invalid_test_dataset",
                    "--model-name",
                    valid_model,
                ]
            )

        # Test 3: Invalid model name
        with pytest.raises(SystemExit):
            parser.parse_args(
                ["--train-dataset", valid_dataset, "--model-name", "invalid_model_name"]
            )

        # Test 4: Missing required argument (train-dataset)
        with pytest.raises(SystemExit):
            parser.parse_args(["--model-name", valid_model])

        # Test 5: Missing required argument (model-name)
        with pytest.raises(SystemExit):
            parser.parse_args(["--train-dataset", valid_dataset])


class TestTrainingPipeline:
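    """Test class for training pipeline functionality (SetFit and Transformers)."""
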
    @pytest.mark.slow
    def test_setfit_training_completes(
        self, minimal_train_config, minimal_model_config_setfit, temp_data_dir
    ):
        """Verify SetFit training runs without errors (using minimal data)."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Load data
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                minimal_train_config,
                test_config=None,
                test_size=0.33,  # 2 samples for test, 4 for train
            )

            # Mock MLflow to avoid logging during tests
            with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
                # Train model
                result = train_model_setfit(minimal_model_config_setfit, train_ds, test_ds)

                # Check that result is returned
                assert result is not None

                # Check that result has expected structure
                model, metrics, y_true, y_pred, model_type = result

                # Check model type
                assert model_type == "setfit"

                # Check that model is returned
                assert model is not None

                # Check that metrics are computed
                assert isinstance(metrics, dict)
                assert "accuracy" in metrics
                assert "f1_macro" in metrics
                assert "f1_weighted" in metrics

                # Check that metrics are in valid range [0, 1]
                assert 0 <= metrics["accuracy"] <= 1
                assert 0 <= metrics["f1_macro"] <= 1
                assert 0 <= metrics["f1_weighted"] <= 1

                # Check that predictions are returned
                assert y_true is not None
                assert y_pred is not None

                # Check that predictions have correct length
                assert len(y_true) == len(test_ds)
                assert len(y_pred) == len(test_ds)

                # Check that predictions are in valid label space
                num_classes = len(train_ds.label_encoder.classes_)
                assert all(0 <= pred < num_classes for pred in y_pred)
                assert all(0 <= true < num_classes for true in y_true)

    @pytest.mark.slow
    def test_transformers_training_completes(
        self, minimal_train_config, minimal_model_config_transformers, temp_data_dir
    ):
        """Verify Transformers training runs without errors (using minimal data)."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Load data
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                minimal_train_config,
                test_config=None,
                test_size=0.33,  # 2 samples for test, 4 for train
            )

            # Mock MLflow to avoid logging during tests
            with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
                # Train model
                result = train_model_transformers(
                    minimal_model_config_transformers, train_ds, test_ds
                )

                # Check that result is returned
                assert result is not None

                # Check that result has expected structure
                model_tuple, metrics, y_true, y_pred, model_type = result

                # Check model type
                assert model_type == "transformers"

                # Check that model and tokenizer are returned
                assert model_tuple is not None
                assert isinstance(model_tuple, tuple)
                assert len(model_tuple) == 2
                model, tokenizer = model_tuple
                assert model is not None
                assert tokenizer is not None

                # Check that metrics are computed
                assert isinstance(metrics, dict)
                # Transformers returns metrics with eval_ prefix
                assert any("accuracy" in key for key in metrics.keys())

                # Extract accuracy value (could be 'accuracy' or 'eval_accuracy')
                accuracy_key = [k for k in metrics.keys() if "accuracy" in k][0]
                accuracy = metrics[accuracy_key]
                assert 0 <= accuracy <= 1

                # Check that predictions are returned
                assert y_true is not None
                assert y_pred is not None

                # Check that predictions have correct length
                assert len(y_true) == len(test_ds)
                assert len(y_pred) == len(test_ds)

                # Check that predictions are in valid label space
                num_classes = len(train_ds.label_encoder.classes_)
                assert all(0 <= pred < num_classes for pred in y_pred)
                assert all(0 <= true < num_classes for true in y_true)

                # Check that predictions are numpy arrays or lists of integers
                assert all(isinstance(pred, (int, np.integer)) for pred in y_pred)
                assert all(isinstance(true, (int, np.integer)) for true in y_true)


class TestErrorHandling:
    """Test class for error handling functionality."""

    def test_missing_train_file_raises_error(self, temp_data_dir):
        """Verify appropriate error when train file doesn't exist."""
        # Create config pointing to non-existent file
        missing_file_config = {
            "data_path": "non_existent_file.csv",
            "label_col": "label",
            "title_col": "title",
            "body_col": "body",
            "sep": ",",
        }

        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Should call sys.exit(1) when file doesn't exist
            with pytest.raises(SystemExit) as excinfo:
                load_and_prepare_data(missing_file_config, test_config=None, test_size=0.2)

            # Check that exit code is 1
            assert excinfo.value.code == 1

    def test_invalid_label_column(self, temp_data_dir):
        """Verify error handling when specified label column doesn't exist."""
        # Create data with specific columns
        sample_data = pd.DataFrame(
            {
                "title": ["Bug 1", "Enhancement 1"],
                "body": ["Bug body", "Enhancement body"],
                "type": ["bug", "enhancement"],  # Different column name
            }
        )

        # Save to file
        train_path = temp_data_dir / "invalid_label_col.csv"
        sample_data.to_csv(train_path, index=False)

        # Create config with wrong label column name
        invalid_label_config = {
            "data_path": "invalid_label_col.csv",
            "label_col": "label",  # This column doesn't exist
            "title_col": "title",
            "body_col": "body",
            "sep": ",",
        }

        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Should call sys.exit(1) when label column doesn't exist
            with pytest.raises(SystemExit) as excinfo:
                load_and_prepare_data(invalid_label_config, test_config=None, test_size=0.2)

            # Check that exit code is 1
            assert excinfo.value.code == 1

    def test_invalid_body_column(self, temp_data_dir):
        """Verify error handling when specified body column doesn't exist."""
        # Create data with specific columns
        sample_data = pd.DataFrame(
            {
                "title": ["Bug 1", "Enhancement 1"],
                "description": ["Bug body", "Enhancement body"],  # Different column name
                "label": ["bug", "enhancement"],
            }
        )

        # Save to file
        train_path = temp_data_dir / "invalid_body_col.csv"
        sample_data.to_csv(train_path, index=False)

        # Create config with wrong body column name
        invalid_body_config = {
            "data_path": "invalid_body_col.csv",
            "label_col": "label",
            "title_col": "title",
            "body_col": "body",  # This column doesn't exist
            "sep": ",",
        }

        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Should call sys.exit(1) when body column doesn't exist
            with pytest.raises(SystemExit) as excinfo:
                load_and_prepare_data(invalid_body_config, test_config=None, test_size=0.2)

            # Check that exit code is 1
            assert excinfo.value.code == 1


class TestEdgeCases:
    """Test class for edge case scenarios."""

    def test_very_small_dataset(self, temp_data_dir):
        """Verify training with very small datasets (< 10 samples)."""
        # Create very small dataset (6 samples total, 3 per class)
        very_small_data = pd.DataFrame(
            {
                "title": ["Bug 1", "Bug 2", "Bug 3", "Enh 1", "Enh 2", "Enh 3"],
                "body": [
                    "Small bug 1",
                    "Small bug 2",
                    "Small bug 3",
                    "Small enh 1",
                    "Small enh 2",
                    "Small enh 3",
                ],
                "label": ["bug", "bug", "bug", "enhancement", "enhancement", "enhancement"],
            }
        )

        # Save to file
        train_path = temp_data_dir / "very_small.csv"
        very_small_data.to_csv(train_path, index=False)

        small_config = {
            "data_path": "very_small.csv",
            "label_col": "label",
            "title_col": "title",
            "body_col": "body",
            "sep": ",",
        }

        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Load data with small test split to ensure at least 1 sample per class in train
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                small_config,
                test_config=None,
                test_size=0.33,  # 2 samples for test (1 per class), 4 for train
            )

            # Check that datasets were created despite small size
            assert isinstance(train_ds, Dataset)
            assert isinstance(test_ds, Dataset)

            # Check that both datasets have samples
            assert len(train_ds) > 0
            assert len(test_ds) > 0

            # Check total is preserved
            assert len(train_ds) + len(test_ds) == 6

            # Check that the train split still contains at least one class;
            # with this test_size, stratification is expected to keep both
            train_unique_labels = set(train_ds["label"])
            assert len(train_unique_labels) >= 1  # At least one class present

            # Check that we have valid label encoding
            num_classes = len(train_ds.label_encoder.classes_)
            assert num_classes == 2
            assert all(0 <= label < num_classes for label in train_ds["label"])
            assert all(0 <= label < num_classes for label in test_ds["label"])

            # Check that text was properly created
            assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
            assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])


class TestOutputValidation:
    """Test class for output validation functionality."""

    @pytest.mark.slow
    def test_predictions_match_label_space(
        self,
        minimal_train_config,
        minimal_model_config_setfit,
        minimal_model_config_transformers,
        temp_data_dir,
    ):
        """Verify predictions are within valid label space."""
        with patch(
            "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
            temp_data_dir,
        ):
            # Load data
            train_ds, test_ds, eval_strategy = load_and_prepare_data(
                minimal_train_config,
                test_config=None,
                test_size=0.33,  # 2 samples for test, 4 for train
            )

            # Get the valid label space
            num_classes = len(train_ds.label_encoder.classes_)
            valid_label_space = set(range(num_classes))

            # Mock MLflow to avoid logging during tests
            with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
                # Test with SetFit
                model, metrics, y_true, y_pred, model_type = train_model_setfit(
                    minimal_model_config_setfit, train_ds, test_ds
                )

                # Check that all predictions are in valid label space
                assert all(
                    pred in valid_label_space for pred in y_pred
                ), f"SetFit predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred)}"

                # Check that all true labels are in valid label space
                assert all(
                    true in valid_label_space for true in y_true
                ), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true)}"

                # Check that predictions are within [0, num_classes)
                assert all(
                    0 <= pred < num_classes for pred in y_pred
                ), f"SetFit predictions out of range [0, {num_classes})"

                # Check that y_true matches the original test labels
                assert list(y_true) == list(
                    test_ds["label"]
                ), "True labels don't match original test dataset labels"

                # Test with Transformers
                (model_t, tokenizer), metrics_t, y_true_t, y_pred_t, model_type_t = (
                    train_model_transformers(minimal_model_config_transformers, train_ds, test_ds)
                )

                # Check that all predictions are in valid label space
                assert all(
                    pred in valid_label_space for pred in y_pred_t
                ), f"Transformers predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred_t)}"

                # Check that all true labels are in valid label space
                assert all(
                    true in valid_label_space for true in y_true_t
                ), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true_t)}"

                # Check that predictions are within [0, num_classes)
                assert all(
                    0 <= pred < num_classes for pred in y_pred_t
                ), f"Transformers predictions out of range [0, {num_classes})"

                # Check that y_true matches the original test labels
                assert list(y_true_t) == list(
                    test_ds["label"]
                ), "True labels don't match original test dataset labels"

                # Additional check: verify predictions are integers
                assert all(
                    isinstance(pred, (int, np.integer)) for pred in y_pred
                ), "SetFit predictions must be integers"
                assert all(
                    isinstance(pred, (int, np.integer)) for pred in y_pred_t
                ), "Transformers predictions must be integers"

                # Sanity check: the prediction arrays contain at least one value.
                # With random initialization we usually see some variation, but with
                # very small data all predictions may legitimately be the same class,
                # so we only require at least one distinct predicted value.
                unique_preds = len(set(y_pred))
                unique_preds_t = len(set(y_pred_t))
                assert unique_preds >= 1, "SetFit made no predictions"
                assert unique_preds_t >= 1, "Transformers made no predictions"