AI/ML notes

Fundamentals

StatQuest Notes

A Gentle Introduction to Machine Learning

  • Important Concepts:
    • Machine learning is about predictions and classifications.
    • The bias-variance trade-off:
      • A model that fits the training data well but performs poorly on new data illustrates this trade-off (low bias, but high variance).
    • Fancy methods (like deep learning) aren’t necessarily better. What matters is how well they predict new data.
    • The best method is the one that performs well on testing data, not just training data.
    • There are methods for choosing which data to use for training vs. testing, which can be explored further.
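One way to choose which data to use for training vs. testing is a random hold-out split. A minimal sketch using scikit-learn (assumes scikit-learn is installed; the feature and label arrays are made up):

```python
# A sketch of a random train/test split with scikit-learn.
# The features and labels below are made up for illustration.
from sklearn.model_selection import train_test_split

X = [[i] for i in range(20)]     # 20 one-feature samples
y = [0] * 10 + [1] * 10          # made-up labels

# Hold out 25% of the data for testing; the rest is used for training.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)
print(len(X_train), len(X_test))  # 15 5
```

`random_state` fixes the shuffle so the split is reproducible.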

Machine Learning Fundamentals: Cross Validation

  • Predicting Heart Disease:

    • We want to predict if someone has heart disease based on variables like chest pain, blood circulation, etc.
    • Once a new patient arrives, we can measure these variables to make a prediction.
  • Choosing the Best Machine Learning Method:

    • Common machine learning methods:
      • Logistic regression
      • K-nearest neighbors (KNN)
      • Support vector machines (SVM)
      • Many more methods are available.
    • How do we choose the best one? The answer is cross-validation.
  • Cross-Validation Overview:

    • Cross-validation helps compare different machine learning methods and assess how well they will perform on new, independent data.
  • Two Key Steps in Machine Learning:

    1. Training the algorithm:
      • We use some of the data to estimate the parameters of a model, e.g., fitting a curve for logistic regression.
      • In machine learning lingo, this step is called training.
    2. Testing the algorithm:
      • We evaluate how well the trained method works by testing it on new data.
      • In machine learning lingo, this is called testing.
  • Why We Don’t Use All Data for Training:

    • Using all the data for training leaves no data to test the method.
    • Reusing the same data for both training and testing gives an overly optimistic (biased) estimate of performance.
  • A Better Approach:

    • Divide the data: Use the first 75% for training and the last 25% for testing.
    • This allows us to compare methods by seeing how well they categorize the test data.
  • The Challenge of Splitting Data:

    • How do we know that using the first 75% for training and the last 25% for testing is the best split?
    • What if using another portion of the data, like the first 25% for testing, would give better results?
  • Cross-Validation Solution:

    • Cross-validation divides the data into multiple blocks and tests each one individually.
    • Example:
      • Start by using the first three blocks for training, and the last block for testing.
      • Keep track of how well the method performs.
      • Rotate the blocks: Use a different set of blocks for training and another block for testing.
      • Repeat until each block has been used for testing.
    • This process allows us to compare methods and find which one performs best across all test blocks.
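The rotation described above can be sketched with scikit-learn's `cross_val_score`, which handles the block splitting and rotation automatically (the dataset and the choice of logistic regression here are illustrative):

```python
# A sketch of four-fold cross-validation: cv=4 splits the data into
# four blocks, and each block is used once for testing while the
# other three are used for training. The dataset is synthetic.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=100, random_state=0)

scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=4)
print(scores)         # one accuracy score per test block
print(scores.mean())  # average performance across the four blocks
```

Comparing the mean score of several methods is how cross-validation picks a winner.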
  • Types of Cross-Validation:

    • Four-fold cross-validation: The data is divided into four blocks (as in the example above).
    • Leave-one-out cross-validation: Each individual sample is treated as a block.
    • Ten-fold cross-validation: Common in practice, where the data is divided into ten blocks.
  • Tuning Parameters:

    • Some methods, like ridge regression, involve a tuning parameter that needs to be guessed (not estimated from the data).
    • Cross-validation, such as ten-fold cross-validation, can help determine the best value for this tuning parameter.
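As a sketch of using ten-fold cross-validation to pick a tuning parameter, scikit-learn's `RidgeCV` scores a list of candidate `alpha` values (ridge regression's tuning parameter) and keeps the best one. The dataset and candidate grid here are assumptions for illustration:

```python
# A sketch of choosing ridge regression's tuning parameter (alpha)
# with ten-fold cross-validation. The regression data is synthetic.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import RidgeCV

X, y = make_regression(n_samples=100, n_features=5, noise=10.0, random_state=0)

# cv=10 scores every candidate alpha with ten-fold cross-validation
# and keeps the one that performed best across the folds.
alphas = np.logspace(-3, 3, 13)
model = RidgeCV(alphas=alphas, cv=10).fit(X, y)
print(model.alpha_)  # the winning tuning-parameter value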

Machine Learning Fundamentals: The Confusion Matrix

  • Predicting Heart Disease with Machine Learning:

    • We have clinical measurements like chest pain, blood circulation, blocked arteries, and weight.
    • We want to predict whether someone will develop heart disease using machine learning methods.
    • Methods to consider:
      • Logistic regression
      • K-nearest neighbors (KNN)
      • Random forest
      • Many other machine learning methods.
  • How to Choose the Best Method:

    • Step 1: Divide the data into training and testing sets.
    • Step 2: Train all selected machine learning methods using the training data.
    • Step 3: Test each method on the testing data.
    • Cross-validation can be used to improve the reliability of these steps (see StatQuest for more details).
  • Evaluating Model Performance:

    • To evaluate each method, we use a confusion matrix.
      • Rows: Predicted values from the algorithm.
      • Columns: Actual known values.
      • Since we are predicting heart disease or no heart disease, the confusion matrix has 2 rows and 2 columns.
  • Confusion Matrix Breakdown:

    • True Positives (Top left): Patients with heart disease correctly classified as having heart disease.
    • True Negatives (Bottom right): Patients without heart disease correctly classified as not having heart disease.
    • False Negatives (Bottom left): Patients with heart disease incorrectly classified as not having it.
    • False Positives (Top right): Patients without heart disease incorrectly classified as having it.
  • Example with Random Forest:

    • True positives: 142 (correctly classified as having heart disease).
    • True negatives: 110 (correctly classified as not having heart disease).
    • False negatives: 29 (incorrectly classified as not having heart disease).
    • False positives: 22 (incorrectly classified as having heart disease).
    • Green boxes on the diagonal represent correct classifications.
    • Red boxes off the diagonal represent errors.
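A confusion matrix like the one above can be built with scikit-learn (the label vectors below are made up). One caveat worth noting: scikit-learn puts the known values in the rows and the predicted values in the columns, the transpose of the layout described in these notes:

```python
# A sketch of building a 2x2 confusion matrix with scikit-learn.
# The actual/predicted vectors are made up. Note: scikit-learn's
# convention is rows = known truth, columns = predicted values --
# the transpose of the layout used in these notes.
from sklearn.metrics import confusion_matrix

actual    = [1, 1, 1, 0, 0, 0, 0, 1]   # 1 = heart disease, 0 = none
predicted = [1, 1, 0, 0, 0, 1, 0, 1]

cm = confusion_matrix(actual, predicted)
tn, fp, fn, tp = cm.ravel()            # documented unpacking order for 2x2
print(tp, tn, fp, fn)                  # 3 3 1 1
```

Correct classifications (`tp` and `tn`) sit on the diagonal; errors (`fp` and `fn`) sit off it, regardless of which convention is used.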
  • Comparing Methods:

    • Comparing random forest to k-nearest neighbors (KNN):
      • Random forest correctly classified more patients with heart disease (142 vs. 107).
      • Random forest also correctly classified more patients without heart disease (110 vs. 79).
    • If we had to choose between the two, we would select random forest.
  • Logistic Regression Example:

    • When applied, logistic regression produces a confusion matrix similar to random forest, making it hard to decide which method is better.
    • More advanced metrics like sensitivity, specificity, ROC, and AUC can help us make a decision (covered in the next StatQuest).

Machine Learning Fundamentals: Sensitivity and Specificity

  • Recap of the Confusion Matrix (2x2 Example):

    • Rows = predicted values.
    • Columns = known truth.
    • Categories: Has heart disease or does not have heart disease.
      • True positives (top left): Patients with heart disease correctly identified.
      • True negatives (bottom right): Patients without heart disease correctly identified.
      • False negatives (bottom left): Patients with heart disease incorrectly classified as not having it.
      • False positives (top right): Patients without heart disease incorrectly classified as having it.
  • Calculating Sensitivity and Specificity (2x2 Matrix):

    • Sensitivity: Measures how well the model correctly identifies patients with heart disease.
      • Formula: Sensitivity = True Positives / (True Positives + False Negatives).
    • Specificity: Measures how well the model correctly identifies patients without heart disease.
      • Formula: Specificity = True Negatives / (True Negatives + False Positives).
  • Example with Logistic Regression:

    • Confusion Matrix:
      • True Positives = 139, False Negatives = 32.
      • True Negatives = 112, False Positives = 20.
    • Sensitivity = 139 / (139 + 32) = 0.81.
      • 81% of patients with heart disease were correctly identified.
    • Specificity = 112 / (112 + 20) = 0.85.
      • 85% of patients without heart disease were correctly identified.
  • Example with Random Forest:

    • Confusion Matrix:
      • True Positives = 142, False Negatives = 29.
      • True Negatives = 110, False Positives = 22.
    • Sensitivity = 142 / (142 + 29) = 0.83.
      • 83% of patients with heart disease were correctly identified.
    • Specificity = 110 / (110 + 22) = 0.83.
      • 83% of patients without heart disease were correctly identified.
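The two examples above can be reproduced with plain arithmetic (the counts come straight from the notes):

```python
# Reproducing the sensitivity and specificity calculations above.
def sensitivity(tp, fn):
    """Fraction of actual positives correctly identified."""
    return tp / (tp + fn)

def specificity(tn, fp):
    """Fraction of actual negatives correctly identified."""
    return tn / (tn + fp)

# Logistic regression: TP=139, FN=32, TN=112, FP=20
print(round(sensitivity(139, 32), 2))   # 0.81
print(round(specificity(112, 20), 2))   # 0.85

# Random forest: TP=142, FN=29, TN=110, FP=22
print(round(sensitivity(142, 29), 2))   # 0.83
print(round(specificity(110, 22), 2))   # 0.83
```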
  • Comparison of Logistic Regression vs. Random Forest:

    • Random forest: Higher sensitivity (better at identifying patients with heart disease).
    • Logistic regression: Higher specificity (better at identifying patients without heart disease).
    • Choice of method depends on which is more important:
      • Prioritize sensitivity if identifying patients with heart disease is crucial.
      • Prioritize specificity if identifying patients without heart disease is more important.
  • Sensitivity and Specificity for Larger Confusion Matrices (3x3 Example):

    • When the confusion matrix has three or more categories, sensitivity and specificity must be calculated for each category individually.
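A sketch of the per-category calculation for a 3x3 matrix: each class in turn is treated as "positive" and everything else as "negative" (the matrix values below are made up; rows = known truth, columns = predicted, matching scikit-learn's convention):

```python
# A sketch of per-class sensitivity and specificity for a 3x3
# confusion matrix. The counts are made up for illustration.
import numpy as np

cm = np.array([[5, 1, 0],
               [1, 6, 2],
               [0, 1, 4]])

for k in range(cm.shape[0]):
    tp = cm[k, k]
    fn = cm[k].sum() - tp      # class-k samples predicted as something else
    fp = cm[:, k].sum() - tp   # other samples predicted as class k
    tn = cm.sum() - tp - fn - fp
    print(k, tp / (tp + fn), tn / (tn + fp))
```

For class 0 this gives sensitivity 5/6 and specificity 13/14; the same one-vs-rest logic applies to any number of categories.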

Machine Learning Fundamentals: Bias and Variance

  • Introduction:

    • We have data on the weight and height of several mice.
    • Lighter mice tend to be shorter, while heavier mice tend to be taller.
    • After a certain weight, however, mice stop getting taller and just become more obese.
    • The goal is to predict mouse height based on its weight.
  • Modeling the Relationship:

    • Ideally, we’d have the exact mathematical formula that describes the relationship between weight and height.
    • Since we don’t know the formula, we’ll use two machine learning methods to approximate the relationship.
    • The true relationship curve will remain in the figure for reference.
  • Step 1: Splitting the Data:

    • The data is split into two sets:
      • Training set: Used to train the machine learning models (blue dots).
      • Testing set: Used to evaluate the models (green dots).
  • Linear Regression (Least Squares):

    • The first method we use is linear regression, which fits a straight line to the training data.
    • A straight line doesn't have the flexibility to accurately capture the arc in the true relationship between weight and height.
    • No matter how well we fit it to the training data, the line won’t capture the curve.
    • The inability of the model to capture the true relationship is called bias.
      • Linear regression has high bias because it cannot adjust to the arc.
  • Flexible Model (Squiggly Line):

    • Another method might fit a squiggly line to the training data.
    • This line is very flexible and follows the arc of the true relationship closely, meaning it has low bias.
    • It hugs the training data tightly and does a much better job of capturing the true relationship than the straight line.
  • Comparing Fits:

    • We can compare the performance of the two models by calculating the sums of squares.
    • This involves measuring the distances between the predicted values and the actual data points, squaring those distances, and summing them up.
      • Distances are squared so that negative and positive values don’t cancel each other out.
    • Squiggly line: Fits the training data perfectly, so the distances are zero.
    • Linear regression: Fits the data less perfectly, resulting in a larger sum of squares.
    • In the training set, the squiggly line wins because it fits the data better.
  • Step 2: Testing on New Data:

    • Now we evaluate the models on the testing set.
    • The squiggly line, which did well on the training set, performs poorly on the testing set.
    • The straight line, despite being less flexible, performs better on the testing set.
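The two fits above can be sketched with numpy polynomials: a straight line (high bias) vs. an interpolating polynomial standing in for the squiggly line (high variance). The weight/height data here are synthetic, generated from an assumed arc-shaped relationship:

```python
# A sketch of the bias-variance story: a straight line vs. a
# "squiggly" interpolating polynomial. The data are synthetic.
import numpy as np

rng = np.random.default_rng(0)
true_height = lambda w: 10 * np.tanh(w / 5.0)   # an arc that plateaus

w_train = np.linspace(1, 10, 8)
h_train = true_height(w_train) + rng.normal(0, 0.3, 8)
w_test = np.linspace(1.5, 9.5, 8)
h_test = true_height(w_test) + rng.normal(0, 0.3, 8)

def sum_of_squares(coeffs, w, h):
    """Sum of squared distances between predictions and observations."""
    return float(np.sum((np.polyval(coeffs, w) - h) ** 2))

line = np.polyfit(w_train, h_train, deg=1)      # straight line
squiggle = np.polyfit(w_train, h_train, deg=7)  # hits every training point

print("train:", sum_of_squares(line, w_train, h_train),
      sum_of_squares(squiggle, w_train, h_train))
print("test: ", sum_of_squares(line, w_test, h_test),
      sum_of_squares(squiggle, w_test, h_test))
```

The squiggle's training sum of squares is essentially zero (it interpolates the points), while its testing sum of squares is typically much larger; the line's two numbers stay similar, which is the "consistent" behavior of a low-variance model.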
  • Understanding Bias and Variance:

    • Bias: The error introduced by the model’s assumptions (e.g., a straight line assuming the relationship is linear).
      • The straight line has high bias but low variance (it performs consistently across datasets).
      • The squiggly line has low bias (can capture the curve) but high variance (inconsistent performance across datasets).
    • Variance: The variability of a model’s performance on different datasets.
      • A model with high variance performs well on the training set but poorly on the testing set (i.e., it’s inconsistent).
    • Overfitting: When a model (like the squiggly line) fits the training data too well but doesn’t generalize well to new data.
  • Finding the Sweet Spot:

    • The ideal machine learning algorithm has low bias (can capture the true relationship) and low variance (makes consistent predictions).
    • This involves finding a balance between simple models (like linear regression) and complex models (like the squiggly line).
  • Methods to Find the Balance Between Simple and Complex Models:

    • Regularization: A technique that reduces overfitting by limiting the flexibility of the model.
    • Boosting: A method that builds a strong model by combining many weak ones.
    • Bagging: A technique that averages the predictions of multiple models to reduce variance (e.g., used in random forest).

ROC and AUC, Clearly Explained!

  • Introduction:

    • We have data on mouse weight and whether a mouse is obese or not obese.
    • The blue dots represent obese mice, and the red dots represent mice that are not obese.
    • We want to predict if a mouse is obese based on its weight.
  • Logistic Regression Model:

    • We fit a logistic regression curve to the data.
      • The y-axis now represents the probability that a mouse is obese.
      • The x-axis remains the weight of the mouse.
    • The curve predicts:
      • High probability of obesity for a heavy mouse.
      • Low probability of obesity for a light mouse.
    • Logistic regression gives us the probability of obesity based on weight.
  • Classification with Logistic Regression:

    • To classify a mouse as obese or not obese, we need a threshold to turn probabilities into categories.
    • A common threshold is 0.5:
      • Mice with a probability greater than 0.5 are classified as obese.
      • Mice with a probability less than or equal to 0.5 are classified as not obese.
    • We use this threshold to classify various mice.
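The threshold step can be sketched with scikit-learn (the mouse weights and obesity labels below are made up):

```python
# A sketch of turning logistic-regression probabilities into
# obese / not-obese categories with a 0.5 threshold.
# The weights and labels are made up for illustration.
import numpy as np
from sklearn.linear_model import LogisticRegression

weights = np.array([[10], [15], [20], [25], [30], [35], [40], [45]])
obese   = np.array([0, 0, 0, 0, 1, 1, 1, 1])

model = LogisticRegression().fit(weights, obese)

# Column 1 of predict_proba is the probability that a mouse is obese.
probs = model.predict_proba(weights)[:, 1]
classified = (probs > 0.5).astype(int)   # apply the 0.5 threshold
print(classified)
```

For a binary problem, `model.predict(...)` applies this same 0.5 cutoff implicitly; computing the probabilities explicitly is what makes it possible to try other thresholds.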
  • Evaluating the Model:

    • We test the logistic regression model using known data:
      • We correctly and incorrectly classify mice as obese or not obese.
    • Confusion matrix:
      • True positives: Correctly classified as obese.
      • False positives: Incorrectly classified as obese but are not obese.
      • True negatives: Correctly classified as not obese.
      • False negatives: Incorrectly classified as not obese but are actually obese.
  • Sensitivity and Specificity:

    • Sensitivity: Measures the percentage of obese mice correctly classified.
    • Specificity: Measures the percentage of non-obese mice correctly classified.
  • Changing the Threshold:

    • If it's important to correctly classify every obese mouse, we can lower the threshold (e.g., to 0.1):
      • This results in fewer false negatives but more false positives.
      • Example: Classifying mice as infected with a disease like Ebola, where it's essential to catch every positive case.
    • If we raise the threshold (e.g., to 0.9):
      • This results in fewer false positives but more false negatives.
      • For some datasets, a higher threshold does a better job at classifying obese and non-obese mice.
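The trade-off above can be demonstrated by sweeping the threshold on one fitted model (the overlapping weight distributions are an assumption, chosen so that some errors exist):

```python
# A sketch of how raising the threshold trades false positives for
# false negatives. The weight data are synthetic and overlapping.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(1)
w_lean  = rng.normal(25, 5, 50)   # not obese
w_obese = rng.normal(35, 5, 50)   # obese (overlaps the lean group)
X = np.concatenate([w_lean, w_obese]).reshape(-1, 1)
y = np.array([0] * 50 + [1] * 50)

probs = LogisticRegression().fit(X, y).predict_proba(X)[:, 1]

counts = {}
for threshold in (0.1, 0.5, 0.9):
    pred = (probs > threshold).astype(int)
    fn = int(np.sum((pred == 0) & (y == 1)))
    fp = int(np.sum((pred == 1) & (y == 0)))
    counts[threshold] = (fn, fp)
    print(threshold, "false negatives:", fn, "false positives:", fp)
```

Raising the threshold can only shrink the set of mice classified as obese, so false positives never increase and false negatives never decrease as the threshold goes up.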
  • Finding the Best Threshold:

    • Instead of testing every possible threshold by hand, we use an ROC curve (Receiver Operating Characteristic curve) to summarize the model’s performance.
    • ROC Curve:
      • Y-axis: True Positive Rate (Sensitivity).
      • X-axis: False Positive Rate (1 - Specificity).
    • The ROC curve helps visualize how the model's sensitivity and false positive rate change as the threshold varies.
  • Building the ROC Curve:

    • Start with a low threshold where all samples are classified as obese:
      • True positive rate = 1 (all obese mice correctly classified).
      • False positive rate = 1 (all non-obese mice incorrectly classified as obese).
      • Plot this point at (1,1) on the ROC curve.
    • As we increase the threshold, the true positive and false positive rates change:
      • We calculate these rates at each threshold and plot the points.
      • The points move leftward and downward on the graph as the threshold increases.
    • The best points lie above and to the left of the diagonal line (where true positive rate equals false positive rate, i.e., random guessing), meaning the proportion of obese mice correctly classified exceeds the proportion of non-obese mice misclassified.
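The construction above is automated by scikit-learn: `roc_curve` tries every useful threshold and returns one (false positive rate, true positive rate) point per threshold, and `roc_auc_score` gives the area under that curve. The overlapping weight data here are synthetic:

```python
# A sketch of computing an ROC curve and its AUC with scikit-learn.
# The weight data are synthetic, overlapping so the curve is not trivial.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve

rng = np.random.default_rng(2)
X = np.concatenate([rng.normal(25, 5, 50),    # not obese
                    rng.normal(35, 5, 50)]).reshape(-1, 1)  # obese
y = np.array([0] * 50 + [1] * 50)

probs = LogisticRegression().fit(X, y).predict_proba(X)[:, 1]

fpr, tpr, thresholds = roc_curve(y, probs)  # one point per threshold
auc = roc_auc_score(y, probs)
print(auc)
```

The returned arrays trace the curve from (0, 0) up to (1, 1); plotting `fpr` against `tpr` reproduces the picture described above.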
  • Area Under the Curve (AUC):

    • The AUC (Area Under the Curve) summarizes the ROC curve:
      • A higher AUC means better model performance.
      • Example: An AUC of 0.9 means the model is very good at classifying the data.
      • If comparing two models (e.g., logistic regression vs. random forest), the model with the higher AUC is better.
  • Alternative to ROC: Precision:

    • Another metric that can replace the false positive rate is precision:
      • Precision = True Positives / (True Positives + False Positives).
      • Precision measures the proportion of positive predictions that are actually correct.
      • Precision is useful when there is a class imbalance (e.g., far more non-obese than obese mice, or when studying a rare disease), because unlike the false positive rate it does not depend on the number of true negatives.
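The formula can be checked against scikit-learn's built-in `precision_score` (the label vectors are made up):

```python
# Precision computed from made-up counts and checked against
# scikit-learn's precision_score.
from sklearn.metrics import precision_score

actual    = [1, 1, 1, 0, 0, 0, 0, 1]
predicted = [1, 1, 0, 0, 0, 1, 0, 1]

tp = sum(1 for a, p in zip(actual, predicted) if a == 1 and p == 1)
fp = sum(1 for a, p in zip(actual, predicted) if a == 0 and p == 1)

print(tp / (tp + fp))                       # 0.75
print(precision_score(actual, predicted))   # 0.75
```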