multiclass log reg
This commit is contained in:
parent
a7c139790c
commit
bf43ecee22
|
@ -34,7 +34,7 @@ func (regr *LogisticRegression) Loss(yTrue, yPred mat.Matrix) float64 {
|
||||||
sum := &mat.Dense{}
|
sum := &mat.Dense{}
|
||||||
sum.Add(y1, y2)
|
sum.Add(y1, y2)
|
||||||
w, h := yTrue.Dims()
|
w, h := yTrue.Dims()
|
||||||
return mat.Sum(sum) / float64(w*h)
|
return -(mat.Sum(sum) / float64(w*h))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (regr *LogisticRegression) forward(X mat.Matrix) mat.Matrix {
|
func (regr *LogisticRegression) forward(X mat.Matrix) mat.Matrix {
|
||||||
|
@ -76,8 +76,8 @@ func (regr *LogisticRegression) backprop(x, y mat.Matrix) float64 {
|
||||||
return loss
|
return loss
|
||||||
}
|
}
|
||||||
|
|
||||||
func (regr *LogisticRegression) Fit(X, Y mat.Matrix, epochs int, losses *[]float64) {
|
func (regr *LogisticRegression) Fit(X, Y mat.Matrix, losses *[]float64) {
|
||||||
for i := 0; i < epochs; i++ {
|
for i := 0; i < regr.Epochs; i++ {
|
||||||
regr.backprop(X, Y)
|
regr.backprop(X, Y)
|
||||||
if losses != nil {
|
if losses != nil {
|
||||||
*losses = append(*losses, regr.Loss(Y, regr.forward(X)))
|
*losses = append(*losses, regr.Loss(Y, regr.forward(X)))
|
||||||
|
|
|
@ -6,16 +6,55 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLogisticRegression(t *testing.T) {
|
func TestLogisticRegression(t *testing.T) {
|
||||||
X := [][]float64{{.1, .1, .1}, {.2, .2, .2}, {.1, .1, .1}, {.2, .2, .2}}
|
X := [][]float64{
|
||||||
Y := [][]float64{{0}, {1}, {0}, {1}}
|
{.1, .1, .1},
|
||||||
|
{.2, .2, .2},
|
||||||
|
{.1, .1, .1},
|
||||||
|
{.2, .2, .2},
|
||||||
|
}
|
||||||
|
//Y := [][]float64{{1}, {0}, {1}, {0}}
|
||||||
|
Y := [][]float64{
|
||||||
|
{0},
|
||||||
|
{1},
|
||||||
|
{0},
|
||||||
|
{1},
|
||||||
|
}
|
||||||
XDense := Array2DToDense(X)
|
XDense := Array2DToDense(X)
|
||||||
YDense := Array2DToDense(Y)
|
YDense := Array2DToDense(Y)
|
||||||
epochs := 1000
|
epochs := 1000
|
||||||
regr := &LogisticRegression{
|
regr := &LogisticRegression{
|
||||||
LearningRate: .1,
|
Epochs: epochs,
|
||||||
|
LearningRate: .01,
|
||||||
}
|
}
|
||||||
regr.Fit(XDense, YDense, epochs, nil)
|
regr.Fit(XDense, YDense, nil)
|
||||||
fmt.Println(regr.Weights, regr.Bias)
|
fmt.Println(regr.Weights, regr.Bias)
|
||||||
yPred := regr.Predict(XDense)
|
yPred := regr.Predict(XDense)
|
||||||
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMCLogisticRegression(t *testing.T) {
|
||||||
|
X := [][]float64{
|
||||||
|
{.1, .1, .1},
|
||||||
|
{.2, .2, .2},
|
||||||
|
{.1, .1, .1},
|
||||||
|
{.2, .2, .2},
|
||||||
|
}
|
||||||
|
//Y := [][]float64{{1}, {0}, {1}, {0}}
|
||||||
|
Y := [][]float64{
|
||||||
|
{0, 1},
|
||||||
|
{1, 0},
|
||||||
|
{0, 1},
|
||||||
|
{1, 0},
|
||||||
|
}
|
||||||
|
XDense := Array2DToDense(X)
|
||||||
|
YDense := Array2DToDense(Y)
|
||||||
|
epochs := 1000
|
||||||
|
regr := &MCLogisticRegression{
|
||||||
|
Epochs: epochs,
|
||||||
|
LearningRate: .01,
|
||||||
|
}
|
||||||
|
regr.Fit(XDense, YDense)
|
||||||
|
// fmt.Println(regr.Weights, regr.Bias)
|
||||||
|
yPred := regr.Predict(XDense)
|
||||||
|
fmt.Println(YDense, yPred) //, regr.Loss(YDense, yPred))
|
||||||
|
}
|
||||||
|
|
44
regr/src/MCLogisticRegression.go
Normal file
44
regr/src/MCLogisticRegression.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package src
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/mat"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MCLogisticRegression struct {
|
||||||
|
Epochs int
|
||||||
|
LearningRate float64
|
||||||
|
Models []*LogisticRegression
|
||||||
|
cols int
|
||||||
|
rows int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (regr *MCLogisticRegression) Fit(x, y mat.Matrix) {
|
||||||
|
regr.rows, regr.cols = y.Dims()
|
||||||
|
regr.Models = make([]*LogisticRegression, regr.cols)
|
||||||
|
for j := 0; j < regr.cols; j++ {
|
||||||
|
regr.Models[j] = &LogisticRegression{
|
||||||
|
Epochs: regr.Epochs,
|
||||||
|
LearningRate: regr.LearningRate,
|
||||||
|
}
|
||||||
|
yj := mat.Col(nil, j, y)
|
||||||
|
Y := mat.NewDense(len(yj), 1, yj)
|
||||||
|
regr.Models[j].Fit(x, Y, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
|
||||||
|
probs := mat.NewDense(regr.rows, regr.cols, nil)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(regr.cols)
|
||||||
|
for j := 0; j < regr.cols; j++ {
|
||||||
|
go func(j int) {
|
||||||
|
pred := regr.Models[j].Predict(x).(*mat.Dense)
|
||||||
|
probs.SetCol(j, mat.Col(nil, 0, pred))
|
||||||
|
wg.Done()
|
||||||
|
}(j)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
return probs
|
||||||
|
}
|
Loading…
Reference in a new issue