From 5d38fb02f9cf282867ea179715309a6bedc030cb Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Fri, 27 Sep 2024 13:20:19 +0200 Subject: [PATCH] + plots, + enc --- deno_jupyter_bug.ipynb | 2 +- encoding.ts | 84 +++++++++++++++++++++++++++++++++++++++++- expr.ts | 11 +++--- plots.ts | 47 +++++++++++++++++++++++ 4 files changed, 136 insertions(+), 8 deletions(-) diff --git a/deno_jupyter_bug.ipynb b/deno_jupyter_bug.ipynb index 0141a2c..206445f 100644 --- a/deno_jupyter_bug.ipynb +++ b/deno_jupyter_bug.ipynb @@ -27,7 +27,7 @@ ], "source": [ "const document = null;\n", - "import * as p from \"https://l12.xyz/x/shortcuts/raw/plots.ts\"\n" + "import * as p from \"https://l12.xyz/x/shortcuts/raw/plots.ts\";" ] } ], diff --git a/encoding.ts b/encoding.ts index a248dd3..9e9a840 100644 --- a/encoding.ts +++ b/encoding.ts @@ -1,6 +1,6 @@ import pl from "npm:nodejs-polars"; -export function oneHotEncoding(dataframe) { +export function oneHotEncoding(dataframe: pl.DataFrame): pl.DataFrame { let df = pl.DataFrame(); for (const columnName of dataframe.columns) { const column = dataframe[columnName]; @@ -12,3 +12,85 @@ export function oneHotEncoding(dataframe) { } return df; } + +export function polynomialTransform( + dataframe: pl.DataFrame, + degree = 2, + interaction_only = false, + include_bias = true, +): pl.DataFrame { + let polyRecords: number[][] = []; + dataframe.map((X: number[]) => { + polyRecords.push( + polynomialFeatures(X, degree, interaction_only, include_bias), + ); + }); + return pl.readRecords(polyRecords); +} + +export function polynomialFeatures( + X: number[], + degree = 2, + interaction_only = false, + include_bias = true, +): number[] { + let features = [...X]; + let prev_chunk = [...X]; + const indices = Array.from({ length: X.length }, (_, i) => i); + for (let d = 1; d < degree; d++) { + const new_chunk: any[] = []; + for (let i = 0; i < (interaction_only ? X.length - d : X.length); i++) { + const v = X[i]; + const next_index = new_chunk.length; + for (let j = i + (interaction_only ? 1 : 0); j < prev_chunk.length; j++) { + new_chunk.push(v * prev_chunk[j]); + } + indices[i] = next_index; + } + features = features.concat(new_chunk); + prev_chunk = new_chunk; + } + if (include_bias) { + features.unshift(1); + } + return features; +} + +/** + * Adds missing rows at given interval, uses mean of previous and next value. + * Example for one feature: [1, 2, 4, 5] -> [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5] + * @param feature + * @param df + * @param bin + */ +export function augmentMeanForward( + feature: string, + df: pl.DataFrame, + interval = 100, +) { + let sorted = df.sort(feature); + let result = sorted.head(0); + let n: null | number = null; + for (let i = 0; i < sorted.height - 1; i++) { + let p1 = n ?? sorted.row(i).at(-1); + let p2 = sorted.row(i + 1).at(-1); + if (p2 - p1 > interval) { + let avg = (p1 + p2) / 2; + result = pl.concat([ + result, + pl.concat([result.tail(2), sorted.slice({ offset: i + 1, length: 2 })]) + .shift(-1) + .fillNull("mean") + .tail(1), + ]); + if (p2 - avg > interval) { + i--; + n = avg; + continue; + } + result = pl.concat([result, sorted.slice(1, i)]); + n = null; + } + } + return result; +} diff --git a/expr.ts b/expr.ts index 498b1dc..58dbef2 100644 --- a/expr.ts +++ b/expr.ts @@ -12,9 +12,8 @@ export const fillzero = ( value = 0.0001, ) => (pl.all().replaceStrict(0, value, pl.all())); -export const ScaleExpr: pl.Expr = (pl.all().minus(pl.all().min())).div( - pl.all().max().minus(pl.all().min()), -); -export const StdNormExpr: pl.Expr = pl.all().minus(pl.all().mean()).div( - pl.all().std(), -); +export const minmaxScale = (col: pl.Expr) => + (col.minus(col.min())).div(col.max().minus(col.min())); + +export const standardScale = (col: pl.Expr) => + (col.minus(col.mean())).div(col.std()); diff --git a/plots.ts b/plots.ts index 3ddda04..37f52ff 100644 --- a/plots.ts +++ b/plots.ts @@ -208,3 +208,50 @@ export function threeChart(data: any[], x = "column", opts = { width: 800 }) { }), }; } + +export function distPlot(...data: number[][]) { + const colors = [ + "red", + "blue", + "green", + "orange", + "purple", + "brown", + "pink", + "gray", + "black", + "cyan", + "magenta", + "yellow", + "lightblue", + "lightgreen", + "lightgray", + ]; + const plt = Plot.plot({ + margin: 50, + width: 1200, + marks: [ + Plot.hexgrid(), + data.map((a, i) => + Plot.areaY( + a, + Plot.binX({ y: "count" }, { + x: (x: number) => x, + fill: colors[i % colors.length], + curve: "catmull-rom", + fillOpacity: .5, + }), + ) + ), + ], + document, + }); + plt.setAttribute("xmlns", "http://www.w3.org/2000/svg"); + return { + [Symbol.for("Jupyter.display")]: () => ({ + "text/html": ``, + }), + }; +}