fix dims for logreg predict

This commit is contained in:
Anton Nesterov 2024-10-03 23:42:50 +02:00
parent cde2e6fe99
commit af3587e284
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
5 changed files with 16 additions and 14 deletions

View file

@ -10,10 +10,10 @@ import (
)
func InitPlotExports(this js.Value, args []js.Value) interface{} {
exports := args[0]
exports := js.Global().Get("Object").New()
exports.Set("Hist", js.FuncOf(src.HistPlot))
exports.Set("Plot", js.FuncOf(src.Plot))
return nil
return exports
}
func main() {

View file

@ -12,10 +12,8 @@ const wasm = wasmMmodule.instance;
go.run(wasm);
const _exports = {} as Record<string, (...args: unknown[]) => unknown>;
// @ts-ignore: no types
__InitPlotExports(_exports);
const _exports = __InitPlotExports() as Record<string, (...args: unknown[]) => unknown>;
for (const key in _exports) {
const draw = _exports[key];

View file

@ -13,6 +13,6 @@ const wasm = wasmMmodule.instance;
go.run(wasm);
// @ts-ignore: no types
const _exports = __InitRegrExports(_exports) as Record<string, (...args: unknown[]) => unknown>;
const _exports = __InitRegrExports() as Record<string, (...args: unknown[]) => unknown>;
export default _exports;

View file

@ -27,7 +27,6 @@ func TestLogisticRegression(t *testing.T) {
LearningRate: .01,
}
regr.Fit(XDense, YDense, nil)
fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
}
@ -36,8 +35,8 @@ func TestMCLogisticRegression(t *testing.T) {
X := [][]float64{
{.1, .1, .1},
{.2, .2, .2},
{.1, .1, .1},
{.2, .2, .2},
{.11, .11, .11},
{.22, .22, .22},
}
//Y := [][]float64{{1}, {0}, {1}, {0}}
Y := [][]float64{
@ -48,13 +47,18 @@ func TestMCLogisticRegression(t *testing.T) {
}
XDense := Array2DToDense(X)
YDense := Array2DToDense(Y)
epochs := 1000
epochs := 100000
regr := &MCLogisticRegression{
Epochs: epochs,
LearningRate: .01,
LearningRate: .001,
}
regr.Fit(XDense, YDense)
// fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
XT := [][]float64{
{.1, .1, .111},
}
p2 := regr.Predict(Array2DToDense(XT))
fmt.Println(p2)
}

View file

@ -11,11 +11,10 @@ type MCLogisticRegression struct {
LearningRate float64
Models []*LogisticRegression
cols int
rows int
}
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
regr.rows, regr.cols = y.Dims()
_, regr.cols = y.Dims()
regr.Models = make([]*LogisticRegression, regr.cols)
for j := 0; j < regr.cols; j++ {
regr.Models[j] = &LogisticRegression{
@ -29,7 +28,8 @@ func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
}
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
probs := mat.NewDense(regr.rows, regr.cols, nil)
rows, _ := x.Dims()
probs := mat.NewDense(rows, regr.cols, nil)
wg := sync.WaitGroup{}
wg.Add(regr.cols)
for j := 0; j < regr.cols; j++ {