{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SPg_7aF97ut2"
},
"source": [
"\n",
"\n",
"# Continuous prediction with regression\n",
"\n",
"Recall that we previously discussed a strict threshold classifier with accuracy as the loss function. Now consider continuous prediction, where we also need a loss function. A reasonable strategy is to minimize the squared distances between our predictions and the observed values, in other words, to minimize $\\sum_{i=1}^n (Y_i - \\hat \\mu_i)^2.$\n",
"\n",
"If we divide this sum by $n$, we get the average of the squared errors, or the *mean squared error* (MSE). Minimizing the squared error can serve both as a rule for finding a good prediction and as our evaluation criterion for held-out data.\n",
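"\n",
"In symbols, $\\text{MSE} = \\frac{1}{n} \\sum_{i=1}^n (Y_i - \\hat \\mu_i)^2$. Here is a small worked illustration. The numbers are made up to show the arithmetic and do not come from our data. Take $n = 3$ observations $Y = (1, 2, 3)$ with predictions $\\hat \\mu = (1.5, 2, 2.5)$:\n",
"\n",
"$$\n",
"\\text{MSE} = \\frac{(1 - 1.5)^2 + (2 - 2)^2 + (3 - 2.5)^2}{3} = \\frac{0.25 + 0 + 0.25}{3} = \\frac{0.5}{3} \\approx 0.167.\n",
"$$\n",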
"\n",
"What's left is to figure out how to obtain $\\hat \\mu_i$, our prediction for the observation $Y_i$. Previously, we considered only a rescaled version of our predictor $X$, using regression through the origin. In this module, we'll try a slightly more complex model that includes a location shift (intercept) and a scale factor (slope). The result is the line that best fits our paired $(X, Y)$ data in the least squares sense, that is, the line minimizing the squared error above.\n",
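"\n",
"Concretely, write the intercept and slope as $\\beta_0$ and $\\beta_1$ (labels chosen here for illustration). A sketch of this model is then $\\hat \\mu_i = \\beta_0 + \\beta_1 X_i$, with the coefficients chosen to minimize\n",
"\n",
"$$\n",
"\\sum_{i=1}^n (Y_i - \\beta_0 - \\beta_1 X_i)^2.\n",
"$$\n",
"\n",
"Setting the partial derivatives with respect to $\\beta_0$ and $\\beta_1$ to zero gives the standard least squares solution, assuming the $X_i$ are not all equal:\n",
"\n",
"$$\n",
"\\hat \\beta_1 = \\frac{\\sum_{i=1}^n (X_i - \\bar X)(Y_i - \\bar Y)}{\\sum_{i=1}^n (X_i - \\bar X)^2}, \\qquad \\hat \\beta_0 = \\bar Y - \\hat \\beta_1 \\bar X,\n",
"$$\n",
"\n",
"where $\\bar X$ and $\\bar Y$ are the sample means. Regression through the origin corresponds to fixing $\\beta_0 = 0$ and fitting only the slope.\n",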
"\n",
"To make this concrete, let's return to the previous lecture's example and try to predict the FLAIR value from the other, non-FLAIR, imaging values."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 192
},
"colab_type": "code",
"id": "NUkxnWrB1RrP",
"outputId": "2d9ff002-6ea2-435a-bbfb-f58ab0361971"
},
"outputs": [
{
"data": {
"text/html": [
"
|   | FLAIR | PD | T1 | T2 | FLAIR_10 | PD_10 | T1_10 | T2_10 | FLAIR_20 | PD_20 | T1_20 | T2_20 | GOLD_Lesions |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.143692 | 1.586219 | -0.799859 | 1.634467 | 0.437568 | 0.823800 | -0.002059 | 0.573663 | 0.279832 | 0.548341 | 0.219136 | 0.298662 | 0 |
| 1 | 1.652552 | 1.766672 | -1.250992 | 0.921230 | 0.663037 | 0.880250 | -0.422060 | 0.542597 | 0.422182 | 0.549711 | 0.061573 | 0.280972 | 0 |
| 2 | 1.036099 | 0.262042 | -0.858565 | -0.058211 | -0.044280 | -0.308569 | 0.014766 | -0.256075 | -0.136532 | -0.350905 | 0.020673 | -0.259914 | 0 |
| 3 | 1.037692 | 0.011104 | -1.228796 | -0.470222 | -0.013971 | -0.000498 | -0.395575 | -0.221900 | 0.000807 | -0.003085 | -0.193249 | -0.139284 | 0 |