diff --git a/notebooks/regressions.ipynb b/notebooks/regressions.ipynb
index f889c55..461a4d4 100644
--- a/notebooks/regressions.ipynb
+++ b/notebooks/regressions.ipynb
@@ -4,11 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Regressions\n",
- "\n",
- "$X = (0,1)$\n",
- "\n",
- "$Y = sin(2\\pi X)$"
+ "# Regressions\n"
]
},
{
@@ -29,8 +25,6 @@
],
"source": [
"// deno-lint-ignore-file\n",
- "\n",
- "import { display } from \"https://deno.land/x/display@v0.1.1/mod.ts\";\n",
"import pl from \"npm:nodejs-polars\";\n",
"import plot from \"../plot/mod.ts\";\n",
"\n",
@@ -249,6 +243,315 @@
"source": [
"lassoPoly.score(df.select('y').rows(), pl.DataFrame({\"py\":predLassoY}).rows());\n"
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Logistic Regression\n",
+ "\n",
+ "Logistic regression is applicable for classification problems when there are linear relationships in the data.\n",
+ "\n",
+ "For example, we'll use a simple linear pattern for predictions:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.dataresource+json": {
+ "bytes": null,
+ "data": [
+ {
+ "is_even": 1,
+ "x": 0,
+ "x2": -0.9906467224227306,
+ "x3": 0.00261179419023964
+ },
+ {
+ "is_even": 0,
+ "x": 1,
+ "x2": 0.00202764718654894,
+ "x3": 0.0013126936809724655
+ },
+ {
+ "is_even": 1,
+ "x": 2,
+ "x2": -0.9972687463767552,
+ "x3": -1.9905576758692132
+ },
+ {
+ "is_even": 0,
+ "x": 3,
+ "x2": 0.009024117668298463,
+ "x3": 1.0021275936361942
+ },
+ {
+ "is_even": 1,
+ "x": 4,
+ "x2": -0.9968599379688912,
+ "x3": -0.9947633704461447
+ },
+ {
+ "is_even": 0,
+ "x": 5,
+ "x2": 0.0019298798125454942,
+ "x3": -0.9913865823439642
+ },
+ {
+ "is_even": 1,
+ "x": 6,
+ "x2": -0.9967558047825875,
+ "x3": 0.00801993440722967
+ },
+ {
+ "is_even": 0,
+ "x": 7,
+ "x2": 0.007736464312311071,
+ "x3": 0.00034330022067074583
+ },
+ {
+ "is_even": 1,
+ "x": 8,
+ "x2": -0.9959077406033643,
+ "x3": -1.994011690184521
+ },
+ {
+ "is_even": 0,
+ "x": 9,
+ "x2": 0.002684616942261051,
+ "x3": 1.0072230674765243
+ }
+ ],
+ "description": null,
+ "dialect": null,
+ "encoding": null,
+ "format": null,
+ "hash": null,
+ "homepage": null,
+ "licenses": null,
+ "mediatype": null,
+ "path": null,
+ "schema": {
+ "fields": [
+ {
+ "constraints": null,
+ "description": null,
+ "example": null,
+ "format": null,
+ "name": "x",
+ "rdfType": null,
+ "title": null,
+ "type": "number"
+ },
+ {
+ "constraints": null,
+ "description": null,
+ "example": null,
+ "format": null,
+ "name": "x2",
+ "rdfType": null,
+ "title": null,
+ "type": "number"
+ },
+ {
+ "constraints": null,
+ "description": null,
+ "example": null,
+ "format": null,
+ "name": "x3",
+ "rdfType": null,
+ "title": null,
+ "type": "number"
+ },
+ {
+ "constraints": null,
+ "description": null,
+ "example": null,
+ "format": null,
+ "name": "is_even",
+ "rdfType": null,
+ "title": null,
+ "type": "number"
+ }
+ ],
+ "foreignKeys": null,
+ "missingValues": null,
+ "primaryKey": null
+ },
+ "sources": null,
+ "title": null
+ },
+ "text/html": [
+ "
x | x2 | x3 | is_even |
---|
0 | -0.9906467224227306 | 0.00261179419023964 | 1 |
1 | 0.00202764718654894 | 0.0013126936809724655 | 0 |
2 | -0.9972687463767552 | -1.9905576758692132 | 1 |
3 | 0.009024117668298463 | 1.0021275936361942 | 0 |
4 | -0.9968599379688912 | -0.9947633704461447 | 1 |
5 | 0.0019298798125454942 | -0.9913865823439642 | 0 |
6 | -0.9967558047825875 | 0.00801993440722967 | 1 |
7 | 0.007736464312311071 | 0.00034330022067074583 | 0 |
8 | -0.9959077406033643 | -1.994011690184521 | 1 |
9 | 0.002684616942261051 | 1.0072230674765243 | 0 |
"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "const clsDf = pl.DataFrame({ \n",
+ " x: new Array(100).fill(0).map((_, i) => i ),\n",
+ " x2: new Array(100).fill(0).map((_, i) => i % 2 - 1 + Math.random() / 100),\n",
+ " x3: new Array(100).fill(0).map((_, i) => i % 2 - i % 3 + Math.random() / 100),\n",
+ " }).select(\n",
+ " pl.all(),\n",
+ " pl.col('x').modulo(2).eq(0).add(0).alias('is_even'),\n",
+ ");\n",
+ "clsDf.head(10);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Notice that I left `x` in the training data. It is a continuous sequence that this model cannot generalize from; however, the model should still guess the correct class in most cases."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "tags": [
+ "hide_code",
+ "parameters"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "const drawBin = (x1, y1, t = \"Example\") => {\n",
+ " const xTrue = clsDf.x.toArray().map((v, i) => v % 2)\n",
+ " const yTrue = clsDf.is_even.toArray()\n",
+ " y1 = y1.map((v, i) => v + i * 0.02)\n",
+ " x1 = x1.map((v, i) => v % 2)\n",
+ " return plot.DrawPlot(\n",
+ " { \n",
+ " title: \"\",\n",
+ " width: 2.5,\n",
+ " height: 2,\n",
+ " XLabel: \"X\", \n",
+ " YLabel: \"Y\", \n",
+ " }, \n",
+ " { type: \"scatter\", data: [xTrue, yTrue], lineDashes: [3, 4], glypRadius: 12, glyphColor: \"#00f\", glyphShape: \"ring\" },\n",
+ " { type: \"scatter\", data: [x1, y1], legend: t, lineDashes: [3, 4], glypRadius: 3, glyphColor: \"#f00\", glyphShape: \"ring\" },\n",
+ " );\n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import {trainTestSplit} from \"../split.ts\";\n",
+ "\n",
+ "\n",
+ "const {testX, trainX, testY, trainY} = trainTestSplit(clsDf, 0.05, true, \"is_even\");\n",
+ "\n",
+ "const drawTestBin = () => drawBin(testX.x.toArray(), testY.is_even.toArray(), \"Test Data\");\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "![name]()"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "const logreg = regr.Logistic({\n",
+ " epochs: 5000,\n",
+ " learningRate: 0.001,\n",
+ "});\n",
+ "logreg.fit(trainX.rows(), trainY.rows());\n",
+ "\n",
+ "const predLogReg = logreg.predict(testX.rows());\n",
+ "\n",
+ "const yPred1 = predLogReg.map((x) => x[0]);\n",
+ "\n",
+ "drawBin(\n",
+ " testX.x.toArray(), \n",
+ " yPred1,\n",
+ " \"Predicted\"\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "tags": [
+ "hide_code",
+ "parameters"
+ ]
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "![name]()"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "drawTestBin()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 0, 1, 1, 0, 1 ] should be [ 0, 1, 1, 0, 1 ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "console.log(yPred1, \"should be\", testY.is_even.toArray());"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\u001b[33m1.6635532311438688\u001b[39m"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "logreg.loss();"
+ ]
}
],
"metadata": {
diff --git a/plot/mod.wasm b/plot/mod.wasm
index b4c33f7..46a2fc4 100755
Binary files a/plot/mod.wasm and b/plot/mod.wasm differ
diff --git a/regr/main.go b/regr/main.go
index 4395ffd..05fe8c3 100644
--- a/regr/main.go
+++ b/regr/main.go
@@ -15,6 +15,7 @@ func InitRegrExports(this js.Value, args []js.Value) interface{} {
exports.Set("ElasticNet", js.FuncOf(src.NewElasticNetJS))
exports.Set("Lasso", js.FuncOf(src.NewLassoJS))
exports.Set("R2Score", js.FuncOf(src.R2ScoreJS))
+ exports.Set("Logistic", js.FuncOf(src.NewLogisticRegressionJS))
return exports
}
diff --git a/regr/mod.wasm b/regr/mod.wasm
index 55a3d54..c8299c7 100755
Binary files a/regr/mod.wasm and b/regr/mod.wasm differ
diff --git a/regr/src/LogisticRegressionJS.go b/regr/src/LogisticRegressionJS.go
new file mode 100644
index 0000000..5c11367
--- /dev/null
+++ b/regr/src/LogisticRegressionJS.go
@@ -0,0 +1,40 @@
+//go:build js && wasm
+// +build js,wasm
+
+package src
+
+import (
+ "syscall/js"
+)
+
+func NewLogisticRegressionJS(this js.Value, args []js.Value) interface{} {
+ var (
+ epochs = args[0].Get("epochs").Int()
+ learningRate = args[0].Get("learningRate").Float()
+ )
+ reg := &MCLogisticRegression{
+ Epochs: epochs,
+ LearningRate: learningRate,
+ }
+ loss := 1.0
+ obj := js.Global().Get("Object").New()
+ obj.Set("fit", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ X := JSFloatArray2D(args[0])
+ Y := JSFloatArray2D(args[1])
+ XDense := Array2DToDense(X)
+ YDense := Array2DToDense(Y)
+ reg.Fit(XDense, YDense)
+ return nil
+ }))
+ obj.Set("predict", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ X := JSFloatArray2D(args[0])
+ XDense := Array2DToDense(X)
+ Y := reg.Predict(XDense)
+ loss = reg.Loss(XDense, Y)
+ return MatrixToJSArray2D(Y)
+ }))
+ obj.Set("loss", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ return loss
+ }))
+ return obj
+}
diff --git a/regr/src/utils_js.go b/regr/src/utils_js.go
index 903470a..d7f5125 100644
--- a/regr/src/utils_js.go
+++ b/regr/src/utils_js.go
@@ -5,6 +5,8 @@ package src
import (
"syscall/js"
+
+ "gonum.org/v1/gonum/mat"
)
func JSFloatArray(arg js.Value) []float64 {
@@ -35,3 +37,26 @@ func ToJSArray[T any](arr []T) []interface{} {
}
return jsArr
}
+
+// MatrixToJSArray flattens m in row-major order (row 0 first, then row 1, ...)
+// and hands the resulting values to ToJSArray.
+func MatrixToJSArray(m mat.Matrix) []interface{} {
+	rows, cols := m.Dims()
+	flat := make([]float64, 0, rows*cols)
+	for i := 0; i < rows; i++ {
+		for j := 0; j < cols; j++ {
+			flat = append(flat, m.At(i, j))
+		}
+	}
+	return ToJSArray(flat)
+}
+
+// MatrixToJSArray2D copies m into nested slices, one inner slice per matrix row,
+// and hands the rows to ToJSArray.
+func MatrixToJSArray2D(m mat.Matrix) []interface{} {
+	rows, cols := m.Dims()
+	grid := make([][]interface{}, rows)
+	for i := range grid {
+		row := make([]interface{}, cols)
+		for j := range row {
+			row[j] = m.At(i, j)
+		}
+		grid[i] = row
+	}
+	return ToJSArray(grid)
+}