Classification in Imbalanced Datasets

Choosing the right metrics and optimizing performance when working with imbalanced data distributions

Introduction

One of the most challenging problems in machine learning projects is working with imbalanced datasets. An imbalanced dataset occurs when some classes have far fewer examples than others.

In this tutorial, we will examine a BERT model trained on a highly imbalanced dataset with categories A, B, C, and D.

Number of Examples per Category

Category     Number of Examples
Category A   100
Category B   100
Category C   10
Category D   10

As the table shows, categories A and B have 100 examples each, while categories C and D have only 10 examples each. This imbalance significantly affects the model's performance.

Data Imbalance Summary: Categories A and B have 10 times more examples than categories C and D. This situation negatively affects the model's ability to correctly predict minority classes (C and D).

Confusion Matrix

To evaluate the model's performance, we use a confusion matrix, which compares the predicted classes against the actual classes.

Predicted/Actual A (Actual) B (Actual) C (Actual) D (Actual)
A (Predicted) 85 7 4 3
B (Predicted) 10 120 3 5
C (Predicted) 3 3 2 2
D (Predicted) 2 0 1 1

From the confusion matrix, it can be observed that the model's performance is good for categories A and B, but poor for categories C and D.
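
The same matrix can be produced directly from the model's predictions. A minimal sketch with scikit-learn, assuming y_true and y_pred are lists of labels such as "A", "B", "C", "D":

# Build the confusion matrix from actual and predicted labels
from sklearn.metrics import confusion_matrix

labels = ["A", "B", "C", "D"]
cm = confusion_matrix(y_true, y_pred, labels=labels)
# scikit-learn returns rows = actual, columns = predicted;
# transpose to match the predicted-by-actual layout of the table above
cm_predicted_by_actual = cm.T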

Metric Calculations for Category A

The confusion matrix below is annotated from the perspective of category A (rows: predicted, columns: actual):

Predicted/Actual   A (Actual)            B (Actual)            C (Actual)           D (Actual)
A (Predicted)      85 (True Positive)    7 (False Positive)    4 (False Positive)   3 (False Positive)
B (Predicted)      10 (False Negative)   120 (True Negative)   3 (True Negative)    5 (True Negative)
C (Predicted)      3 (False Negative)    3 (True Negative)     2 (True Negative)    2 (True Negative)
D (Predicted)      2 (False Negative)    0 (True Negative)     1 (True Negative)    1 (True Negative)

Basic Values

From the confusion matrix for category A, we obtain the following values:

  • True Positive (TP) = 85
    Actual category A examples predicted as A
  • False Positive (FP) = 7 + 4 + 3 = 14
    Examples not actually A, but predicted as A
    • 7 examples from B
    • 4 examples from C
    • 3 examples from D
  • False Negative (FN) = 10 + 3 + 2 = 15
    Examples actually A, but not predicted as A
    • 10 examples predicted as B
    • 3 examples predicted as C
    • 2 examples predicted as D
  • True Negative (TN) = 120 + 3 + 5 + 3 + 2 + 2 + 0 + 1 + 1 = 137
    Examples not A and not predicted as A

Step-by-Step Metric Calculations

1. Precision

Precision = TP / (TP + FP)

Precision = 85 / (85 + 14) = 85 / 99 ≈ 0.859 = 85.9%

2. Recall (Sensitivity)

Recall = TP / (TP + FN)

Recall = 85 / (85 + 15) = 85 / 100 = 0.85 = 85%

3. F1 Score

F1 = 2 * (Precision * Recall) / (Precision + Recall)

F1 ≈ 2 * (0.859 * 0.85) / (0.859 + 0.85) ≈ 0.854 = 85.4%
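
The same numbers can be reproduced programmatically from the confusion matrix. A short sketch, assuming the predicted-by-actual layout used above with class order A, B, C, D:

# Reproduce the category A metrics from the confusion matrix
import numpy as np

cm = np.array([
    [85,   7,  4,  3],   # predicted A
    [10, 120,  3,  5],   # predicted B
    [ 3,   3,  2,  2],   # predicted C
    [ 2,   0,  1,  1],   # predicted D
])

i = 0  # index of category A
tp = cm[i, i]                 # 85
fp = cm[i, :].sum() - tp      # predicted A but actually another class -> 14
fn = cm[:, i].sum() - tp      # actually A but predicted as another class -> 15
tn = cm.sum() - tp - fp - fn  # everything else -> 137

precision = tp / (tp + fp)                          # ~0.859
recall = tp / (tp + fn)                             # 0.85
f1 = 2 * precision * recall / (precision + recall)  # ~0.854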

Metric Calculations for Category C

The confusion matrix below is annotated from the perspective of category C (rows: predicted, columns: actual):

Predicted/Actual   A (Actual)           B (Actual)            C (Actual)            D (Actual)
A (Predicted)      85 (True Negative)   7 (True Negative)     4 (False Negative)    3 (True Negative)
B (Predicted)      10 (True Negative)   120 (True Negative)   3 (False Negative)    5 (True Negative)
C (Predicted)      3 (False Positive)   3 (False Positive)    2 (True Positive)     2 (False Positive)
D (Predicted)      2 (True Negative)    0 (True Negative)     1 (False Negative)    1 (True Negative)

Basic Values

From the confusion matrix for category C, we obtain the following values:

  • True Positive (TP) = 2
    Actual category C examples predicted as C
  • False Positive (FP) = 3 + 3 + 2 = 8
    Examples not actually C, but predicted as C
    • 3 examples from A
    • 3 examples from B
    • 2 examples from D
  • False Negative (FN) = 4 + 3 + 1 = 8
    Examples actually C, but not predicted as C
    • 4 examples predicted as A
    • 3 examples predicted as B
    • 1 example predicted as D
  • True Negative (TN) = 233
    Examples not C and not predicted as C

Step-by-Step Metric Calculations

1. Precision

Precision = TP / (TP + FP)

Precision = 2 / (2 + 8) = 0.200 = 20.0%

2. Recall (Sensitivity)

Recall = TP / (TP + FN)

Recall = 2 / (2 + 8) = 0.200 = 20.0%

3. F1 Score

F1 = 2 * (Precision * Recall) / (Precision + Recall)

F1 = 2 * (0.200 * 0.200) / (0.200 + 0.200) = 0.200 = 20.0%

These values indicate that the model is failing to detect category C effectively.

Difficulty Level by Category

Ranked from best to worst F1 score:

  1. Category B: F1 = 0.896
  2. Category A: F1 = 0.854
  3. Category C: F1 = 0.200
  4. Category D: F1 = 0.133

Priority Areas for Improvement

  • Increase Recall value for category D
  • Increase Precision and Recall values for category C

For this purpose, techniques such as oversampling, SMOTE, class weighting, and focal loss can be used.

Performance Metrics

The basic metrics used to evaluate classification performance are Precision, Recall, and F1 Score.

Basic Concepts

Formulas

Precision = TP / (TP + FP)
Recall = TP / (TP + FN)
F1 Score = 2 * (Precision * Recall) / (Precision + Recall)

Category-Based Metrics

F1 score by category:

  • Category A: 0.85
  • Category B: 0.90
  • Category C: 0.20
  • Category D: 0.13

Category-Based Performance Metrics

Category     Precision   Recall   F1 Score   Performance Level
Category A   0.86        0.85     0.85       ✓ Good
Category B   0.87        0.92     0.90       ✓ Good
Category C   0.20        0.20     0.20       ✗ Very Low
Category D   0.25        0.09     0.13       ✗ Very Low

Average Metrics

Macro-average

Equal weight is given to each class.

  • Precision: 0.55
  • Recall: 0.52
  • F1 Score: 0.52
Poor performance of minority classes lowers the macro-average.

Weighted Average

Weighted according to class sizes.

  • Precision: 0.81
  • Recall: 0.79
  • F1 Score: 0.80
Appears higher due to the effect of larger classes.

Micro-average

Calculated across all examples.

  • Precision: 0.84
  • Recall: 0.80
  • F1 Score: 0.82
May be misleading due to the effect of larger classes.
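
In practice, all three averages can be obtained through scikit-learn's average parameter. A minimal sketch, again assuming y_true and y_pred hold the actual and predicted labels:

# Compare macro, weighted, and micro averages
from sklearn.metrics import precision_recall_fscore_support

for avg in ("macro", "weighted", "micro"):
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=avg, zero_division=0)
    print(f"{avg:>8}: precision={p:.3f}  recall={r:.3f}  f1={f1:.3f}")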

Imbalance Analysis

The extreme imbalance in the dataset seriously affects model performance. In this section, the effects of imbalance are analyzed.

Main Effects of Imbalance

Balanced Data

The model learns in a balanced way when all classes are equally represented.

Result: High F1 score

Imbalanced Data

When some classes are underrepresented, the model becomes biased toward the majority classes.

Result: Very low F1 score for C and D

Root Causes of Imbalance

  1. Insufficient Learning: Few examples prevent the model from learning correctly.
  2. Generalization Difficulty: Difficult to generalize to new examples with few samples.
  3. Bias Toward Majority Classes: The model focuses on majority classes.
  4. Decision Boundary Problem: Few examples make it difficult to determine correct boundaries.

Solution Strategies

Various strategies can be applied to improve model performance in imbalanced datasets:

1. Data Balancing Techniques

Techniques for augmenting and expanding the data of underrepresented classes:

  • Oversampling: Increase the number of data points for categories C and D
  • SMOTE: Generate synthetic examples
  • Data Augmentation: Augmentation techniques for text data
# Example of data augmentation for categories C and D
# minority_samples is assumed to be a list of (text, label) pairs from categories C and D
from nlpaug.augmenter.word import SynonymAug

syn_aug = SynonymAug()  # replaces words with WordNet synonyms by default
augmented_samples = []

for text, label in minority_samples:
    for _ in range(50):  # generate 50 augmented variants per original example
        aug_text = syn_aug.augment(text)
        # recent nlpaug versions return a list of strings; keep the first variant
        if isinstance(aug_text, list):
            aug_text = aug_text[0]
        augmented_samples.append((aug_text, label))
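
SMOTE itself operates on numeric feature vectors rather than raw text, so in a BERT pipeline it is usually applied to sentence embeddings. A minimal sketch with imbalanced-learn, assuming X is an array of embeddings and y the corresponding labels:

# Oversample minority classes in embedding space with SMOTE
from imblearn.over_sampling import SMOTE

# k_neighbors must be smaller than the number of minority-class examples (10 here)
smote = SMOTE(k_neighbors=5, random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)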

2. Model Improvement Strategies

Optimizing the training process:

  • Class Weighting: Assigning higher weights to categories C and D
  • Focal Loss: Loss function focusing on difficult examples
  • Threshold Optimization: Lowering the decision thresholds for minority classes (a sketch follows the focal loss example below)
# Class-weighted Focal Loss implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # per-class weights (tensor) or None
        self.gamma = gamma  # focusing parameter: larger values emphasize hard examples

    def forward(self, inputs, targets):
        # per-example cross-entropy, weighted by class
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)                       # probability assigned to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss  # down-weight easy, well-classified examples
        return focal_loss.mean()

# A:1, B:1, C:25, D:25 weights
class_weights = torch.tensor([1.0, 1.0, 25.0, 25.0])
criterion = FocalLoss(alpha=class_weights, gamma=3.0)
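
The threshold-optimization idea mentioned above can be sketched as follows: instead of taking a plain argmax over the softmax output, each class gets its own decision threshold, with lower thresholds for C and D, and the class that exceeds its threshold by the largest margin is predicted. Here, logits is assumed to be the model output for a batch with class order A, B, C, D:

# Class-specific decision thresholds (lower for minority classes)
import torch

thresholds = torch.tensor([0.50, 0.50, 0.25, 0.25])  # A, B, C, D

probs = torch.softmax(logits, dim=-1)  # (batch_size, 4) predicted probabilities
ratio = probs / thresholds             # how strongly each class exceeds its threshold
preds = ratio.argmax(dim=-1)           # favors C and D relative to plain argmax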

3. Ensemble Methods

Combining multiple models to improve performance:

  • Multiple Model Integration: Combination of models trained with different strategies
  • Bagging and Boosting: Using AdaBoost or Gradient Boosting
  • Cascade Classification: Staged classification approach
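
As a sketch of the first idea, predictions from several models can be combined by weighted soft voting, where models focused on the minority classes receive a higher weight. Here, models, input_ids, and attention_mask are assumed to already exist, and each model returns logits in Hugging Face style:

# Weighted soft voting over multiple trained models
import torch

weights = [0.3, 0.3, 0.4]  # e.g., two general models and one minority-focused model

with torch.no_grad():
    probs = [torch.softmax(m(input_ids=input_ids, attention_mask=attention_mask).logits, dim=-1)
             for m in models]
    ensemble_probs = sum(w * p for w, p in zip(weights, probs))
    preds = ensemble_probs.argmax(dim=-1)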

4. Advanced Techniques

Deep learning and meta-learning approaches:

  • Few-Shot Learning: Learning techniques with few examples
  • Meta-Learning: MAML or Prototypical Networks
  • Active Learning: Selection of the most informative examples

These techniques enhance the model's generalization ability, especially in situations with limited data.

Practical Implementation Steps

Imbalanced Dataset Improvement Plan

  1. Data Augmentation: Expand categories C and D 50-100x through data augmentation.
    Use backtranslation, word replacement, and SMOTE techniques.
  2. Class Weighting: Use high weights (25-50x) for C and D.
    Give more importance to minority classes with Focal Loss.
  3. Two-Stage Training: First general training, then fine-tuning focused on minority classes (see the sketch after this list).
    Lower the learning rate (e.g., 5e-6) and train longer in the second stage.
  4. Threshold Optimization: Set low threshold values for C and D (0.2-0.3).
    Sacrifice precision to increase recall.
  5. Ensemble Models: Combine models trained with different strategies.
    Give higher weight (0.3-0.4) to models focused on C and D.
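
A minimal sketch of step 3, assuming model is a Hugging Face-style classifier, criterion is the FocalLoss defined earlier, and general_loader / minority_loader are DataLoaders over the full dataset and an oversampled minority subset:

# Two-stage training: general training first, then minority-focused fine-tuning
import torch

def train(model, loader, criterion, lr, epochs):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for input_ids, attention_mask, labels in loader:
            optimizer.zero_grad()
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

# Stage 1: general training on the full (augmented) dataset
train(model, general_loader, criterion, lr=2e-5, epochs=3)

# Stage 2: longer, gentler fine-tuning focused on categories C and D
train(model, minority_loader, criterion, lr=5e-6, epochs=10)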

Expected Improvement

Metric           Current   Target
Category C F1    0.20      0.50
Category D F1    0.13      0.45
Macro F1 Score   0.52      0.70

Real World Applications

Imbalanced datasets appear in various fields in the real world. Here are some example scenarios:

Medical Diagnosis

Diagnosis of rare diseases naturally involves imbalanced datasets.

Problem:

  • Very limited data for rare diseases
  • False negatives can have serious consequences

Solution:

  • Special loss functions that increase the cost of false negatives
  • Creating synthetic data with examples from different stages of disease
  • Few-shot learning techniques

Fraud Detection

Financial fraud detection typically involves extremely imbalanced data.

Problem:

  • Fraudulent transactions represent a very small percentage of total transactions
  • Patterns change rapidly

Solution:

  • Anomaly detection and one-class classification
  • Time-based sampling
  • Improving labeling efficiency with active learning

Fault Detection

Fault detection in industrial equipment is challenging due to limited fault data.

Problem:

  • Real fault data is scarce
  • May not have enough examples for each fault type

Solution:

  • Creating fault data with physics-based simulations
  • Transfer learning
  • Hybrid supervised/unsupervised learning methods

Advanced Resources

You can check the following resources to learn more about classification with imbalanced datasets:

Online Courses

  • "Practical Solutions for Class Imbalance" - Kaggle Courses
  • "Advanced NLP with BERT" - TensorFlow Data and Deployment Specialization
  • "Anomaly Detection in Time Series Data with Keras" - Coursera
  • "Machine Learning with Imbalanced Datasets" - edX

Figure: Precision, recall, and F1 score compared across categories A-D; performance drops sharply as the number of examples decreases.
Calculated Metrics for All Categories - Summary Table

Summary table of precision, recall, and F1 scores calculated for each category:

Category   Examples   Precision   Recall   F1 Score   TP    FP   FN   TN
A          100        0.859       0.850    0.854      85    14   15   137
B          100        0.870       0.923    0.896      120   18   10   103
C          10         0.200       0.200    0.200      2     8    8    233
D          10         0.250       0.091    0.133      1     3    10   237

Calculation Details and Formulas

Precision = TP / (TP + FP)
Recall (Sensitivity) = TP / (TP + FN)
F1 Score = 2 * (Precision * Recall) / (Precision + Recall)
Average Metrics

Average Type       Precision   Recall   F1 Score   Description
Macro-average      0.545       0.516    0.521      Gives equal weight to each class.
Weighted Average   0.807       0.789    0.798      Weighted by the number of examples in each class.
Micro-average      0.840       0.800    0.820      Calculated across all examples.

Impact of Data Imbalance

  • A and B (100 examples): F1 ≈ 0.85-0.90
  • C and D (10 examples): F1 ≈ 0.13-0.20