From cde2e6fe9936cca997671efd0673f5ce95568655 Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Thu, 3 Oct 2024 21:55:48 +0200 Subject: [PATCH] loss for multi log reg --- regr/src/LogisticRegression_test.go | 2 +- regr/src/MCLogisticRegression.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/regr/src/LogisticRegression_test.go b/regr/src/LogisticRegression_test.go index b449c3a..a99c5dd 100644 --- a/regr/src/LogisticRegression_test.go +++ b/regr/src/LogisticRegression_test.go @@ -56,5 +56,5 @@ func TestMCLogisticRegression(t *testing.T) { regr.Fit(XDense, YDense) // fmt.Println(regr.Weights, regr.Bias) yPred := regr.Predict(XDense) - fmt.Println(YDense, yPred) //, regr.Loss(YDense, yPred)) + fmt.Println(YDense, yPred, regr.Loss(YDense, yPred)) } diff --git a/regr/src/MCLogisticRegression.go b/regr/src/MCLogisticRegression.go index 39db3f7..3db8c7d 100644 --- a/regr/src/MCLogisticRegression.go +++ b/regr/src/MCLogisticRegression.go @@ -42,3 +42,15 @@ func (regr *MCLogisticRegression) Predict(x mat.Matrix) mat.Matrix { wg.Wait() return probs } + +func (regr *MCLogisticRegression) Loss(yTrue, yPred mat.Matrix) float64 { + loss := 0. + for j := 0; j < regr.cols; j++ { + yj := mat.Col(nil, j, yTrue) + Y := mat.NewDense(len(yj), 1, yj) + ypj := mat.Col(nil, j, yPred) + YP := mat.NewDense(len(ypj), 1, ypj) + loss += regr.Models[j].Loss(Y, YP) + } + return loss / float64(regr.cols) +}