multiclass log reg

This commit is contained in:
Anton Nesterov 2024-10-03 21:45:44 +02:00
parent a7c139790c
commit bf43ecee22
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
3 changed files with 90 additions and 7 deletions

View file

@ -34,7 +34,7 @@ func (regr *LogisticRegression) Loss(yTrue, yPred mat.Matrix) float64 {
sum := &mat.Dense{} sum := &mat.Dense{}
sum.Add(y1, y2) sum.Add(y1, y2)
w, h := yTrue.Dims() w, h := yTrue.Dims()
return mat.Sum(sum) / float64(w*h) return -(mat.Sum(sum) / float64(w*h))
} }
func (regr *LogisticRegression) forward(X mat.Matrix) mat.Matrix { func (regr *LogisticRegression) forward(X mat.Matrix) mat.Matrix {
@ -76,8 +76,8 @@ func (regr *LogisticRegression) backprop(x, y mat.Matrix) float64 {
return loss return loss
} }
func (regr *LogisticRegression) Fit(X, Y mat.Matrix, epochs int, losses *[]float64) { func (regr *LogisticRegression) Fit(X, Y mat.Matrix, losses *[]float64) {
for i := 0; i < epochs; i++ { for i := 0; i < regr.Epochs; i++ {
regr.backprop(X, Y) regr.backprop(X, Y)
if losses != nil { if losses != nil {
*losses = append(*losses, regr.Loss(Y, regr.forward(X))) *losses = append(*losses, regr.Loss(Y, regr.forward(X)))

View file

@ -6,16 +6,55 @@ import (
) )
func TestLogisticRegression(t *testing.T) { func TestLogisticRegression(t *testing.T) {
X := [][]float64{{.1, .1, .1}, {.2, .2, .2}, {.1, .1, .1}, {.2, .2, .2}} X := [][]float64{
Y := [][]float64{{0}, {1}, {0}, {1}} {.1, .1, .1},
{.2, .2, .2},
{.1, .1, .1},
{.2, .2, .2},
}
//Y := [][]float64{{1}, {0}, {1}, {0}}
Y := [][]float64{
{0},
{1},
{0},
{1},
}
XDense := Array2DToDense(X) XDense := Array2DToDense(X)
YDense := Array2DToDense(Y) YDense := Array2DToDense(Y)
epochs := 1000 epochs := 1000
regr := &LogisticRegression{ regr := &LogisticRegression{
LearningRate: .1, Epochs: epochs,
LearningRate: .01,
} }
regr.Fit(XDense, YDense, epochs, nil) regr.Fit(XDense, YDense, nil)
fmt.Println(regr.Weights, regr.Bias) fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense) yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred)) fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
} }
func TestMCLogisticRegression(t *testing.T) {
X := [][]float64{
{.1, .1, .1},
{.2, .2, .2},
{.1, .1, .1},
{.2, .2, .2},
}
//Y := [][]float64{{1}, {0}, {1}, {0}}
Y := [][]float64{
{0, 1},
{1, 0},
{0, 1},
{1, 0},
}
XDense := Array2DToDense(X)
YDense := Array2DToDense(Y)
epochs := 1000
regr := &MCLogisticRegression{
Epochs: epochs,
LearningRate: .01,
}
regr.Fit(XDense, YDense)
// fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred) //, regr.Loss(YDense, yPred))
}

View file

@ -0,0 +1,44 @@
package src
import (
"sync"
"gonum.org/v1/gonum/mat"
)
type MCLogisticRegression struct {
Epochs int
LearningRate float64
Models []*LogisticRegression
cols int
rows int
}
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
regr.rows, regr.cols = y.Dims()
regr.Models = make([]*LogisticRegression, regr.cols)
for j := 0; j < regr.cols; j++ {
regr.Models[j] = &LogisticRegression{
Epochs: regr.Epochs,
LearningRate: regr.LearningRate,
}
yj := mat.Col(nil, j, y)
Y := mat.NewDense(len(yj), 1, yj)
regr.Models[j].Fit(x, Y, nil)
}
}
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
probs := mat.NewDense(regr.rows, regr.cols, nil)
wg := sync.WaitGroup{}
wg.Add(regr.cols)
for j := 0; j < regr.cols; j++ {
go func(j int) {
pred := regr.Models[j].Predict(x).(*mat.Dense)
probs.SetCol(j, mat.Col(nil, 0, pred))
wg.Done()
}(j)
}
wg.Wait()
return probs
}