Classification in Imbalanced Datasets
Choosing the right metrics and optimizing performance when working with imbalanced data distributions
Introduction
One of the most challenging problems in machine learning projects is working with imbalanced datasets. An imbalanced dataset occurs when some classes have far fewer examples than others.
In this tutorial, we will examine a BERT model trained on a highly imbalanced dataset with categories A, B, C, and D.
Number of Examples per Category
Category | Number of Examples |
---|---|
Category A | 100 |
Category B | 100 |
Category C | 10 |
Category D | 10 |
As the table shows, categories A and B have 100 examples each, while categories C and D have only 10 examples each. This 10:1 imbalance significantly affects the model's performance.
Confusion Matrix
To evaluate the model's performance, we use a confusion matrix that cross-tabulates predicted classes against actual classes.
Predicted/Actual | A (Actual) | B (Actual) | C (Actual) | D (Actual) |
---|---|---|---|---|
A (Predicted) | 85 | 7 | 4 | 3 |
B (Predicted) | 10 | 120 | 3 | 5 |
C (Predicted) | 3 | 3 | 2 | 2 |
D (Predicted) | 2 | 0 | 1 | 1 |
From the confusion matrix, it can be observed that the model's performance is good for categories A and B, but poor for categories C and D.
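A confusion matrix like the one above can be produced directly from the model's predictions. Below is a minimal sketch with scikit-learn; the `y_true` and `y_pred` arrays here are placeholders standing in for the labels collected during the evaluation loop (0=A, 1=B, 2=C, 3=D):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Placeholder label arrays (0=A, 1=B, 2=C, 3=D); in practice these come
# from running the BERT model over the evaluation set.
y_true = np.array([0, 0, 1, 2, 3, 1])
y_pred = np.array([0, 1, 1, 2, 3, 1])

# sklearn returns rows = actual class, columns = predicted class;
# the table above uses rows = predicted, so transpose to match.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])
print(cm.T)
```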
Basic Values
From the confusion matrix for category A, we obtain the following values:
- True Positive (TP) = 85: actual category A examples predicted as A
- False Positive (FP) = 7 + 4 + 3 = 14: examples not actually A, but predicted as A
  - 7 examples from B
  - 4 examples from C
  - 3 examples from D
- False Negative (FN) = 10 + 3 + 2 = 15: examples actually A, but not predicted as A
  - 10 examples predicted as B
  - 3 examples predicted as C
  - 2 examples predicted as D
- True Negative (TN) = 120 + 3 + 5 + 3 + 2 + 2 + 0 + 1 + 1 = 137: examples neither actually A nor predicted as A
Step-by-Step Metric Calculations
1. Precision
Precision = TP / (TP + FP)
Precision = 85 / (85 + 14) = 85 / 99 ≈ 0.859 = 85.9%
2. Recall (Sensitivity)
Recall = TP / (TP + FN)
Recall = 85 / (85 + 15) = 85 / 100 = 0.85 = 85%
3. F1 Score
F1 = 2 * (Precision * Recall) / (Precision + Recall)
F1 ≈ 2 * (0.859 * 0.85) / (0.859 + 0.85) ≈ 0.854 = 85.4%
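These per-class calculations can be verified with a short script. The sketch below stores the confusion matrix exactly as in the table above (rows = predicted, columns = actual, in class order A, B, C, D) and reproduces the precision, recall, and F1 values for each class:

```python
import numpy as np

# Rows = predicted, columns = actual (A, B, C, D), as in the table above
cm = np.array([
    [85,   7, 4, 3],
    [10, 120, 3, 5],
    [ 3,   3, 2, 2],
    [ 2,   0, 1, 1],
])

for i, name in enumerate(["A", "B", "C", "D"]):
    tp = cm[i, i]
    fp = cm[i].sum() - tp     # predicted as class i, actually another class
    fn = cm[:, i].sum() - tp  # actually class i, predicted as another class
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    print(f"{name}: P={precision:.3f} R={recall:.3f} F1={f1:.3f}")
```

For category A this prints P=0.859, R=0.850, F1=0.854, matching the hand calculation.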
Applying the same calculations to every class gives the following F1 scores:
- Category B: F1 = 0.896
- Category A: F1 = 0.854
- Category C: F1 = 0.200
- Category D: F1 = 0.133
These scores point to two improvement priorities:
- Increase the Recall value for category D
- Increase both the Precision and Recall values for category C
Performance Metrics
The basic metrics used to evaluate classification performance are Precision, Recall, and F1 Score.
Basic Concepts
- True Positive (TP): Correct positive predictions
- False Positive (FP): Incorrect positive predictions
- False Negative (FN): Incorrect negative predictions
- True Negative (TN): Correct negative predictions
Formulas
Precision = TP / (TP + FP)
Recall = TP / (TP + FN)
F1 Score = 2 * (Precision * Recall) / (Precision + Recall)
Category-Based Performance Metrics
Category | Precision | Recall | F1 Score | Performance Level |
---|---|---|---|---|
Category A | 0.85 | 0.85 | 0.85 | ✓ Good |
Category B | 0.89 | 0.86 | 0.87 | ✓ Good |
Category C | 0.20 | 0.20 | 0.20 | ✗ Very Low |
Category D | 0.25 | 0.10 | 0.14 | ✗ Very Low |
Average Metrics
Macro-average
Equal weight is given to each class.
- Precision: 0.55
- Recall: 0.50
- F1 Score: 0.52
Weighted Average
Weighted according to class sizes.
- Precision: 0.81
- Recall: 0.79
- F1 Score: 0.80
Micro-average
Calculated from the global totals of TP, FP, and FN across all examples.
- Precision: 0.84
- Recall: 0.80
- F1 Score: 0.82
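All three averaging schemes can be computed in one call with scikit-learn. A minimal sketch; the `y_true` and `y_pred` lists are placeholders standing in for the evaluation-set labels:

```python
from sklearn.metrics import precision_recall_fscore_support

# Placeholder label arrays (0=A, 1=B, 2=C, 3=D); in practice these come
# from the evaluation loop.
y_true = [0, 0, 1, 1, 2, 3]
y_pred = [0, 1, 1, 1, 2, 0]

for avg in ("macro", "weighted", "micro"):
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average=avg, zero_division=0)
    print(f"{avg}: P={p:.2f} R={r:.2f} F1={f1:.2f}")
```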
Imbalance Analysis
The extreme imbalance in the dataset seriously affects model performance. In this section, the effects of imbalance are analyzed.
Main Effects of Imbalance
- Difference Between Macro and Weighted Average: A large difference indicates failure in minority classes.
- Low Recall: Low recall for C and D shows that the model struggles to recognize these classes.
- Error Distribution: Misclassification of C and D examples as A and B in the confusion matrix.
Balanced Data
The model learns in a balanced way when all classes are equally represented.
Imbalanced Data
Underrepresented classes cause bias toward majority classes.
Root Causes of Imbalance
- Insufficient Learning: Few examples prevent the model from learning correctly.
- Generalization Difficulty: Difficult to generalize to new examples with few samples.
- Bias Toward Majority Classes: The model focuses on majority classes.
- Decision Boundary Problem: Few examples make it difficult to determine correct boundaries.
Solution Strategies
Various strategies can be applied to improve model performance in imbalanced datasets:
1. Data Balancing Techniques
Data augmentation and resampling techniques for underrepresented classes:
- Oversampling: Increase the number of data points for categories C and D
- SMOTE: Generate synthetic examples
- Data Augmentation: Augmentation techniques for text data
```python
from nlpaug.augmenter.word import SynonymAug

# WordNet-based synonym replacement (nlpaug's default source)
syn_aug = SynonymAug()

augmented_samples = []
for text, label in minority_samples:  # (text, label) pairs from categories C and D
    for _ in range(50):               # 50 augmented variants per original example
        aug_text = syn_aug.augment(text)[0]  # recent nlpaug versions return a list
        augmented_samples.append((aug_text, label))
```
2. Model Improvement Strategies
Optimizing the training process:
- Class Weighting: Assigning higher weights to categories C and D
- Focal Loss: Loss function focusing on difficult examples
- Threshold Optimization: Lowering prediction thresholds for the minority classes (see the sketch after the focal-loss example below)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # per-class weights
        self.gamma = gamma  # focusing parameter: larger values down-weight easy examples

    def forward(self, inputs, targets):
        # Per-example, class-weighted cross-entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)                       # probability of the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss  # down-weight easy examples
        return torch.mean(focal_loss)

# A:1, B:1, C:25, D:25 class weights
class_weights = torch.tensor([1.0, 1.0, 25.0, 25.0])
criterion = FocalLoss(alpha=class_weights, gamma=3.0)
```
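The threshold optimization mentioned in the list above is applied at inference time rather than during training. Below is a minimal sketch of one simple decision rule, assuming `logits` are the model's raw outputs; the threshold values are hypothetical, picked from the 0.2-0.3 range suggested later in this article:

```python
import torch

# Hypothetical per-class probability thresholds: lowered for minority classes C and D
thresholds = torch.tensor([0.5, 0.5, 0.25, 0.25])

def predict_with_thresholds(logits):
    probs = torch.softmax(logits, dim=-1)
    above = probs >= thresholds  # classes that clear their (lowered) threshold
    # Among qualifying classes, pick the largest margin over its threshold;
    # if no class qualifies, fall back to the plain argmax.
    margin = torch.where(above, probs - thresholds, torch.full_like(probs, -1.0))
    best = margin.argmax(dim=-1)
    fallback = probs.argmax(dim=-1)
    return torch.where(~above.any(dim=-1), fallback, best)

# Example: class C wins despite a lower raw probability than class A
logits = torch.tensor([[1.0, 0.2, 0.9, 0.1]])
print(predict_with_thresholds(logits))
```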
3. Ensemble Methods
Combining multiple models to improve performance:
- Multiple Model Integration: Combination of models trained with different strategies
- Bagging and Boosting: Using AdaBoost or Gradient Boosting
- Cascade Classification: Staged classification approach
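For the multiple-model integration idea, a common approach is weighted soft voting over each model's predicted probabilities. A minimal sketch, using two hypothetical models `general_model` and `minority_model` (the latter trained with the minority-focused strategies above); the 0.3-0.4 weight range comes from the implementation plan below:

```python
import torch

def ensemble_predict(general_model, minority_model, inputs, w_minority=0.35):
    # Weighted soft voting: the minority-focused model receives the
    # 0.3-0.4 weight suggested in the implementation plan.
    p_general = torch.softmax(general_model(inputs), dim=-1)
    p_minority = torch.softmax(minority_model(inputs), dim=-1)
    probs = (1 - w_minority) * p_general + w_minority * p_minority
    return probs.argmax(dim=-1)
```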
4. Advanced Techniques
Deep learning and meta-learning approaches:
- Few-Shot Learning: Learning techniques with few examples
- Meta-Learning: MAML or Prototypical Networks
- Active Learning: Selection of the most informative examples
These techniques enhance the model's generalization ability, especially in situations with limited data.
Practical Implementation Steps
Imbalanced Dataset Improvement Plan
1. Data Augmentation: Apply 50-100x data augmentation for categories C and D. Use backtranslation, word replacement, and SMOTE techniques.
2. Class Weighting: Use high weights (25-50x) for C and D. Give more importance to minority classes with Focal Loss.
3. Two-Stage Training: First general training, then fine-tuning focused on the minority classes (a sampler sketch follows this list). Lower the learning rate (e.g., 5e-6) and train longer in the second stage.
4. Threshold Optimization: Set low threshold values (0.2-0.3) for C and D. Sacrifice some precision to increase recall.
5. Ensemble Models: Combine models trained with different strategies. Give higher weight (0.3-0.4) to the models focused on C and D.
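One way to focus the second training stage on the minority classes is oversampling with PyTorch's WeightedRandomSampler. A minimal sketch under stated assumptions: `train_dataset` yields (input, label) pairs, `labels` is the list of its integer labels (0=A, 1=B, 2=C, 3=D), and `model` is the BERT classifier from stage one:

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Assumed available: train_dataset, its integer labels `labels`, and `model`.
class_sample_weights = torch.tensor([1.0, 1.0, 10.0, 10.0])  # draw C and D ~10x more often
example_weights = class_sample_weights[torch.tensor(labels)]

sampler = WeightedRandomSampler(example_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)

# Stage 2: fine-tune with the low learning rate suggested above
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
```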
Expected Improvement
With this plan, the largest gains are expected for Category C, Category D, and the macro F1 score.
Real-World Applications
Imbalanced datasets appear in various fields in the real world. Here are some example scenarios:
Medical Diagnosis
Diagnosis of rare diseases naturally involves imbalanced datasets.
Problem:
- Very limited data for rare diseases
- False negatives can have serious consequences
Solution:
- Special loss functions that increase the cost of false negatives
- Creating synthetic data with examples from different stages of disease
- Few-shot learning techniques
Fraud Detection
Financial fraud detection typically involves extremely imbalanced data.
Problem:
- Fraudulent transactions represent a very small percentage of total transactions
- Patterns change rapidly
Solution:
- Anomaly detection and one-class classification
- Time-based sampling
- Improving labeling efficiency with active learning
Fault Detection
Fault detection in industrial equipment is challenging due to limited fault data.
Problem:
- Real fault data is scarce
- May not have enough examples for each fault type
Solution:
- Creating fault data with physics-based simulations
- Transfer learning
- Hybrid supervised/unsupervised learning methods
Advanced Resources
You can check the following resources to learn more about classification with imbalanced datasets: