Class Imbalance¶
⚖️ What is Class Imbalance?¶
Class imbalance occurs when one class significantly outnumbers the other(s). For example, in fraud detection, 98% of transactions may be legitimate while only 2% are fraudulent. This skew can distort model learning.
- Models often learn to favor the majority class
- Overall accuracy may be misleading
- Minority class may be poorly predicted
🌍 Real-World Examples¶
- Healthcare: predicting rare diseases (e.g., cancer)
- Finance: credit card fraud detection
- Cybersecurity: detecting intrusions or malware
- Customer Success: identifying churned or escalated users
🚧 Challenges with Imbalanced Data¶
- Accuracy becomes misleading (e.g., always predicting the majority yields high accuracy but zero value)
- Decision boundaries may skew toward the dominant class
- Minority class data often lacks enough variation to generalize
🎯 Why It Matters¶
- Improves precision and recall for the minority class
- Makes models fairer and more useful in high-risk applications
- Avoids false confidence from inflated accuracy metrics
This notebook covers techniques to address imbalance using oversampling, undersampling, synthetic data generation, and algorithm-level adjustments.
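To make the accuracy pitfall concrete before building the fraud dataset, here is a minimal, self-contained sketch (using an illustrative 980/20 label split and scikit-learn's `DummyClassifier`, not the data generated below): a model that always predicts the majority class scores about 98% accuracy while catching zero fraud.

```python
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score
import numpy as np

# Illustrative labels: 98% legitimate (0), 2% fraudulent (1)
y_toy = np.array([0] * 980 + [1] * 20)
X_toy = np.zeros((len(y_toy), 1))  # features don't matter for this baseline

# Baseline that always predicts the majority class
baseline = DummyClassifier(strategy="most_frequent").fit(X_toy, y_toy)
y_pred = baseline.predict(X_toy)

print("Accuracy:", accuracy_score(y_toy, y_pred))    # 0.98, looks impressive
print("Fraud recall:", recall_score(y_toy, y_pred))  # 0.0, catches no fraud at all
```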
import pandas as pd
import numpy as np
# Set random seed for reproducibility
np.random.seed(42)
# Number of samples
n_samples = 10_000
# Class distribution (98% legitimate, 2% fraudulent)
fraud_ratio = 0.02
legitimate_ratio = 1 - fraud_ratio
# Number of samples per class
n_fraud = int(n_samples * fraud_ratio)
n_legitimate = n_samples - n_fraud
# Generate features
# 1. Transaction Amount (log-normal distribution for skewness)
transaction_amount_legit = np.random.lognormal(mean=3.5, sigma=1, size=n_legitimate)
transaction_amount_fraud = np.random.lognormal(mean=5, sigma=1.5, size=n_fraud)
# 2. Transaction Type (categorical: online, in-person, other)
transaction_type_legit = np.random.choice(['Online', 'In-Person', 'Other'], size=n_legitimate, p=[0.5, 0.4, 0.1])
transaction_type_fraud = np.random.choice(['Online', 'In-Person', 'Other'], size=n_fraud, p=[0.7, 0.2, 0.1])
# 3. Time of Transaction (hour of the day: 0-23)
time_of_transaction_legit = np.random.randint(0, 24, size=n_legitimate)
time_of_transaction_fraud = np.random.randint(0, 24, size=n_fraud)
# Combine data
data = {
"Transaction_Amount": np.concatenate([transaction_amount_legit, transaction_amount_fraud]),
"Transaction_Type": np.concatenate([transaction_type_legit, transaction_type_fraud]),
"Time_of_Transaction": np.concatenate([time_of_transaction_legit, time_of_transaction_fraud]),
"Class": np.concatenate([np.zeros(n_legitimate), np.ones(n_fraud)]), # 0: Legitimate, 1: Fraud
}
# Create DataFrame
df = pd.DataFrame(data)
# Shuffle the dataset
df = df.sample(frac=1).reset_index(drop=True)
df.head()
|   | Transaction_Amount | Transaction_Type | Time_of_Transaction | Class |
|---|---|---|---|---|
| 0 | 16.360820 | In-Person | 3 | 0.0 |
| 1 | 7.728736 | In-Person | 17 | 0.0 |
| 2 | 19.464428 | In-Person | 6 | 0.0 |
| 3 | 1.828955 | Online | 15 | 0.0 |
| 4 | 7.894720 | Online | 10 | 0.0 |
⚖️ Oversampling¶
⚖️ What is Oversampling?¶
Oversampling is a technique to address class imbalance by increasing the number of samples in the minority class.
- Involves duplicating or synthetically generating minority class examples
- Aims to balance the class distribution without losing any data
- Typically applied before model training, and only to the training set (a short sketch of this workflow follows below)
📌 When to Use¶
- Severe imbalance causing poor recall or precision
- You want to preserve all majority class data
- You're using models that can handle redundant data (e.g., tree-based models)
⚠️ Cautions¶
- Risk of overfitting if the same samples are duplicated
- Synthetic methods (like SMOTE) can introduce noise or unrealistic data points
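Because oversampling should only ever see the training split (resampled copies leaking into the test set inflate evaluation metrics), the sketch below shows that workflow. The toy dataset, variable names, and the use of `RandomOverSampler` here are illustrative, separate from the examples that follow.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler

# Illustrative imbalanced dataset (any X, y would do)
X_demo, y_demo = make_classification(weights=[0.9, 0.1], n_samples=1000, random_state=42)

# 1) Split first so the test set stays untouched
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, stratify=y_demo, random_state=42)

# 2) Resample the training set only
ros = RandomOverSampler(random_state=42)
X_tr_res, y_tr_res = ros.fit_resample(X_tr, y_tr)

# 3) Fit on the resampled training data; evaluate on the original, imbalanced test set
# model.fit(X_tr_res, y_tr_res)
# model.score(X_te, y_te)
```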
Create Datasets
# Install required libraries
# pip install imbalanced-learn scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import pandas as pd
# Create a toy imbalanced dataset
X, y = make_classification(n_classes=2, class_sep=2,
weights=[0.9, 0.1], n_informative=3, n_redundant=1,
flip_y=0, n_features=5, n_clusters_per_class=1, n_samples=1000, random_state=42)
print(pd.DataFrame(X))
print(y)
# Visualize original class distribution
print(pd.Series(y).value_counts())
            0         1         2         3         4
0    1.021519 -0.548166 -2.079610  2.988799 -2.282637
1    0.210297  0.378811 -1.558116  1.737750 -2.083806
2    0.312219  0.266673 -0.873507  2.422799 -2.373239
3    0.613577 -0.487993 -1.824985  2.346751 -2.172025
4   -0.062001 -0.837709 -1.414938  1.124668 -1.826684
..        ...       ...       ...       ...       ...
995  0.141158  1.410214 -2.320817  0.749351 -1.453040
996  0.580855  0.286195 -2.667136  1.670716 -1.840654
997  0.769980 -1.460925 -3.103829  1.648837 -1.679590
998  1.133483  0.553364 -3.761090  1.939160 -1.648539
999  0.457990 -2.664332 -0.959610  2.629161 -2.386679

[1000 rows x 5 columns]
[0 0 0 ... 0 0 0]
0    900
1    100
dtype: int64
Define Functions - Plot Before vs After
import matplotlib.pyplot as plt
import numpy as np
# Function to visualize before and after resampling
def plot_before_after(X_before, y_before, X_after, y_after, technique_name, feature_names=None):
    """
    Plots before and after resampling for a dataset.

    Parameters:
        X_before (array): Original feature data before resampling.
        y_before (array): Target labels before resampling.
        X_after (array): Feature data after resampling.
        y_after (array): Target labels after resampling.
        technique_name (str): Name of the resampling technique.
        feature_names (list or None): Feature names for x-axis and y-axis labels. Should be a list like ['Feature 1', 'Feature 2'].
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharex=True, sharey=True)

    # Set feature names for axes
    x_label = feature_names[0] if feature_names else 'Feature 1'
    y_label = feature_names[1] if feature_names else 'Feature 2'

    # Get unique labels dynamically for legends
    labels_before = np.unique(y_before)
    labels_after = np.unique(y_after)

    # Before Resampling
    for label in labels_before:
        axes[0].scatter(X_before[y_before == label][:, 0], X_before[y_before == label][:, 1],
                        label=f'Class {label}', alpha=0.7)
    axes[0].set_title(f"Before Resampling ({technique_name})")
    axes[0].set_xlabel(x_label)
    axes[0].set_ylabel(y_label)
    axes[0].legend()
    axes[0].grid(True)

    # After Resampling
    for label in labels_after:
        axes[1].scatter(X_after[y_after == label][:, 0], X_after[y_after == label][:, 1],
                        label=f'Class {label}', alpha=0.7)
    axes[1].set_title(f"After Resampling ({technique_name})")
    axes[1].set_xlabel(x_label)
    axes[1].set_ylabel(y_label)
    axes[1].legend()
    axes[1].grid(True)

    # Adjust layout
    plt.tight_layout()
    plt.show()
🎲 Random Oversampling¶
🎲 What is Random Oversampling?¶
Random Oversampling duplicates existing minority class samples until class balance is achieved.
- Simple and fast to implement
- No synthetic data — just replication
- Often used as a baseline technique
⚠️ Limitations¶
- High risk of overfitting to duplicated samples
- No increase in data diversity
# !pip uninstall -y imbalanced-learn scikit-learn
# !pip install scikit-learn==1.2.2
# !pip install imbalanced-learn==0.10.1
# Random Oversampling
from imblearn.over_sampling import RandomOverSampler
ros = RandomOverSampler(random_state=42)
X_resampled, y_resampled = ros.fit_resample(X, y)
# Check new class distribution
print(pd.Series(y_resampled).value_counts())
0    900
1    900
dtype: int64
plot_before_after(X, y, X_resampled, y_resampled, "Random Oversampling")
🧪 SMOTE¶
🧪 What is SMOTE?¶
SMOTE (Synthetic Minority Over-sampling Technique) generates new, synthetic minority class samples by interpolating between existing ones.
- Increases minority class diversity
- Reduces overfitting compared to random duplication
- Widely used in practice
⚠️ Limitations¶
- Can create overlap between classes if not used carefully
- Sensitive to noise and outliers in the minority class
- Not ideal for high-dimensional or categorical data without adaptation
# SMOTE
from imblearn.over_sampling import SMOTE
smote = SMOTE(random_state=42)
X_resampled_smote, y_resampled_smote = smote.fit_resample(X, y)
# Check new class distribution
print(pd.Series(y_resampled_smote).value_counts())
0    900
1    900
dtype: int64
plot_before_after(X, y, X_resampled_smote, y_resampled_smote, "SMOTE")
📈 ADASYN¶
📈 What is ADASYN?¶
ADASYN (Adaptive Synthetic Sampling) is a SMOTE variant that focuses on generating more synthetic samples in hard-to-learn minority regions.
- Adapts sampling density based on local data difficulty
- Prioritizes borderline or sparsely represented areas
- Encourages better model learning near class boundaries
⚠️ Limitations¶
- Can amplify noise if the minority class is poorly defined
- Adds complexity compared to SMOTE
- May shift decision boundaries unexpectedly
# ADASYN
from imblearn.over_sampling import ADASYN
adasyn = ADASYN(random_state=42)
X_resampled_adasyn, y_resampled_adasyn = adasyn.fit_resample(X, y)
# Check new class distribution
print(pd.Series(y_resampled_adasyn).value_counts())
0    900
1    900
dtype: int64
plot_before_after(X, y, X_resampled_adasyn, y_resampled_adasyn, "ADASYN")
🧊 Borderline SMOTE¶
🧊 What is Borderline SMOTE?¶
Borderline SMOTE generates synthetic samples only for minority class points near the decision boundary, where misclassification risk is highest.
- Focuses on ambiguous regions between classes
- Improves classifier sensitivity at the boundary
- Less likely to oversample safe or noisy regions
⚠️ Limitations¶
- Still assumes meaningful interpolation in feature space
- Not suitable if boundary regions are heavily noisy or mislabeled
# Borderline-SMOTE
from imblearn.over_sampling import BorderlineSMOTE
borderline_smote = BorderlineSMOTE(kind='borderline-1', random_state=42)
X_resampled_borderline, y_resampled_borderline = borderline_smote.fit_resample(X, y)
# Check new class distribution
print(pd.Series(y_resampled_borderline).value_counts())
0    900
1    900
dtype: int64
plot_before_after(X, y, X_resampled_borderline, y_resampled_borderline, "Borderline-SMOTE")
📉 Undersampling¶
📉 What is Undersampling?¶
Undersampling reduces the number of majority class samples to balance class distribution.
- Works by removing redundant or less informative majority class data
- Helps speed up training and reduces memory usage
- Often paired with ensemble methods to preserve performance
📌 When to Use¶
- Abundant majority class data that’s easy to trim
- Class imbalance is mild-to-moderate
- When model overfitting to the majority class is a concern
⚠️ Limitations¶
- Risk of losing valuable information
- May lead to underperforming models if informative samples are discarded
Create Datasets
# Install required libraries
# pip install imbalanced-learn scikit-learn pandas matplotlib
from sklearn.datasets import make_classification
import pandas as pd
# Create a toy imbalanced dataset
X, y = make_classification(n_classes=2, class_sep=2,
weights=[0.9, 0.1], n_informative=3, n_redundant=1,
flip_y=0, n_features=5, n_clusters_per_class=1, n_samples=1000, random_state=42)
# Convert to DataFrame for better visualization
df_X = pd.DataFrame(X, columns=[f"Feature_{i+1}" for i in range(X.shape[1])])
df_y = pd.Series(y, name="Target")
# Visualize class distribution
print("Original Class Distribution:")
print(df_y.value_counts())
Original Class Distribution:
0    900
1    100
Name: Target, dtype: int64
🎯 Random Undersampling¶
🎯 What is Random Undersampling?¶
Random Undersampling removes a subset of majority class samples at random to balance the dataset.
- Very easy to implement
- Reduces training time and memory usage
- Often used as a quick baseline method
⚠️ Limitations¶
- Can discard informative majority samples
- May lead to underfitting or unstable models if too aggressive
from imblearn.under_sampling import RandomUnderSampler
# Random Undersampling
rus = RandomUnderSampler(random_state=42)
X_resampled_rus, y_resampled_rus = rus.fit_resample(X, y)
# Visualize resampled class distribution
print("Random Undersampling Class Distribution:")
print(pd.Series(y_resampled_rus).value_counts())
Random Undersampling Class Distribution:
0    100
1    100
dtype: int64
# Plot before and after
plot_before_after(X, y, X_resampled_rus, y_resampled_rus, "Random Undersampling")
🔗 Tomek Links¶
🔗 What are Tomek Links?¶
A Tomek Link is a pair of samples from opposite classes that are each other’s nearest neighbors. Removing the majority class sample in each pair helps clean the class boundary.
- Used to reduce class overlap
- Improves boundary clarity between classes
- Often combined with other resampling methods such as SMOTE (a combined SMOTE + Tomek sketch follows the example below)
⚠️ Limitations¶
- Only removes borderline points — not sufficient alone for severe imbalance
- Assumes Euclidean distance is meaningful for proximity
from imblearn.under_sampling import TomekLinks
# Tomek Links
tomek = TomekLinks()
X_resampled_tomek, y_resampled_tomek = tomek.fit_resample(X, y)
# Visualize resampled class distribution
print("Tomek Links Class Distribution:")
print(pd.Series(y_resampled_tomek).value_counts())
plot_before_after(X, y, X_resampled_tomek, y_resampled_tomek, "Tomek Links")
Tomek Links Class Distribution:
0    900
1    100
dtype: int64
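Tomek Links alone leave the counts unchanged here (the boundary in this toy data is clean), so in practice they are usually paired with oversampling. A minimal sketch using imblearn's `SMOTETomek` combiner on the same `X`, `y`:

```python
from imblearn.combine import SMOTETomek

# SMOTE the minority class, then remove any Tomek links that result
smote_tomek = SMOTETomek(random_state=42)
X_resampled_st, y_resampled_st = smote_tomek.fit_resample(X, y)

print("SMOTE + Tomek Links Class Distribution:")
print(pd.Series(y_resampled_st).value_counts())
```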
🧹 ENN¶
🧹 What is ENN?¶
Edited Nearest Neighbors (ENN) removes majority class samples that are misclassified by their nearest neighbors.
- Acts as a data cleaning filter
- Helps refine noisy or overlapping class boundaries
- Can improve classifier generalization
⚠️ Limitations¶
- May remove too many samples if class boundaries are unclear
- Sensitive to the choice of `k` (number of neighbors)
from imblearn.under_sampling import EditedNearestNeighbours
# Edited Nearest Neighbors (ENN)
enn = EditedNearestNeighbours(n_neighbors=3)
X_resampled_enn, y_resampled_enn = enn.fit_resample(X, y)
# Visualize resampled class distribution
print("Edited Nearest Neighbors Class Distribution:")
print(pd.Series(y_resampled_enn).value_counts())
plot_before_after(X, y, X_resampled_enn, y_resampled_enn, "ENN")
Edited Nearest Neighbors Class Distribution:
0    899
1    100
dtype: int64
📉 NearMiss¶
📉 What is NearMiss?¶
NearMiss is an undersampling technique that selects majority class samples closest to the minority class based on distance.
- Retains majority samples near the decision boundary
- Focuses learning on challenging cases
- Comes in different variants (NearMiss-1, 2, 3) depending on how distances are computed
⚠️ Limitations¶
- Can discard important global patterns in the majority class
- Sensitive to noisy or overlapping data
from imblearn.under_sampling import NearMiss
# NearMiss (Version 1)
nearmiss = NearMiss(version=1)
X_resampled_nearmiss, y_resampled_nearmiss = nearmiss.fit_resample(X, y)
# Visualize resampled class distribution
print("NearMiss Class Distribution:")
print(pd.Series(y_resampled_nearmiss).value_counts())
plot_before_after(X, y, X_resampled_nearmiss, y_resampled_nearmiss, "NearMiss")
NearMiss Class Distribution:
0    100
1    100
dtype: int64
🧬 Data Augmentation¶
🧬 What is Data Augmentation?¶
Data augmentation involves creating new synthetic samples to enrich the dataset, especially the minority class.
- Helps boost sample diversity without collecting new data
- Often used when oversampling alone is insufficient
- Techniques range from simple noise injection (sketched below) to generative models like GANs
📌 Benefits¶
- Improves model robustness
- Reduces overfitting on limited minority examples
- Encourages better generalization in real-world settings
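As a point of comparison before the GAN example, the simplest augmentation mentioned above is plain noise injection: jitter copies of the minority samples with small Gaussian noise. The sketch below reuses the toy `X`, `y` from the earlier sections and picks an arbitrary noise scale; both are assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)

# Copy the minority-class rows and jitter them with small Gaussian noise
X_minority = X[y == 1]
noise = rng.normal(loc=0.0, scale=0.05 * X_minority.std(axis=0), size=X_minority.shape)
X_jittered = X_minority + noise

# Append the jittered copies as additional minority samples
X_aug = np.vstack([X, X_jittered])
y_aug = np.concatenate([y, np.ones(len(X_jittered), dtype=y.dtype)])

print("Class counts before:", np.bincount(y))      # roughly [900 100]
print("Class counts after: ", np.bincount(y_aug))  # minority roughly doubled
```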
🧠 Data Augmentation using GANs¶
🧠 What is GAN-based Augmentation?¶
Generative Adversarial Networks (GANs) can be used to create high-fidelity synthetic samples for the minority class.
- A GAN learns the data distribution and generates new, realistic examples
- Particularly useful for image, text, or structured data where diversity matters
- Helps capture nonlinear patterns that simpler methods miss
⚠️ Limitations¶
- Requires significant training time and tuning
- Risk of generating unrealistic or mode-collapsed samples
- Needs careful validation to avoid injecting noise
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# Generate synthetic 2D real data
def create_real_data(n_samples=100):
    X1 = np.random.normal(loc=0.0, scale=1.0, size=(n_samples, 1))
    X2 = np.random.normal(loc=0.0, scale=1.0, size=(n_samples, 1))
    return np.hstack((X1, X2))

real_data = create_real_data(100)

# Define Generator
def build_generator(latent_dim, output_dim):
    model = tf.keras.Sequential([
        layers.Dense(16, activation='relu', input_dim=latent_dim),
        layers.Dense(output_dim, activation='linear')
    ])
    return model

# Define Discriminator
def build_discriminator(input_dim):
    model = tf.keras.Sequential([
        layers.Dense(16, activation='relu', input_dim=input_dim),
        layers.Dense(1, activation='sigmoid')
    ])
    return model
# Instantiate models
latent_dim = 2
generator = build_generator(latent_dim=latent_dim, output_dim=2)
discriminator = build_discriminator(input_dim=2)
# Optimizers
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# Loss function
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)
# Training step
@tf.function
def train_step(real_data):
    # Convert real_data to float32 for compatibility
    real_data = tf.cast(real_data, tf.float32)
    batch_size = real_data.shape[0]

    # Generate fake data
    noise = tf.random.normal([batch_size, latent_dim])
    generated_data = generator(noise, training=True)

    # Combine real and fake data
    combined_data = tf.concat([real_data, generated_data], axis=0)
    labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)

    # Add noise to labels for stability
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train discriminator
    with tf.GradientTape() as disc_tape:
        predictions = discriminator(combined_data, training=True)
        disc_loss = cross_entropy(labels, predictions)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Train generator
    noise = tf.random.normal([batch_size, latent_dim])
    misleading_labels = tf.ones((batch_size, 1))  # Labels trick discriminator into thinking generated data is real
    with tf.GradientTape() as gen_tape:
        generated_data = generator(noise, training=True)
        predictions = discriminator(generated_data, training=True)
        gen_loss = cross_entropy(misleading_labels, predictions)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

    return disc_loss, gen_loss

# GAN training loop
def train_gan(generator, discriminator, real_data, epochs=1000):
    for epoch in range(epochs):
        disc_loss, gen_loss = train_step(real_data)
        if epoch % 100 == 0:
            print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
# Train GAN
train_gan(generator, discriminator, real_data, epochs=1000)
# Generate synthetic data
noise = tf.random.normal([100, latent_dim])
synthetic_data = generator(noise, training=False)
# Visualize real and synthetic data
plt.scatter(real_data[:, 0], real_data[:, 1], label="Real Data", alpha=0.7)
plt.scatter(synthetic_data[:, 0], synthetic_data[:, 1], label="Synthetic Data", alpha=0.7)
plt.legend()
plt.show()
Epoch 0/1000, Discriminator Loss: 0.6918, Generator Loss: 0.6879
Epoch 100/1000, Discriminator Loss: 0.6609, Generator Loss: 0.6349
Epoch 200/1000, Discriminator Loss: 0.7305, Generator Loss: 0.6583
Epoch 300/1000, Discriminator Loss: 0.6890, Generator Loss: 0.7710
Epoch 400/1000, Discriminator Loss: 0.6894, Generator Loss: 0.6655
Epoch 500/1000, Discriminator Loss: 0.6957, Generator Loss: 0.6012
Epoch 600/1000, Discriminator Loss: 0.6843, Generator Loss: 0.6411
Epoch 700/1000, Discriminator Loss: 0.6939, Generator Loss: 0.6868
Epoch 800/1000, Discriminator Loss: 0.7018, Generator Loss: 0.6934
Epoch 900/1000, Discriminator Loss: 0.6937, Generator Loss: 0.6285
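The generated points above only mimic the generic 2D `real_data`, but the same idea carries over to imbalance: in practice the GAN would be trained on minority-class feature vectors and its output appended as extra minority samples. A minimal, purely illustrative sketch of that last step:

```python
# Treat the GAN output as extra minority-class rows (illustrative only:
# here real_data stands in for real minority-class feature vectors)
X_minority_real = real_data
X_minority_aug = np.vstack([X_minority_real, synthetic_data.numpy()])
y_minority_aug = np.ones(len(X_minority_aug))  # all minority labels

print("Minority samples before/after augmentation:",
      len(X_minority_real), "->", len(X_minority_aug))
```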
⚖️ Class Weighting¶
⚖️ What is Class Weighting?¶
Class weighting assigns higher penalty to errors on the minority class during training, guiding the model to pay more attention to underrepresented outcomes.
- Implemented via `class_weight='balanced'` in many ML libraries (a shorthand example follows the logistic regression demo below)
- Adjusts the loss function without modifying the data
- Especially useful for models like logistic regression, SVMs, and neural networks
📌 When to Use¶
- When resampling isn’t feasible or effective
- When the algorithm natively supports weighted loss
- To avoid modifying data distribution through oversampling/undersampling
📊 Logistic Regression Example¶
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
# Generate an imbalanced dataset
X, y = make_classification(n_classes=2, class_sep=2,
weights=[0.9, 0.1], n_informative=3, n_redundant=1,
flip_y=0, n_features=5, n_clusters_per_class=1, n_samples=1000, random_state=42)
# Split dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Calculate class weights
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y_train)
class_weight_dict = dict(enumerate(class_weights))
print("Class Weights:", class_weight_dict)
# Train a Logistic Regression model with class weights
model = LogisticRegression(class_weight=class_weight_dict, random_state=42)
model.fit(X_train, y_train)
# Evaluate the model
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred))
Class Weights: {0: 0.5564387917329093, 1: 4.929577464788732}
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       271
           1       1.00      1.00      1.00        29

    accuracy                           1.00       300
   macro avg       1.00      1.00      1.00       300
weighted avg       1.00      1.00      1.00       300
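The weights above can also be derived automatically: many scikit-learn classifiers accept the `class_weight='balanced'` shorthand, which computes the same inverse-frequency weights from the training labels internally. An equivalent sketch reusing the split from above:

```python
# Equivalent shorthand: scikit-learn derives balanced weights from y_train internally
model_balanced = LogisticRegression(class_weight='balanced', random_state=42)
model_balanced.fit(X_train, y_train)

y_pred_balanced = model_balanced.predict(X_test)
print(classification_report(y_test, y_pred_balanced))
```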
⚙️ Algorithm-Level Techniques¶
⚙️ What Are Algorithm-Level Techniques?¶
Algorithm-level techniques address class imbalance within the learning algorithm itself, without altering the data.
🔧 Cost-Sensitive Learning¶
- Modifies the loss function to assign higher penalties to misclassified minority class samples
- Common in logistic regression, SVMs, and neural networks via `class_weight` or a custom loss
🧠 Ensemble Approaches¶
- Use internal balancing or adaptive weighting to improve minority class performance (a sketch follows this list)
- Examples:
- Balanced Random Forest: samples each tree’s data to be class-balanced
- EasyEnsemble: trains multiple models on balanced subsets and combines predictions
- Boosting with Class Weights: supported in XGBoost, LightGBM, CatBoost
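For concreteness, here is a minimal sketch of two of the ensemble approaches listed above using imblearn's implementations, reusing the train/test split from the class-weighting example; the model settings are illustrative, not tuned. Boosting libraries expose similar knobs, e.g. XGBoost's `scale_pos_weight`.

```python
from imblearn.ensemble import BalancedRandomForestClassifier, EasyEnsembleClassifier
from sklearn.metrics import classification_report

# Balanced Random Forest: each tree sees a class-balanced bootstrap sample
brf = BalancedRandomForestClassifier(n_estimators=100, random_state=42)
brf.fit(X_train, y_train)
print(classification_report(y_test, brf.predict(X_test)))

# EasyEnsemble: AdaBoost learners trained on balanced subsets of the majority class
eec = EasyEnsembleClassifier(n_estimators=10, random_state=42)
eec.fit(X_train, y_train)
print(classification_report(y_test, eec.predict(X_test)))
```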