diff --git a/notebooks/regressions.ipynb b/notebooks/regressions.ipynb index f889c55..461a4d4 100644 --- a/notebooks/regressions.ipynb +++ b/notebooks/regressions.ipynb @@ -4,11 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Regressions\n", - "\n", - "$X = (0,1)$\n", - "\n", - "$Y = sin(2\\pi X)$" + "# Regressions\n" ] }, { @@ -29,8 +25,6 @@ ], "source": [ "// deno-lint-ignore-file\n", - "\n", - "import { display } from \"https://deno.land/x/display@v0.1.1/mod.ts\";\n", "import pl from \"npm:nodejs-polars\";\n", "import plot from \"../plot/mod.ts\";\n", "\n", @@ -249,6 +243,315 @@ "source": [ "lassoPoly.score(df.select('y').rows(), pl.DataFrame({\"py\":predLassoY}).rows());\n" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Logistic Regression\n", + "\n", + "Logistic regression is applicable for classification problems when there are linear relationships in the data.\n", + "\n", + "For example we'll use a simple linear pattern for predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.dataresource+json": { + "bytes": null, + "data": [ + { + "is_even": 1, + "x": 0, + "x2": -0.9906467224227306, + "x3": 0.00261179419023964 + }, + { + "is_even": 0, + "x": 1, + "x2": 0.00202764718654894, + "x3": 0.0013126936809724655 + }, + { + "is_even": 1, + "x": 2, + "x2": -0.9972687463767552, + "x3": -1.9905576758692132 + }, + { + "is_even": 0, + "x": 3, + "x2": 0.009024117668298463, + "x3": 1.0021275936361942 + }, + { + "is_even": 1, + "x": 4, + "x2": -0.9968599379688912, + "x3": -0.9947633704461447 + }, + { + "is_even": 0, + "x": 5, + "x2": 0.0019298798125454942, + "x3": -0.9913865823439642 + }, + { + "is_even": 1, + "x": 6, + "x2": -0.9967558047825875, + "x3": 0.00801993440722967 + }, + { + "is_even": 0, + "x": 7, + "x2": 0.007736464312311071, + "x3": 0.00034330022067074583 + }, + { + "is_even": 1, + "x": 8, + "x2": -0.9959077406033643, + "x3": -1.994011690184521 + }, + { + "is_even": 0, + "x": 9, + "x2": 0.002684616942261051, + "x3": 1.0072230674765243 + } + ], + "description": null, + "dialect": null, + "encoding": null, + "format": null, + "hash": null, + "homepage": null, + "licenses": null, + "mediatype": null, + "path": null, + "schema": { + "fields": [ + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "x", + "rdfType": null, + "title": null, + "type": "number" + }, + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "x2", + "rdfType": null, + "title": null, + "type": "number" + }, + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "x3", + "rdfType": null, + "title": null, + "type": "number" + }, + { + "constraints": null, + "description": null, + "example": null, + "format": null, + "name": "is_even", + "rdfType": null, + "title": null, + "type": "number" + } + ], + "foreignKeys": null, + "missingValues": null, + "primaryKey": null + }, + "sources": null, + "title": null + }, + "text/html": [ + "
xx2x3is_even
0-0.99064672242273060.002611794190239641
10.002027647186548940.00131269368097246550
2-0.9972687463767552-1.99055767586921321
30.0090241176682984631.00212759363619420
4-0.9968599379688912-0.99476337044614471
50.0019298798125454942-0.99138658234396420
6-0.99675580478258750.008019934407229671
70.0077364643123110710.000343300220670745830
8-0.9959077406033643-1.9940116901845211
90.0026846169422610511.00722306747652430
" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "const clsDf = pl.DataFrame({ \n", + " x: new Array(100).fill(0).map((_, i) => i ),\n", + " x2: new Array(100).fill(0).map((_, i) => i % 2 - 1 + Math.random() / 100),\n", + " x3: new Array(100).fill(0).map((_, i) => i % 2 - i % 3 + Math.random() / 100),\n", + " }).select(\n", + " pl.all(),\n", + " pl.col('x').modulo(2).eq(0).add(0).alias('is_even'),\n", + ");\n", + "clsDf.head(10);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that I left `x` in training data which is a continuos sequence that couln't be generalized by this model, however the model should guess a correct class in most cases. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [ + "hide_code", + "parameters" + ] + }, + "outputs": [], + "source": [ + "const drawBin = (x1, y1, t = \"Example\") => {\n", + " const xTrue = clsDf.x.toArray().map((v, i) => v % 2)\n", + " const yTrue = clsDf.is_even.toArray()\n", + " y1 = y1.map((v, i) => v + i * 0.02)\n", + " x1 = x1.map((v, i) => v % 2)\n", + " return plot.DrawPlot(\n", + " { \n", + " title: \"\",\n", + " width: 2.5,\n", + " height: 2,\n", + " XLabel: \"X\", \n", + " YLabel: \"Y\", \n", + " }, \n", + " { type: \"scatter\", data: [xTrue, yTrue], lineDashes: [3, 4], glypRadius: 12, glyphColor: \"#00f\", glyphShape: \"ring\" },\n", + " { type: \"scatter\", data: [x1, y1], legend: t, lineDashes: [3, 4], glypRadius: 3, glyphColor: \"#f00\", glyphShape: \"ring\" },\n", + " );\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import {trainTestSplit} from \"../split.ts\";\n", + "\n", + "\n", + "const {testX, trainX, testY, trainY} = trainTestSplit(clsDf, 0.05, true, \"is_even\");\n", + "\n", + "const drawTestBin = () => drawBin(testX.x.toArray(), testY.is_even.toArray(), \"Test Data\");\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "![name](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAPAAAADACAIAAAC1TXloAAAaF0lEQVR4nOydCVhTV9rHTxJIJOybgCiCYIkDMootbuNu1a+t49KOtNNPB2fQ1qJOfQBFRT9XqkMddUa62oIjo0NFXNCKSgUFC1ZFMYgJICiLZYssYclCcr+HXOYSQ5YbCCa5vr+Hx+eee97z8l755+Q9y73XAsMwRCFoNFJm1LpooBcLYwdgMJSlrEWvuBn+L8iaetAo0EM7OaGmpu6DIUNQZyepJvb2qLW1+4DNRu3tgxyfmVBaWvqf//yntbV148aNrq6uhnVeW1ubmJj4/Pnz8PBwf39/wzpXxuwFTXTMaq7j7Fl08mRP3bJl6L339GhrJly8eDEpKSk7OzsgIIDFYllYWEydOnXTpk0MBqMf3h4/fuzn58fj8fz9/c+fPx8eHs7lct3c3AwSqkwms7CwyMjImD9/vkEcqgczZxDq+VHDhg3YwYOYSNR9LBZjCQnY2rV9reh0zR7MhIyMDIRQfn4+hmHV1dW+vr4ffPBB/1yVlZUhhHg8HoZhpaWlUVFRYrFYi31RUVF1dTVJ511dXQihjIyM/sVGEvogflYGGbx/tbRU17+eOYNGjECffopYrO4ik4k++QRxOCg1VcVQJkNWVoj8aNLE8fT0DAsLS0tLE4lE/WhOp/fqwc/PLz4+nslkarGPjo6uqKgg6Zw2kP/i5mak+DzoxFwHhQUFPQcSibrqH35AiYmqJ1etQv/7v30Tj46OHjX/8EN3YmLu2NjYiMVigUBw9erVo0ePxsTE5OTk/PDDDxUVFY2NjevXr0cIiUSikJCQmJgYxX+gJDY2Ni8vz9vbu+u/onn+/PnBgwczMzNPnTo1fPhwhNC9e/e2bt1qbW3d1tYmEAjOnDlz5MiRS5cuSSQSPz+/r776qqOjIyoqqr6+nkajjRgx4sCBAzRad0IbHx9/7tw5b29vKysrvTUtk6G//Q3l5yMPj25NDxmC4uLQsGHamgxq/z8QsrOz6+rqNNXqyBOWLcMUX8DY5593rlqF7d+PVVZ2nwkN7adDU6WhoUE55VBc+jJPT0/82MHBISwsLDc3NywsDMOwhQsXrlu3DsOw9vZ2JyentLQ0RWq2ITAwUCqVqqQcPB4PIVRWVoZhWE1NjbW1NW6PYdj27dvLysqam5sRQrm5ufjJdevWLVy4EE8tAgIC/v73v2MYdvjwYXd39+bm5v6lHF0rV4pTUnrLT59i77yDaVaFSaccOTk59fX1aqscHJCOkRyDgf79bxQZiSZNqgkLQ9OmoZgYlJSkpQ1eo/UL1hQRCAT4sP7u3bvnzp1bs2ZNVlbWv/71L8Jg7ty5U6dOxWcY0tPTFy9erJjbYU+YMOHGjRtyufyrr75as2aNhYW27+rvvvvOwcFhyZIleHH16tXu7u4qNseOHcOdMxiMKVOm5OTkIIQSEhKWL19ub2/fn2u7fFns5SWcM6f3jJcXOngQ7dihpZFZphwtLbosJk9Ghw+j7Gx06ZLdnTto/Hj0zTdo3ry++YYKUqkBw3yp8Hi8rq6uBQsW7N+/387Orq9BZWUlQujQoUO43FtaWiQSSU1NTWdnp7e3t3bnJSUlPj4+RNHT0xP3QJxpbm5ubW1NTk7GdVxeXu7g4CCXyysqKnQ618i5c2JFgvQCfn6otlZLI8MIur29ff/+/YsXLw4ODlapEolEe/bswb9utm/fzmazDfIbdUy0FRZ2f5onTEAbN3YGB6POThQSgnx9EZer3af5Dg3/+Mc/Tpo0SYsBSzE+jo6OnjZtGnGyoaFBZSyoFjabLRaLdTpfuXLl8uXLiZMYhjGZTJ3ONSIUYvh3sQqWlloaGSDluH79+q5du77++usmfHnjRVatWvXaa6/t27ePw+H85S9/GfivI6W52lpEp6PPPkNXrzomJaFLl9Du3cjWFtXVGca/GeLl5cVisYqKipRPurq6Ojs7379/X3vbsWPHPnr0CE+aVcATHisrK09PTxXnNBqNw+HodK6RwEALYuxP0NWlYR5AKSCDwOFwMjMzVU42NTWxWCx8TNDa2spiserr60k63L17N5fL7Xue1OjN1xd7/Bg/xIc43Tx5gvn4aG9ndkNDHo934cIFfMjRt9bBwSE5OZko/vWvf3Vycjp16lR5efm9e/eePn2KYVhcXJyrq+uFCxd4PF5CQoLaQWFHR4evr++8efNu3rxZXl6elZVVXV0tl8tdXV03b95cWlra0dFx8OBBFov13XfflZWVcblcPp+PYdiJEyfYbHZycnJJScnp06f1GxTW10v+538Ez569cPLzz7FTp7Q0YuzQmmKTJyEh4a233ho1apTySS6Xm5ycvHPnTvxb6dChQ2+++aaXlxcZhzdu3OBwOEOHDlU5v3Mn8vDoHu9pIzERBQQgxRKrQCBwcXHpPpmXh3Jz0ccfa2mXkoIaG7WPOkyLjIyMvLy8kJAQHo9XWloaEhJCVH3//ffDhg1rbm5mMBh4Ijt//nxXV9crCjo6OkJCQths9rRp09zc3C5evJiXlzdz5kxnZ+fKykpHR8fU1NRx48aVlpZ6eHi4u7uvXLmyvr4+PT392rVrikHK5CFDhnh4eOTk5FRVVU2dOnXGjBljxozJysq6ePGiQCAICQmxt7cfO3ZsYGBgZmbmTz/9FBgYOGrUqF9//TUoKMgKn/zXjrV1u4uLdUwMw9kZ2digx49RXByyskIffaStFdmPiy7U9tA//viju7s7URw5ciQx9aMTLT20bpYu7f65dau3hy4owBYtwpYs0dnU7HpoY4cwiDQ2Nj4vL8e+/BL79FNs1y6sqEhnk8Gd5bC1tVUeTIjFYltbW6KYmZmpNi3DwedE+8m4cWjyZHT6NNq3z9XCQtLRIffykoaFWeTldQgEuho7C3TbmApCodCMotWXpqYmR0dH7V+qKgyuoEeMGCEUCru6uiwsLGQyWXNzM77sNOhs2IA++AAdOYJGjKjNz3cfPZpWV2cTGSn85puX8dsB4zEogsYw7ObNm5MnTx45cmRQUNDdu3cnTpx4586dUaNGcTgcwmzu3LlanOCDkn5iY4OOHkWbNiEGY6iDg1NrKxKJUGKik/ZV0//i7Ozc/1/9cmlsbDSjaF8CBhD07du3CwoKBAJBWlpafX39smXLSkpKpk2bdvv27ddff/3o0aObN29euHBhWlpaYt/9Ff1i2DD07JkuIzc3lJSEBIL23FyXKVMQuQ2+Y8YYJEDAaBhgP7RYLO5S2gllbW2NEKqoqCDWlvBtjcOHD9drb8qePXsWL14cGBioGrGeN5vw+XzyO8rN7k4Wva7O7MCHB3p9BRmgh2YpUDmpvFKKb78a+C8CAJ2Y7uYkTQx292lG3TPQF/MTNM5gLFBTddH7lcIsBd2/3Yhk0LrvBTADzFLQ+GqMYTtU3Jv2fS+A6WOWglbsZ0cG1DTuR3HnB2DemNAG/zNnzqSkpBDFqqoq/A4ItRD7ri0sSN49qRFin8yg3l0PvBxMSNBLFBDFPXv2aLfH9+PLZN3/9ntqgk7vaQuTG9TAXFMOHEKF/cs9iE8CqJkymLegcS06OiJcnX2WdzRiY9PzGbC2BjVTChNKOfrN8+cIF7REonvtmuQzHQEzhQqCxsHVqfxw0f+iZqsDSJmqmH3KoQKG9fz03dDi6NhbC1AVqgmaoLGxR7s8Hh8/wDMTgNpQVtDAqwkIGqAUJjQolEqlynfUSs33sVyA8TAhQV+5cuXMmTNE8fHjx3/4wx+MGhFgfpiQoN9WQBR1Ln0DQF8ghwYoBQgaoBQgaIBSgKABSgGCBigFCBqgFCBogFKAoAFKYUILKwUFBb/88gtR5HK5Wm6SBQC1mJCgnZ2dlR+2W1hYaNRwALPEhAQ9UgFRzM3NNWo4gFkCOTRAKUDQAKUAQQOUAgQNUAoQNEApQNAApQBBA5QCBA1QChNaWNHr+dAAoJYeQd+9e3fChAnGDUXf50MDQF96Uo6DBw8O/A2cAGB0egT9888/z549G/YDAeZOj6DDw8O/+OKLmJiY9evXt7S0qBjl5eUZIzYA0JueHHrLli0IoUuXLqWlpc2bNy8iImLFihV4lUgkunTp0uTJk7V4SU5OfvjwIYfDWbFihcoLvZOSkp48eYIfczic999/f9CuBQD6TNstXbr08uXLe/fu9fHxCVQwbtw47el1bGxseXl5XFxceXn55s2bVWovXLgwCGEDgHpUp+2ysrI+/vjjkpKSOXPmODg4IISUH6DYl87OziNHjvB4PBqNFhER4evrGxsba2NjQxi4uLjs2LFj0OIHgBfo6aH/8Y9/NDQ0/OlPf5o9e7ZUKr18+XJmZmaqgrS0NDpd4/oLn8+XSqXu7u4IoaFDh1paWj58+PAlxg8AL9DTQyclJe3cubOlpSUyMnLXrl1sNpuwsLS0/PDDDzW1r6urU+6PbW1t6+rqlA0sLS03b97c3Nzc2Ni4Zs2a2bNnD86FAADqFfSjR4/GjBnz7bffql1eee211zS1ZzKZMpmMKHZ1dVm++AL4f/7zn/jBgwcPJk+eXFZW5uHhgZ+pqqrSks/0nWwBAJ30CPrtt99OSUlhMBj6tvfw8BAKhRiG4ZMbLS0thF5VCAoKsrGxKSwsJAzOnz//66+/avJcXFysbzBqEQqFfD7fIK5MkJqaGmOHMIgIhUIfHx/92mAKsrOzsX4hl8t9fHyKi4sxDOPxeJ6enjKZDMOwlpYWDMO6urqOHz+OW0okEjabXVRURNLz7t27uVxu/6JShsfjDdyJyULtq2tUoFeTntHejBkz+vcZotFoBw4ciI2NvXnzZnR09IEDB+h0Op/Pd3BwKCgoYDAYhw8fzszMzMvLW7169Z///OeAgID+/SIAIIMBdtstWbJkwoQJhYWFR44c8fLyQgj5+vqePn06MDAQIZSSklJRUVFbWxsTE+Pvr+YdmABgQAyzfdRLQa9TCwti39woBQb5LQCgE9jgD1AKEDRAKUDQAKUAQQOUAgQNUAoTukn2xo0b165dI4p3796Fm2QBfTEhQQcEBAwdOpQoCoVCo4YDmCUmJGhnBUTR0dHRqOEAZgnk0AClAEEDlAIEDVAKEDRAKUDQAKUAQQOUAgQNUAoQNEApTGhhBZ4PDQwcExI0PB8aGDiQcgCUAgQNUAoQNEApQNAApQBBA5QCBA1QChA0QClA0AClMKGFlZqamqqqKqJYW1tr1HAAs8S0BJ2fn08UQdBAPzAhQYcoIIptbW1GDQcwSyCHBigFCBqgFCBogFKAoAETYuZMRKP1/ri4OLu4OCufmTNHhwcTGhQCrzIvviEe4S/jFggE+CO1CINr13oONL2tG3powMj4+/eqGcN6fvqiUkWjIbXvn4IeGjAmylImCW5Jo6Hi4u5/VRpCDw0YDVzNTKYeaibAMGRhgfrmKibUQ1++fDk9PZ0o8ng8uEmWwuBCnD8fZWRosHj+nJmZ2X3w5pvIyalvvVSKZs1C2dkv9NMmJOjp06crv2n80KFDRg0HGERwNXt7a1AzhqE9exCfT584sbu4fn13oh0bq9obI5SVhby8UFVVr6ZNSNBWCojikCFDjBoOMOhUVGio2LkTjRmDtm0TKWY5rNetQykpaNcu9H//19e2svIFnUMODbxstM+7ocZGVF6OQkNfOBkaisrKuqvUQQwTQdCAcRg/XnPdL7+g2bPVnJ81C925o6kRMYVnGEEXFxcfP36cy+X2oxZ4pcD70YICA7stKupxbgBBHz9+/LPPPgsODo6Li0tMTNSrFgBUCQlBSi9D6yUrC73+us7WNKwfc4BKyOVyLy+v7OxsPz+/8vLyKVOmVFdXW1hYkKnVzp49exYvXhwYGDiQ8BBCfD7f399/gE5MFrO7ur5LIWrYsQPZ2aHCQklLC0KIaW+PgoJQW5vaQaGK84H20I8ePWpsbPTz80MIjRo1qqOj48GDByRrgVcNOkm5vfkm+v571NkpnTNHOmcO6uhAiYlo7lwyTQc6bVdbW2tra0sU7ezslG+d0l4LvGqQyga6ulBcXHeW3doqx3OPL79Etrbo3XfRuXOIwdDeeqCClsvldKXPHZ1Ol8lkJGsBQA23bqHZsxGTiVxcJPhuUfztlTNndldNmaK99UAF7ebmpvzK19bWVjc3N5K1MTExT5480eTZ0tLS09NzgOEhhFgsFp/PH7gf06S9vd2srs4fz/u1WNjev0+j01sVNlKpdPTo0T0Vw4Zpmod+AWxgiMViZ2fnyspKDMOqq6vt7Ow6OjpI1gKvGghhuhVXWIjFxqo5HxuLcbk6/Q90UMhkMjdu3Hj48OHOzs74+PioqCgrK6vKykoOh8Pn89XWDvA3AhQnKAg9fIhUxlq1tai4GJGZ8hrA562X8+fP79ix49y5c3hRKBRu2bKlqalJbS3wyuLtTaKHxjCsshJbuBC7eBGTSrt/Ll7E3nkHq6rS2Q4hbKDz0ACgF6TmoRFCnZ3o6FF082b38dSpKDwckfhup9EGvLACAHqhY2fSgD3D5iSAUoCggZeK8lZPA0J0/CBogFKAoIGXjcE7aeW8HAQNGIEFC5ChNI07wR0aYPsoAPSPfjyRg4wT6KEB46D8DKT+ofYjAYIGjAaGITs7hEtz82Y9Gn76aY+aXVxUO3hIOQDjQ/S1q1ahb77RZhkWho4d6zlWq1zooQHjg2Fo797ug2+/7X1y7qRJPbWTJvWexNW8f7/GzBt6aMC0oNM1i5WG5HIdzUHQAKWAlAOgFCBogFKAoAFKAYIGKAUIGqAUIGiAUoCgAUoBggYohQm9kmKQOHv2bBH+9GASCIVC5YfxaUEul4tEIjabTcZYKpXK5XIWi0XGWCQSWSggY9ze3s5ms2nkdqyRvzp9jdva2mxsbMhYymQyiURC8vEsEokkOjqafBivxErhkiVLdu/eTdI4Ojo6Pj6ejOWzZ89SUlI2bNhAxjg/P7+mpubdd98lY3zy5EkOhzNe20Pue9m7d29ERISDgwMZY/JXp5exRCLZuXPnXnw3hi7Kysp++umnjz76iIzx1atXR4wY8d5775ExxqF+D81iscg/ZNre3p6ksY2NjaurK0lj/JmrJI09PDx8fHxIGjs7O3M4HBcXFzLG5K9OL2ORSOTo6Ejec1FREUljHo9H0icB5NAApQBBA5QCBA1QC4M8hM+UCQ0NJW/85MkTkpZisfjZs2ckjYVCoUAgIGnc0NDQ1tZG0jgiIqKhoYGkMfmr08tYLpc/ffqUpHFnZ2dtbS1J41MKSBrjUH9QqBcjR44kaclkMj08PEgak5zSwiE5wusH5K9OL2Majebl5UXSeIgC8mHoC6QcAKUAQQOUAgQNUArq59BLly41dgiDyIIFC0guv5sjY8aM0bcJ9Ze+gVcKSvXQzc3NsbGxNjY2zc3NcXFxTk5O5GtNH7lcvmPHDolE0tDQsH79+t/+9rfKtV9//TXxKnUHB4eMjAwjhTkgUlNTs7KyEhIS+ladOHHi9u3bYrF4+vTp77//vkYXek3ymTiLFi06ffo0Pn/51ltv6VVr+sTHx2/YsEHxSp1KT09PkUikXLt169aMjIw8BXfu3DFemP2kra0tLi5uxYoVM2fO7Ft769atN954QyaTSaXSoKAgLRdIHUE3NDRYWlriSxLt7e1MJlN54UN7rVkQEBBw5coV/PiNN95ITU1Vrt26dSv5FRaTJTk5Wa2gIyIitmzZgh9HR0evXbtWkwfqzHKUl5dbWVlZW1sjhNhstr29fVlZGclas6C0tNTd3R0/Hj58eElJiYpBe3v7kydP2trajBHd4FJSUqL92gmok0MLBALl8b61tXWj0pt0tdeaPkKhUCKREJfQN/7f/e53x44ds7Ozu379Oo1GO3nyJMn7CcwC5T+f9r8ddQRtb28vEomIYmdnp729Pcla08fa2prBYBCX0Df+BQoUj5r9NDg4+NixY6tXrzZSsIZH+c+n/W9HnZRj5MiRbW1tYrEYv+WpqalJeTeC9lrTh06ne3l5ET1TXV2dt7e3soFMJiOOR48eXVlZ+dJjHES8vb21XLsy1BG0p6fnxIkT8/LyEEK5ublBQUG+vr4ymSw1NVUikaitNXbI+rFs2bKcnByEUGtra3Fx8e9//3uE0OXLl+vq6hBCW7duxc0wDHv48GFISIix4zUADx48uH//vvK1I4SuXbsWGhqqqQmlFlYeP34cFRU1duzY+/fvx8fH+/v78/n83/zmN3fu3Bk/fnzfWmPHqx/t7e2rVq3y8/MrLi5evnz5okWLEEJubm7btm1bu3ZteHi4jY2Ns7Mzl8sNDAzcvn27sePVj9bW1hMnTly/fv3+/furV6+eMWNGcHBwaGgo3iUhhGJjY/HvWFdX1127dmnyQylB40gkEiaTSRRVbkhWqTU7VOJvb2/Hp27UFs0dqVSKYRhxvTKZjEaj0ena0goKChp4laFODg0AIGiAaoCgAUoBggYoBXVWCl8Rmpqa9u3bV1JSwmAw/P39o6KiHB0dExISfv75ZxaL9eGHH86ZM8fYMRqVwdw7BQwW4eHhDAYjPz8fL9bV1U2YMKGurs7YcRkfmLYzS4RC4dixY62srO7du8dkMt9+++1NmzbNnDnT2HEZHxC0uXLt2rW5c+dGRUXhT5vdtm2bsSMyCUDQZsy6deu++OKLWbNmXblyRfv62asD/C+YMZGRkRiGPX36VHln7CsOCNpckcvln3zyyf79+8vLyzdt2mTscEwFmLYzV+Li4qZPnx4dHV1fX3/gwIGlS5fOmjXL2EEZH8ihzZLs7OzPP/88PT2dRqOJRKLg4OCOjg4ul6vX60goCaQc5kd9fX10dHRSUhL+rqAhQ4YcO3asuro6MjLS2KEZHxC0mXH27Nk1a9YMGzbsxx9/xM9IJJL09PSpU6feu3cvMjKyurra2DEaE0g5AEoBPTRAKUDQAKUAQQOUAgQNUIr/DwAA///NBL6CqzON9AAAAABJRU5ErkJggg==)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "const logreg = regr.Logistic({\n", + " epochs: 5000,\n", + " learningRate: 0.001,\n", + "});\n", + "logreg.fit(trainX.rows(), trainY.rows());\n", + "\n", + "const predLogReg = logreg.predict(testX.rows());\n", + "\n", + "const yPred1 = predLogReg.map((x) => x[0]);\n", + "\n", + "drawBin(\n", + " testX.x.toArray(), \n", + " yPred1,\n", + " \"Predicted\"\n", + ");" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [ + "hide_code", + "parameters" + ] + }, + "outputs": [ + { + "data": { + "text/markdown": [ + "![name](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAPAAAADACAIAAAC1TXloAAAZiklEQVR4nOydCVRT19bHd8KkzAjUCUEEJD4pr0LV0larYq2v1gkH7HM5tI4Vi7IABUU/B0Rd4OvDp/ZZWZ+g/bRYxbF14ilWLVQBRYQSBESmgkABQ8ic+y1yeTEmIdxAMMl1/xaLdc85+5y7D/xzsu85595rShAE0AgGg5IZvTqNvMRU3w7oDEUpa9AraUb+RlnTD6a+HdAB/fp1CLRPn3aNapYpaWBrC6Ssraxel5fIa4Fh7CGHfGBW049z5+DkyY6y+fNh7lwt6hoJLi4uw4cPt7GxEQqF9+/fZ7FYdnZ2PB6vvLz8xx9/9Pf3716zWVlZSUlJ58+f9/DwcHR0bGpqMjU1HTZs2Jo1a8aMGaPrTugUwpgB6PhRQ1gY8c03BJ/ffiwQEAcPEmvXqloxmZ23YAwEBweTB/n5+QBw8uRJMhkfH5+ZmUmlhevXr0ulUtX8J0+eAMDx48fJZHNz84YNG8zMzNatW6e5wXv37jU1NWnZD51hxCEHOb6amakbX8+ehSFDYP16sLBoT5qbw5o1wGLB6dNKhhIJ9O0L1K8mDY2FCxeqzQ8MDBw0aFCX1QUCweLFi6VSaZeWdnZ2e/fujYmJSUxM/OmnnzRYrl69urGxscsGtaO5GcRiSpb6+iT1kJwcjSPrggUEj6ecKRAQ8+apNSebSk3VtZevEaURurW1ddWqVXNlREREkGPwnTt3JkyY8Pe///2TTz5JSUmpqakJDg4GgPnz5yckJCg1qDRCkwgEAltb2xkzZpDJqVOnzp0799NPP502bVpDQwOHw1mzZg0ATJs2bePGjQRBXLp0afLkyYsWLfL39z9w4IDWvRKLibg4YsYMYtUqIjiYWLKEqK7WXMNwBZ2RkVFXV9dZaRdxwvz57b+rqoiEBN6KFcTevURFBSH7hu5mg4ZKfX09eaAk6DVr1syePVsmCTGLxdq/fz9BEF5eXmlpaQRBNDU1HT58mCCIK1euAIBYLFZtWa2gCYIYPXr08OHDyeOoqCjywM3NLTY2liCIoqIiACgpKSHzk5KSyOOzZ88ymcyWlhateif+4guB4jDz7Bnx2WdE56ow6JDj9u3bz58/V1tkbw9dXMmZmMD//R+Eh8N771UvXQrjxkFUFCQna6hDlpib68Dz10lnX+4pKSmzZs2S/SVMAgICbt++DQBcLjc9PV0kEtnb269cubJ7Z3R2dpafdPfu3eQBi8Wqra1VNV62bJmHhwdpIJVK6+vrtTjT1asCV1dOYODLHFdX+OYb2LZNQyWjnIduaenKIiAAEhMhIwMuX7bNzoZRo+C772DKFNWJDiVEIh26qTcaGhq4XO6xY8du3rwJAKWlpc7OzgAQFxcXEhJy/vz5RYsWRUdH25KTl1pSU1NDalQqlf773//Oy8uztrZms9menp6qxqWlpQcPHhSLxa2trWR8q8WZzp8XhIYqZ3p6grpPjhzdCJrL5e7du3fWrFl+fn5KRXw+PzY2ViyL6Ldu3WppaamTM3bxl8nLa/80+/vDhg08Pz/g8WDMGPDwANn3soY2jfTSUAkL2aXwsmXLPv/8c8X8JUuWTJ8+/cSJEzt37szOzr5+/bq2LXM4nJKSEvJKNCoq6sGDB5cuXbKwsCgoKFA15nK5o0ePTkpKCgoKYrPZR48eZWj19+VwCPK7WAkzMw2VdBBy3Lp1a8eOHYcPH25qalItXbFixfDhw/fs2cNisZYtW9bz01H6m9TWApMJu3fD9esOyclw+TLs3Ak2NlBXp5v2DRsbG5v+/fs/fvxYKT8nJ6dfv35r1649fPgwGYSQqB04SfEpSTAxMREAYmJiAODOnTtjxowhPzxKkA0WFBQ0NTWNGzeum93w8THNzVXOFItBKNRQSQcj9EcyLly4oFrU3Nz8448/HjhwAADmzJmzevXq+vp68uuvdykqgmvXYNgwmDWrhs329vZuz/T3h4kTe/3UeoL8DhT/d24rIiJiy5YtXl5e48aN4/F4FhYWXl5eYWFhaWlpdnZ29+7dCwgIAAAyeDh37pyPjw+LxVJtsLq6urS01MTEpLS09MyZM1euXElLS3NxcSHrXrx48bPPPmtsbHz27BkZcri4uFhYWFy6dOlvf/ubi4uLiYnJ/v37g4KCLl++rHWXvvyy75Il4lGjXslMTIROZipJevei8MmTJ2ZmZnZ2duSwQQZbPW924MCuLGxsQGV8gsJC6CpqfPV/ajQkJSUdO3Zs3bp1OTk5u3btIgV97Nix27dvR0dHp6Wl9ZVNts+YMWP9+vVLly5lMpmnTp2SRaSeW7du/eGHH5QEl5WVdeTIkdDQ0MrKyn/84x8JCQk5OTlTpkwpKCj4+OOPSZuEhIQPP/zw0KFDHA4nNjbWysoqOzvbysoqISHhl19++eGHH5ydnU+dOlVcXHzs2LF58+aFhYWdOnVKRP0yxdmZt26dzRdftH/B/vEHPHwIoaEgEHRxIaTVNIoGWCxWenq6UubPP/88YMAAedLNzY2cNqLCzp078/PzVfMpuRwU1P7z22/kRFJ7Tm4uMXMmIZvJ0oxxTd519I6mNDQ0/FlWRnz7LbF+PbFjB/H4cZdVeneWw8bGRiAQyJMCgcDGxkaeTE9Pb25u7qxuSUlJ90/8zjsQEABnzsCePc6mpsK2Nqmrq2jpUtPMzLauF7Ecdb/Q1WtwOBwj8lZbmpqaHBwcYPVq6lV6V9BDhgzhcDhisdjU1FQikTQ3N5PhV68TFgaffw4HDsCQIbVZWQO8vBh1ddbh4ZzvvnsdZ0f0R68ImiCIu3fvBgQEuLm5+fr65uTkjB07Njs7e9iwYYpXHpMnT9bQCLnm1E2srSEpCTZuBBOTt+zt+714AXw+HD3aj8L2hvYh2tGx+6d+vTQ0NBiRt68BHQj6/v37ubm5jY2NaWlpz58/nz9/fnFx8bhx4+7fv//uu+8mJSVFR0dPnz49LS3t6NGjuvAZBg2CmpqujPr3h+RkaGzk3rnj9P77QG1qZcQInTiI6A0d7IcWCARihZ1QVrI980+fPnV3dydzZLsqqlxcXLSaV4+NjZ01a5aPj4+yx1rebMKWT9tRwOjuZNGqd0YHeXmg1VeQDkZoCxlKmXI1k5PzQ4YM6fmJEKRLDHdzUmf09vBpRMMzoorxCZqkNxaoabDojRiloGUrj72Cxn0viBFglIImV2N0O6CSrWnc94IYAUYpaNnGMdChpsl2ZHdvIMaNAW3wP3v2bGpqqjxZWVlJ3nOhFvm+a1NTqndPdgZ5kywAfPJJj9pBDAEDEvRsGfJkbGysZntyP75E0v6721MTTGZHXZzcoAfGGnKQyFXYvdhD/klANdMG4xY0qUUHByDVqe7mCfVYW3d8BqysUM20woBCjm7z559ACloo7HrtmuIzHREjhQ6CJiHVqfhw0f+iZqsDSpmuGH3IoQT5cFGCANUNLQ4OL0sRukI3QctpaOjQblERmzwgIxOE3tBW0MibCQoaoRUGdFEoEokU76jV4n53BPkvBiToa9eunT17Vp4sLS2dN2+eXj1CjA8DEvQ0GfJkl0vfCKIKxtAIrUBBI7QCBY3QChQ0QitQ0AitQEEjtAIFjdAKFDRCKwxoYSU3N/fevXvyZH5+voabZBFELQYkaEdHR8WH7ebl5enVHcQoMSBBu8mQJ+/cuaNXdxCjBGNohFagoBFagYJGaAUKGqEVKGiEVqCgEVqBgkZoBQoaoRUGtLCi1fOhEUQtHYLOycnx9/fXryvaPh8aQVTpCDm++eabnr+BE0H0Toegf/3110mTJuF+IMTY6RD08uXLDx06FBUVFRoa2tLSomSUmZmpD98QRGs6YuhNmzYBwOXLl9PS0qZMmRISErJ48WKyiM/nX758OSAgQEMr33//fUFBAYvFWrx4sdILvZOTk8vLy8ljFou1YMGCXusLgqhM2wUFBV29enXXrl3u7u4+Mt555x3N4XVMTExZWVlcXFxZWVl0dLRS6aVLl3rBbQRRj/K03c2bN1evXl1cXBwYGGhvbw8Aig9QVIXH4x04cKCoqIjBYISEhHh4eMTExFhbW8sNnJyctm3b1mv+I8grdIzQ+/fvr6+vX7JkyaRJk0Qi0dWrV9PT00/LSEtLYzI7XX9hs9kikWjAgAEA8NZbb5mZmRUUFLxG/xHkFTpG6OTk5O3bt7e0tISHh+/YscPS0lJuYWZmtnDhws7q19XVKY7HNjY2dXV1igZmZmbR0dHNzc0NDQ1fffXVpEmTeqcjCAIvBf3777+PGDHiyJEjapdXhg8f3ll9c3NziUQiT4rFYrNXXwD/r3/9izx49OhRQEBASUnJwIEDyZzKykoN8YzqZAuCdEmHoKdNm5aammpiYqJt/YEDB3I4HIIgyMmNlpYWuV6V8PX1tba2zsvLkxtcuHDhjz/+6KzlwsJCbZ1RC4fDYbPZOmnKAKmurta3C70Ih8Nxd3fXrg4hIyMjg+gWUqnU3d29sLCQIIiioqLBgwdLJBKCIFpaWgiCEIvFx48fJy2FQqGlpeXjx48ptrxz5878/PzueaVIUVFRzxsxWOjduwYZWlXpuNr76KOPuvcZYjAY+/bti4mJuXv3bmRk5L59+5hMJpvNtre3z83NNTExSUxMTE9Pz8zMXLly5Zdffjly5MjunQhBqKCD3XazZ8/29/fPy8s7cOCAq6srAHh4eJw5c8bHxwcAUlNTnz59WltbGxUV5e2t5h2YCKJDdLN91FXGy0ZNTeX75obJ0MlZEKRLcIM/QitQ0AitQEEjtAIFjdAKFDRCKwzoJtlffvnlxo0b8mROTg7eJItoiwEJeuTIkW+99ZY8yeFw9OoOYpQYkKAdZciTDg4OenUHMUowhkZoBQoaoRUoaIRWoKARWoGCRmgFChqhFShohFagoBFaYUALK/h8aKTnGJCg8fnQSM/BkAOhFShohFagoBFagYJGaAUKGqEVKGiEVqCgEVqBgkZohQEtrFRXV1dWVsqTtbW1enUHMUoMS9BZWVnyJAoa6QYGJOgxMuTJ1tZWvbqDGCUYQyO0AgWN0AoUNEIrUNCIATFhAjAYL3+cnBydnBwVcwIDu2jBgC4KkTeZV98QD+TLuBsbG8lHaskNbtzoOOjsbd04QiN6xtv7pZoJouNHFaUiBgPUvn8KR2hEnyhKmSKkJYMBhYXtv5Uq4giN6A1SzebmWqhZDkGAqSmoxioGNEJfvXr14sWL8mRRURHeJEtjSCF+8glcudKJxZ9/mqentx98/DH066daLhLBxImQkfHKOG1Agh4/frzim8b/+c9/6tUdpBch1Tx0aCdqJgiIjQU2mzl2bHsyNLQ90I6JUR6NAW7eBFdXqKx8qWkDEnRfGfJknz599OoO0us8fdpJwfbtMGIEbNnCl81yWH39NaSmwo4d8D//o2pbUfGKzjGGRl43mufdoKEBysogOPiVzOBgKClpL1KH/DIRBY3oh1GjOi+7dw8mTVKTP3EiZGd3Vkk+hacbQRcWFh4/fjw/P78bpcgbBTmO5ubquNnHjzsa14Ggjx8/vnv3bj8/v7i4uKNHj2pViiDKjBkDCi9De8nNm/Duu13WZhDdmANUQCqVurq6ZmRkeHp6lpWVvf/++1VVVaamplRKNRMbGztr1iwfH5+euAcAbDbb29u7h40YLEbXO9WlEDVs2wa2tpCXJ2xpAQBzOzvw9YXWVrUXhUqN93SE/v333xsaGjw9PQFg2LBhbW1tjx49oliKvGkwKcrt44/hf/8XeDxRYKAoMBDa2uDoUZg8mUrVnk7b1dbW2tjYyJO2traKt05pLkXeNChFA2IxxMW1R9kvXkjJ2OPbb8HGBubMgfPnwcREc+2eCloqlTIVPndMJlMikVAsRRA1/PYbTJoE5ubg5CQkd4uSb6+cMKG96P33NdfuqaD79++v+MrXFy9e9O/fn2JpVFRUeXl5Zy2bmZkNHjy4h+4BgIWFBZvN7nk7hgmXyzWq3nmTcb8GC5uHDxlM5guZjUgk8vLy6igYNKizeehXIHqGQCBwdHSsqKggCKKqqsrW1ratrY1iKfKmAUB0rbi8PCImRk1+TAyRn99l+z29KDQ3N9+wYUNiYiKPx4uPj4+IiOjbt29FRQWLxWKz2WpLe3hGhOb4+kJBAShda9XWQmEhUJny6sHn7SUXLlzYtm3b+fPnySSHw9m0aVNTU5PaUuSNZehQCiM0QRAVFcT06cRPPxEiUfvPTz8Rn31GVFZ2WQ+A6Ok8NIJoBaV5aADg8SApCe7ebT/+4ANYvhwofLczGD1eWEEQrehiZ1KPW8bNSQitQEEjrxXFrZ46RD7wo6ARWoGCRl43Oh+kFeNyFDSiB6ZOBV1pmmyEbFAH20cRpHt044kcVBrBERrRD4rPQOoeaj8SKGhEbxAE2NoCKc3oaC0qrl/foWYnJ+UBHkMORP/Ix9oVK+C77zRZLl0KKSkdx2qViyM0on8IAnbtaj84cuTlk3Pfe6+j9L33XmaSat67t9PIG0doxLBgMjsXKwOk0i6qo6ARWoEhB0IrUNAIrUBBI7QCBY3QChQ0QitQ0AitQEEjtAIFjdAKA3olRS9x7ty5x+TTgynA4XAUH8anAalUyufzLS0tqRiLRCKpVGphYUHFmM/nm8qgYszlci0tLRnUdqxR7522xq2trdbW1lQsJRKJUCik+HgWoVAYGRlJ3Y03YqVw9uzZO3fupGgcGRkZHx9PxbKmpiY1NTUsLIyKcVZWVnV19Zw5c6gYnzx5ksVijdL0kPuX7Nq1KyQkxN7enoox9d5pZSwUCrdv376L3I3RFSUlJf/5z39WrVpFxfj69etDhgyZO3cuFWMS+o/QFhYW1B8ybWdnR9HY2tra2dmZojH5zFWKxgMHDnR3d6do7OjoyGKxnJycqBhT751Wxnw+38HBgXrLjx8/pmhcVFREsU05GEMjtAIFjdAKFDRCL3TyED5DJjg4mLpxeXk5RUuBQFBTU0PRmMPhNDY2UjSur69vbW2laBwSElJfX0/RmHrvtDKWSqXPnj2jaMzj8Wprayka/yiDojEJ/S8KtcLNzY2ipbm5+cCBAykaU5zSIqF4hdcNqPdOK2MGg+Hq6krRuI8M6m5oC4YcCK1AQSO0AgWN0Ar6x9BBQUH6dqEXmTp1KsXld2NkxIgR2lah/9I38kZBqxG6ubk5JibG2tq6ubk5Li6uX79+1EsNH6lUum3bNqFQWF9fHxoa+te//lWx9PDhw/JXqdvb21+5ckVPbvaI06dP37x58+DBg6pFJ06cuH//vkAgGD9+/IIFCzptQqtJPgNn5syZZ86cIecvP/30U61KDZ/4+PiwsDDZK3UqBg8ezOfzFUs3b9585cqVTBnZ2dn6c7ObtLa2xsXFLV68eMKECaqlv/322+jRoyUSiUgk8vX11dBB+gi6vr7ezMyMXJLgcrnm5uaKCx+aS42CkSNHXrt2jTwePXr06dOnFUs3b95MfYXFYPn+++/VCjokJGTTpk3kcWRk5Nq1aztrgT6zHGVlZX379rWysgIAS0tLOzu7kpISiqVGwZMnTwYMGEAeu7i4FBcXKxlwudzy8vLW1lZ9eNe7FBcXa+67HPrE0I2NjYrX+1ZWVg0Kb9LVXGr4cDgcoVAo74Kq/x9++GFKSoqtre2tW7cYDMbJkycp3k9gFCj++zT/7+gjaDs7Oz6fL0/yeDw7OzuKpYaPlZWViYmJvAuq/k+VIXvU7Ho/P7+UlJSVK1fqyVndo/jv0/y/o0/I4ebm1traKhAIyFuempqaFHcjaC41fJhMpqurq3xkqqurGzp0qKKBRCKRH3t5eVVUVLx2H3uRoUOHaui7IvQR9ODBg8eOHZuZmQkAd+7c8fX19fDwkEgkp0+fFgqFakv17bJ2zJ8///bt2wDw4sWLwsLCGTNmAMDVq1fr6uoAYPPmzaQZQRAFBQVjxozRt7864NGjRw8fPlTsOwDcuHEjODi4syq0WlgpLS2NiIh4++23Hz58GB8f7+3tzWaz//KXv2RnZ48aNUq1VN/+ageXy12xYoWnp2dhYeGiRYtmzpwJAP3799+yZcvatWuXL19ubW3t6OiYn5/v4+OzdetWffurHS9evDhx4sStW7cePny4cuXKjz76yM/PLzg4mBySACAmJob8jnV2dt6xY0dn7dBK0CRCodDc3FyeVLohWanU6FDyn8vlklM3apPGjkgkIghC3l+JRMJgMJhMTWEFDQWNvMnQJ4ZGEBQ0QjdQ0AitQEEjtII+K4VvCE1NTXv27CkuLjYxMfH29o6IiHBwcDh48OCvv/5qYWGxcOHCwMBAffuoV3pz7xTSWyxfvtzExCQrK4tM1tXV+fv719XV6dsv/YPTdkYJh8N5++23+/bt++DBA3Nz82nTpm3cuHHChAn69kv/oKCNlRs3bkyePDkiIoJ82uyWLVv07ZFBgII2Yr7++utDhw5NnDjx2rVrmtfP3hzwr2DEhIeHEwTx7NkzxZ2xbzgoaGNFKpWuWbNm7969ZWVlGzdu1Lc7hgJO2xkrcXFx48ePj4yMfP78+b59+4KCgiZOnKhvp/QPxtBGSUZGRkJCwsWLFxkMBp/P9/Pza2try8/P1+p1JLQEQw7j4/nz55GRkcnJyeS7gvr06ZOSklJVVRUeHq5v1/QPCtrIOHfu3FdffTVo0KCff/6ZzBEKhRcvXvzggw8ePHgQHh5eVVWlbx/1CYYcCK3AERqhFShohFagoBFagYJGaMX/BwAA//+3vLVlxWJLSQAAAABJRU5ErkJggg==)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "drawTestBin()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0, 1, 1, 0, 1 ] should be [ 0, 1, 1, 0, 1 ]\n" + ] + } + ], + "source": [ + "console.log(yPred1, \"should be\", testY.is_even.toArray());" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[33m1.6635532311438688\u001b[39m" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logreg.loss();" + ] } ], "metadata": { diff --git a/plot/mod.wasm b/plot/mod.wasm index b4c33f7..46a2fc4 100755 Binary files a/plot/mod.wasm and b/plot/mod.wasm differ diff --git a/regr/main.go b/regr/main.go index 4395ffd..05fe8c3 100644 --- a/regr/main.go +++ b/regr/main.go @@ -15,6 +15,7 @@ func InitRegrExports(this js.Value, args []js.Value) interface{} { exports.Set("ElasticNet", js.FuncOf(src.NewElasticNetJS)) exports.Set("Lasso", js.FuncOf(src.NewLassoJS)) exports.Set("R2Score", js.FuncOf(src.R2ScoreJS)) + exports.Set("Logistic", js.FuncOf(src.NewLogisticRegressionJS)) return exports } diff --git a/regr/mod.wasm b/regr/mod.wasm index 55a3d54..c8299c7 100755 Binary files a/regr/mod.wasm and b/regr/mod.wasm differ diff --git a/regr/src/LogisticRegressionJS.go b/regr/src/LogisticRegressionJS.go new file mode 100644 index 0000000..5c11367 --- /dev/null +++ b/regr/src/LogisticRegressionJS.go @@ -0,0 +1,40 @@ +//go:build js && wasm +// +build js,wasm + +package src + +import ( + "syscall/js" +) + +func NewLogisticRegressionJS(this js.Value, args []js.Value) interface{} { + var ( + epochs = args[0].Get("epochs").Int() + learningRate = args[0].Get("learningRate").Float() + ) + reg := &MCLogisticRegression{ + Epochs: epochs, + LearningRate: learningRate, + } + loss := 1.0 + obj := js.Global().Get("Object").New() + obj.Set("fit", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + X := JSFloatArray2D(args[0]) + Y := JSFloatArray2D(args[1]) + XDense := Array2DToDense(X) + YDense := Array2DToDense(Y) + reg.Fit(XDense, YDense) + return nil + })) + obj.Set("predict", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + X := JSFloatArray2D(args[0]) + XDense := Array2DToDense(X) + Y := reg.Predict(XDense) + loss = reg.Loss(XDense, Y) + return MatrixToJSArray2D(Y) + })) + obj.Set("loss", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + return loss + })) + return obj +} diff --git a/regr/src/utils_js.go b/regr/src/utils_js.go index 903470a..d7f5125 100644 --- a/regr/src/utils_js.go +++ b/regr/src/utils_js.go @@ -5,6 +5,8 @@ package src import ( "syscall/js" + + "gonum.org/v1/gonum/mat" ) func JSFloatArray(arg js.Value) []float64 { @@ -35,3 +37,26 @@ func ToJSArray[T any](arr []T) []interface{} { } return jsArr } + +func MatrixToJSArray(m mat.Matrix) []interface{} { + r, c := m.Dims() + data := make([]float64, r*c) + for i := 0; i < r; i++ { + for j := 0; j < c; j++ { + data[i*c+j] = m.At(i, j) + } + } + return ToJSArray(data) +} + +func MatrixToJSArray2D(m mat.Matrix) []interface{} { + r, c := m.Dims() + data := make([][]interface{}, r) + for i := 0; i < r; i++ { + data[i] = make([]interface{}, c) + for j := 0; j < c; j++ { + data[i][j] = m.At(i, j) + } + } + return ToJSArray(data) +}