fix dims for logreg predict
This commit is contained in:
parent
cde2e6fe99
commit
af3587e284
|
@ -10,10 +10,10 @@ import (
|
|||
)
|
||||
|
||||
func InitPlotExports(this js.Value, args []js.Value) interface{} {
|
||||
exports := args[0]
|
||||
exports := js.Global().Get("Object").New()
|
||||
exports.Set("Hist", js.FuncOf(src.HistPlot))
|
||||
exports.Set("Plot", js.FuncOf(src.Plot))
|
||||
return nil
|
||||
return exports
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -12,10 +12,8 @@ const wasm = wasmMmodule.instance;
|
|||
|
||||
go.run(wasm);
|
||||
|
||||
const _exports = {} as Record<string, (...args: unknown[]) => unknown>;
|
||||
|
||||
// @ts-ignore: no types
|
||||
__InitPlotExports(_exports);
|
||||
const _exports = __InitPlotExports() as Record<string, (...args: unknown[]) => unknown>;
|
||||
|
||||
for (const key in _exports) {
|
||||
const draw = _exports[key];
|
||||
|
|
|
@ -13,6 +13,6 @@ const wasm = wasmMmodule.instance;
|
|||
go.run(wasm);
|
||||
|
||||
// @ts-ignore: no types
|
||||
const _exports = __InitRegrExports(_exports) as Record<string, (...args: unknown[]) => unknown>;
|
||||
const _exports = __InitRegrExports() as Record<string, (...args: unknown[]) => unknown>;
|
||||
|
||||
export default _exports;
|
||||
|
|
|
@ -27,7 +27,6 @@ func TestLogisticRegression(t *testing.T) {
|
|||
LearningRate: .01,
|
||||
}
|
||||
regr.Fit(XDense, YDense, nil)
|
||||
fmt.Println(regr.Weights, regr.Bias)
|
||||
yPred := regr.Predict(XDense)
|
||||
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
||||
}
|
||||
|
@ -36,8 +35,8 @@ func TestMCLogisticRegression(t *testing.T) {
|
|||
X := [][]float64{
|
||||
{.1, .1, .1},
|
||||
{.2, .2, .2},
|
||||
{.1, .1, .1},
|
||||
{.2, .2, .2},
|
||||
{.11, .11, .11},
|
||||
{.22, .22, .22},
|
||||
}
|
||||
//Y := [][]float64{{1}, {0}, {1}, {0}}
|
||||
Y := [][]float64{
|
||||
|
@ -48,13 +47,18 @@ func TestMCLogisticRegression(t *testing.T) {
|
|||
}
|
||||
XDense := Array2DToDense(X)
|
||||
YDense := Array2DToDense(Y)
|
||||
epochs := 1000
|
||||
epochs := 100000
|
||||
regr := &MCLogisticRegression{
|
||||
Epochs: epochs,
|
||||
LearningRate: .01,
|
||||
LearningRate: .001,
|
||||
}
|
||||
regr.Fit(XDense, YDense)
|
||||
// fmt.Println(regr.Weights, regr.Bias)
|
||||
yPred := regr.Predict(XDense)
|
||||
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
||||
XT := [][]float64{
|
||||
{.1, .1, .111},
|
||||
}
|
||||
p2 := regr.Predict(Array2DToDense(XT))
|
||||
fmt.Println(p2)
|
||||
}
|
||||
|
|
|
@ -11,11 +11,10 @@ type MCLogisticRegression struct {
|
|||
LearningRate float64
|
||||
Models []*LogisticRegression
|
||||
cols int
|
||||
rows int
|
||||
}
|
||||
|
||||
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
|
||||
regr.rows, regr.cols = y.Dims()
|
||||
_, regr.cols = y.Dims()
|
||||
regr.Models = make([]*LogisticRegression, regr.cols)
|
||||
for j := 0; j < regr.cols; j++ {
|
||||
regr.Models[j] = &LogisticRegression{
|
||||
|
@ -29,7 +28,8 @@ func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
|
|||
}
|
||||
|
||||
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
|
||||
probs := mat.NewDense(regr.rows, regr.cols, nil)
|
||||
rows, _ := x.Dims()
|
||||
probs := mat.NewDense(rows, regr.cols, nil)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(regr.cols)
|
||||
for j := 0; j < regr.cols; j++ {
|
||||
|
|
Loading…
Reference in a new issue