From af3587e2843ae1ff83a74aa1792374a5f99e4f74 Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Thu, 3 Oct 2024 23:42:50 +0200 Subject: [PATCH] fix dims for logreg predict --- plot/main.go | 4 ++-- plot/mod.ts | 4 +--- regr/mod.ts | 2 +- regr/src/LogisticRegression_test.go | 14 +++++++++----- regr/src/MCLogisticRegression.go | 6 +++--- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/plot/main.go b/plot/main.go index 2568d22..48a4918 100644 --- a/plot/main.go +++ b/plot/main.go @@ -10,10 +10,10 @@ import ( ) func InitPlotExports(this js.Value, args []js.Value) interface{} { - exports := args[0] + exports := js.Global().Get("Object").New() exports.Set("Hist", js.FuncOf(src.HistPlot)) exports.Set("Plot", js.FuncOf(src.Plot)) - return nil + return exports } func main() { diff --git a/plot/mod.ts b/plot/mod.ts index 8f002e4..bfff38e 100644 --- a/plot/mod.ts +++ b/plot/mod.ts @@ -12,10 +12,8 @@ const wasm = wasmMmodule.instance; go.run(wasm); -const _exports = {} as Record unknown>; - // @ts-ignore: no types -__InitPlotExports(_exports); +const _exports = __InitPlotExports() as Record unknown>; for (const key in _exports) { const draw = _exports[key]; diff --git a/regr/mod.ts b/regr/mod.ts index 4d8c0cc..1036d42 100644 --- a/regr/mod.ts +++ b/regr/mod.ts @@ -13,6 +13,6 @@ const wasm = wasmMmodule.instance; go.run(wasm); // @ts-ignore: no types -const _exports = __InitRegrExports(_exports) as Record unknown>; +const _exports = __InitRegrExports() as Record unknown>; export default _exports; diff --git a/regr/src/LogisticRegression_test.go b/regr/src/LogisticRegression_test.go index a99c5dd..33b3f33 100644 --- a/regr/src/LogisticRegression_test.go +++ b/regr/src/LogisticRegression_test.go @@ -27,7 +27,6 @@ func TestLogisticRegression(t *testing.T) { LearningRate: .01, } regr.Fit(XDense, YDense, nil) - fmt.Println(regr.Weights, regr.Bias) yPred := regr.Predict(XDense) fmt.Println(YDense, yPred, regr.Loss(YDense, yPred)) } @@ -36,8 +35,8 @@ func TestMCLogisticRegression(t *testing.T) { X := [][]float64{ {.1, .1, .1}, {.2, .2, .2}, - {.1, .1, .1}, - {.2, .2, .2}, + {.11, .11, .11}, + {.22, .22, .22}, } //Y := [][]float64{{1}, {0}, {1}, {0}} Y := [][]float64{ @@ -48,13 +47,18 @@ func TestMCLogisticRegression(t *testing.T) { } XDense := Array2DToDense(X) YDense := Array2DToDense(Y) - epochs := 1000 + epochs := 100000 regr := &MCLogisticRegression{ Epochs: epochs, - LearningRate: .01, + LearningRate: .001, } regr.Fit(XDense, YDense) // fmt.Println(regr.Weights, regr.Bias) yPred := regr.Predict(XDense) fmt.Println(YDense, yPred, regr.Loss(YDense, yPred)) + XT := [][]float64{ + {.1, .1, .111}, + } + p2 := regr.Predict(Array2DToDense(XT)) + fmt.Println(p2) } diff --git a/regr/src/MCLogisticRegression.go b/regr/src/MCLogisticRegression.go index 3db8c7d..5fab324 100644 --- a/regr/src/MCLogisticRegression.go +++ b/regr/src/MCLogisticRegression.go @@ -11,11 +11,10 @@ type MCLogisticRegression struct { LearningRate float64 Models []*LogisticRegression cols int - rows int } func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) { - regr.rows, regr.cols = y.Dims() + _, regr.cols = y.Dims() regr.Models = make([]*LogisticRegression, regr.cols) for j := 0; j < regr.cols; j++ { regr.Models[j] = &LogisticRegression{ @@ -29,7 +28,8 @@ func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) { } func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix { - probs := mat.NewDense(regr.rows, regr.cols, nil) + rows, _ := x.Dims() + probs := mat.NewDense(rows, regr.cols, nil) wg := sync.WaitGroup{} wg.Add(regr.cols) for j := 0; j < regr.cols; j++ {