Precision-Recall Curve
A graphical representation of classification model performance that shows the trade-off between precision and recall across decision thresholds.
What is a Precision-Recall Curve?
A Precision-Recall (PR) curve is a graphical representation of a classification model's performance that illustrates the trade-off between precision (positive predictive value) and recall (sensitivity) across different decision thresholds. It provides a comprehensive view of how well a model performs in identifying positive instances while maintaining accuracy in its positive predictions.
Key Concepts
Precision-Recall Curve Components
graph TD
A[Precision-Recall Curve] --> B[Precision]
A --> C[Recall]
A --> D[Threshold]
A --> E[Performance Metrics]
B --> B1["Precision = TP / (TP + FP)"]
B --> B2[Positive Predictive Value]
C --> C1["Recall = TP / (TP + FN)"]
C --> C2[Sensitivity]
C --> C3[True Positive Rate]
D --> D1[Decision Threshold]
D --> D2[Varies from 0 to 1]
E --> E1[AUC-PR]
E --> E2[F1-Score]
E --> E3[Optimal Threshold]
style A fill:#f9f,stroke:#333
style B fill:#cfc,stroke:#333
style C fill:#fcc,stroke:#333
Core Metrics
| Metric | Formula | Interpretation |
|---|---|---|
| Precision | TP / (TP + FP) | Positive predictive value |
| Recall | TP / (TP + FN) | Sensitivity, true positive rate |
| F1-Score | 2 × (Precision × Recall) / (Precision + Recall) | Harmonic mean of precision and recall |
| Support | TP + FN | Total positive instances |
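As a quick sanity check of these formulas, the sketch below computes them from made-up confusion-matrix counts (the counts are purely illustrative):
# Hypothetical confusion-matrix counts, for illustration only
tp, fp, fn = 80, 20, 40
precision = tp / (tp + fp)  # 0.80
recall = tp / (tp + fn)     # ~0.67
f1 = 2 * precision * recall / (precision + recall)
support = tp + fn           # total positive instances
print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}, Support: {support}")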
Mathematical Foundations
PR Curve Construction
The PR curve is constructed by plotting precision against recall at various threshold settings:
- Sort predictions: Order predicted probabilities from highest to lowest
- Vary threshold: Move threshold from 1 to 0
- Calculate precision/recall: At each threshold, compute precision and recall
- Plot points: Connect the (recall, precision) points
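To make these steps concrete, the sketch below builds the PR points by hand from a handful of made-up scores and labels (toy data, for demonstration only); treating each top-k cut-off as a threshold closely mirrors what precision_recall_curve computes.
import numpy as np
# Toy scores and labels, for illustration only
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_score = np.array([0.95, 0.85, 0.80, 0.70, 0.60, 0.40, 0.30, 0.10])
# Step 1: sort predictions from highest to lowest score
order = np.argsort(-y_score)
y_sorted = y_true[order]
# Steps 2-3: treat each score as the threshold and compute precision/recall
n_pos = y_true.sum()
points = []
for k in range(1, len(y_sorted) + 1):  # top-k predictions labelled positive
    tp = y_sorted[:k].sum()
    points.append((tp / n_pos, tp / k))  # (recall, precision)
# Step 4: plot/connect the (recall, precision) points
for r, p in points:
    print(f"recall={r:.2f}, precision={p:.2f}")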
Area Under the PR Curve (AUC-PR)
The AUC-PR represents the average precision across all recall levels:
$$AUC_{PR} = \int_{0}^{1} Precision(Recall) \, d(Recall)$$
Where:
- $Precision(Recall)$ is precision as a function of recall
- AUC-PR ranges from 0 to 1
- Higher values indicate better performance
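In practice, scikit-learn's average_precision_score approximates this integral with the step-wise sum $AP = \sum_n (R_n - R_{n-1}) P_n$ rather than trapezoidal interpolation; a minimal sketch of that computation from the outputs of precision_recall_curve (toy data again):
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_score = np.array([0.95, 0.85, 0.80, 0.70, 0.60, 0.40, 0.30, 0.10])
precision, recall, _ = precision_recall_curve(y_true, y_score)
# Step-wise sum AP = sum_n (R_n - R_{n-1}) * P_n; recall is returned in decreasing
# order, hence the leading minus sign
ap_manual = -np.sum(np.diff(recall) * precision[:-1])
print(ap_manual, average_precision_score(y_true, y_score))  # the two values should agree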
Relationship to Other Metrics
The PR curve is particularly useful for imbalanced datasets where the ROC curve can be misleading:
- For balanced datasets: ROC and PR curves provide similar information
- For imbalanced datasets: PR curve focuses on the positive class performance
- AUC-PR vs AUC-ROC: AUC-PR is more informative when positive class is rare
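A quick way to see this in practice is to score the same imbalanced problem with both metrics; the sketch below uses synthetic data (the class weights, model choice, and exact numbers are illustrative):
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.model_selection import train_test_split
# Synthetic problem with a rare positive class (~2%)
X, y = make_classification(n_samples=20000, weights=[0.98, 0.02], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=0)
scores = LogisticRegression(max_iter=1000).fit(X_train, y_train).predict_proba(X_test)[:, 1]
# AUC-ROC is judged against a 0.5 baseline, AUC-PR against the positive prevalence (~0.02),
# so AUC-PR usually paints a harsher, more informative picture of rare-class performance
print(f"AUC-ROC: {roc_auc_score(y_test, scores):.3f}")
print(f"AUC-PR (AP): {average_precision_score(y_test, scores):.3f}")
print(f"Positive prevalence (PR baseline): {y_test.mean():.3f}")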
Applications
Model Evaluation
- Imbalanced Classification: Fraud detection, rare disease diagnosis
- Information Retrieval: Search engine performance
- Model Comparison: Comparing different algorithms
- Threshold Selection: Choosing optimal decision threshold
- Performance Assessment: Evaluating model effectiveness
Performance Analysis
- Positive Class Focus: Performance on the class of interest
- Threshold Optimization: Finding best trade-off point
- Model Selection: Choosing between different models
- Feature Importance: Evaluating feature impact on positive class
- Error Analysis: Understanding model weaknesses on positive class
Industry Applications
- Healthcare: Rare disease detection
- Finance: Fraud detection systems
- Security: Intrusion detection
- Marketing: Customer churn prediction
- Manufacturing: Defect detection
Implementation
Basic PR Curve
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
# Generate synthetic data with class imbalance
X, y = make_classification(n_samples=1000, n_classes=2,
weights=[0.9, 0.1], random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Train model
model = LogisticRegression(class_weight='balanced')
model.fit(X_train, y_train)
# Get predicted probabilities
y_scores = model.predict_proba(X_test)[:, 1]
# Compute PR curve
precision, recall, thresholds = precision_recall_curve(y_test, y_scores)
average_precision = average_precision_score(y_test, y_scores)
# Plot PR curve
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, color='darkorange', lw=2,
label=f'PR curve (AP = {average_precision:.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc="upper right")
plt.grid(True)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
# Add baseline (proportion of positive class)
pos_proportion = np.sum(y_test) / len(y_test)
plt.plot([0, 1], [pos_proportion, pos_proportion], 'k--',
label=f'Random (AP = {pos_proportion:.2f})')
plt.legend(loc="lower left")
plt.show()
# Find optimal threshold (maximize F1-score)
# precision and recall have one more element than thresholds; drop the final
# (recall=0, precision=1) point so indices line up with thresholds
f1_scores = 2 * (precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1])
f1_scores = np.nan_to_num(f1_scores)  # handle division by zero where precision + recall == 0
optimal_idx = np.argmax(f1_scores)
optimal_threshold = thresholds[optimal_idx]
print(f"Optimal threshold (max F1): {optimal_threshold:.4f}")
print(f"Precision at optimal: {precision[optimal_idx]:.4f}")
print(f"Recall at optimal: {recall[optimal_idx]:.4f}")
print(f"F1-score at optimal: {f1_scores[optimal_idx]:.4f}")
Multi-Class PR Curve
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from itertools import cycle
# Generate multi-class data with imbalance
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=5,
weights=[0.7, 0.2, 0.1], random_state=42)
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Train model
classifier = OneVsRestClassifier(LogisticRegression(class_weight='balanced'))
classifier.fit(X_train, y_train)
# Get predicted probabilities
y_score = classifier.predict_proba(X_test)
# Compute PR curve for each class
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i], y_score[:, i])
average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])
# Compute micro-average PR curve
precision["micro"], recall["micro"], _ = precision_recall_curve(
y_test.ravel(), y_score.ravel()
)
average_precision["micro"] = average_precision_score(y_test, y_score, average="micro")
# Plot PR curves
plt.figure(figsize=(8, 6))
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes), colors):
plt.plot(recall[i], precision[i], color=color, lw=2,
label=f'Class {i} (AP = {average_precision[i]:.2f})')
plt.plot(recall["micro"], precision["micro"], color='deeppink', linestyle=':', lw=4,
label=f'Micro-average (AP = {average_precision["micro"]:.2f})')
# Add baseline for each class
color_list = ['aqua', 'darkorange', 'cornflowerblue']  # cycle objects cannot be indexed, so use a plain list here
for i in range(n_classes):
    pos_proportion = np.sum(y_test[:, i]) / len(y_test)
    plt.plot([0, 1], [pos_proportion, pos_proportion], '--', color=color_list[i],
             alpha=0.5, label='Random baseline' if i == 0 else None)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Multi-Class Precision-Recall Curve')
plt.legend(loc="lower left")
plt.grid(True)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.show()
# Print class distribution and AP scores
print("Class distribution in test set:")
for i in range(n_classes):
print(f"Class {i}: {np.sum(y_test[:, i])} positive samples ({np.sum(y_test[:, i])/len(y_test)*100:.1f}%)")
print("\nAverage Precision scores:")
for i in range(n_classes):
print(f"Class {i}: {average_precision[i]:.4f}")
print(f"Micro-average: {average_precision['micro']:.4f}")
Interactive PR Curve
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def plot_interactive_pr_curve(precision, recall, average_precision, thresholds=None):
"""Create interactive PR curve visualization"""
fig = go.Figure()
# Add PR curve
fig.add_trace(go.Scatter(
x=recall, y=precision,
name=f'PR curve (AP = {average_precision:.2f})',
line=dict(color='darkorange', width=2),
hovertemplate='Recall: %{x:.2f}<br>Precision: %{y:.2f}<extra></extra>'
))
    # Add baseline: a random classifier's average precision equals the positive class proportion
    pos_proportion = precision[0]  # precision at the lowest threshold (recall ~ 1) equals the positive proportion
fig.add_trace(go.Scatter(
x=[0, 1], y=[pos_proportion, pos_proportion],
name=f'Random (AP = {pos_proportion:.2f})',
line=dict(color='navy', width=2, dash='dash')
))
# Add threshold markers if provided
if thresholds is not None:
# Sample thresholds for visualization
sample_indices = np.linspace(0, len(thresholds)-1, 20, dtype=int)
for idx in sample_indices:
fig.add_trace(go.Scatter(
x=[recall[idx]], y=[precision[idx]],
mode='markers',
marker=dict(size=8, color='red'),
name=f'Threshold: {thresholds[idx]:.2f}',
hovertemplate=f'Threshold: {thresholds[idx]:.2f}<br>Recall: {recall[idx]:.2f}<br>Precision: {precision[idx]:.2f}<extra></extra>',
showlegend=False
))
# Update layout
fig.update_layout(
title='Interactive Precision-Recall Curve',
xaxis_title='Recall',
yaxis_title='Precision',
xaxis=dict(range=[0, 1], constrain='domain'),
yaxis=dict(range=[0, 1.05]),
width=800,
height=600,
hovermode='closest'
)
fig.show()
# Example usage (assumes precision, recall, average_precision, and thresholds
# still hold the arrays from the basic binary example above)
plot_interactive_pr_curve(precision, recall, average_precision, thresholds)
Performance Optimization
Threshold Selection Methods
| Method | Description | Formula |
|---|---|---|
| Max F1-Score | Maximizes harmonic mean of precision and recall | $F1 = \max(2 \cdot \frac{Precision \cdot Recall}{Precision + Recall})$ |
| Precision at Minimum Recall | Maximizes precision among thresholds achieving recall of at least $K$ | $\max(Precision)$ subject to $Recall \geq K$ |
| Youden's J | Maximizes sensitivity + specificity - 1 (requires specificity, i.e., ROC-style information) | $J = \max(Sensitivity + Specificity - 1)$ |
| Cost-Based | Minimizes expected cost | $C = \min(C_{FP} \cdot FP + C_{FN} \cdot FN)$ |
| Break-Even Point | Point where precision equals recall | $P = R$ |
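Given the precision, recall, and thresholds arrays returned by precision_recall_curve, a couple of these strategies can be sketched as follows (the function names and the 0.8 recall floor are illustrative choices, not a standard API):
import numpy as np
def threshold_at_min_recall(precision, recall, thresholds, min_recall=0.8):
    """Highest-precision threshold among those achieving at least min_recall."""
    # Drop the final (recall=0, precision=1) point, which has no threshold
    p, r = precision[:-1], recall[:-1]
    mask = r >= min_recall
    if not np.any(mask):
        return None  # no threshold reaches the requested recall
    idx = np.flatnonzero(mask)[np.argmax(p[mask])]
    return thresholds[idx], p[idx], r[idx]
def break_even_threshold(precision, recall, thresholds):
    """Threshold where precision and recall are (approximately) equal."""
    p, r = precision[:-1], recall[:-1]
    idx = np.argmin(np.abs(p - r))
    return thresholds[idx], p[idx], r[idx]
# Example usage with the arrays from the basic binary example:
# print(threshold_at_min_recall(precision, recall, thresholds))
# print(break_even_threshold(precision, recall, thresholds))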
Cost-Sensitive PR Analysis
def cost_sensitive_pr_analysis(precision, recall, thresholds, cost_fp=1, cost_fn=1):
"""Perform cost-sensitive PR analysis"""
# Calculate cost at each threshold
# Note: We need to estimate FP and FN counts
# For simplicity, we'll use the precision and recall values
# In practice, you would use actual counts
# Estimate FP and FN rates
# FP_rate = (1 - precision) * (TP + FP) / (TN + FP) - this is complex without counts
# Instead, we'll use a simplified cost function
# Cost function: C = cost_fp * (1 - precision) + cost_fn * (1 - recall)
costs = cost_fp * (1 - precision[:-1]) + cost_fn * (1 - recall[:-1])
# Find optimal threshold
optimal_idx = np.argmin(costs)
optimal_threshold = thresholds[optimal_idx]
optimal_cost = costs[optimal_idx]
# Calculate cost-adjusted metrics
cost_adjusted_f1 = 1 - optimal_cost / max(cost_fp, cost_fn)
return {
'optimal_threshold': optimal_threshold,
'optimal_cost': optimal_cost,
'cost_adjusted_f1': cost_adjusted_f1,
'precision_at_optimal': precision[optimal_idx],
'recall_at_optimal': recall[optimal_idx],
'all_costs': costs
}
# Example with different cost scenarios (uses precision, recall, and thresholds from the basic binary example)
print("Standard Cost Scenario (FP=1, FN=1):")
results = cost_sensitive_pr_analysis(precision, recall, thresholds)
print(f"Optimal threshold: {results['optimal_threshold']:.4f}")
print(f"Optimal cost: {results['optimal_cost']:.4f}")
print(f"Precision at optimal: {results['precision_at_optimal']:.4f}")
print(f"Recall at optimal: {results['recall_at_optimal']:.4f}")
print("\nHigh Cost for False Negatives (e.g., medical diagnosis):")
results = cost_sensitive_pr_analysis(precision, recall, thresholds, cost_fp=1, cost_fn=10)
print(f"Optimal threshold: {results['optimal_threshold']:.4f}")
print(f"Optimal cost: {results['optimal_cost']:.4f}")
print("\nHigh Cost for False Positives (e.g., spam filtering):")
results = cost_sensitive_pr_analysis(precision, recall, thresholds, cost_fp=5, cost_fn=1)
print(f"Optimal threshold: {results['optimal_threshold']:.4f}")
print(f"Optimal cost: {results['optimal_cost']:.4f}")
PR Curve Comparison
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
def compare_models_pr(X, y, models, model_names):
"""Compare PR curves of multiple models"""
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
plt.figure(figsize=(10, 8))
for model, name in zip(models, model_names):
# Train model
if hasattr(model, 'class_weight'):
model.set_params(class_weight='balanced')
model.fit(X_train, y_train)
# Get predicted probabilities
if hasattr(model, "predict_proba"):
y_scores = model.predict_proba(X_test)[:, 1]
else: # For models without predict_proba
y_scores = model.decision_function(X_test)
# Compute PR curve
precision, recall, _ = precision_recall_curve(y_test, y_scores)
average_precision = average_precision_score(y_test, y_scores)
# Plot PR curve
plt.plot(recall, precision, lw=2,
label=f'{name} (AP = {average_precision:.2f})')
# Add baseline
pos_proportion = np.sum(y_test) / len(y_test)
plt.plot([0, 1], [pos_proportion, pos_proportion], 'k--', lw=2,
label=f'Random (AP = {pos_proportion:.2f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Model Comparison with Precision-Recall Curves')
plt.legend(loc="lower left")
plt.grid(True)
plt.show()
# Example comparison with imbalanced data
X, y = make_classification(n_samples=1000, n_classes=2,
weights=[0.9, 0.1], random_state=42)
models = [
LogisticRegression(class_weight='balanced'),
RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42),
SVC(probability=True, class_weight='balanced', random_state=42)
]
model_names = ['Logistic Regression', 'Random Forest', 'SVM']
compare_models_pr(X, y, models, model_names)
Challenges
Interpretation Challenges
- Class Imbalance: The PR baseline equals the positive-class prevalence, so curves and AUC-PR values are not directly comparable across datasets with different class distributions
- Threshold Dependence: Performance varies with threshold
- Multiple Classes: Complexity increases with multi-class problems
- Cost Sensitivity: Doesn't account for different error costs
- Context Dependence: Needs domain-specific interpretation
Practical Challenges
- Data Quality: Sensitive to labeling errors
- Model Selection: Different models may have similar curves
- Threshold Selection: Choosing optimal threshold
- Visualization: Hard to visualize for many models
- Comparison: Comparing curves across different datasets
Technical Challenges
- Computational Complexity: Calculating for large datasets
- Probability Calibration: Models need well-calibrated probabilities
- Statistical Significance: Determining meaningful differences
- Multi-Class Extension: Extending to multi-class problems
- Interpretability: Making results understandable to stakeholders
Research and Advancements
Key Developments
- "The Relationship Between Precision-Recall and ROC Curves" (Davis & Goadrich, 2006)
- Established formal relationship between PR and ROC curves
- Showed that PR curves are more informative for imbalanced data
- "A Systematic Analysis of Performance Measures for Classification Tasks" (Sokolova & Lapalme, 2009)
- Comprehensive analysis of evaluation metrics
- Guidelines for choosing between PR and ROC curves
- "The Precision-Recall Plot Is More Informative than the ROC Plot When Evaluating Binary Classifiers on Imbalanced Datasets" (Saito & Rehmsmeier, 2015)
- Demonstrated advantages of PR curves for imbalanced data
- Provided empirical evidence for PR curve superiority in imbalanced scenarios
Emerging Research Directions
- Cost-Sensitive PR Analysis: Incorporating error costs
- Dynamic PR Curves: Time-dependent PR analysis
- Uncertainty Quantification: Confidence intervals for PR curves
- Fairness-Aware PR: Bias detection in PR analysis
- Multi-Objective PR: Balancing multiple metrics
- Deep Learning PR: PR curve analysis for neural networks
- Causal PR: Causal interpretation of PR curves
- Explainable PR: Interpretable PR curve analysis
Best Practices
Design
- Class Definition: Clearly define positive/negative classes
- Evaluation Protocol: Use appropriate cross-validation
- Multiple Metrics: Use PR curves with other evaluation metrics
- Class Imbalance: Consider PR curves for imbalanced data
- Visualization: Use appropriate visualization techniques
Implementation
- Data Quality: Ensure high-quality labeled data
- Class Balance: Address severe class imbalance
- Probability Calibration: Calibrate model probabilities so that decision thresholds are interpretable (see the sketch after this list)
- Multiple Models: Compare multiple models
- Cross-Validation: Use robust evaluation protocols
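On the probability-calibration point above, one option is scikit-learn's CalibratedClassifierCV; the sketch below reuses X_train, y_train, X_test, and y_test from the basic binary example (the random-forest base model and isotonic method are illustrative choices). Monotonic calibration barely changes the ranking, and hence the PR curve itself; its main benefit here is that the chosen probability threshold becomes interpretable.
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import average_precision_score
# Wrap the base model so its predicted probabilities are calibrated via cross-validation
base = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)
calibrated = CalibratedClassifierCV(base, method='isotonic', cv=5)
calibrated.fit(X_train, y_train)
y_scores_cal = calibrated.predict_proba(X_test)[:, 1]
print(f"AP with calibrated probabilities: {average_precision_score(y_test, y_scores_cal):.3f}")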
Analysis
- AUC-PR Interpretation: Understand AUC-PR limitations
- Threshold Selection: Choose appropriate threshold method
- Error Analysis: Investigate misclassified instances
- Feature Importance: Analyze feature impact on PR curve
- Domain Context: Interpret results in domain context
Reporting
- Complete Reporting: Report AUC-PR with confidence intervals (see the bootstrap sketch after this list)
- Contextual Information: Provide domain context
- Visual Representation: Include PR curve visualizations
- Statistical Significance: Report p-values for comparisons
- Cost Analysis: Include cost-sensitive analysis when relevant
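For the confidence-interval recommendation above, a common approach is a percentile bootstrap over the test set; the sketch below reuses y_test and y_scores from the basic binary example (the 2,000 resamples and 95% level are arbitrary choices):
import numpy as np
from sklearn.metrics import average_precision_score
def bootstrap_ap_ci(y_true, y_score, n_boot=2000, alpha=0.05, seed=42):
    """Percentile bootstrap confidence interval for average precision."""
    rng = np.random.default_rng(seed)
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    aps = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_true), len(y_true))
        if y_true[idx].sum() == 0:  # skip resamples with no positive instances
            continue
        aps.append(average_precision_score(y_true[idx], y_score[idx]))
    lo, hi = np.percentile(aps, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lo, hi
lo, hi = bootstrap_ap_ci(y_test, y_scores)
print(f"AP 95% confidence interval: [{lo:.3f}, {hi:.3f}]")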