diff --git a/notebooks/regressions.ipynb b/notebooks/regressions.ipynb index f889c55..461a4d4 100644 --- a/notebooks/regressions.ipynb +++ b/notebooks/regressions.ipynb @@ -4,11 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Regressions\n", - "\n", - "$X = (0,1)$\n", - "\n", - "$Y = sin(2\\pi X)$" + "# Regressions\n" ] }, { @@ -29,8 +25,6 @@ ], "source": [ "// deno-lint-ignore-file\n", - "\n", - "import { display } from \"https://deno.land/x/display@v0.1.1/mod.ts\";\n", "import pl from \"npm:nodejs-polars\";\n", "import plot from \"../plot/mod.ts\";\n", "\n", @@ -249,6 +243,315 @@ "source": [ "lassoPoly.score(df.select('y').rows(), pl.DataFrame({\"py\":predLassoY}).rows());\n" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Logistic Regression\n", + "\n", + "Logistic regression is applicable for classification problems when there are linear relationships in the data.\n", + "\n", + "For example we'll use a simple linear pattern for predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.dataresource+json": { + "bytes": null, + "data": [ + { + "is_even": 1, + "x": 0, + "x2": -0.9906467224227306, + "x3": 0.00261179419023964 + }, + { + "is_even": 0, + "x": 1, + "x2": 0.00202764718654894, + "x3": 0.0013126936809724655 + }, + { + "is_even": 1, + "x": 2, + "x2": -0.9972687463767552, + "x3": -1.9905576758692132 + }, + { + "is_even": 0, + "x": 3, + "x2": 0.009024117668298463, + "x3": 1.0021275936361942 + }, + { + "is_even": 1, + "x": 4, + "x2": -0.9968599379688912, + "x3": -0.9947633704461447 + }, + { + "is_even": 0, + "x": 5, + "x2": 0.0019298798125454942, + "x3": -0.9913865823439642 + }, + { + "is_even": 1, + "x": 6, + "x2": -0.9967558047825875, + "x3": 0.00801993440722967 + }, + { + "is_even": 0, + "x": 7, + "x2": 0.007736464312311071, + "x3": 0.00034330022067074583 + }, + { + "is_even": 1, + "x": 8, + "x2": -0.9959077406033643, + "x3": -1.994011690184521 + }, + { + "is_even": 0, + "x": 9, + "x2": 0.002684616942261051, + "x3": 1.0072230674765243 + } + ], + "description": null, + "dialect": null, + "encoding": null, + "format": null, + "hash": null, + "homepage": null, + "licenses": null, + "mediatype": null, + "path": null, + "schema": { + "fields": [ + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "x", + "rdfType": null, + "title": null, + "type": "number" + }, + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "x2", + "rdfType": null, + "title": null, + "type": "number" + }, + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "x3", + "rdfType": null, + "title": null, + "type": "number" + }, + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "is_even", + "rdfType": null, + "title": null, + "type": "number" + } + ], + "foreignKeys": null, + "missingValues": null, + "primaryKey": null + }, + "sources": null, + "title": null + }, + "text/html": [ + "
xx2x3is_even
0-0.99064672242273060.002611794190239641
10.002027647186548940.00131269368097246550
2-0.9972687463767552-1.99055767586921321
30.0090241176682984631.00212759363619420
4-0.9968599379688912-0.99476337044614471
50.0019298798125454942-0.99138658234396420
6-0.99675580478258750.008019934407229671
70.0077364643123110710.000343300220670745830
8-0.9959077406033643-1.9940116901845211
90.0026846169422610511.00722306747652430
" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "const clsDf = pl.DataFrame({ \n", + " x: new Array(100).fill(0).map((_, i) => i ),\n", + " x2: new Array(100).fill(0).map((_, i) => i % 2 - 1 + Math.random() / 100),\n", + " x3: new Array(100).fill(0).map((_, i) => i % 2 - i % 3 + Math.random() / 100),\n", + " }).select(\n", + " pl.all(),\n", + " pl.col('x').modulo(2).eq(0).add(0).alias('is_even'),\n", + ");\n", + "clsDf.head(10);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that I left `x` in training data which is a continuos sequence that couln't be generalized by this model, however the model should guess a correct class in most cases. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [ + "hide_code", + "parameters" + ] + }, + "outputs": [], + "source": [ + "const drawBin = (x1, y1, t = \"Example\") => {\n", + " const xTrue = clsDf.x.toArray().map((v, i) => v % 2)\n", + " const yTrue = clsDf.is_even.toArray()\n", + " y1 = y1.map((v, i) => v + i * 0.02)\n", + " x1 = x1.map((v, i) => v % 2)\n", + " return plot.DrawPlot(\n", + " { \n", + " title: \"\",\n", + " width: 2.5,\n", + " height: 2,\n", + " XLabel: \"X\", \n", + " YLabel: \"Y\", \n", + " }, \n", + " { type: \"scatter\", data: [xTrue, yTrue], lineDashes: [3, 4], glypRadius: 12, glyphColor: \"#00f\", glyphShape: \"ring\" },\n", + " { type: \"scatter\", data: [x1, y1], legend: t, lineDashes: [3, 4], glypRadius: 3, glyphColor: \"#f00\", glyphShape: \"ring\" },\n", + " );\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import {trainTestSplit} from \"../split.ts\";\n", + "\n", + "\n", + "const {testX, trainX, testY, trainY} = trainTestSplit(clsDf, 0.05, true, \"is_even\");\n", + "\n", + "const drawTestBin = () => drawBin(testX.x.toArray(), testY.is_even.toArray(), \"Test Data\");\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "![name]()" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "const logreg = regr.Logistic({\n", + " epochs: 5000,\n", + " learningRate: 0.001,\n", + "});\n", + "logreg.fit(trainX.rows(), trainY.rows());\n", + "\n", + "const predLogReg = logreg.predict(testX.rows());\n", + "\n", + "const yPred1 = predLogReg.map((x) => x[0]);\n", + "\n", + "drawBin(\n", + " testX.x.toArray(), \n", + " yPred1,\n", + " \"Predicted\"\n", + ");" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [ + "hide_code", + "parameters" + ] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "![name]()" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "drawTestBin()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0, 1, 1, 0, 1 ] should be [ 0, 1, 1, 0, 1 ]\n" + ] + } + ], + "source": [ + "console.log(yPred1, \"should be\", testY.is_even.toArray());" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[33m1.6635532311438688\u001b[39m" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logreg.loss();" + ] } ], "metadata": { diff --git a/plot/mod.wasm b/plot/mod.wasm index b4c33f7..46a2fc4 100755 Binary files a/plot/mod.wasm and b/plot/mod.wasm differ diff --git a/regr/main.go b/regr/main.go index 4395ffd..05fe8c3 100644 --- a/regr/main.go +++ b/regr/main.go @@ -15,6 +15,7 @@ func InitRegrExports(this js.Value, args []js.Value) interface{} { exports.Set("ElasticNet", js.FuncOf(src.NewElasticNetJS)) exports.Set("Lasso", js.FuncOf(src.NewLassoJS)) exports.Set("R2Score", js.FuncOf(src.R2ScoreJS)) + exports.Set("Logistic", js.FuncOf(src.NewLogisticRegressionJS)) return exports } diff --git a/regr/mod.wasm b/regr/mod.wasm index 55a3d54..c8299c7 100755 Binary files a/regr/mod.wasm and b/regr/mod.wasm differ diff --git a/regr/src/LogisticRegressionJS.go b/regr/src/LogisticRegressionJS.go new file mode 100644 index 0000000..5c11367 --- /dev/null +++ b/regr/src/LogisticRegressionJS.go @@ -0,0 +1,40 @@ +//go:build js && wasm +// +build js,wasm + +package src + +import ( + "syscall/js" +) + +func NewLogisticRegressionJS(this js.Value, args []js.Value) interface{} { + var ( + epochs = args[0].Get("epochs").Int() + learningRate = args[0].Get("learningRate").Float() + ) + reg := &MCLogisticRegression{ + Epochs: epochs, + LearningRate: learningRate, + } + loss := 1.0 + obj := js.Global().Get("Object").New() + obj.Set("fit", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + X := JSFloatArray2D(args[0]) + Y := JSFloatArray2D(args[1]) + XDense := Array2DToDense(X) + YDense := Array2DToDense(Y) + reg.Fit(XDense, YDense) + return nil + })) + obj.Set("predict", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + X := JSFloatArray2D(args[0]) + XDense := Array2DToDense(X) + Y := reg.Predict(XDense) + loss = reg.Loss(XDense, Y) + return MatrixToJSArray2D(Y) + })) + obj.Set("loss", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + return loss + })) + return obj +} diff --git a/regr/src/utils_js.go b/regr/src/utils_js.go index 903470a..d7f5125 100644 --- a/regr/src/utils_js.go +++ b/regr/src/utils_js.go @@ -5,6 +5,8 @@ package src import ( "syscall/js" + + "gonum.org/v1/gonum/mat" ) func JSFloatArray(arg js.Value) []float64 { @@ -35,3 +37,26 @@ func ToJSArray[T any](arr []T) []interface{} { } return jsArr } + +func MatrixToJSArray(m mat.Matrix) []interface{} { + r, c := m.Dims() + data := make([]float64, r*c) + for i := 0; i < r; i++ { + for j := 0; j < c; j++ { + data[i*c+j] = m.At(i, j) + } + } + return ToJSArray(data) +} + +func MatrixToJSArray2D(m mat.Matrix) []interface{} { + r, c := m.Dims() + data := make([][]interface{}, r) + for i := 0; i < r; i++ { + data[i] = make([]interface{}, c) + for j := 0; j < c; j++ { + data[i][j] = m.At(i, j) + } + } + return ToJSArray(data) +}