+ linear regression solver

This commit is contained in:
Anton Nesterov 2024-10-01 01:58:32 +02:00
parent 2ad3e91a2c
commit 31ecd4fbfb
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
9 changed files with 228 additions and 109 deletions

1
Makefile vendored
View file

@ -1,5 +1,6 @@
dev:
GOOS=js GOARCH=wasm tinygo build -o stat/mod.wasm -no-debug ./stat/main.go
GOOS=js GOARCH=wasm tinygo build -o regr/mod.wasm ./regr/main.go
GOOS=js GOARCH=wasm go build -o plot/mod.wasm ./plot/main.go
prod:

File diff suppressed because one or more lines are too long

126
notebooks/regressions.ipynb vendored Normal file

File diff suppressed because one or more lines are too long

View file

@ -11,7 +11,7 @@ import (
func InitRegrExports(this js.Value, args []js.Value) interface{} {
exports := args[0]
exports.Set("ABCD", js.FuncOf(src.ABCD))
exports.Set("Linear", js.FuncOf(src.NewLinearRegressionJS))
return nil
}

View file

@ -1,5 +1,4 @@
import "../lib/wasm_tinygo.js";
import type { Stat } from "./types.ts";
// @ts-expect-error: no types
const go = new Go();
@ -13,9 +12,9 @@ const wasm = wasmMmodule.instance;
go.run(wasm);
const _exports = {} as Record<string, (...args: unknown[]) => unknown> & Stat;
const _exports = {} as Record<string, (...args: unknown[]) => unknown>;
// @ts-ignore: no types
__InitStatExports(_exports);
__InitRegrExports(_exports);
export default _exports;

BIN
regr/mod.wasm Executable file

Binary file not shown.

View file

@ -1,16 +0,0 @@
//go:build js && wasm
// +build js,wasm
package src
import "syscall/js"
// ref: mat.Dense
// fit: solve least squares
// predict: predict y from x
// save: save model
// load: load model
// note: separate wasm/js glue
func ABCD(this js.Value, args []js.Value) interface{} {
return nil
}

View file

@ -0,0 +1,59 @@
//go:build js && wasm
// +build js,wasm
package src
import (
"syscall/js"
"gonum.org/v1/gonum/mat"
)
// ref: mat.Dense
// fit: solve least squares
// predict: predict y from x
// save: save model
// load: load model
// note: separate wasm/js glue
type LinearRegression struct {
Coef *mat.Dense
}
func (reg *LinearRegression) Fit(X, Y [][]float64) error {
XDense := Array2DToDense(X)
YDense := Array2DToDense(Y)
reg.Coef = new(mat.Dense)
reg.Coef.Solve(XDense, YDense)
return nil
}
func (reg *LinearRegression) Predict(X [][]float64) ([]float64, error) {
YDense := new(mat.Dense)
YDense.Mul(Array2DToDense(X), reg.Coef)
return YDense.RawMatrix().Data, nil
}
func (l *LinearRegression) Save() ([]byte, error) {
return nil, nil
}
func (l *LinearRegression) Load(data []byte) error {
return nil
}
func NewLinearRegressionJS(this js.Value, args []js.Value) interface{} {
reg := new(LinearRegression)
obj := js.Global().Get("Object").New()
obj.Set("fit", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
X := JSFloatArray2D(args[0])
Y := JSFloatArray2D(args[1])
return reg.Fit(X, Y)
}))
obj.Set("predict", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
X := JSFloatArray2D(args[0])
Y, _ := reg.Predict(X)
return ToJSArray(Y)
}))
return obj
}

39
regr/src/utils.go Normal file
View file

@ -0,0 +1,39 @@
//go:build js && wasm
// +build js,wasm
package src
import (
"syscall/js"
"gonum.org/v1/gonum/mat"
)
func Array2DToDense(X [][]float64) *mat.Dense {
dense := mat.NewDense(len(X), len(X[0]), nil)
for i, row := range X {
dense.SetRow(i, row)
}
return dense
}
func JSFloatArray2D(arg js.Value) [][]float64 {
arr := make([][]float64, arg.Length())
for i := 0; i < len(arr); i++ {
arr[i] = make([]float64, arg.Index(i).Length())
}
for i := 0; i < len(arr); i++ {
for j := 0; j < arg.Index(i).Length(); j++ {
arr[i][j] = arg.Index(i).Index(j).Float()
}
}
return arr
}
func ToJSArray[T any](arr []T) []interface{} {
jsArr := make([]interface{}, len(arr))
for i, v := range arr {
jsArr[i] = v
}
return jsArr
}