fix dims for logreg predict
This commit is contained in:
parent
cde2e6fe99
commit
af3587e284
|
@ -10,10 +10,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitPlotExports(this js.Value, args []js.Value) interface{} {
|
func InitPlotExports(this js.Value, args []js.Value) interface{} {
|
||||||
exports := args[0]
|
exports := js.Global().Get("Object").New()
|
||||||
exports.Set("Hist", js.FuncOf(src.HistPlot))
|
exports.Set("Hist", js.FuncOf(src.HistPlot))
|
||||||
exports.Set("Plot", js.FuncOf(src.Plot))
|
exports.Set("Plot", js.FuncOf(src.Plot))
|
||||||
return nil
|
return exports
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -12,10 +12,8 @@ const wasm = wasmMmodule.instance;
|
||||||
|
|
||||||
go.run(wasm);
|
go.run(wasm);
|
||||||
|
|
||||||
const _exports = {} as Record<string, (...args: unknown[]) => unknown>;
|
|
||||||
|
|
||||||
// @ts-ignore: no types
|
// @ts-ignore: no types
|
||||||
__InitPlotExports(_exports);
|
const _exports = __InitPlotExports() as Record<string, (...args: unknown[]) => unknown>;
|
||||||
|
|
||||||
for (const key in _exports) {
|
for (const key in _exports) {
|
||||||
const draw = _exports[key];
|
const draw = _exports[key];
|
||||||
|
|
|
@ -13,6 +13,6 @@ const wasm = wasmMmodule.instance;
|
||||||
go.run(wasm);
|
go.run(wasm);
|
||||||
|
|
||||||
// @ts-ignore: no types
|
// @ts-ignore: no types
|
||||||
const _exports = __InitRegrExports(_exports) as Record<string, (...args: unknown[]) => unknown>;
|
const _exports = __InitRegrExports() as Record<string, (...args: unknown[]) => unknown>;
|
||||||
|
|
||||||
export default _exports;
|
export default _exports;
|
||||||
|
|
|
@ -27,7 +27,6 @@ func TestLogisticRegression(t *testing.T) {
|
||||||
LearningRate: .01,
|
LearningRate: .01,
|
||||||
}
|
}
|
||||||
regr.Fit(XDense, YDense, nil)
|
regr.Fit(XDense, YDense, nil)
|
||||||
fmt.Println(regr.Weights, regr.Bias)
|
|
||||||
yPred := regr.Predict(XDense)
|
yPred := regr.Predict(XDense)
|
||||||
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
||||||
}
|
}
|
||||||
|
@ -36,8 +35,8 @@ func TestMCLogisticRegression(t *testing.T) {
|
||||||
X := [][]float64{
|
X := [][]float64{
|
||||||
{.1, .1, .1},
|
{.1, .1, .1},
|
||||||
{.2, .2, .2},
|
{.2, .2, .2},
|
||||||
{.1, .1, .1},
|
{.11, .11, .11},
|
||||||
{.2, .2, .2},
|
{.22, .22, .22},
|
||||||
}
|
}
|
||||||
//Y := [][]float64{{1}, {0}, {1}, {0}}
|
//Y := [][]float64{{1}, {0}, {1}, {0}}
|
||||||
Y := [][]float64{
|
Y := [][]float64{
|
||||||
|
@ -48,13 +47,18 @@ func TestMCLogisticRegression(t *testing.T) {
|
||||||
}
|
}
|
||||||
XDense := Array2DToDense(X)
|
XDense := Array2DToDense(X)
|
||||||
YDense := Array2DToDense(Y)
|
YDense := Array2DToDense(Y)
|
||||||
epochs := 1000
|
epochs := 100000
|
||||||
regr := &MCLogisticRegression{
|
regr := &MCLogisticRegression{
|
||||||
Epochs: epochs,
|
Epochs: epochs,
|
||||||
LearningRate: .01,
|
LearningRate: .001,
|
||||||
}
|
}
|
||||||
regr.Fit(XDense, YDense)
|
regr.Fit(XDense, YDense)
|
||||||
// fmt.Println(regr.Weights, regr.Bias)
|
// fmt.Println(regr.Weights, regr.Bias)
|
||||||
yPred := regr.Predict(XDense)
|
yPred := regr.Predict(XDense)
|
||||||
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
||||||
|
XT := [][]float64{
|
||||||
|
{.1, .1, .111},
|
||||||
|
}
|
||||||
|
p2 := regr.Predict(Array2DToDense(XT))
|
||||||
|
fmt.Println(p2)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,11 +11,10 @@ type MCLogisticRegression struct {
|
||||||
LearningRate float64
|
LearningRate float64
|
||||||
Models []*LogisticRegression
|
Models []*LogisticRegression
|
||||||
cols int
|
cols int
|
||||||
rows int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
|
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
|
||||||
regr.rows, regr.cols = y.Dims()
|
_, regr.cols = y.Dims()
|
||||||
regr.Models = make([]*LogisticRegression, regr.cols)
|
regr.Models = make([]*LogisticRegression, regr.cols)
|
||||||
for j := 0; j < regr.cols; j++ {
|
for j := 0; j < regr.cols; j++ {
|
||||||
regr.Models[j] = &LogisticRegression{
|
regr.Models[j] = &LogisticRegression{
|
||||||
|
@ -29,7 +28,8 @@ func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
|
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
|
||||||
probs := mat.NewDense(regr.rows, regr.cols, nil)
|
rows, _ := x.Dims()
|
||||||
|
probs := mat.NewDense(rows, regr.cols, nil)
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(regr.cols)
|
wg.Add(regr.cols)
|
||||||
for j := 0; j < regr.cols; j++ {
|
for j := 0; j < regr.cols; j++ {
|
||||||
|
|
Loading…
Reference in a new issue