loss for multi log reg

This commit is contained in:
Anton Nesterov 2024-10-03 21:55:48 +02:00
parent bf43ecee22
commit cde2e6fe99
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
2 changed files with 13 additions and 1 deletions

View file

@ -56,5 +56,5 @@ func TestMCLogisticRegression(t *testing.T) {
regr.Fit(XDense, YDense)
// fmt.Println(regr.Weights, regr.Bias)
yPred := regr.Predict(XDense)
fmt.Println(YDense, yPred) //, regr.Loss(YDense, yPred))
fmt.Println(YDense, yPred, regr.Loss(YDense, yPred))
}

View file

@ -42,3 +42,15 @@ func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix {
wg.Wait()
return probs
}
func (regr *MCLogisticRegression) Loss(yTrue, yPred mat.Matrix) float64 {
loss := 0.
for j := 0; j < regr.cols; j++ {
yj := mat.Col(nil, j, yTrue)
Y := mat.NewDense(len(yj), 1, yj)
ypj := mat.Col(nil, j, yPred)
YP := mat.NewDense(len(ypj), 1, ypj)
loss += regr.Models[j].Loss(Y, YP)
}
return loss / float64(regr.cols)
}