StatQuest Notes
A Gentle Introduction to Machine Learning
- Important Concepts:
- Machine learning is about predictions and classifications.
- The bias-variance trade-off:
- A model that fits the training data well but predicts poorly on new data illustrates this trade-off (low bias, but high variance).
- Fancy methods (like deep learning) aren’t necessarily better. What matters is how well they predict new data.
- The best method is the one that performs well on testing data, not just training data.
- There are methods for choosing which data to use for training vs. testing, which can be explored further.
Machine Learning Fundamentals: Cross Validation
-
Predicting Heart Disease:
- We want to predict if someone has heart disease based on variables like chest pain, blood circulation, etc.
- Once a new patient arrives, we can measure these variables to make a prediction.
-
Choosing the Best Machine Learning Method:
- Common machine learning methods:
- Logistic regression
- K-nearest neighbors (KNN)
- Support vector machines (SVM)
- Many more methods are available.
- How do we choose the best one? The answer is cross-validation.
-
Cross-Validation Overview:
- Cross-validation helps compare different machine learning methods and assess how well they will perform on new data.
-
Two Key Steps in Machine Learning:
- Training the algorithm:
- We use some of the data to estimate the parameters of a model, e.g., fitting a curve for logistic regression.
- In machine learning lingo, this step is called training.
- Testing the algorithm:
- We evaluate how well the trained method works by testing it on new data.
- In machine learning lingo, this is called testing.
-
Why We Don’t Use All Data for Training:
- Using all the data for training leaves no data to test the method.
- Reusing the same data for both training and testing gives a biased (overly optimistic) estimate of performance.
-
A Better Approach:
- Divide the data: Use the first 75% for training and the last 25% for testing.
- This allows us to compare methods by seeing how well they categorize the test data.
-
The Challenge of Splitting Data:
- How do we know that using the first 75% for training and the last 25% for testing is the best split?
- What if using another portion of the data, like the first 25% for testing, would give better results?
-
Cross-Validation Solution:
- Cross-validation divides the data into multiple blocks and tests each one individually.
- Example:
- Start by using the first three blocks for training, and the last block for testing.
- Keep track of how well the method performs.
- Rotate the blocks: Use a different set of blocks for training and another block for testing.
- Repeat until each block has been used for testing.
- This process allows us to compare methods and find which one performs best across all test blocks.
-
Types of Cross-Validation:
- Four-fold cross-validation: The data is divided into four blocks (as in the example above).
- Leave-one-out cross-validation: Each individual sample is treated as a block.
- Ten-fold cross-validation: Common in practice, where the data is divided into ten blocks.
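The block rotation described above can be sketched in code. This is a minimal sketch, assuming scikit-learn (the notes don't name a library) and toy stand-in data rather than the heart-disease measurements:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Toy stand-in for the heart-disease data.
X, y = make_classification(n_samples=200, random_state=0)
model = LogisticRegression(max_iter=1000)

# 4-fold cross-validation: each block is used once for testing,
# with the other three blocks used for training.
scores = cross_val_score(model, X, y, cv=KFold(n_splits=4, shuffle=True, random_state=0))
print(scores)         # one accuracy score per test block
print(scores.mean())  # average performance across all test blocks
```

Swapping `n_splits=4` for `n_splits=10` (or `LeaveOneOut()`) gives the other variants listed above.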
-
Tuning Parameters:
- Some methods, like ridge regression, involve a tuning parameter that needs to be guessed (not estimated from the data).
- Cross-validation, such as ten-fold cross-validation, can help determine the best value for this tuning parameter.
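A sketch of using ten-fold cross-validation to pick ridge regression's tuning parameter (`alpha` in scikit-learn); the candidate values and the toy data are invented for illustration:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# Toy regression data; the alpha grid below is an invented example.
X, y = make_regression(n_samples=100, n_features=20, noise=10, random_state=0)

# 10-fold cross-validation over each candidate value of the tuning parameter.
search = GridSearchCV(Ridge(), {"alpha": [0.01, 0.1, 1.0, 10.0, 100.0]}, cv=10)
search.fit(X, y)
print(search.best_params_)  # the alpha that performed best across the 10 folds
```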
Machine Learning Fundamentals: The Confusion Matrix
-
Predicting Heart Disease with Machine Learning:
- We have clinical measurements like chest pain, blood circulation, blocked arteries, and weight.
- We want to predict whether someone will develop heart disease using machine learning methods.
- Methods to consider:
- Logistic regression
- K-nearest neighbors (KNN)
- Random forest
- Many other machine learning methods.
-
How to Choose the Best Method:
- Step 1: Divide the data into training and testing sets.
- Step 2: Train all selected machine learning methods using the training data.
- Step 3: Test each method on the testing data.
- Cross-validation can be used to improve the reliability of these steps (see StatQuest for more details).
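The three steps above can be sketched as follows, assuming scikit-learn and toy stand-in data (the notes don't specify either):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=300, random_state=0)

# Step 1: divide the data into training (75%) and testing (25%) sets.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

models = {
    "logistic regression": LogisticRegression(max_iter=1000),
    "KNN": KNeighborsClassifier(),
    "random forest": RandomForestClassifier(random_state=0),
}

results = {}
for name, model in models.items():
    model.fit(X_train, y_train)                  # Step 2: train on training data
    results[name] = model.score(X_test, y_test)  # Step 3: test on testing data
print(results)
```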
-
Evaluating Model Performance:
- To evaluate each method, we use a confusion matrix.
- Rows: Predicted values from the algorithm.
- Columns: Actual known values.
- Since we are predicting heart disease or no heart disease, the confusion matrix has 2 rows and 2 columns.
-
Confusion Matrix Breakdown:
- True Positives (Top left): Patients with heart disease correctly classified as having heart disease.
- True Negatives (Bottom right): Patients without heart disease correctly classified as not having heart disease.
- False Negatives (Bottom left): Patients with heart disease incorrectly classified as not having it.
- False Positives (Top right): Patients without heart disease incorrectly classified as having it.
-
Example with Random Forest:
- True positives: 142 (correctly classified as having heart disease).
- True negatives: 110 (correctly classified as not having heart disease).
- False negatives: 29 (incorrectly classified as not having heart disease).
- False positives: 22 (incorrectly classified as having heart disease).
- Green boxes on the diagonal represent correct classifications.
- Red boxes off the diagonal represent errors.
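The random-forest confusion matrix above can be reconstructed from the stated counts; this sketch assumes scikit-learn's `confusion_matrix`, whose layout is the transpose of the one in these notes:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# 1 = has heart disease, 0 = does not.
# 142 true positives, 29 false negatives, 110 true negatives, 22 false positives.
y_true = np.array([1] * 142 + [1] * 29 + [0] * 110 + [0] * 22)
y_pred = np.array([1] * 142 + [0] * 29 + [0] * 110 + [1] * 22)

# Note: scikit-learn puts the known truth in rows and predictions in columns,
# the transpose of the layout described in these notes.
cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
print(cm)  # [[142  29]
           #  [ 22 110]]
```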
-
Comparing Methods:
- Comparing random forest to k-nearest neighbors (KNN):
- Random forest correctly classified more patients with heart disease (142 vs. 107).
- Random forest also correctly classified more patients without heart disease (110 vs. 79).
- If we had to choose between the two, we would select random forest.
-
Logistic Regression Example:
- When applied, logistic regression produces a confusion matrix similar to random forest, making it hard to decide which method is better.
- More advanced metrics like sensitivity, specificity, ROC, and AUC can help us make a decision (covered in the next StatQuest).
Machine Learning Fundamentals: Sensitivity and Specificity
-
Recap of the Confusion Matrix (2x2 Example):
- Rows = predicted values.
- Columns = known truth.
- Categories: Has heart disease or does not have heart disease.
- True positives (top left): Patients with heart disease correctly identified.
- True negatives (bottom right): Patients without heart disease correctly identified.
- False negatives (bottom left): Patients with heart disease incorrectly classified as not having it.
- False positives (top right): Patients without heart disease incorrectly classified as having it.
-
Calculating Sensitivity and Specificity (2x2 Matrix):
- Sensitivity: Measures how well the model correctly identifies patients with heart disease.
- Formula: Sensitivity = True Positives / (True Positives + False Negatives).
- Specificity: Measures how well the model correctly identifies patients without heart disease.
- Formula: Specificity = True Negatives / (True Negatives + False Positives).
-
Example with Logistic Regression:
- Confusion Matrix:
- True Positives = 139, False Negatives = 32.
- True Negatives = 112, False Positives = 20.
- Sensitivity = 139 / (139 + 32) = 0.81.
- 81% of patients with heart disease were correctly identified.
- Specificity = 112 / (112 + 20) = 0.85.
- 85% of patients without heart disease were correctly identified.
-
Example with Random Forest:
- Confusion Matrix:
- True Positives = 142, False Negatives = 29.
- True Negatives = 110, False Positives = 22.
- Sensitivity = 142 / (142 + 29) = 0.83.
- 83% of patients with heart disease were correctly identified.
- Specificity = 110 / (110 + 22) = 0.83.
- 83% of patients without heart disease were correctly identified.
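The sensitivity and specificity figures in both examples can be checked directly from the confusion-matrix counts with plain Python:

```python
def sensitivity(tp, fn):
    # Fraction of patients with heart disease that were correctly identified.
    return tp / (tp + fn)

def specificity(tn, fp):
    # Fraction of patients without heart disease that were correctly identified.
    return tn / (tn + fp)

# Logistic regression: TP=139, FN=32, TN=112, FP=20
print(round(sensitivity(139, 32), 2), round(specificity(112, 20), 2))  # 0.81 0.85
# Random forest: TP=142, FN=29, TN=110, FP=22
print(round(sensitivity(142, 29), 2), round(specificity(110, 22), 2))  # 0.83 0.83
```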
-
Comparison of Logistic Regression vs. Random Forest:
- Random forest: Higher sensitivity (better at identifying patients with heart disease).
- Logistic regression: Higher specificity (better at identifying patients without heart disease).
- Choice of method depends on which is more important:
- Prioritize sensitivity if identifying patients with heart disease is crucial.
- Prioritize specificity if identifying patients without heart disease is more important.
-
Sensitivity and Specificity for Larger Confusion Matrices (3x3 Example):
- When the confusion matrix has three or more categories, sensitivity and specificity must be calculated for each category individually.
Machine Learning Fundamentals: Bias and Variance
-
Introduction:
- We have data on the weight and height of several mice.
- Lighter mice tend to be shorter, while heavier mice tend to be taller.
- After a certain weight, however, mice stop getting taller and just become more obese.
- The goal is to predict mouse height based on its weight.
-
Modeling the Relationship:
- Ideally, we’d have the exact mathematical formula that describes the relationship between weight and height.
- Since we don’t know the formula, we’ll use two machine learning methods to approximate the relationship.
- The true relationship curve will remain in the figure for reference.
-
Step 1: Splitting the Data:
- The data is split into two sets:
- Training set: Used to train the machine learning models (blue dots).
- Testing set: Used to evaluate the models (green dots).
-
Linear Regression (Least Squares):
- The first method we use is linear regression, which fits a straight line to the training data.
- A straight line doesn't have the flexibility to accurately capture the arc in the true relationship between weight and height.
- No matter how well we fit it to the training data, the line won’t capture the curve.
- The inability of the model to capture the true relationship is called bias.
- Linear regression has high bias because it cannot adjust to the arc.
-
Flexible Model (Squiggly Line):
- Another method might fit a squiggly line to the training data.
- This line is very flexible and follows the arc of the true relationship closely, meaning it has low bias.
- It hugs the training data tightly and does a much better job of capturing the true relationship than the straight line.
-
Comparing Fits:
- We can compare the performance of the two models by calculating the sums of squares.
- This involves measuring the distances between the predicted values and the actual data points, squaring those distances, and summing them up.
- Distances are squared so that negative and positive values don’t cancel each other out.
- Squiggly line: Fits the training data perfectly, so the distances are zero.
- Linear regression: Fits the data less perfectly, resulting in a larger sum of squares.
- In the training set, the squiggly line wins because it fits the data better.
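The sum-of-squares comparison can be sketched with invented numbers (the mouse heights and fitted values below are illustrative, not from the notes):

```python
import numpy as np

# Invented observed heights and two sets of predictions for illustration.
heights = np.array([2.1, 2.9, 3.6, 3.9, 4.0])
straight_line = np.array([2.0, 2.6, 3.2, 3.8, 4.4])  # a linear fit
squiggly_line = heights.copy()                        # hugs the training data exactly

def sum_of_squares(observed, predicted):
    # Squaring keeps positive and negative distances from cancelling out.
    return float(np.sum((observed - predicted) ** 2))

print(sum_of_squares(heights, squiggly_line))  # 0.0 -- perfect fit on the training data
print(sum_of_squares(heights, straight_line))  # larger than 0
```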
-
Step 2: Testing on New Data:
- Now we evaluate the models on the testing set.
- The squiggly line, which did well on the training set, performs poorly on the testing set.
- The straight line, despite being less flexible, performs better on the testing set.
-
Understanding Bias and Variance:
- Bias: The error introduced by the model’s assumptions (e.g., a straight line assuming the relationship is linear).
- The straight line has high bias but low variance (it performs consistently across datasets).
- The squiggly line has low bias (can capture the curve) but high variance (inconsistent performance across datasets).
- Variance: The variability of a model’s performance on different datasets.
- A model with high variance performs well on the training set but poorly on the testing set (i.e., it’s inconsistent).
- Overfitting: When a model (like the squiggly line) fits the training data too well but doesn’t generalize well to new data.
-
Finding the Sweet Spot:
- The ideal machine learning algorithm has low bias (can capture the true relationship) and low variance (makes consistent predictions).
- This involves finding a balance between simple models (like linear regression) and complex models (like the squiggly line).
-
Overfitting: When a model performs well on training data but poorly on testing data due to high variance.
- Methods to find the balance between simple and complex models:
- Regularization: A technique to reduce overfitting by controlling the flexibility of the model.
- Boosting: A method that builds strong models by combining weaker ones.
- Bagging: A technique that involves combining the predictions of multiple models to reduce variance (e.g., used in Random Forest).
ROC and AUC, Clearly Explained!
-
Introduction:
- We have data on mouse weight and whether a mouse is obese or not obese.
- The blue dots represent obese mice, and the red dots represent mice that are not obese.
- We want to predict if a mouse is obese based on its weight.
-
Logistic Regression Model:
- We fit a logistic regression curve to the data.
- The y-axis now represents the probability that a mouse is obese.
- The x-axis remains the weight of the mouse.
- The curve predicts:
- High probability of obesity for a heavy mouse.
- Low probability of obesity for a light mouse.
- Logistic regression gives us the probability of obesity based on weight.
-
Classification with Logistic Regression:
- To classify a mouse as obese or not obese, we need a threshold to turn probabilities into categories.
- A common threshold is 0.5:
- Mice with a probability greater than 0.5 are classified as obese.
- Mice with a probability less than or equal to 0.5 are classified as not obese.
- We use this threshold to classify various mice.
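Turning probabilities into categories with the 0.5 threshold is a one-line comparison; the probabilities below are invented for illustration:

```python
import numpy as np

# Invented predicted probabilities of obesity for five mice.
prob_obese = np.array([0.05, 0.30, 0.55, 0.80, 0.97])

threshold = 0.5
is_obese = prob_obese > threshold  # above the threshold -> classified as obese
print(is_obese)  # [False False  True  True  True]
```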
-
Evaluating the Model:
- We test the logistic regression model using known data:
- We correctly and incorrectly classify mice as obese or not obese.
- Confusion matrix:
- True positives: Correctly classified as obese.
- False positives: Incorrectly classified as obese but are not obese.
- True negatives: Correctly classified as not obese.
- False negatives: Incorrectly classified as not obese but are actually obese.
-
Sensitivity and Specificity:
- Sensitivity: Measures the percentage of obese mice correctly classified.
- Specificity: Measures the percentage of non-obese mice correctly classified.
-
Changing the Threshold:
- If it's important to correctly classify every obese mouse, we can lower the threshold (e.g., to 0.1):
- This results in fewer false negatives but more false positives.
- Example: Classifying mice as infected with a disease like Ebola, where it's essential to catch every positive case.
- If we raise the threshold (e.g., to 0.9):
- This results in fewer false positives but more false negatives.
- For some datasets, a higher threshold does a better job at classifying obese and non-obese mice.
-
Finding the Best Threshold:
- Instead of testing every possible threshold one at a time, we use an ROC (Receiver Operating Characteristic) curve to summarize the model’s performance.
- ROC Curve:
- Y-axis: True Positive Rate (Sensitivity).
- X-axis: False Positive Rate (1 - Specificity).
- The ROC curve helps visualize how the model's sensitivity and false positive rate change as the threshold varies.
-
Building the ROC Curve:
- Start with a low threshold where all samples are classified as obese:
- True positive rate = 1 (all obese mice correctly classified).
- False positive rate = 1 (all non-obese mice incorrectly classified as obese).
- Plot this point at (1,1) on the ROC curve.
- As we increase the threshold, the true positive and false positive rates change:
- We calculate these rates at each threshold and plot the points.
- The points move leftward and downward on the graph as the threshold increases.
- The best points are above and to the left of the green diagonal line, showing that more obese mice are correctly classified than non-obese mice are misclassified.
-
Area Under the Curve (AUC):
- The AUC (Area Under the Curve) summarizes the ROC curve:
- A higher AUC means better model performance.
- Example: An AUC of 0.9 means the model is very good at classifying the data.
- If comparing two models (e.g., logistic regression vs. random forest), the model with the higher AUC is better.
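The ROC-and-AUC procedure above can be sketched with scikit-learn (an assumed library choice) on invented labels and probabilities:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

# Invented known labels (1 = obese) and predicted probabilities.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.3, 0.4, 0.6, 0.35, 0.7, 0.8, 0.9])

# One (false positive rate, true positive rate) pair per threshold tried.
fpr, tpr, thresholds = roc_curve(y_true, y_prob)
print(list(zip(fpr, tpr)))

# The area under that curve, as a single summary number.
print(roc_auc_score(y_true, y_prob))  # 0.875 for this toy data
```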
-
Alternative to ROC: Precision:
- Another metric that replaces the false positive rate is precision:
- Precision: True Positives / (True Positives + False Positives).
- Precision measures the proportion of positive classifications that are actually positive.
- Precision is useful when there's a class imbalance (e.g., more non-obese than obese mice or when studying a rare disease).
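As a quick check of the formula, precision can be computed from confusion-matrix counts, e.g. the random-forest numbers used earlier (TP=142, FP=22):

```python
def precision(tp, fp):
    # Of everything classified as positive, the fraction that truly is positive.
    return tp / (tp + fp)

print(round(precision(142, 22), 3))  # 0.866
```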