shortcuts/notebooks/regressions.ipynb

171 lines
99 KiB
Plaintext
Raw Normal View History

2024-09-30 23:58:32 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Regressions\n",
"\n",
"$X = (0,1)$\n",
"\n",
"$Y = sin(2\\pi X)$"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"![name](
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"// deno-lint-ignore-file\n",
"\n",
"import { display } from \"https://deno.land/x/display@v0.1.1/mod.ts\";\n",
"import pl from \"npm:nodejs-polars\";\n",
"import plot from \"../plot/mod.ts\";\n",
"\n",
"const data = await Deno.readTextFile(\"assets/X_Y_Sinusoid_Data.csv\");\n",
"const df = pl.readCSV(data, { sep: \",\" });\n",
"\n",
"const real = pl.DataFrame({ x: new Array(100).fill(0).map((_, i) => i / 100)}).select(\n",
" pl.col('x'),\n",
" pl.col('x').mul(2).mul(3.14).sin().alias('y')\n",
");\n",
"\n",
"const draw = (x, y, title = \"Sinusoid Data\") => \n",
" plot.DrawPlot(\n",
" { \n",
" title,\n",
" width: 6,\n",
" height: 3,\n",
" XLabel: \"X\", \n",
" YLabel: \"Y\", \n",
" }, \n",
" { type: \"line\", data: [real.x, real.y], legend: \"Sinusoid\", lineDashes: [3, 4], lineColor: \"#ff8888\", lineWidth: 1 },\n",
" { type: \"scatter\", data: [x, y], legend: \"Data\", lineDashes: [3, 4], lineWidth: 2, glyphColor: \"#4444ff\", glyphShape: \"circle\" },\n",
" { type: \"trend\", data: [x, y], legend: \"Trend\", lineDashes: [4, 2], lineColor: '#aacccc', lineWidth: .5 },\n",
" );\n",
"\n",
2024-10-01 00:00:01 +00:00
" const comparePredicted = (x, y, predicted) => plot.DrawPlot(\n",
" { \n",
" width: 7,\n",
" height: 4,\n",
" XLabel: \"X\", \n",
" YLabel: \"Y\", \n",
" }, \n",
" { type: \"line\", data: [real.x, real.y], legend: \"Sinusoid\", lineDashes: [3, 4], lineColor: \"#ff8888\", lineWidth: 1 },\n",
" { type: \"linePoints\", data: [x, y], legend: \"Test Data\", lineDashes: [3, 4], lineColor: \"#8888ff\", glyphColor: \"#4444ff\", glyphShape: \"circle\" },\n",
" { type: \"linePoints\", data: [x, predicted], lineWidth: .5, legend: \"Predicted\", glyphColor: '#f00', glyphShape: \"pyramid\" },\n",
");\n",
"\n",
2024-09-30 23:58:32 +00:00
"\n",
"draw(df.x, df.y);"
]
},
2024-10-01 00:00:01 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Polynomial Tranformation\n",
"\n",
"First we try to predict values without polynomial transformation:"
]
},
2024-09-30 23:58:32 +00:00
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
2024-10-01 00:00:01 +00:00
"![name](
2024-09-30 23:58:32 +00:00
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import regr from '../regr/mod.ts';\n",
"\n",
2024-10-01 00:00:01 +00:00
"const linregWoPoly = regr.Linear();\n",
2024-09-30 23:58:32 +00:00
"\n",
2024-10-01 00:00:01 +00:00
"linregWoPoly.fit(df.drop('y').rows(), df.select('y').rows());\n",
"const predWoPoly = linregWoPoly.predict(df.drop('y').rows());\n",
2024-09-30 23:58:32 +00:00
"\n",
2024-10-01 00:00:01 +00:00
"comparePredicted(df.x, df.y, predWoPoly);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now augment our dataset with high degree polynomial:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"![name](
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import { polynomialTransform } from '../encoding.ts';\n",
2024-09-30 23:58:32 +00:00
"\n",
2024-10-01 00:00:01 +00:00
"\n",
"\n",
"const polyFeatures = polynomialTransform(df.drop('y'), 36, false, false)\n",
"\n",
"const [polyX, polyY] = [polyFeatures, df.select('y')]\n",
"\n",
"const linregPoly = regr.Linear();\n",
"\n",
"linregPoly.fit(polyX.rows(), polyY.rows());\n",
"const predY = linregPoly.predict(polyX.rows());\n",
"\n",
"comparePredicted(df.x, df.y, predY);\n"
2024-09-30 23:58:32 +00:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Deno",
"language": "typescript",
"name": "deno"
},
"language_info": {
"codemirror_mode": "typescript",
"file_extension": ".ts",
"mimetype": "text/x.typescript",
"name": "typescript",
"nbconvert_exporter": "script",
"pygments_lexer": "typescript",
"version": "5.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}