fix dims for logreg predict

This commit is contained in:
Anton Nesterov 2024-10-03 23:42:50 +02:00
parent cde2e6fe99
commit af3587e284
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
5 changed files with 16 additions and 14 deletions

View file

@ -10,10 +10,10 @@ import (
) )
func InitPlotExports(this js.Value, args []js.Value) interface{} { func InitPlotExports(this js.Value, args []js.Value) interface{} {
exports := args[0] exports := js.Global().Get("Object").New()
exports.Set("Hist", js.FuncOf(src.HistPlot)) exports.Set("Hist", js.FuncOf(src.HistPlot))
exports.Set("Plot", js.FuncOf(src.Plot)) exports.Set("Plot", js.FuncOf(src.Plot))
return nil return exports
} }
func main() { func main() {

View file

@ -12,10 +12,8 @@ const wasm = wasmMmodule.instance;
go.run(wasm); go.run(wasm);
const _exports = {} as Record<string, (...args: unknown[]) => unknown>;
// @ts-ignore: no types // @ts-ignore: no types
__InitPlotExports(_exports); const _exports = __InitPlotExports() as Record<string, (...args: unknown[]) => unknown>;
for (const key in _exports) { for (const key in _exports) {
const draw = _exports[key]; const draw = _exports[key];

View file

@ -13,6 +13,6 @@ const wasm = wasmMmodule.instance;
go.run(wasm); go.run(wasm);
// @ts-ignore: no types // @ts-ignore: no types
const _exports = __InitRegrExports(_exports) as Record<string, (...args: unknown[]) => unknown>; const _exports = __InitRegrExports() as Record<string, (...args: unknown[]) => unknown>;
export default _exports; export default _exports;

View file

@ -27,7 +27,6 @@ func TestLogisticRegression(t *testing.T) {
LearningRate: .01, LearningRate: .01,
} }
regr.Fit(XDense, YDense, nil) regr.Fit(XDense, YDense, nil)
fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense) yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred)) fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
} }
@ -36,8 +35,8 @@ func TestMCLogisticRegression(t *testing.T) {
X := [][]float64{ X := [][]float64{
{.1, .1, .1}, {.1, .1, .1},
{.2, .2, .2}, {.2, .2, .2},
{.1, .1, .1}, {.11, .11, .11},
{.2, .2, .2}, {.22, .22, .22},
} }
//Y := [][]float64{{1}, {0}, {1}, {0}} //Y := [][]float64{{1}, {0}, {1}, {0}}
Y := [][]float64{ Y := [][]float64{
@ -48,13 +47,18 @@ func TestMCLogisticRegression(t *testing.T) {
} }
XDense := Array2DToDense(X) XDense := Array2DToDense(X)
YDense := Array2DToDense(Y) YDense := Array2DToDense(Y)
epochs := 1000 epochs := 100000
regr := &MCLogisticRegression{ regr := &MCLogisticRegression{
Epochs: epochs, Epochs: epochs,
LearningRate: .01, LearningRate: .001,
} }
regr.Fit(XDense, YDense) regr.Fit(XDense, YDense)
// fmt.Println(regr.Weights, regr.Bias) // fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense) yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred)) fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
XT := [][]float64{
{.1, .1, .111},
}
p2 := regr.Predict(Array2DToDense(XT))
fmt.Println(p2)
} }

View file

@ -11,11 +11,10 @@ type MCLogisticRegression struct {
LearningRate float64 LearningRate float64
Models []*LogisticRegression Models []*LogisticRegression
cols int cols int
rows int
} }
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) { func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
regr.rows, regr.cols = y.Dims() _, regr.cols = y.Dims()
regr.Models = make([]*LogisticRegression, regr.cols) regr.Models = make([]*LogisticRegression, regr.cols)
for j := 0; j < regr.cols; j++ { for j := 0; j < regr.cols; j++ {
regr.Models[j] = &LogisticRegression{ regr.Models[j] = &LogisticRegression{
@ -29,7 +28,8 @@ func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
} }
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix { func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
probs := mat.NewDense(regr.rows, regr.cols, nil) rows, _ := x.Dims()
probs := mat.NewDense(rows, regr.cols, nil)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(regr.cols) wg.Add(regr.cols)
for j := 0; j < regr.cols; j++ { for j := 0; j < regr.cols; j++ {