expr, split

Anton Nesterov 2024-09-26 03:41:33 +02:00
parent 6c5a77ae3f
commit 71583054f8
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
3 changed files with 214 additions and 140 deletions


@ -13,11 +13,13 @@
"source": [ "source": [
"# Data Analytics with JS\n", "# Data Analytics with JS\n",
"\n", "\n",
"The dataset contains all the information about cars, a name of a manufacturer, all car's technical parameters and a sale price of a car.\n", "The dataset contains all the information about cars, a name of a manufacturer,\n",
"all car's technical parameters and a sale price of a car.\n",
"\n", "\n",
"Libraries:\n", "Libraries:\n",
" - nodejs-polars\n", "\n",
" - @observable/plot" "- nodejs-polars\n",
"- @observable/plot"
] ]
}, },
{ {
@ -33,7 +35,8 @@
"source": [ "source": [
"## Exploring Data\n", "## Exploring Data\n",
"\n", "\n",
"Use [pola-rs](https://pola-rs.github.io/nodejs-polars/modules.html) dataframes to read and manipulate data.\n", "Use [pola-rs](https://pola-rs.github.io/nodejs-polars/modules.html) dataframes\n",
"to read and manipulate data.\n",
"\n", "\n",
"- use `df.head(n)` to get n first rows\n", "- use `df.head(n)` to get n first rows\n",
"- use `df.sample(n)` to get n random rows from the dataset\n", "- use `df.sample(n)` to get n random rows from the dataset\n",
@ -489,13 +492,13 @@
} }
], ],
"source": [ "source": [
"import { display } from \"https://deno.land/x/display@v0.1.1/mod.ts\"\n", "import { display } from \"https://deno.land/x/display@v0.1.1/mod.ts\";\n",
"import pl from \"npm:nodejs-polars\"\n", "import pl from \"npm:nodejs-polars\";\n",
"\n", "\n",
"let data = await Deno.readTextFile('assets/CarPrice_Assignment.csv')\n", "let data = await Deno.readTextFile(\"assets/CarPrice_Assignment.csv\");\n",
"let df = pl.readCSV(data, { sep: \",\" })\n", "let df = pl.readCSV(data, { sep: \",\" });\n",
"\n", "\n",
"await display(df.head(5))" "await display(df.head(5));"
] ]
}, },
{ {
@ -610,12 +613,12 @@
], ],
"source": [ "source": [
"await display(\n", "await display(\n",
" df.select(\n", " df.select(\n",
" 'enginesize', \n", " \"enginesize\",\n",
" 'horsepower', \n", " \"horsepower\",\n",
" 'price'\n", " \"price\",\n",
" ).describe()\n", " ).describe(),\n",
")" ");"
] ]
}, },
{ {
@ -645,9 +648,9 @@
], ],
"source": [ "source": [
"// check for duplicates\n", "// check for duplicates\n",
"const hasDups = df.select('car_ID').isDuplicated().toArray().includes(true)\n", "const hasDups = df.select(\"car_ID\").isDuplicated().toArray().includes(true);\n",
"// if there are duplicates, use df.filter()\n", "// if there are duplicates, use df.filter()\n",
"hasDups" "hasDups;"
] ]
}, },
{ {
@ -793,17 +796,17 @@
], ],
"source": [ "source": [
"// get brand names from `CarName`\n", "// get brand names from `CarName`\n",
"let brandNameTable = df.select('CarName').map((row) => {\n", "let brandNameTable = df.select(\"CarName\").map((row) => {\n",
" const [carName] = row\n", " const [carName] = row;\n",
" const brand = carName.split(' ')[0].toLowerCase()\n", " const brand = carName.split(\" \")[0].toLowerCase();\n",
" return brand\n", " return brand;\n",
"})\n", "});\n",
"\n", "\n",
"// create a dataframe from brand names\n", "// create a dataframe from brand names\n",
"let brandDf = pl.DataFrame({\n", "let brandDf = pl.DataFrame({\n",
" \"brand\": brandNameTable\n", " \"brand\": brandNameTable,\n",
"})\n", "});\n",
"await display(brandDf.unique())" "await display(brandDf.unique());"
] ]
}, },
{ {
@ -927,18 +930,18 @@
"source": [ "source": [
"// transform to remove duplicates\n", "// transform to remove duplicates\n",
"brandNameTable = brandNameTable.map((name) => {\n", "brandNameTable = brandNameTable.map((name) => {\n",
" name = name\n", " name = name\n",
" .replace('maxda', 'mazda')\n", " .replace(\"maxda\", \"mazda\")\n",
" .replace('porcshce', 'porsche')\n", " .replace(\"porcshce\", \"porsche\")\n",
" .replace('toyouta', 'toyota')\n", " .replace(\"toyouta\", \"toyota\")\n",
" .replace(/(vw|vokswagen)/ig, 'volkswagen');\n", " .replace(/(vw|vokswagen)/ig, \"volkswagen\");\n",
" return name\n", " return name;\n",
"})\n", "});\n",
"\n", "\n",
"brandDf = pl.DataFrame({\n", "brandDf = pl.DataFrame({\n",
" brand: brandNameTable\n", " brand: brandNameTable,\n",
"})\n", "});\n",
"await display(brandDf.unique())" "await display(brandDf.unique());"
] ]
}, },
{ {
@ -1398,8 +1401,8 @@
], ],
"source": [ "source": [
"// add new column `brand` to our dataframe\n", "// add new column `brand` to our dataframe\n",
"df = brandDf.hstack(df)\n", "df = brandDf.hstack(df);\n",
"await display(df.head(5))" "await display(df.head(5));"
] ]
}, },
{ {
@ -1775,8 +1778,8 @@
} }
], ],
"source": [ "source": [
"df = df.drop('car_ID', 'symboling', 'CarName')\n", "df = df.drop(\"car_ID\", \"symboling\", \"CarName\");\n",
"df.head(3)" "df.head(3);"
] ]
}, },
{ {
@ -1940,8 +1943,8 @@
} }
], ],
"source": [ "source": [
"let brandCount = df.groupBy('brand').len().sort('brand_count')\n", "let brandCount = df.groupBy(\"brand\").len().sort(\"brand_count\");\n",
"brandCount" "brandCount;"
] ]
}, },
{ {
@ -2040,8 +2043,8 @@
} }
], ],
"source": [ "source": [
"let avgPricePerBrand = df.groupBy('brand').agg({'price': ['mean']})\n", "let avgPricePerBrand = df.groupBy(\"brand\").agg({ \"price\": [\"mean\"] });\n",
"avgPricePerBrand.describe()" "avgPricePerBrand.describe();"
] ]
}, },
{ {
@ -2087,8 +2090,8 @@
"source": [ "source": [
"// map brand name to price\n", "// map brand name to price\n",
"avgPricePerBrand = avgPricePerBrand\n", "avgPricePerBrand = avgPricePerBrand\n",
" .toRecords()\n", " .toRecords()\n",
" .reduce((acc, rec) => ({...acc, [rec.brand]: rec.price}), {})" " .reduce((acc, rec) => ({ ...acc, [rec.brand]: rec.price }), {});"
] ]
}, },
{ {
@ -2165,14 +2168,18 @@
"source": [ "source": [
"// create brand categories by budget\n", "// create brand categories by budget\n",
"let brandCategory = df.brand.toArray().map((brand) => {\n", "let brandCategory = df.brand.toArray().map((brand) => {\n",
" const avgPrice = avgPricePerBrand[brand]\n", " const avgPrice = avgPricePerBrand[brand];\n",
" return avgPrice < 10000 ? \"Budget\" : avgPrice > 20000 ? \"Luxury\" : \"Mid_Range\"\n", " return avgPrice < 10000\n",
"})\n", " ? \"Budget\"\n",
" : avgPrice > 20000\n",
" ? \"Luxury\"\n",
" : \"Mid_Range\";\n",
"});\n",
"let catDf = pl.DataFrame({\n", "let catDf = pl.DataFrame({\n",
" \"brand_category\": brandCategory\n", " \"brand_category\": brandCategory,\n",
"})\n", "});\n",
"\n", "\n",
"catDf.sample(5)" "catDf.sample(5);"
] ]
}, },
{ {
@ -2196,8 +2203,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"df = catDf.hstack(df)\n", "df = catDf.hstack(df);\n",
"df.writeCSV('assets/cleaned_car_prices.csv')" "df.writeCSV(\"assets/cleaned_car_prices.csv\");"
] ]
}, },
{ {
@ -2213,7 +2220,9 @@
"source": [ "source": [
"## Exploratory Data Analysis\n", "## Exploratory Data Analysis\n",
"\n", "\n",
"For plotting we use [@observable/plot](https://observablehq.com/plot) and configured shotcuts for jupyter notebooks imported from [l12.xyz/x/shortcuts](https://l12.xyz/x/shortcuts)." "For plotting we use [@observable/plot](https://observablehq.com/plot) and\n",
"configured shotcuts for jupyter notebooks imported from\n",
"[l12.xyz/x/shortcuts](https://l12.xyz/x/shortcuts)."
] ]
}, },
{ {
@ -2280,22 +2289,26 @@
} }
], ],
"source": [ "source": [
"import {Plot, document} from 'https://l12.xyz/x/shortcuts/raw/plots.ts'\n", "import { document, Plot } from \"https://l12.xyz/x/shortcuts/raw/plots.ts\";\n",
"\n", "\n",
"const brandCountRecords = brandCount.toRecords()\n", "const brandCountRecords = brandCount.toRecords();\n",
"console.log(brandCountRecords)\n", "console.log(brandCountRecords);\n",
"const brandCountPlot = Plot.plot({\n", "const brandCountPlot = Plot.plot({\n",
" marginLeft: 80,\n", " marginLeft: 80,\n",
" style: {\n", " style: {\n",
" backgroundColor: \"#fff\"\n", " backgroundColor: \"#fff\",\n",
" },\n", " },\n",
" x: {padding: 0.4},\n", " x: { padding: 0.4 },\n",
" marks: [\n", " marks: [\n",
" Plot.barX(brandCountRecords, {x: \"brand_count\", y: \"brand\", sort: {y: \"x\", order: \"descending\"}}),\n", " Plot.barX(brandCountRecords, {\n",
" ],\n", " x: \"brand_count\",\n",
" document\n", " y: \"brand\",\n",
"})\n", " sort: { y: \"x\", order: \"descending\" },\n",
"await display(brandCountPlot)" " }),\n",
" ],\n",
" document,\n",
"});\n",
"await display(brandCountPlot);"
] ]
}, },
{ {
@ -2330,8 +2343,10 @@
} }
], ],
"source": [ "source": [
"let numericColumns = df.columns.filter((col) => df[col].isNumeric() && col !== 'price')\n", "let numericColumns = df.columns.filter((col) =>\n",
"numericColumns" " df[col].isNumeric() && col !== \"price\"\n",
");\n",
"numericColumns;"
] ]
}, },
{ {
@ -2339,7 +2354,10 @@
"id": "1c8d3b90-b4b1-4da6-9e7c-f8fd7cffb77d", "id": "1c8d3b90-b4b1-4da6-9e7c-f8fd7cffb77d",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Sometimes we can get some intuitive insight seeing the data plotted from different dimensions. It is an optional step, but it might help to get some assumtions about the relationships in the dataset. Below is an example for drawing plots side-by-side." "Sometimes we can get some intuitive insight seeing the data plotted from\n",
"different dimensions. It is an optional step, but it might help to get some\n",
"assumtions about the relationships in the dataset. Below is an example for\n",
"drawing plots side-by-side."
] ]
}, },
{ {
@ -2370,22 +2388,22 @@
} }
], ],
"source": [ "source": [
"import { sideBySidePlot } from 'https://l12.xyz/x/shortcuts/raw/plots.ts'\n", "import { sideBySidePlot } from \"https://l12.xyz/x/shortcuts/raw/plots.ts\";\n",
"\n", "\n",
"let records = df.toRecords();\n", "let records = df.toRecords();\n",
"\n", "\n",
"const plt = sideBySidePlot({\n", "const plt = sideBySidePlot({\n",
" x: numericColumns,\n", " x: numericColumns,\n",
" y: ['price'],\n", " y: [\"price\"],\n",
" marks: [\n", " marks: [\n",
" (x, y) => Plot.dot(records, {x, y}),\n", " (x, y) => Plot.dot(records, { x, y }),\n",
" (x, y) => Plot.linearRegressionY(records, {x, y, stroke: \"red\"}),\n", " (x, y) => Plot.linearRegressionY(records, { x, y, stroke: \"red\" }),\n",
" ],\n", " ],\n",
" cols: 3\n", " cols: 3,\n",
"})\n", "});\n",
"\n", "\n",
"await display(\n", "await display(\n",
" plt,\n", " plt,\n",
");" ");"
] ]
}, },
@ -2400,7 +2418,9 @@
"tags": [] "tags": []
}, },
"source": [ "source": [
"Let's view the list of top features that have high correlation coefficient. The pearsonCorr() function calculates the Pearson'r correlation coefficients with respect to the 'price'." "Let's view the list of top features that have high correlation coefficient. The\n",
"pearsonCorr() function calculates the Pearson'r correlation coefficients with\n",
"respect to the 'price'."
] ]
}, },
{ {
@ -2522,18 +2542,18 @@
} }
], ],
"source": [ "source": [
"// we select pearon's idx from dataframe for all numeric cols, \n", "// we select pearon's idx from dataframe for all numeric cols,\n",
"// then we transpose result so that columns become rows, \n", "// then we transpose result so that columns become rows,\n",
"// then we sort by the idx column\n", "// then we sort by the idx column\n",
"df.select(\n", "df.select(\n",
" ...numericColumns.map((col) => pl.pearsonCorr(col, 'price'))\n", " ...numericColumns.map((col) => pl.pearsonCorr(col, \"price\")),\n",
")\n", ")\n",
" .transpose({\n", " .transpose({\n",
" columnNames: [\"idx (price)\"],\n", " columnNames: [\"idx (price)\"],\n",
" headerName: \"Variable\",\n", " headerName: \"Variable\",\n",
" includeHeader: true\n", " includeHeader: true,\n",
" })\n", " })\n",
" .sort('idx (price)', true)" " .sort(\"idx (price)\", true);"
] ]
}, },
{ {
@ -2549,9 +2569,12 @@
"source": [ "source": [
"### Linearity Assumption\n", "### Linearity Assumption\n",
"\n", "\n",
"Linear regression needs the relationship between independent variable and the dependent variable to be linear. We can test this assumption with some scatter plots and regression lines.\n", "Linear regression needs the relationship between independent variable and the\n",
"dependent variable to be linear. We can test this assumption with some scatter\n",
"plots and regression lines.\n",
"\n", "\n",
"** Here we use the same side-by-side plot shortcut, but for selected varisbles with high correlation coefficent.\n" "** Here we use the same side-by-side plot shortcut, but for selected varisbles\n",
"with high correlation coefficent."
] ]
}, },
{ {
@ -2577,17 +2600,17 @@
], ],
"source": [ "source": [
"const plt = sideBySidePlot({\n", "const plt = sideBySidePlot({\n",
" x: ['enginesize', 'curbweight', 'horsepower', 'carwidth'],\n", " x: [\"enginesize\", \"curbweight\", \"horsepower\", \"carwidth\"],\n",
" y: ['price'],\n", " y: [\"price\"],\n",
" marks: [\n", " marks: [\n",
" (x, y) => Plot.dot(records, {x, y}),\n", " (x, y) => Plot.dot(records, { x, y }),\n",
" (x, y) => Plot.linearRegressionY(records, {x, y, stroke: \"red\"}),\n", " (x, y) => Plot.linearRegressionY(records, { x, y, stroke: \"red\" }),\n",
" ],\n", " ],\n",
" cols: 2\n", " cols: 2,\n",
"})\n", "});\n",
"\n", "\n",
"await display(\n", "await display(\n",
" plt,\n", " plt,\n",
");" ");"
] ]
}, },
@ -2604,12 +2627,21 @@
"source": [ "source": [
"### Homoscedasticity\n", "### Homoscedasticity\n",
"\n", "\n",
"The assumption of homoscedasticity (constant variance), is crucial to linear regression models. Homoscedasticity describes a situation in which the error term or variance or the \"noise\" or random disturbance in the relationship between the independent variables and the dependent variable is the same across all values of the independent variable. In other words, there is a constant variance present in the response variable as the predictor variable increases. If the \"noise\" is not the same across the values of an independent variable, we call it heteroscedasticity, opposite of homoscedasticity. \n", "The assumption of homoscedasticity (constant variance), is crucial to linear\n",
"regression models. Homoscedasticity describes a situation in which the error\n",
"term or variance or the \"noise\" or random disturbance in the relationship\n",
"between the independent variables and the dependent variable is the same across\n",
"all values of the independent variable. In other words, there is a constant\n",
"variance present in the response variable as the predictor variable increases.\n",
"If the \"noise\" is not the same across the values of an independent variable, we\n",
"call it heteroscedasticity, opposite of homoscedasticity.\n",
"\n", "\n",
"#### Residuals\n", "#### Residuals\n",
"\n", "\n",
"Next we apply residual expression to 'price' and 'enginesize' varibles in order to check this assumption.\n", "Next we apply residual expression to 'price' and 'enginesize' varibles in order\n",
"[The residuals function](https://l12.xyz/x/shortcuts/src/branch/main/expr.ts) uses mean squared." "to check this assumption.\n",
"[The residuals function](https://l12.xyz/x/shortcuts/src/branch/main/expr.ts)\n",
"uses mean squared."
] ]
}, },
{ {
@ -2646,24 +2678,24 @@
} }
], ],
"source": [ "source": [
"import { residuals } from 'https://l12.xyz/x/shortcuts/raw/expr.ts'\n", "import { residuals } from \"https://l12.xyz/x/shortcuts/raw/expr.ts\";\n",
"\n", "\n",
"let residualDf = df.select(\n", "let residualDf = df.select(\n",
" 'enginesize', \n", " \"enginesize\",\n",
" residuals(pl.col('enginesize'), pl.col('price'))\n", " residuals(pl.col(\"enginesize\"), pl.col(\"price\")),\n",
")\n", ");\n",
"\n", "\n",
"let residPlot = Plot.plot({\n", "let residPlot = Plot.plot({\n",
" x: \"enginesize\", \n", " x: \"enginesize\",\n",
" y: \"price\",\n", " y: \"price\",\n",
" marks: [\n", " marks: [\n",
" Plot.dot(residualDf.toRecords(), { x: \"enginesize\", y: \"price\"}),\n", " Plot.dot(residualDf.toRecords(), { x: \"enginesize\", y: \"price\" }),\n",
" Plot.ruleY([0], {stroke: '#ccc'})\n", " Plot.ruleY([0], { stroke: \"#ccc\" }),\n",
" ],\n", " ],\n",
" document\n", " document,\n",
"})\n", "});\n",
"\n", "\n",
"await display(residPlot)\n" "await display(residPlot);"
] ]
}, },
{ {
@ -2677,7 +2709,9 @@
"tags": [] "tags": []
}, },
"source": [ "source": [
"From the above plot, we can tell the error variance across the true line is dispersed somewhat not uniformly, but in a funnel like shape. So, the assumption of the *homoscedasticity* is more likely not met.\n" "From the above plot, we can tell the error variance across the true line is\n",
"dispersed somewhat not uniformly, but in a funnel like shape. So, the assumption\n",
"of the _homoscedasticity_ is more likely not met."
] ]
}, },
{ {
@ -2693,7 +2727,12 @@
"source": [ "source": [
"## Normality\n", "## Normality\n",
"\n", "\n",
"The linear regression analysis requires the dependent variable, 'price', to be normally distributed. A histogram, box plot, or a Q-Q-Plot can check if the target variable is normally distributed. The goodness of fit test, e.g., the Kolmogorov-Smirnov test can check for normality in the dependent variable. [This documentation](https://towardsdatascience.com/normality-tests-in-python-31e04aa4f411) contains more information on the normality assumption. \n", "The linear regression analysis requires the dependent variable, 'price', to be\n",
"normally distributed. A histogram, box plot, or a Q-Q-Plot can check if the\n",
"target variable is normally distributed. The goodness of fit test, e.g., the\n",
"Kolmogorov-Smirnov test can check for normality in the dependent variable.\n",
"[This documentation](https://towardsdatascience.com/normality-tests-in-python-31e04aa4f411)\n",
"contains more information on the normality assumption.\n",
"\n", "\n",
"Let's display all three charts to show how our target variable, 'price' behaves." "Let's display all three charts to show how our target variable, 'price' behaves."
] ]
@ -2732,9 +2771,9 @@
} }
], ],
"source": [ "source": [
"import { threeChart } from 'https://l12.xyz/x/shortcuts/raw/plots.ts'\n", "import { threeChart } from \"https://l12.xyz/x/shortcuts/raw/plots.ts\";\n",
"\n", "\n",
"await display(threeChart(records, \"price\"))" "await display(threeChart(records, \"price\"));"
] ]
}, },
{ {
@ -2754,9 +2793,13 @@
"- Our target variable is right-skewed\n", "- Our target variable is right-skewed\n",
"- There are some outliers in the variable\n", "- There are some outliers in the variable\n",
"\n", "\n",
"The right-skewed plot means that most prices in the dataset are on the lower end (below 15,000). The 'max' value is very far from the '75%' quantile statistic. All these plots show that the assumption for accurate linear regression modeling is not met. \n", "The right-skewed plot means that most prices in the dataset are on the lower end\n",
"(below 15,000). The 'max' value is very far from the '75%' quantile statistic.\n",
"All these plots show that the assumption for accurate linear regression modeling\n",
"is not met.\n",
"\n", "\n",
"Next, we will perform the log transformation to correct our target variable and to make it more normally distributed." "Next, we will perform the log transformation to correct our target variable and\n",
"to make it more normally distributed."
] ]
}, },
{ {
@ -2783,9 +2826,9 @@
} }
], ],
"source": [ "source": [
"import { ShapiroWilkW } from \"https://l12.xyz/x/shortcuts/raw/shapiro.ts\"\n", "import { ShapiroWilkW } from \"https://l12.xyz/x/shortcuts/raw/shapiro.ts\";\n",
"\n", "\n",
"ShapiroWilkW(df.price.sort())" "ShapiroWilkW(df.price.sort());"
] ]
}, },
{ {
@ -2822,9 +2865,9 @@
} }
], ],
"source": [ "source": [
"let log2df = df.select(pl.col(\"price\").log())\n", "let log2df = df.select(pl.col(\"price\").log());\n",
"\n", "\n",
"await display(threeChart(log2df.toRecords(), \"price\"))" "await display(threeChart(log2df.toRecords(), \"price\"));"
] ]
}, },
{ {
@ -2851,7 +2894,7 @@
} }
], ],
"source": [ "source": [
"ShapiroWilkW(log2df.price.sort())" "ShapiroWilkW(log2df.price.sort());"
] ]
}, },
{ {
@ -3301,11 +3344,11 @@
], ],
"source": [ "source": [
"let carData = pl.readCSV(\n", "let carData = pl.readCSV(\n",
" await Deno.readTextFile('assets/cleaned_car_prices.csv'),\n", " await Deno.readTextFile(\"assets/cleaned_car_prices.csv\"),\n",
" { sep: \",\" }\n", " { sep: \",\" },\n",
")\n", ");\n",
"\n", "\n",
"carData.head(5)" "carData.head(5);"
] ]
}, },
{ {
@ -3658,8 +3701,15 @@
} }
], ],
"source": [ "source": [
"let carDataGeneralized = carData.drop('brand', 'carbody', 'enginelocation', 'stroke', 'compressionratio', 'peakrpm')\n", "let carDataGeneralized = carData.drop(\n",
"carDataGeneralized.head(5)" " \"brand\",\n",
" \"carbody\",\n",
" \"enginelocation\",\n",
" \"stroke\",\n",
" \"compressionratio\",\n",
" \"peakrpm\",\n",
");\n",
"carDataGeneralized.head(5);"
] ]
}, },
{ {
@ -3673,7 +3723,8 @@
"tags": [] "tags": []
}, },
"source": [ "source": [
"Next we use one hot (binary) encoding. We assume that all non-numeric colums are categorical." "Next we use one hot (binary) encoding. We assume that all non-numeric colums are\n",
"categorical."
] ]
}, },
{ {
@ -4408,10 +4459,10 @@
} }
], ],
"source": [ "source": [
"import { oneHotEncoding } from 'https://l12.xyz/x/shortcuts/raw/encoding.ts'\n", "import { oneHotEncoding } from \"https://l12.xyz/x/shortcuts/raw/encoding.ts\";\n",
"\n", "\n",
"let encodedCarData = oneHotEncoding(carDataGeneralized)\n", "let encodedCarData = oneHotEncoding(carDataGeneralized);\n",
"encodedCarData.head(5)" "encodedCarData.head(5);"
] ]
}, },
{ {
@ -4427,7 +4478,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"encodedCarData.writeCSV('assets/encoded_car_data.csv')" "encodedCarData.writeCSV(\"assets/encoded_car_data.csv\");"
] ]
}, },
{ {


@ -7,3 +7,10 @@ export function residuals(x: pl.Expr, y: pl.Expr): pl.Expr {
const beta = xM.dot(yM).div(xMSQ.sum()); const beta = xM.dot(yM).div(xMSQ.sum());
return yM.minus(beta.mul(xM)); return yM.minus(beta.mul(xM));
} }
export const ScaleExpr: pl.Expr = (pl.all().minus(pl.all().min())).div(
pl.all().max().minus(pl.all().min()),
);
export const StdNormExpr: pl.Expr = pl.all().minus(pl.all().mean()).div(
pl.all().std(),
);
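The two new expressions are column-wise scalers: `ScaleExpr` rescales values to the [0, 1] range (min-max scaling) and `StdNormExpr` standardizes them to zero mean and unit standard deviation. A minimal usage sketch under Deno, assuming the new exports are published at the same raw expr.ts URL the notebook already imports `residuals` from (the toy DataFrame below is illustrative):

```ts
import pl from "npm:nodejs-polars";
// Assumption: ScaleExpr and StdNormExpr are served from the same raw expr.ts
// path that the notebook uses for `residuals`; adjust the URL if it differs.
import { ScaleExpr, StdNormExpr } from "https://l12.xyz/x/shortcuts/raw/expr.ts";

// Toy numeric frame; because the expressions are built on pl.all(), they are
// applied to every (numeric) column of the frame they are selected on.
const df = pl.DataFrame({
  enginesize: [130, 152, 109],
  price: [13495, 16500, 13950],
});

// Min-max scale each column into [0, 1].
console.log(df.select(ScaleExpr).toRecords());

// Standardize each column to zero mean and unit standard deviation.
console.log(df.select(StdNormExpr).toRecords());
```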

split.ts (new file, 16 lines added)

@ -0,0 +1,16 @@
import pl from "npm:nodejs-polars";
export function trainTestSplit(df: pl.DataFrame, size: number, ...yFeatures: string[]) {
  let shuffle = df.sample(df.height - 1); // shuffle rows (note: this drops one random row)
  let testSize = Math.round(shuffle.shape.height * size); // `size` is the test fraction, e.g. 0.2
  let trainSize = shuffle.shape.height - testSize;
  let [train, test] = [shuffle.head(trainSize), shuffle.tail(testSize)];
  let [trainY, testY] = [train.select(...yFeatures), test.select(...yFeatures)]; // target column(s)
  let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)]; // feature columns
  return {
    trainX,
    trainY,
    testX,
    testY,
  };
}
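A sketch of how the new trainTestSplit helper might be used on the encoded dataset written earlier in the notebook; the raw split.ts import path is an assumption (mirroring how the other shortcuts are imported), and the 0.2 test fraction and "price" target are illustrative choices:

```ts
import pl from "npm:nodejs-polars";
// Assumption: split.ts is served from the same raw path as the other shortcuts.
import { trainTestSplit } from "https://l12.xyz/x/shortcuts/raw/split.ts";

// Reuse the one-hot encoded data produced in the notebook.
const encoded = pl.readCSV(
  await Deno.readTextFile("assets/encoded_car_data.csv"),
  { sep: "," },
);

// Hold out ~20% of the rows for testing; "price" is the target column.
const { trainX, trainY, testX, testY } = trainTestSplit(encoded, 0.2, "price");

console.log(trainX.shape, trainY.shape, testX.shape, testY.shape);
```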