micheleminervini0604 committed on
Commit 3aaced5 · 1 Parent(s): 0425031

Added tests for the train.py module; fixed an error when loading the dataset

pyproject.toml CHANGED
@@ -61,3 +61,7 @@ force-sort-within-sections = true
 quote-style = "double"
 indent-style = "space"
 
+[tool.pytest.ini_options]
+markers = [
+    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+]
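
The [tool.pytest.ini_options] block registers the slow marker used by the training tests below, so the expensive runs can be skipped by default. A minimal sketch of how a test opts in (the test name here is hypothetical; the real uses are in tests/test_train.py):

import pytest


@pytest.mark.slow  # registered in pyproject.toml; deselect with: pytest -m "not slow"
def test_expensive_end_to_end_training():
    """Stand-in for a test that actually fine-tunes a model."""
    assert True

Running pytest -m "not slow" then deselects it, while a bare pytest invocation still collects everything.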
syntetic_issue_report_data_generation/modeling/train.py CHANGED
@@ -2,196 +2,215 @@
 import argparse
 import os
 import sys
-import warnings
-import pandas as pd
+
 import dagshub
-import mlflow
-from pathlib import Path
 from datasets import Dataset
+import mlflow
+import pandas as pd
+from sklearn.metrics import accuracy_score, classification_report, f1_score
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import classification_report, accuracy_score, f1_score
-
 import torch
-print(f"CUDA available: {torch.cuda.is_available()}")
-print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
 
 from syntetic_issue_report_data_generation.config import (
-    DATASET_CONFIGs,
-    MODEL_CONFIGS,
-    MLFLOW_TRACKING_URI,
     MLFLOW_EXPERIMENT_NAME,
-    INTERIM_DATA_DIR,
-    SOFT_CLEANED_DATA_DIR
+    MLFLOW_TRACKING_URI,
+    MODEL_CONFIGS,
+    SOFT_CLEANED_DATA_DIR,
+    DATASET_CONFIGs,
 )
 
-
+print(f"CUDA available: {torch.cuda.is_available()}")
+print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
 
 # Global settings
 GLOBAL_SEED = 42
-os.environ['MLFLOW_TRACKING_URI'] = MLFLOW_TRACKING_URI
+os.environ["MLFLOW_TRACKING_URI"] = MLFLOW_TRACKING_URI
+
 
 def init_parser():
     """Initialize the argument parser."""
     parser = argparse.ArgumentParser(description="Train a model for issue classification.")
     parser.add_argument(
-        "--train-dataset",
-        type=str,
-        required=True,
+        "--train-dataset",
+        type=str,
+        required=True,
         choices=DATASET_CONFIGs.keys(),
-        help="Name of the train dataset configuration to use (from config.py)"
+        help="Name of the train dataset configuration to use (from config.py)",
     )
     parser.add_argument(
-        "--test-dataset",
-        type=str,
+        "--test-dataset",
+        type=str,
         required=False,
         choices=DATASET_CONFIGs.keys(),
-        help="Name of the test dataset configuration to use (from config.py). If not provided, will create a holdout split from train data."
+        help="Name of the test dataset configuration to use (from config.py). If not provided, will create a holdout split from train data.",
     )
     parser.add_argument(
-        "--model-name",
-        type=str,
-        required=True,
+        "--model-name",
+        type=str,
+        required=True,
         choices=MODEL_CONFIGS.keys(),
-        help="Name of the model configuration to use (from config.py)"
+        help="Name of the model configuration to use (from config.py)",
    )
     parser.add_argument(
         "--use-setfit",
         action="store_true",
-        help="Use SetFit for training instead of standard transformers"
+        help="Use SetFit for training instead of standard transformers",
     )
     parser.add_argument(
         "--test-size",
         type=float,
         default=0.2,
-        help="Test size for holdout split if test-dataset not provided (default: 0.2)"
+        help="Test size for holdout split if test-dataset not provided (default: 0.2)",
     )
     parser.add_argument(
         "--max-train-samples",
         type=int,
-        default=None,
-        help="Maximum number of train samples to use for training. Uses stratified sampling if provided."
+        default=None,
+        help="Maximum number of train samples to use for training. Uses stratified sampling if provided.",
     )
     parser.add_argument(
-        "--run-name",
-        type=str,
-        default=None,
-        help="Custom name for the MLflow run"
+        "--run-name", type=str, default=None, help="Custom name for the MLflow run"
     )
     return parser
 
+
 def load_and_prepare_data(train_config, test_config=None, test_size=0.2, max_train_samples=None):
     """
     Load and prepare data from config entries.
-
+
     Args:
         train_config: Train dataset configuration dictionary
         test_config: Optional test dataset configuration dictionary
         test_size: Size of holdout split if test_config not provided
-        max_train_samples: Maximum number of train samples to use
+        max_train_samples: Maximum number of train samples to use
     """
     from sklearn.preprocessing import LabelEncoder
-
+
     print(f"Loading train data from: {train_config['data_path']}")
-
+
     # Get train configuration
-    train_path = SOFT_CLEANED_DATA_DIR / train_config['data_path']
-    train_label_col = train_config['label_col']
-    train_title_col = train_config.get('title_col')
-    train_body_col = train_config['body_col']
-    train_sep = train_config.get('sep', ',')
+    train_path = SOFT_CLEANED_DATA_DIR / train_config["data_path"]
+    train_label_col = train_config["label_col"]
+    train_title_col = train_config.get("title_col")
+    train_body_col = train_config["body_col"]
+    train_sep = train_config.get("sep", ",")
 
     # Load train data
     if not train_path.exists():
         print(f"Error: Train file not found at {train_path}")
         sys.exit(1)
-
+
     train_df = pd.read_csv(train_path, sep=train_sep)
-
+
+    # Validate required columns exist in train data
+    required_columns = [train_label_col, train_body_col]
+    if train_title_col:
+        required_columns.append(train_title_col)
+
+    missing_columns = [col for col in required_columns if col not in train_df.columns]
+    if missing_columns:
+        print(
+            f"Error: Required columns {missing_columns} not found in train dataset. Available columns: {list(train_df.columns)}"
+        )
+        sys.exit(1)
+
     # Handle test data
     if test_config:
         print(f"Loading test data from: {test_config['data_path']}")
-        test_path = SOFT_CLEANED_DATA_DIR / test_config['data_path']
-        test_label_col = test_config['label_col']
-        test_title_col = test_config.get('title_col')
-        test_body_col = test_config['body_col']
-        test_sep = test_config.get('sep', ',')
-
+        test_path = SOFT_CLEANED_DATA_DIR / test_config["data_path"]
+        test_label_col = test_config["label_col"]
+        test_title_col = test_config.get("title_col")
+        test_body_col = test_config["body_col"]
+        test_sep = test_config.get("sep", ",")
+
         if not test_path.exists():
             print(f"Error: Test file not found at {test_path}")
             sys.exit(1)
-
+
         test_df = pd.read_csv(test_path, sep=test_sep)
+
         evaluation_strategy = "pre-split"
-
+
         # Create text columns with respective configurations
         if train_title_col and train_body_col:
-            train_df['text'] = train_df[train_title_col].fillna('') + " " + train_df[train_body_col].fillna('')
+            train_df["text"] = (
+                train_df[train_title_col].fillna("") + " " + train_df[train_body_col].fillna("")
+            )
         else:
-            train_df['text'] = train_df[train_body_col].fillna('')
-
+            train_df["text"] = train_df[train_body_col].fillna("")
+
         if test_title_col and test_body_col:
-            test_df['text'] = test_df[test_title_col].fillna('') + " " + test_df[test_body_col].fillna('')
+            test_df["text"] = (
+                test_df[test_title_col].fillna("") + " " + test_df[test_body_col].fillna("")
+            )
         else:
-            test_df['text'] = test_df[test_body_col].fillna('')
-
+            test_df["text"] = test_df[test_body_col].fillna("")
+
         # Rename label columns to 'label'
         train_df = train_df[["text", train_label_col]].rename(columns={train_label_col: "label"})
         test_df = test_df[["text", test_label_col]].rename(columns={test_label_col: "label"})
     else:
         print(f"No test dataset provided. Creating holdout split with test_size={test_size}")
-
+
         # Create text column
         if train_title_col and train_body_col:
-            train_df['text'] = train_df[train_title_col].fillna('') + " " + train_df[train_body_col].fillna('')
+            train_df["text"] = (
+                train_df[train_title_col].fillna("") + " " + train_df[train_body_col].fillna("")
+            )
         else:
-            train_df['text'] = train_df[train_body_col].fillna('')
-
+            train_df["text"] = train_df[train_body_col].fillna("")
+
         # Select and rename columns
         train_df = train_df[["text", train_label_col]].rename(columns={train_label_col: "label"})
-
+
         # Create holdout split
         train_df, test_df = train_test_split(
-            train_df,
-            test_size=test_size,
-            random_state=GLOBAL_SEED,
-            stratify=train_df["label"]
+            train_df, test_size=test_size, random_state=GLOBAL_SEED, stratify=train_df["label"]
        )
        evaluation_strategy = "holdout"
 
     # Apply sampling if max_train_samples is specified and the dataset is larger
     if max_train_samples is not None and len(train_df) > max_train_samples:
-        print(f"Sampling {max_train_samples} samples from the training set (original size: {len(train_df)}).")
-
+        print(
+            f"Sampling {max_train_samples} samples from the training set (original size: {len(train_df)})."
+        )
+
         # To guarantee stratified sampling, compute how many samples to take per class
-        num_classes = train_df["label"].nunique()  # Number of unique classes
+        num_classes = train_df["label"].nunique()  # Number of unique classes
         samples_per_class = max_train_samples // num_classes
-
+
         # Stratified sampling
         sampled_train_df_list = []
         for label_val in train_df["label"].unique():
             class_subset = train_df[train_df["label"] == label_val]
-            sampled_train_df_list.append(class_subset.sample(n=min(len(class_subset), samples_per_class), random_state=GLOBAL_SEED))
-
+            sampled_train_df_list.append(
+                class_subset.sample(
+                    n=min(len(class_subset), samples_per_class), random_state=GLOBAL_SEED
+                )
+            )
 
-        train_df = pd.concat(sampled_train_df_list).sample(frac=1, random_state=GLOBAL_SEED).reset_index(drop=True)  # Recombine and shuffle
-        print(f"New train samples after stratified sampling: {len(train_df)}")
-
+        train_df = (
+            pd.concat(sampled_train_df_list)
+            .sample(frac=1, random_state=GLOBAL_SEED)
+            .reset_index(drop=True)
+        )  # Recombine and shuffle
+        print(f"New train samples after stratified sampling: {len(train_df)}")
 
     # Encode labels to integers
     label_encoder = LabelEncoder()
-
+
     # Fit on combined labels to ensure consistency
     all_labels = pd.concat([train_df["label"], test_df["label"]])
     label_encoder.fit(all_labels)
-
+
     # Transform labels
     train_df["label"] = label_encoder.transform(train_df["label"])
     test_df["label"] = label_encoder.transform(test_df["label"])
-
+
     # Log label mapping
     label_mapping = {str(label): int(idx) for idx, label in enumerate(label_encoder.classes_)}
     print(f"Label mapping: {label_mapping}")
-
+
     # Reset index to avoid issues
     train_df = train_df.reset_index(drop=True)
     test_df = test_df.reset_index(drop=True)
@@ -200,66 +219,67 @@ def load_and_prepare_data(train_config, test_config=None, test_size=0.2, max_tra
     train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
     test_dataset = Dataset.from_pandas(test_df, preserve_index=False)
 
-    print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
-    print(f"Train columns: {train_dataset.column_names}")
-    print(f"Test columns: {test_dataset.column_names}")
+    print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
+    print(f"Train columns: {train_dataset.column_names}")
+    print(f"Test columns: {test_dataset.column_names}")
     print(f"Number of unique labels: {len(label_encoder.classes_)}")
-
+
     # Store label encoder in the dataset for later use
     train_dataset.label_encoder = label_encoder
     test_dataset.label_encoder = label_encoder
-
+
     return train_dataset, test_dataset, evaluation_strategy
 
+
 def train_model_setfit(model_config, train_ds, test_ds):
     """Train the model using SetFit."""
     from setfit import SetFitModel, SetFitTrainer
-
+
     # 1. Load the pretrained model from Hugging Face
     print(f"Loading SetFit model: {model_config['model_checkpoint']}")
     model = SetFitModel.from_pretrained(
-        model_config['model_checkpoint'],
+        model_config["model_checkpoint"],
     )
 
     # 2. Define training arguments
-    setfit_params = model_config['params']
+    setfit_params = model_config["params"]
 
     # 3. Monkey patch to disable the problematic callbacks
     import transformers.integrations.integration_utils as integration_utils
-
+
     # Save the original classes
     original_mlflow_callback = integration_utils.MLflowCallback
-    original_dagshub_callback = getattr(integration_utils, 'DagsHubCallback', None)
-
+    original_dagshub_callback = getattr(integration_utils, "DagsHubCallback", None)
+
     # Replace them with mocks that do nothing
     integration_utils.MLflowCallback = lambda: None
     if original_dagshub_callback:
         integration_utils.DagsHubCallback = lambda: None
-
+
     try:
         # Initialize the Trainer
         trainer = SetFitTrainer(
             model=model,
             train_dataset=train_ds,
             eval_dataset=test_ds,
-            metric="accuracy",
+            metric="accuracy",
             **setfit_params,
             seed=GLOBAL_SEED,
         )
-
+
         # IMPORTANT: remove the problematic callbacks from the st_trainer
-        if hasattr(trainer, 'st_trainer') and trainer.st_trainer is not None:
+        if hasattr(trainer, "st_trainer") and trainer.st_trainer is not None:
             callbacks_to_remove = []
             for callback in trainer.st_trainer.callback_handler.callbacks:
                 callback_class_name = callback.__class__.__name__
                 # Remove MLflow and DagsHub callbacks
-                if 'MLflow' in callback_class_name or 'DagsHub' in callback_class_name:
+                if "MLflow" in callback_class_name or "DagsHub" in callback_class_name:
                     callbacks_to_remove.append(callback)
-
+
             for callback in callbacks_to_remove:
                 print(f"Removing problematic callback: {callback.__class__.__name__}")
                 trainer.st_trainer.callback_handler.remove_callback(callback)
-
+
     finally:
         # Restore the original classes
         integration_utils.MLflowCallback = original_mlflow_callback
@@ -268,14 +288,16 @@ def train_model_setfit(model_config, train_ds, test_ds):
 
     # 4. Train the model
     print("Starting SetFit model training...")
-
+
     # Log parameters manually
-    mlflow.log_params({
-        "model_checkpoint": model_config['model_checkpoint'],
-        **setfit_params,
-        "seed": GLOBAL_SEED,
-    })
-
+    mlflow.log_params(
+        {
+            "model_checkpoint": model_config["model_checkpoint"],
+            **setfit_params,
+            "seed": GLOBAL_SEED,
+        }
+    )
+
     trainer.train()
     print("Training complete.")
 
@@ -283,67 +305,64 @@
     print("Evaluating model...")
     metrics = trainer.evaluate()
     print(f"Metrics: {metrics}")
-
+
     # Log metrics manually
     mlflow.log_metrics(metrics)
 
-    # 6. Get predictions for the classification report
+    # 6. Get predictions for the classification report
     y_true = test_ds["label"]
     y_pred = model.predict(test_ds["text"])
 
     return model, metrics, y_true, y_pred, "setfit"
 
-def train_model_transformers(model_config, train_ds, test_ds):
+
+def train_model_transformers(model_config, train_ds, test_ds):
     """Train the model using standard Transformers Trainer."""
+    import numpy as np
     from transformers import (
-        AutoTokenizer,
-        AutoModelForSequenceClassification,
-        TrainingArguments,
+        AutoModelForSequenceClassification,
+        AutoTokenizer,
+        DataCollatorWithPadding,
         Trainer,
-        DataCollatorWithPadding
+        TrainingArguments,
     )
-    import torch
-    import numpy as np
-
+
     # 1. Load tokenizer and model
     print(f"Loading Transformers model: {model_config['model_checkpoint']}")
-    tokenizer = AutoTokenizer.from_pretrained(model_config['model_checkpoint'])
-
+    tokenizer = AutoTokenizer.from_pretrained(model_config["model_checkpoint"])
+
     # Determine the number of unique labels
     num_labels = len(set(train_ds["label"]))
     model = AutoModelForSequenceClassification.from_pretrained(
-        model_config['model_checkpoint'],
-        num_labels=num_labels
+        model_config["model_checkpoint"], num_labels=num_labels
     )
-
+
     # 2. Tokenize the datasets
     def tokenize_function(examples):
-        return tokenizer(examples["text"], truncation=True, max_length=256, padding=False)  # try also with 256
-
+        return tokenizer(
+            examples["text"], truncation=True, max_length=256, padding=False
+        )  # try also with 256
+
     print("Tokenizing datasets...")
-    tokenized_train = train_ds.map(tokenize_function, batched=True, remove_columns=["text"])
-    tokenized_test = test_ds.map(tokenize_function, batched=True, remove_columns=["text"])
-
+    tokenized_train = train_ds.map(tokenize_function, batched=True, remove_columns=["text"])
+    tokenized_test = test_ds.map(tokenize_function, batched=True, remove_columns=["text"])
+
     # 3. Data collator
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
     # 4. Define evaluation metrics
-    def compute_metrics(eval_pred):
-        logits, labels = eval_pred
+    def compute_metrics(eval_pred):
+        logits, labels = eval_pred
         predictions = np.argmax(logits, axis=-1)
         acc = accuracy_score(labels, predictions)
-        f1_macro = f1_score(labels, predictions, average='macro')
-        f1_weighted = f1_score(labels, predictions, average='weighted')
-        return {
-            "accuracy": acc,
-            "f1_macro": f1_macro,
-            "f1_weighted": f1_weighted
-        }
-
+        f1_macro = f1_score(labels, predictions, average="macro")
+        f1_weighted = f1_score(labels, predictions, average="weighted")
+        return {"accuracy": acc, "f1_macro": f1_macro, "f1_weighted": f1_weighted}
+
     # 5. Training arguments
     training_args = TrainingArguments(
         output_dir="./results",
-        **model_config['params'],
+        **model_config["params"],
         seed=GLOBAL_SEED,
         eval_strategy="epoch",
         save_strategy="epoch",
@@ -351,7 +370,7 @@ def train_model_transformers(model_config, train_ds, test_ds):
         report_to="none",  # Disable automatic reporting to avoid conflicts with MLflow
         push_to_hub=False,  # Disable pushing to hub
     )
-
+
     # 6. Initialize the Trainer
     trainer = Trainer(
         model=model,
@@ -362,82 +381,83 @@
         data_collator=data_collator,
         compute_metrics=compute_metrics,
     )
-
+
     # 7. Train the model
     print("Starting Transformers model training...")
     trainer.train()
     print("Training complete.")
-
+
     # 8. Evaluate the model
     print("Evaluating model...")
     metrics = trainer.evaluate()
     print(f"Metrics: {metrics}")
-
+
     # 9. Log metrics to MLflow manually
     for key, value in metrics.items():
         mlflow.log_metric(key, value)
-
+
     # 10. Get predictions
     predictions = trainer.predict(tokenized_test)
     y_pred = np.argmax(predictions.predictions, axis=-1)
     y_true = tokenized_test["label"]
-
+
     # 11. Log classification report
     label_encoder = test_ds.label_encoder
    target_names = label_encoder.classes_
-
+
     report = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
     print("\nClassification Report:")
     print(classification_report(y_true, y_pred, target_names=target_names))
-
+
     # Log per-class metrics to MLflow
     for label, metrics_dict in report.items():
         if isinstance(metrics_dict, dict):
             for metric_name, value in metrics_dict.items():
                 mlflow.log_metric(f"{label}_{metric_name}", value)
-
+
     # Log label mapping
-    mlflow.log_dict({str(k): v for k, v in enumerate(label_encoder.classes_)}, "label_mapping.json")
-
+    mlflow.log_dict(
+        {str(k): v for k, v in enumerate(label_encoder.classes_)}, "label_mapping.json"
+    )
+
     return (model, tokenizer), metrics, y_true, y_pred, "transformers"
 
 
 if __name__ == "__main__":
     args = init_parser().parse_args()
-
+
     # Get configurations
     train_config = DATASET_CONFIGs[args.train_dataset]
     test_config = DATASET_CONFIGs[args.test_dataset] if args.test_dataset else None
     model_config = MODEL_CONFIGS[args.model_name]
-
+
     # Load data
     train_ds, test_ds, eval_strategy = load_and_prepare_data(
-        train_config,
-        test_config,
-        args.test_size,
-        max_train_samples=args.max_train_samples
+        train_config, test_config, args.test_size, max_train_samples=args.max_train_samples
     )
-
+
     # Set up MLflow
-    dagshub.init(repo_owner='se4ai2526-uniba', repo_name='Capibara', mlflow=True)
+    dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Capibara", mlflow=True)
     mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
-
+
     # Generate run name
     run_name = args.run_name or f"{args.model_name}_{args.train_dataset}"
     if args.test_dataset:
         run_name += f"_{args.test_dataset}"
-
+
     # Train model
     with mlflow.start_run(run_name=run_name):
-        mlflow.log_params({
-            "train_dataset": args.train_dataset,
-            "test_dataset": args.test_dataset or "holdout",
-            "model_name": args.model_name,
-            "evaluation_strategy": eval_strategy,
-            "use_setfit": args.use_setfit,
-        })
-
+        mlflow.log_params(
+            {
+                "train_dataset": args.train_dataset,
+                "test_dataset": args.test_dataset or "holdout",
+                "model_name": args.model_name,
+                "evaluation_strategy": eval_strategy,
+                "use_setfit": args.use_setfit,
+            }
+        )
+
         if args.use_setfit:
             train_model_setfit(model_config, train_ds, test_ds)
         else:
-            train_model_transformers(model_config, train_ds, test_ds)
+            train_model_transformers(model_config, train_ds, test_ds)
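
The loading fix called out in the commit message is the column-validation block added to load_and_prepare_data: the script now exits with a readable error instead of raising a KeyError deep inside pandas when a configured column is missing. A self-contained sketch of the same check, assuming a hypothetical in-memory DataFrame in place of the CSV read from SOFT_CLEANED_DATA_DIR:

import sys

import pandas as pd

# Hypothetical stand-in for the loaded train CSV; only the column names matter here.
train_df = pd.DataFrame(
    {"title": ["Bug in login"], "body": ["Cannot login"], "label": ["bug"]}
)

# Same fail-fast shape as the new validation in load_and_prepare_data.
required_columns = ["label", "body", "title"]
missing_columns = [col for col in required_columns if col not in train_df.columns]
if missing_columns:
    print(f"Error: Required columns {missing_columns} not found in train dataset.")
    sys.exit(1)
print("All required columns present.")

The TestErrorHandling cases in the new test file below exercise exactly this path by pointing the config at files whose label or body column is absent.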
tests/test_train.py ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import sys
3
+ import tempfile
4
+ from unittest.mock import patch
5
+
6
+ from datasets import Dataset
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pytest
10
+
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
+ from syntetic_issue_report_data_generation.config import MODEL_CONFIGS, DATASET_CONFIGs
14
+ from syntetic_issue_report_data_generation.modeling.train import (
15
+ init_parser,
16
+ load_and_prepare_data,
17
+ train_model_setfit,
18
+ train_model_transformers,
19
+ )
20
+
21
+
22
+ @pytest.fixture
23
+ def temp_data_dir():
24
+ """Create a temporary directory for test data files."""
25
+ with tempfile.TemporaryDirectory() as tmpdirname:
26
+ yield Path(tmpdirname)
27
+
28
+
29
+ @pytest.fixture
30
+ def sample_train_data():
31
+ """Create sample training data with balanced classes."""
32
+ return pd.DataFrame(
33
+ {
34
+ "title": [
35
+ "Bug in login",
36
+ "Feature request",
37
+ "Performance issue",
38
+ "UI problem",
39
+ "Bug in logout",
40
+ "Add search",
41
+ "Memory leak",
42
+ "New API endpoint",
43
+ "Crash on startup",
44
+ "Enhancement needed",
45
+ ],
46
+ "body": [
47
+ "Cannot login to system",
48
+ "Add dark mode feature",
49
+ "Slow loading times",
50
+ "Button misaligned",
51
+ "Cannot logout properly",
52
+ "Need search functionality",
53
+ "High memory usage",
54
+ "REST API needed",
55
+ "Application crashes",
56
+ "Improve user experience",
57
+ ],
58
+ "label": [
59
+ "bug",
60
+ "enhancement",
61
+ "bug",
62
+ "bug",
63
+ "bug",
64
+ "enhancement",
65
+ "bug",
66
+ "enhancement",
67
+ "bug",
68
+ "enhancement",
69
+ ],
70
+ }
71
+ )
72
+
73
+
74
+ @pytest.fixture
75
+ def sample_imbalanced_data():
76
+ """Create sample data with imbalanced classes (for stratified sampling test)."""
77
+ return pd.DataFrame(
78
+ {
79
+ "title": [
80
+ "Bug 1",
81
+ "Bug 2",
82
+ "Bug 3",
83
+ "Bug 4",
84
+ "Bug 5",
85
+ "Bug 6",
86
+ "Bug 7",
87
+ "Bug 8",
88
+ "Enhancement 1",
89
+ "Enhancement 2",
90
+ ],
91
+ "body": [
92
+ "Bug body 1",
93
+ "Bug body 2",
94
+ "Bug body 3",
95
+ "Bug body 4",
96
+ "Bug body 5",
97
+ "Bug body 6",
98
+ "Bug body 7",
99
+ "Bug body 8",
100
+ "Enhancement body 1",
101
+ "Enhancement body 2",
102
+ ],
103
+ "label": [
104
+ "bug",
105
+ "bug",
106
+ "bug",
107
+ "bug",
108
+ "bug",
109
+ "bug",
110
+ "bug",
111
+ "bug",
112
+ "enhancement",
113
+ "enhancement",
114
+ ],
115
+ }
116
+ )
117
+
118
+
119
+ @pytest.fixture
120
+ def train_config_with_title(temp_data_dir, sample_train_data):
121
+ """Create train config with title and body columns."""
122
+ train_path = temp_data_dir / "train_with_title.csv"
123
+ sample_train_data.to_csv(train_path, index=False)
124
+
125
+ return {
126
+ "data_path": "train_with_title.csv",
127
+ "label_col": "label",
128
+ "title_col": "title",
129
+ "body_col": "body",
130
+ "sep": ",",
131
+ }
132
+
133
+
134
+ @pytest.fixture
135
+ def imbalanced_train_config(temp_data_dir, sample_imbalanced_data):
136
+ """Create train config with imbalanced data."""
137
+ train_path = temp_data_dir / "train_imbalanced.csv"
138
+ sample_imbalanced_data.to_csv(train_path, index=False)
139
+
140
+ return {
141
+ "data_path": "train_imbalanced.csv",
142
+ "label_col": "label",
143
+ "title_col": "title",
144
+ "body_col": "body",
145
+ "sep": ",",
146
+ }
147
+
148
+
149
+ @pytest.fixture
150
+ def minimal_train_data():
151
+ """Create minimal training data for quick training tests."""
152
+ return pd.DataFrame(
153
+ {
154
+ "title": [
155
+ "Bug 1",
156
+ "Bug 2",
157
+ "Enhancement 1",
158
+ "Enhancement 2",
159
+ "Bug 3",
160
+ "Enhancement 3",
161
+ ],
162
+ "body": [
163
+ "Bug body 1",
164
+ "Bug body 2",
165
+ "Enh body 1",
166
+ "Enh body 2",
167
+ "Bug body 3",
168
+ "Enh body 3",
169
+ ],
170
+ "label": ["bug", "bug", "enhancement", "enhancement", "bug", "enhancement"],
171
+ }
172
+ )
173
+
174
+
175
+ @pytest.fixture
176
+ def minimal_train_config(temp_data_dir, minimal_train_data):
177
+ """Create train config with minimal data for fast training."""
178
+ train_path = temp_data_dir / "minimal_train.csv"
179
+ minimal_train_data.to_csv(train_path, index=False)
180
+
181
+ return {
182
+ "data_path": "minimal_train.csv",
183
+ "label_col": "label",
184
+ "title_col": "title",
185
+ "body_col": "body",
186
+ "sep": ",",
187
+ }
188
+
189
+
190
+ @pytest.fixture
191
+ def minimal_model_config_setfit():
192
+ """Create minimal SetFit model configuration for testing."""
193
+ return {
194
+ "model_checkpoint": "sentence-transformers/paraphrase-MiniLM-L3-v2", # Small, fast model
195
+ "params": {"num_epochs": 1, "batch_size": 4, "num_iterations": 5, "max_length": 64},
196
+ }
197
+
198
+
199
+ @pytest.fixture
200
+ def minimal_model_config_transformers():
201
+ """Create minimal Transformers model configuration for testing."""
202
+ return {
203
+ "model_checkpoint": "prajjwal1/bert-tiny", # Very small BERT model
204
+ "params": {
205
+ "num_train_epochs": 1,
206
+ "per_device_train_batch_size": 2,
207
+ "per_device_eval_batch_size": 2,
208
+ "learning_rate": 5e-5,
209
+ "warmup_steps": 0,
210
+ "weight_decay": 0.01,
211
+ "logging_steps": 1,
212
+ },
213
+ }
214
+
215
+
216
+ class TestDataLoadingAndPreparation:
217
+ """Test class for data loading and preparation functionality."""
218
+
219
+ def test_load_data_with_valid_config(self, train_config_with_title, temp_data_dir):
220
+ """Verify that data loads correctly with valid train dataset configuration."""
221
+ with patch(
222
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
223
+ temp_data_dir,
224
+ ):
225
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
226
+ train_config_with_title, test_config=None, test_size=0.2
227
+ )
228
+
229
+ # Check that datasets were created
230
+ assert isinstance(train_ds, Dataset)
231
+ assert isinstance(test_ds, Dataset)
232
+
233
+ # Check that datasets have correct columns
234
+ assert set(train_ds.column_names) == {"text", "label"}
235
+ assert set(test_ds.column_names) == {"text", "label"}
236
+
237
+ # Check that datasets are not empty
238
+ assert len(train_ds) > 0
239
+ assert len(test_ds) > 0
240
+
241
+ # Check total samples
242
+ assert len(train_ds) + len(test_ds) == 10
243
+
244
+ # Check label encoder is attached
245
+ assert hasattr(train_ds, "label_encoder")
246
+ assert hasattr(test_ds, "label_encoder")
247
+
248
+ # Check that labels are integers
249
+ assert all(isinstance(label, int) for label in train_ds["label"])
250
+ assert all(isinstance(label, int) for label in test_ds["label"])
251
+
252
+ def test_load_data_creates_holdout_split(self, train_config_with_title, temp_data_dir):
253
+ """Verify holdout split is created when no test dataset is provided."""
254
+ with patch(
255
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
256
+ temp_data_dir,
257
+ ):
258
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
259
+ train_config_with_title, test_config=None, test_size=0.2
260
+ )
261
+
262
+ # Check eval strategy is holdout
263
+ assert eval_strategy == "holdout"
264
+
265
+ # Check that both datasets exist
266
+ assert len(train_ds) > 0
267
+ assert len(test_ds) > 0
268
+
269
+ # Check that split is approximately correct (80/20 split of 10 samples)
270
+ total_samples = len(train_ds) + len(test_ds)
271
+ assert total_samples == 10
272
+ assert len(test_ds) == 2 # 20% of 10 = 2
273
+ assert len(train_ds) == 8 # 80% of 10 = 8
274
+
275
+ # Verify no data leakage (no overlap between train and test)
276
+ train_texts = set(train_ds["text"])
277
+ test_texts = set(test_ds["text"])
278
+ assert len(train_texts.intersection(test_texts)) == 0
279
+
280
+ def test_label_encoding_consistency(self, train_config_with_title, temp_data_dir):
281
+ """Verify labels are encoded consistently across train/test sets."""
282
+ with patch(
283
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
284
+ temp_data_dir,
285
+ ):
286
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
287
+ train_config_with_title, test_config=None, test_size=0.2
288
+ )
289
+
290
+ # Check that label encoders are the same object
291
+ assert train_ds.label_encoder is test_ds.label_encoder
292
+
293
+ # Check that labels are integers
294
+ assert all(isinstance(label, int) for label in train_ds["label"])
295
+ assert all(isinstance(label, int) for label in test_ds["label"])
296
+
297
+ # Check that label classes are consistent
298
+ train_label_classes = train_ds.label_encoder.classes_
299
+ test_label_classes = test_ds.label_encoder.classes_
300
+ assert list(train_label_classes) == list(test_label_classes)
301
+
302
+ # Check that we have the expected classes (bug and enhancement)
303
+ expected_classes = sorted(["bug", "enhancement"])
304
+ actual_classes = sorted(train_ds.label_encoder.classes_)
305
+ assert actual_classes == expected_classes
306
+
307
+ # Check that encoded labels are in valid range [0, num_classes)
308
+ num_classes = len(train_ds.label_encoder.classes_)
309
+ assert num_classes == 2
310
+ assert all(0 <= label < num_classes for label in train_ds["label"])
311
+ assert all(0 <= label < num_classes for label in test_ds["label"])
312
+
313
+ # Check that both classes appear in train set (due to stratification)
314
+ train_unique_labels = set(train_ds["label"])
315
+ assert len(train_unique_labels) == 2
316
+
317
+ def test_text_column_creation_with_title_and_body(
318
+ self, train_config_with_title, temp_data_dir
319
+ ):
320
+ """Verify text column combines title and body correctly."""
321
+ with patch(
322
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
323
+ temp_data_dir,
324
+ ):
325
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
326
+ train_config_with_title, test_config=None, test_size=0.2
327
+ )
328
+
329
+ # Check that text column exists and is the only text column
330
+ assert "text" in train_ds.column_names
331
+ assert "text" in test_ds.column_names
332
+ assert "title" not in train_ds.column_names
333
+ assert "body" not in train_ds.column_names
334
+
335
+ # Check that all text entries are non-empty strings
336
+ assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
337
+ assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])
338
+
339
+ # Check that text contains content (not just whitespace)
340
+ assert all(len(text.strip()) > 0 for text in train_ds["text"])
341
+ assert all(len(text.strip()) > 0 for text in test_ds["text"])
342
+
343
+ # Check that text is longer than just title or body alone
344
+ # (indicating concatenation happened)
345
+ for text in train_ds["text"]:
346
+ # Text should have reasonable length (at least 10 chars)
347
+ assert len(text) >= 10
348
+
349
+ def test_max_train_samples_stratified_sampling(self, imbalanced_train_config, temp_data_dir):
350
+ """Verify stratified sampling works correctly when max_train_samples is specified."""
351
+ with patch(
352
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
353
+ temp_data_dir,
354
+ ):
355
+ # Original data has 10 samples: 8 bugs, 2 enhancements
356
+ # Request only 4 samples
357
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
358
+ imbalanced_train_config, test_config=None, test_size=0.2, max_train_samples=4
359
+ )
360
+
361
+ # Check that train dataset size is reduced to 4
362
+ assert len(train_ds) == 4
363
+
364
+ # Check that we still have both classes (stratified sampling)
365
+ unique_labels = set(train_ds["label"])
366
+ assert len(unique_labels) == 2, "Stratified sampling should preserve both classes"
367
+
368
+ # Check that class distribution is approximately maintained
369
+ # Original: 80% bug, 20% enhancement
370
+ # With 4 samples: should have ~3 bugs, ~1 enhancement
371
+ label_counts = {}
372
+ for label in train_ds["label"]:
373
+ label_name = train_ds.label_encoder.inverse_transform([label])[0]
374
+ label_counts[label_name] = label_counts.get(label_name, 0) + 1
375
+
376
+ # At least one sample from minority class
377
+ assert label_counts.get("enhancement", 0) >= 1
378
+ # Majority class should have more samples
379
+ assert label_counts.get("bug", 0) >= label_counts.get("enhancement", 0)
380
+
381
+ # Test dataset should remain at 20% of original (2 samples)
382
+ assert len(test_ds) == 2
383
+
384
+ # Total should be less than original
385
+ assert len(train_ds) + len(test_ds) < 10
386
+
387
+
388
+ class TestConfiguration:
389
+ """Test class for configuration and argument parsing functionality."""
390
+
391
+ def test_parser_accepts_valid_arguments(self):
392
+ """Verify parser accepts all valid combinations of arguments."""
393
+ parser = init_parser()
394
+
395
+ # Get first valid dataset and model from configs
396
+ valid_dataset = list(DATASET_CONFIGs.keys())[0]
397
+ valid_model = list(MODEL_CONFIGS.keys())[0]
398
+
399
+ # Test 1: Minimal required arguments
400
+ args = parser.parse_args(["--train-dataset", valid_dataset, "--model-name", valid_model])
401
+ assert args.train_dataset == valid_dataset
402
+ assert args.model_name == valid_model
403
+ assert args.test_dataset is None
404
+ assert args.test_size == 0.2 # default value
405
+ assert args.max_train_samples is None # default value
406
+ assert args.use_setfit is False # default value
407
+ assert args.run_name is None # default value
408
+
409
+ # Test 2: All arguments provided
410
+ if len(DATASET_CONFIGs.keys()) > 1:
411
+ valid_test_dataset = list(DATASET_CONFIGs.keys())[1]
412
+ else:
413
+ valid_test_dataset = valid_dataset
414
+
415
+ args = parser.parse_args(
416
+ [
417
+ "--train-dataset",
418
+ valid_dataset,
419
+ "--test-dataset",
420
+ valid_test_dataset,
421
+ "--model-name",
422
+ valid_model,
423
+ "--use-setfit",
424
+ "--test-size",
425
+ "0.3",
426
+ "--max-train-samples",
427
+ "100",
428
+ "--run-name",
429
+ "test_run",
430
+ ]
431
+ )
432
+ assert args.train_dataset == valid_dataset
433
+ assert args.test_dataset == valid_test_dataset
434
+ assert args.model_name == valid_model
435
+ assert args.use_setfit is True
436
+ assert args.test_size == 0.3
437
+ assert args.max_train_samples == 100
438
+ assert args.run_name == "test_run"
439
+
440
+ # Test 3: Only use-setfit flag
441
+ args = parser.parse_args(
442
+ ["--train-dataset", valid_dataset, "--model-name", valid_model, "--use-setfit"]
443
+ )
444
+ assert args.use_setfit is True
445
+
446
+ # Test 4: Custom test-size
447
+ args = parser.parse_args(
448
+ ["--train-dataset", valid_dataset, "--model-name", valid_model, "--test-size", "0.15"]
449
+ )
450
+ assert args.test_size == 0.15
451
+
452
+ # Test 5: Custom max-train-samples
453
+ args = parser.parse_args(
454
+ [
455
+ "--train-dataset",
456
+ valid_dataset,
457
+ "--model-name",
458
+ valid_model,
459
+ "--max-train-samples",
460
+ "500",
461
+ ]
462
+ )
463
+ assert args.max_train_samples == 500
464
+
465
+ def test_parser_rejects_invalid_dataset_names(self):
466
+ """Verify parser rejects dataset names not in DATASET_CONFIGs."""
467
+ parser = init_parser()
468
+
469
+ # Get valid model
470
+ valid_model = list(MODEL_CONFIGS.keys())[0]
471
+
472
+ # Test 1: Invalid train dataset
473
+ with pytest.raises(SystemExit):
474
+ parser.parse_args(
475
+ ["--train-dataset", "invalid_dataset_name", "--model-name", valid_model]
476
+ )
477
+
478
+ # Test 2: Invalid test dataset
479
+ valid_dataset = list(DATASET_CONFIGs.keys())[0]
480
+ with pytest.raises(SystemExit):
481
+ parser.parse_args(
482
+ [
483
+ "--train-dataset",
484
+ valid_dataset,
485
+ "--test-dataset",
486
+ "invalid_test_dataset",
487
+ "--model-name",
488
+ valid_model,
489
+ ]
490
+ )
491
+
492
+ # Test 3: Invalid model name
493
+ with pytest.raises(SystemExit):
494
+ parser.parse_args(
495
+ ["--train-dataset", valid_dataset, "--model-name", "invalid_model_name"]
496
+ )
497
+
498
+ # Test 4: Missing required argument (train-dataset)
499
+ with pytest.raises(SystemExit):
500
+ parser.parse_args(["--model-name", valid_model])
501
+
502
+ # Test 5: Missing required argument (model-name)
503
+ with pytest.raises(SystemExit):
504
+ parser.parse_args(["--train-dataset", valid_dataset])
505
+
506
+
507
+ class TestTrainingPipeline:
508
+ @pytest.mark.slow
509
+ def test_setfit_training_completes(
510
+ self, minimal_train_config, minimal_model_config_setfit, temp_data_dir
511
+ ):
512
+ """Verify SetFit training runs without errors (using minimal data)."""
513
+ with patch(
514
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
515
+ temp_data_dir,
516
+ ):
517
+ # Load data
518
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
519
+ minimal_train_config,
520
+ test_config=None,
521
+ test_size=0.33, # 2 samples for test, 4 for train
522
+ )
523
+
524
+ # Mock MLflow to avoid logging during tests
525
+ with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
526
+ # Train model
527
+ result = train_model_setfit(minimal_model_config_setfit, train_ds, test_ds)
528
+
529
+ # Check that result is returned
530
+ assert result is not None
531
+
532
+ # Check that result has expected structure
533
+ model, metrics, y_true, y_pred, model_type = result
534
+
535
+ # Check model type
536
+ assert model_type == "setfit"
537
+
538
+ # Check that model is returned
539
+ assert model is not None
540
+
541
+ # Check that metrics are computed
542
+ assert isinstance(metrics, dict)
543
+ assert "accuracy" in metrics
544
+ assert "f1_macro" in metrics
545
+ assert "f1_weighted" in metrics
546
+
547
+ # Check that metrics are in valid range [0, 1]
548
+ assert 0 <= metrics["accuracy"] <= 1
549
+ assert 0 <= metrics["f1_macro"] <= 1
550
+ assert 0 <= metrics["f1_weighted"] <= 1
551
+
552
+ # Check that predictions are returned
553
+ assert y_true is not None
554
+ assert y_pred is not None
555
+
556
+ # Check that predictions have correct length
557
+ assert len(y_true) == len(test_ds)
558
+ assert len(y_pred) == len(test_ds)
559
+
560
+ # Check that predictions are in valid label space
561
+ num_classes = len(train_ds.label_encoder.classes_)
562
+ assert all(0 <= pred < num_classes for pred in y_pred)
563
+ assert all(0 <= true < num_classes for true in y_true)
564
+
565
+ @pytest.mark.slow
566
+ def test_transformers_training_completes(
567
+ self, minimal_train_config, minimal_model_config_transformers, temp_data_dir
568
+ ):
569
+ """Verify Transformers training runs without errors (using minimal data)."""
570
+ with patch(
571
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
572
+ temp_data_dir,
573
+ ):
574
+ # Load data
575
+ train_ds, test_ds, eval_strategy = load_and_prepare_data(
576
+ minimal_train_config,
577
+ test_config=None,
578
+ test_size=0.33, # 2 samples for test, 4 for train
579
+ )
580
+
581
+ # Mock MLflow to avoid logging during tests
582
+ with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
583
+ # Train model
584
+ result = train_model_transformers(
585
+ minimal_model_config_transformers, train_ds, test_ds
586
+ )
587
+
588
+ # Check that result is returned
589
+ assert result is not None
590
+
591
+ # Check that result has expected structure
592
+ model_tuple, metrics, y_true, y_pred, model_type = result
593
+
594
+ # Check model type
595
+ assert model_type == "transformers"
596
+
597
+ # Check that model and tokenizer are returned
598
+ assert model_tuple is not None
599
+ assert isinstance(model_tuple, tuple)
600
+ assert len(model_tuple) == 2
601
+ model, tokenizer = model_tuple
602
+ assert model is not None
603
+ assert tokenizer is not None
604
+
605
+ # Check that metrics are computed
606
+ assert isinstance(metrics, dict)
607
+ # Transformers returns metrics with eval_ prefix
608
+ assert any("accuracy" in key for key in metrics.keys())
609
+
610
+ # Extract accuracy value (could be 'accuracy' or 'eval_accuracy')
611
+ accuracy_key = [k for k in metrics.keys() if "accuracy" in k][0]
612
+ accuracy = metrics[accuracy_key]
613
+ assert 0 <= accuracy <= 1
614
+
615
+ # Check that predictions are returned
616
+ assert y_true is not None
617
+ assert y_pred is not None
618
+
619
+ # Check that predictions have correct length
620
+ assert len(y_true) == len(test_ds)
621
+ assert len(y_pred) == len(test_ds)
622
+
623
+ # Check that predictions are in valid label space
624
+ num_classes = len(train_ds.label_encoder.classes_)
625
+ assert all(0 <= pred < num_classes for pred in y_pred)
626
+ assert all(0 <= true < num_classes for true in y_true)
627
+
628
+ # Check that predictions are numpy arrays or lists of integers
629
+ assert all(isinstance(pred, (int, np.integer)) for pred in y_pred)
630
+ assert all(isinstance(true, (int, np.integer)) for true in y_true)
631
+
632
+
633
+ class TestErrorHandling:
634
+ """Test class for error handling functionality."""
635
+
636
+ def test_missing_train_file_raises_error(self, temp_data_dir):
637
+ """Verify appropriate error when train file doesn't exist."""
638
+ # Create config pointing to non-existent file
639
+ missing_file_config = {
640
+ "data_path": "non_existent_file.csv",
641
+ "label_col": "label",
642
+ "title_col": "title",
643
+ "body_col": "body",
644
+ "sep": ",",
645
+ }
646
+
647
+ with patch(
648
+ "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
649
+ temp_data_dir,
650
+ ):
651
+ # Should call sys.exit(1) when file doesn't exist
652
+ with pytest.raises(SystemExit) as excinfo:
653
+ load_and_prepare_data(missing_file_config, test_config=None, test_size=0.2)
654
+
655
+ # Check that exit code is 1
656
+ assert excinfo.value.code == 1
657
+
658
+     def test_invalid_label_column(self, temp_data_dir):
+         """Verify error handling when the specified label column doesn't exist."""
+         # Create data with specific columns
+         sample_data = pd.DataFrame(
+             {
+                 "title": ["Bug 1", "Enhancement 1"],
+                 "body": ["Bug body", "Enhancement body"],
+                 "type": ["bug", "enhancement"],  # Different column name
+             }
+         )
+
+         # Save to file
+         train_path = temp_data_dir / "invalid_label_col.csv"
+         sample_data.to_csv(train_path, index=False)
+
+         # Create a config with the wrong label column name
+         invalid_label_config = {
+             "data_path": "invalid_label_col.csv",
+             "label_col": "label",  # This column doesn't exist
+             "title_col": "title",
+             "body_col": "body",
+             "sep": ",",
+         }
+
+         with patch(
+             "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
+             temp_data_dir,
+         ):
+             # Should call sys.exit(1) when the label column doesn't exist
+             with pytest.raises(SystemExit) as excinfo:
+                 load_and_prepare_data(invalid_label_config, test_config=None, test_size=0.2)
+
+             # Check that the exit code is 1
+             assert excinfo.value.code == 1
+
+     def test_invalid_body_column(self, temp_data_dir):
+         """Verify error handling when the specified body column doesn't exist."""
+         # Create data with specific columns
+         sample_data = pd.DataFrame(
+             {
+                 "title": ["Bug 1", "Enhancement 1"],
+                 "description": ["Bug body", "Enhancement body"],  # Different column name
+                 "label": ["bug", "enhancement"],
+             }
+         )
+
+         # Save to file
+         train_path = temp_data_dir / "invalid_body_col.csv"
+         sample_data.to_csv(train_path, index=False)
+
+         # Create a config with the wrong body column name
+         invalid_body_config = {
+             "data_path": "invalid_body_col.csv",
+             "label_col": "label",
+             "title_col": "title",
+             "body_col": "body",  # This column doesn't exist
+             "sep": ",",
+         }
+
+         with patch(
+             "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
+             temp_data_dir,
+         ):
+             # Should call sys.exit(1) when the body column doesn't exist
+             with pytest.raises(SystemExit) as excinfo:
+                 load_and_prepare_data(invalid_body_config, test_config=None, test_size=0.2)
+
+             # Check that the exit code is 1
+             assert excinfo.value.code == 1
+
+
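+ # The three tests above assume load_and_prepare_data() fails fast via
+ # sys.exit(1) on a missing file or column. A minimal sketch of that pattern
+ # (a hypothetical illustration; the real checks live in train.py):
+ #
+ #     if train_config["label_col"] not in df.columns:
+ #         print(f"Error: label column '{train_config['label_col']}' not found")
+ #         sys.exit(1)
+
+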
+ class TestEdgeCases:
+     """Test class for edge case scenarios."""
+
+     def test_very_small_dataset(self, temp_data_dir):
+         """Verify training with very small datasets (< 10 samples)."""
+         # Create a very small dataset (6 samples total, 3 per class)
+         very_small_data = pd.DataFrame(
+             {
+                 "title": ["Bug 1", "Bug 2", "Bug 3", "Enh 1", "Enh 2", "Enh 3"],
+                 "body": [
+                     "Small bug 1",
+                     "Small bug 2",
+                     "Small bug 3",
+                     "Small enh 1",
+                     "Small enh 2",
+                     "Small enh 3",
+                 ],
+                 "label": ["bug", "bug", "bug", "enhancement", "enhancement", "enhancement"],
+             }
+         )
+
+         # Save to file
+         train_path = temp_data_dir / "very_small.csv"
+         very_small_data.to_csv(train_path, index=False)
+
+         small_config = {
+             "data_path": "very_small.csv",
+             "label_col": "label",
+             "title_col": "title",
+             "body_col": "body",
+             "sep": ",",
+         }
+
+         with patch(
+             "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
+             temp_data_dir,
+         ):
+             # Load data with a small test split so each class keeps at least 1 train sample
+             train_ds, test_ds, eval_strategy = load_and_prepare_data(
+                 small_config,
+                 test_config=None,
+                 test_size=0.33,  # 2 samples for test (1 per class), 4 for train
+             )
+
+             # Check that datasets were created despite the small size
+             assert isinstance(train_ds, Dataset)
+             assert isinstance(test_ds, Dataset)
+
+             # Check that both datasets have samples
+             assert len(train_ds) > 0
+             assert len(test_ds) > 0
+
+             # Check that the total number of samples is preserved
+             assert len(train_ds) + len(test_ds) == 6
+
+             # Check that stratification kept at least one class in the train set
+             # (with 3 samples per class and this split, both classes should survive)
+             train_unique_labels = set(train_ds["label"])
+             assert len(train_unique_labels) >= 1  # At least one class present
+
+             # Check that we have a valid label encoding
+             num_classes = len(train_ds.label_encoder.classes_)
+             assert num_classes == 2
+             assert all(0 <= label < num_classes for label in train_ds["label"])
+             assert all(0 <= label < num_classes for label in test_ds["label"])
+
+             # Check that the text field was properly created
+             assert all(isinstance(text, str) and len(text) > 0 for text in train_ds["text"])
+             assert all(isinstance(text, str) and len(text) > 0 for text in test_ds["text"])
+
+
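+ # The class-preservation checks above rely on stratified splitting. train.py
+ # imports sklearn's train_test_split, so the split presumably looks like this
+ # sketch (the exact call shown here is an assumption, not the verified code):
+ #
+ #     train_df, test_df = train_test_split(
+ #         df, test_size=test_size, stratify=df["label"], random_state=GLOBAL_SEED
+ #     )
+
+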
+ class TestOutputValidation:
+     """Test class for output validation functionality."""
+
+     @pytest.mark.slow
+     def test_predictions_match_label_space(
+         self,
+         minimal_train_config,
+         minimal_model_config_setfit,
+         minimal_model_config_transformers,
+         temp_data_dir,
+     ):
+         """Verify predictions are within the valid label space."""
+         with patch(
+             "syntetic_issue_report_data_generation.modeling.train.SOFT_CLEANED_DATA_DIR",
+             temp_data_dir,
+         ):
+             # Load data
+             train_ds, test_ds, eval_strategy = load_and_prepare_data(
+                 minimal_train_config,
+                 test_config=None,
+                 test_size=0.33,  # 2 samples for test, 4 for train
+             )
+
+             # Get the valid label space
+             num_classes = len(train_ds.label_encoder.classes_)
+             valid_label_space = set(range(num_classes))
+
+             # Mock MLflow to avoid logging during tests
+             with patch("syntetic_issue_report_data_generation.modeling.train.mlflow"):
+                 # Test with SetFit
+                 model, metrics, y_true, y_pred, model_type = train_model_setfit(
+                     minimal_model_config_setfit, train_ds, test_ds
+                 )
+
+                 # Check that all predictions are in the valid label space
+                 assert all(
+                     pred in valid_label_space for pred in y_pred
+                 ), f"SetFit predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred)}"
+
+                 # Check that all true labels are in the valid label space
+                 assert all(
+                     true in valid_label_space for true in y_true
+                 ), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true)}"
+
+                 # Check that predictions are within [0, num_classes)
+                 assert all(
+                     0 <= pred < num_classes for pred in y_pred
+                 ), f"SetFit predictions out of range [0, {num_classes})"
+
+                 # Check that y_true matches the original test labels
+                 assert list(y_true) == list(
+                     test_ds["label"]
+                 ), "True labels don't match original test dataset labels"
+
+                 # Test with Transformers
+                 (model_t, tokenizer), metrics_t, y_true_t, y_pred_t, model_type_t = (
+                     train_model_transformers(minimal_model_config_transformers, train_ds, test_ds)
+                 )
+
+                 # Check that all predictions are in the valid label space
+                 assert all(
+                     pred in valid_label_space for pred in y_pred_t
+                 ), f"Transformers predictions contain invalid labels. Valid: {valid_label_space}, Got: {set(y_pred_t)}"
+
+                 # Check that all true labels are in the valid label space
+                 assert all(
+                     true in valid_label_space for true in y_true_t
+                 ), f"True labels contain invalid values. Valid: {valid_label_space}, Got: {set(y_true_t)}"
+
+                 # Check that predictions are within [0, num_classes)
+                 assert all(
+                     0 <= pred < num_classes for pred in y_pred_t
+                 ), f"Transformers predictions out of range [0, {num_classes})"
+
+                 # Check that y_true matches the original test labels
+                 assert list(y_true_t) == list(
+                     test_ds["label"]
+                 ), "True labels don't match original test dataset labels"
+
+                 # Additional check: verify predictions are integer-typed
+                 assert all(
+                     isinstance(pred, (int, np.integer)) for pred in y_pred
+                 ), "SetFit predictions must be integers"
+                 assert all(
+                     isinstance(pred, (int, np.integer)) for pred in y_pred_t
+                 ), "Transformers predictions must be integers"
+
+                 # Sanity check that predictions were actually produced; with very
+                 # small training data all predictions may legitimately be identical,
+                 # so we only require at least one unique value per model
+                 unique_preds = len(set(y_pred))
+                 unique_preds_t = len(set(y_pred_t))
+                 assert unique_preds >= 1, "SetFit made no predictions"
+                 assert unique_preds_t >= 1, "Transformers made no predictions"
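+
+
+ # Tests marked @pytest.mark.slow, such as test_predictions_match_label_space
+ # above, can be deselected during quick local runs via the marker registered
+ # in pyproject.toml:
+ #
+ #     pytest -m "not slow"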