logistic reg

This commit is contained in:
Anton Nesterov 2024-10-03 20:36:36 +02:00
parent f7efdc6ff0
commit a7c139790c
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
2 changed files with 70 additions and 49 deletions

View file

@ -7,10 +7,10 @@ import (
)
type LogisticRegression struct {
Epochs int
Weights *mat.Dense
Bias float64
Losses []float64
Epochs int
Weights *mat.Dense
Bias float64
LearningRate float64
}
func sigmoidFunction(x float64) float64 {
@ -20,54 +20,75 @@ func sigmoidFunction(x float64) float64 {
return 1. / (1. + math.Exp(-x))
}
func (regr *LogisticRegression) backprop(x, y mat.Matrix) float64 {
_, c := x.Dims()
ry, cy := y.Dims()
regr.Bias = 0.1
regr.Weights = mat.NewDense(cy, c, nil)
coef := &mat.Dense{}
// binary cross-entropy Loss
func (regr *LogisticRegression) Loss(yTrue, yPred mat.Matrix) float64 {
ep := 1e-9
y1 := &mat.Dense{}
y1.Apply(func(i, j int, v float64) float64 {
return v * math.Log1p(yPred.At(i, j)+ep)
}, yTrue)
y2 := &mat.Dense{}
y2.Apply(func(i, j int, v float64) float64 {
return (1. - v) * math.Log1p(1.-yPred.At(i, j)+ep)
}, yTrue)
sum := &mat.Dense{}
sum.Add(y1, y2)
w, h := yTrue.Dims()
return mat.Sum(sum) / float64(w*h)
}
coef.Mul(regr.Weights, x.T())
func (regr *LogisticRegression) forward(X mat.Matrix) mat.Matrix {
coef := &mat.Dense{}
coef.Mul(X, regr.Weights)
coef.Apply(func(i, j int, v float64) float64 {
return sigmoidFunction(v + regr.Bias)
}, coef)
diff := &mat.Dense{}
diff.Sub(y.T(), coef)
w := &mat.Dense{}
w.Mul(diff, x)
regr.Weights = w
regr.Bias -= 0.1 * (mat.Sum(diff) / float64(c))
// Loss
yZeroLoss := &mat.Dense{}
yZeroLoss.Apply(func(i, j int, v float64) float64 {
return v * math.Log1p(coef.At(i, j)+1e-9)
}, y.T())
yOneLoss := &mat.Dense{}
yOneLoss.Apply(func(i, j int, v float64) float64 {
return (1. - v) * math.Log1p(1.-coef.At(i, j)+1e-9)
}, y.T())
sum := &mat.Dense{}
sum.Add(yZeroLoss, yOneLoss)
return mat.Sum(sum) / float64(ry+cy)
return coef
}
func (regr *LogisticRegression) Fit(X, Y mat.Matrix, epochs int) error {
for i := 0; i < epochs; i++ {
loss := regr.backprop(X, Y)
regr.Losses = append(regr.Losses, loss)
func (regr *LogisticRegression) grad(x, yTrue, yPred mat.Matrix) (*mat.Dense, float64) {
nSamples, _ := x.Dims()
deriv := &mat.Dense{}
deriv.Sub(yPred, yTrue)
dw := &mat.Dense{}
dw.Mul(x.T(), deriv)
dw.Apply(func(i, j int, v float64) float64 {
return 1. / float64(nSamples) * v
}, dw)
db := (1. / float64(nSamples)) * mat.Sum(deriv)
return dw, db
}
func (regr *LogisticRegression) backprop(x, y mat.Matrix) float64 {
_, c := x.Dims()
_, cy := y.Dims()
if regr.Weights == nil {
regr.Weights = mat.NewDense(c, cy, nil)
}
return nil
if regr.LearningRate == 0 {
regr.LearningRate = 0.01
}
yPred := regr.forward(x)
loss := regr.Loss(y, yPred)
dw, db := regr.grad(x, y, yPred)
regr.Weights.Sub(regr.Weights, dw)
regr.Bias -= regr.LearningRate * db
return loss
}
func (regr *LogisticRegression) Fit(X, Y mat.Matrix, epochs int, losses *[]float64) {
for i := 0; i < epochs; i++ {
regr.backprop(X, Y)
if losses != nil {
*losses = append(*losses, regr.Loss(Y, regr.forward(X)))
}
}
}
func (regr *LogisticRegression) Predict(X mat.Matrix) mat.Matrix {
coef := &mat.Dense{}
coef.Mul(X, regr.Weights.T())
coef.Mul(X, regr.Weights)
coef.Apply(func(i, j int, v float64) float64 {
p := sigmoidFunction(v + regr.Bias)
if p > .5 {

View file

@ -6,16 +6,16 @@ import (
)
func TestLogisticRegression(t *testing.T) {
X := [][]float64{{10.1, 10.1, 10.1}, {2.1, 2.1, 2.1}, {10.2, 10.2, 10.2}, {2.2, 2.2, 2.2}}
X := [][]float64{{.1, .1, .1}, {.2, .2, .2}, {.1, .1, .1}, {.2, .2, .2}}
Y := [][]float64{{0}, {1}, {0}, {1}}
XDense := Array2DToDense(X)
YDense := Array2DToDense(Y)
epochs := 10
regr := &LogisticRegression{}
err := regr.Fit(XDense, YDense, epochs)
if err != nil {
t.Error(err)
epochs := 1000
regr := &LogisticRegression{
LearningRate: .1,
}
fmt.Println(regr.Weights, regr.Bias, regr.Losses)
fmt.Println(YDense, regr.Predict(XDense))
regr.Fit(XDense, YDense, epochs, nil)
fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
}