port gonum stat

This commit is contained in:
Anton Nesterov 2024-09-28 22:18:36 +02:00
parent 58174039b1
commit cc0dcd0bf4
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
9 changed files with 767 additions and 44 deletions

View file

@ -1,2 +1,5 @@
wasm:
dev:
GOOS=js GOARCH=wasm tinygo build -o stat/mod.wasm ./stat/main.go
prod:
GOOS=js GOARCH=wasm tinygo build -o stat/mod.wasm -no-debug ./stat/main.go

View file

@ -2,24 +2,60 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stats initialized\n",
"Estimated offset is: 0.988572\n",
"Estimated slope is: 3.000154\n",
"R^2: 0.999999\n"
"Estimated offset is: 0.982618\n",
"Estimated slope is: 3.000173\n",
"R^2: 0.999999\n",
"2 <object> <object>\n"
]
},
{
"data": {
"text/plain": [
"\u001b[1mnull\u001b[22m"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import stats from \"./stat/mod.ts\";\n",
"\n",
"stats.example();"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 \u0001\u0002\u0003\n"
]
},
{
"data": {
"text/plain": [
"\u001b[1mnull\u001b[22m"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
}
],
"metadata": {

View file

@ -23,16 +23,16 @@ export function sliceK(
`k value is too large, max k value is ${df.shape.height / testSize}`,
);
}
let trainSize = df.shape.height - testSize;
let result: DfSplit[] = [];
const trainSize = df.shape.height - testSize;
const result: DfSplit[] = [];
let data = df;
for (let i = 0; i < k; i++) {
let [train, test] = [data.head(trainSize), data.tail(testSize)];
let [trainY, testY] = [
const [train, test] = [data.head(trainSize), data.tail(testSize)];
const [trainY, testY] = [
train.select(...yFeatures),
test.select(...yFeatures),
];
let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)];
const [trainX, testX] = [train.drop(yFeatures), test.drop(yFeatures)];
result.push({
trainX,
trainY,
@ -51,7 +51,7 @@ export function trainTestSplit(
shuffle = true,
...yFeatures: string[]
) {
let data = shuffle ? df.sample(df.height - 1) : df;
const data = shuffle ? df.sample(df.height - 1) : df;
const result = sliceK(data, size, 1, ...yFeatures);
return result[0];
}
@ -62,6 +62,6 @@ export function kFold(
shuffle = true,
...yFeatures: string[]
): DfSplit[] {
let data = shuffle ? df.sample(df.height - 1) : df;
const data = shuffle ? df.sample(df.height - 1) : df;
return sliceK(data, 1 / k, k, ...yFeatures);
}

View file

@ -1,42 +1,65 @@
//go:build js && wasm
// +build js,wasm
package main
import (
"fmt"
"golang.org/x/exp/rand"
"l12.xyz/x/shortcuts/stat/src"
"gonum.org/v1/gonum/stat"
"syscall/js"
)
//go:export example
func Example() {
var (
xs = make([]float64, 100)
ys = make([]float64, 100)
weights []float64
)
line := func(x float64) float64 {
return 1 + 3*x
}
for i := range xs {
xs[i] = float64(i)
ys[i] = line(xs[i]) + 0.1*rand.NormFloat64()
}
// Do not force the regression line to pass through the origin.
origin := false
alpha, beta := stat.LinearRegression(xs, ys, weights, origin)
r2 := stat.RSquared(xs, ys, weights, alpha, beta)
fmt.Printf("Estimated offset is: %.6f\n", alpha)
fmt.Printf("Estimated slope is: %.6f\n", beta)
fmt.Printf("R^2: %.6f\n", r2)
func InitStatExports(this js.Value, args []js.Value) interface{} {
exports := args[0]
exports.Set("Bhattacharyya", js.FuncOf(src.Bhattacharyya))
exports.Set("BivariateMoment", js.FuncOf(src.BivariateMoment))
exports.Set("ChiSquare", js.FuncOf(src.ChiSquare))
exports.Set("CircularMean", js.FuncOf(src.CircularMean))
exports.Set("Correlation", js.FuncOf(src.Correlation))
exports.Set("Covariance", js.FuncOf(src.Covariance))
exports.Set("CrossEntropy", js.FuncOf(src.CrossEntropy))
exports.Set("Entropy", js.FuncOf(src.Entropy))
exports.Set("ExKurtosis", js.FuncOf(src.ExKurtosis))
exports.Set("GeometricMean", js.FuncOf(src.GeometricMean))
exports.Set("HarmonicMean", js.FuncOf(src.HarmonicMean))
exports.Set("Hellinger", js.FuncOf(src.Hellinger))
exports.Set("Histogram", js.FuncOf(src.Histogram))
exports.Set("JensenShannon", js.FuncOf(src.JensenShannon))
exports.Set("Kendall", js.FuncOf(src.Kendall))
exports.Set("KolmogorovSmirnov", js.FuncOf(src.KolmogorovSmirnov))
exports.Set("KullbackLeibler", js.FuncOf(src.KullbackLeibler))
exports.Set("LinearRegression", js.FuncOf(src.LinearRegression))
exports.Set("Mean", js.FuncOf(src.Mean))
exports.Set("MeanStdDev", js.FuncOf(src.MeanStdDev))
exports.Set("MeanVariance", js.FuncOf(src.MeanVariance))
exports.Set("Mode", js.FuncOf(src.Mode))
exports.Set("Moment", js.FuncOf(src.Moment))
exports.Set("MomentAbout", js.FuncOf(src.MomentAbout))
exports.Set("PopMeanStdDev", js.FuncOf(src.PopMeanStdDev))
exports.Set("PopMeanVariance", js.FuncOf(src.PopMeanVariance))
exports.Set("PopStdDev", js.FuncOf(src.PopStdDev))
exports.Set("PopVariance", js.FuncOf(src.PopVariance))
exports.Set("Quantile", js.FuncOf(src.Quantile))
exports.Set("RNoughtSquared", js.FuncOf(src.RNoughtSquared))
exports.Set("ROC", js.FuncOf(src.ROC))
exports.Set("RSquared", js.FuncOf(src.RSquared))
exports.Set("RSquaredFrom", js.FuncOf(src.RSquaredFrom))
exports.Set("Skew", js.FuncOf(src.Skew))
exports.Set("SortWeighted", js.FuncOf(src.SortWeighted))
exports.Set("SortWeightedLabeled", js.FuncOf(src.SortWeightedLabeled))
exports.Set("StdDev", js.FuncOf(src.StdDev))
exports.Set("StdErr", js.FuncOf(src.StdErr))
exports.Set("StdScore", js.FuncOf(src.StdScore))
exports.Set("TOC", js.FuncOf(src.TOC))
exports.Set("Variance", js.FuncOf(src.Variance))
return nil
}
func main() {
wait := make(chan struct{}, 0)
js.Global().Set("__InitStatExports", js.FuncOf(InitStatExports))
fmt.Println("Stats initialized")
<-wait
}

View file

@ -1,4 +1,5 @@
import "../lib/wasm.js";
import type { Stat } from "./types.ts";
// @ts-expect-error: no types
const go = new Go();
@ -6,7 +7,15 @@ const go = new Go();
const code =
await (await fetch(import.meta.url.replace("/mod.ts", "/mod.wasm")))
.arrayBuffer();
const wasmMmodule = await WebAssembly.instantiate(code, go.importObject);
const wasm = wasmMmodule.instance;
go.run(wasm);
export default wasm.exports as Record<string, (...args: unknown[])=> unknown>;
// Empty container handed to the Go side: `__InitStatExports` (a global that the
// wasm module's `main` installs) attaches one function per statistic to it.
const _exports = {} as Record<string, (...args: unknown[]) => unknown> & Stat;
// @ts-ignore: no types
__InitStatExports(_exports);
// The populated object is the module's public API.
export default _exports;

Binary file not shown.

514
stat/src/GoNumStat.go Normal file
View file

@ -0,0 +1,514 @@
//go:build js && wasm
// +build js,wasm
package src
import (
"syscall/js"
"gonum.org/v1/gonum/stat"
)
func Bhattacharyya(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
q = JSFloatArray(args[1])
)
return stat.Bhattacharyya(p, q)
}
func BivariateMoment(this js.Value, args []js.Value) interface{} {
var (
r = args[0].Float()
s = args[1].Float()
xs = JSFloatArray(args[2])
ys = JSFloatArray(args[3])
weights = JSFloatArray(args[4])
)
if len(weights) == 0 {
weights = nil
}
return stat.BivariateMoment(r, s, xs, ys, weights)
}
func CDF(this js.Value, args []js.Value) interface{} {
var (
q = args[0].Float()
xs = JSFloatArray(args[2])
weights = JSFloatArray(args[3])
)
return stat.CDF(q, 1, xs, weights)
}
func ChiSquare(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
q = JSFloatArray(args[1])
)
return stat.ChiSquare(p, q)
}
func CircularMean(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.CircularMean(xs, weights)
}
func Correlation(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
)
if len(weights) == 0 {
weights = nil
}
return stat.Correlation(xs, ys, weights)
}
// CorrelationMatrix is a placeholder with the JS callback signature; it has
// no implementation yet and always panics with "not implemented".
func CorrelationMatrix(this js.Value, args []js.Value) interface{} {
panic("not implemented")
}
func Covariance(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
)
if len(weights) == 0 {
weights = nil
}
return stat.Covariance(xs, ys, weights)
}
// CovarianceMatrix is a placeholder with the JS callback signature; it has no
// implementation yet and always panics with "not implemented".
func CovarianceMatrix(this js.Value, args []js.Value) interface{} {
panic("not implemented")
}
func CrossEntropy(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
q = JSFloatArray(args[1])
)
return stat.CrossEntropy(p, q)
}
func Entropy(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
)
return stat.Entropy(p)
}
func ExKurtosis(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.ExKurtosis(xs, weights)
}
func GeometricMean(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.GeometricMean(xs, weights)
}
func HarmonicMean(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.HarmonicMean(xs, weights)
}
func Hellinger(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
q = JSFloatArray(args[1])
)
return stat.Hellinger(p, q)
}
func Histogram(this js.Value, args []js.Value) interface{} {
var (
counts = JSFloatArray(args[0])
divs = JSFloatArray(args[1])
xs = JSFloatArray(args[2])
weights = JSFloatArray(args[3])
)
if len(weights) == 0 {
weights = nil
}
return stat.Histogram(counts, divs, xs, weights)
}
func JensenShannon(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
q = JSFloatArray(args[1])
)
return stat.JensenShannon(p, q)
}
func Kendall(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
)
if len(weights) == 0 {
weights = nil
}
return stat.Kendall(xs, ys, weights)
}
func KolmogorovSmirnov(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
xWeights = JSFloatArray(args[1])
ys = JSFloatArray(args[2])
yWeights = JSFloatArray(args[3])
)
return stat.KolmogorovSmirnov(xs, xWeights, ys, yWeights)
}
func KullbackLeibler(this js.Value, args []js.Value) interface{} {
var (
p = JSFloatArray(args[0])
q = JSFloatArray(args[1])
)
return stat.KullbackLeibler(p, q)
}
func LinearRegression(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
origin = args[3].Bool()
)
if len(weights) == 0 {
weights = nil
}
alpha, beta := stat.LinearRegression(xs, ys, weights, origin)
return map[string]interface{}{
"alpha": alpha,
"beta": beta,
}
}
// Mahalanobis is a placeholder with the JS callback signature; it has no
// implementation yet and always panics with "not implemented".
func Mahalanobis(this js.Value, args []js.Value) interface{} {
panic("not implemented")
}
func Mean(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.Mean(xs, weights)
}
func MeanStdDev(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
mean, stdDev := stat.MeanStdDev(xs, weights)
return map[string]interface{}{
"mean": mean,
"stdDev": stdDev,
}
}
func MeanVariance(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
mean, variance := stat.MeanVariance(xs, weights)
return map[string]interface{}{
"mean": mean,
"variance": variance,
}
}
func Mode(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
value, count := stat.Mode(xs, weights)
return map[string]interface{}{
"value": value,
"count": count,
}
}
func Moment(this js.Value, args []js.Value) interface{} {
var (
r = args[0].Float()
xs = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
)
if len(weights) == 0 {
weights = nil
}
return stat.Moment(r, xs, weights)
}
func MomentAbout(this js.Value, args []js.Value) interface{} {
var (
r = args[0].Float()
xs = JSFloatArray(args[1])
mean = args[2].Float()
weights = JSFloatArray(args[3])
)
if len(weights) == 0 {
weights = nil
}
return stat.MomentAbout(r, xs, mean, weights)
}
func PopMeanStdDev(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
mean, stdDev := stat.PopMeanStdDev(xs, weights)
return map[string]interface{}{
"mean": mean,
"stdDev": stdDev,
}
}
func PopMeanVariance(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
mean, variance := stat.PopMeanVariance(xs, weights)
return map[string]interface{}{
"mean": mean,
"variance": variance,
}
}
func PopStdDev(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.PopStdDev(xs, weights)
}
func PopVariance(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.PopVariance(xs, weights)
}
func Quantile(this js.Value, args []js.Value) interface{} {
var (
q = args[0].Float()
typ = args[1].String()
xs = JSFloatArray(args[2])
weights = JSFloatArray(args[3])
)
if typ == "empirical" {
return stat.Quantile(q, 1, xs, weights)
} else if typ == "linear" {
return stat.Quantile(q, 4, xs, weights)
} else {
return stat.Quantile(q, 2, xs, weights)
}
}
func RNoughtSquared(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
beta = args[3].Float()
)
if len(weights) == 0 {
weights = nil
}
return stat.RNoughtSquared(xs, ys, weights, beta)
}
func ROC(this js.Value, args []js.Value) interface{} {
var (
cutoffs = JSFloatArray(args[0])
y = JSFloatArray(args[1])
classes = JSBoolArray(args[2])
weights = JSFloatArray(args[3])
)
tpr, fpr, tresh := stat.ROC(cutoffs, y, classes, weights)
return map[string]interface{}{
"tpr": tpr,
"fpr": fpr,
"tresh": tresh,
}
}
func RSquared(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
alpha = args[3].Float()
beta = args[4].Float()
)
if len(weights) == 0 {
weights = nil
}
return stat.RSquared(xs, ys, weights, alpha, beta)
}
func RSquaredFrom(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
ys = JSFloatArray(args[1])
weights = JSFloatArray(args[2])
)
if len(weights) == 0 {
weights = nil
}
return stat.RSquaredFrom(xs, ys, weights)
}
func Skew(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.Skew(xs, weights)
}
func SortWeighted(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
stat.SortWeighted(xs, weights)
return xs
}
func SortWeightedLabeled(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
labels = JSBoolArray(args[1])
weights = JSFloatArray(args[2])
)
if len(weights) == 0 {
weights = nil
}
stat.SortWeightedLabeled(xs, labels, weights)
return map[string]interface{}{
"xs": xs,
"labels": labels,
}
}
func StdDev(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.StdDev(xs, weights)
}
func StdErr(this js.Value, args []js.Value) interface{} {
var (
std = args[0].Float()
size = args[1].Float()
)
return stat.StdErr(std, size)
}
func StdScore(this js.Value, args []js.Value) interface{} {
var (
x = args[0].Float()
mean = args[1].Float()
std = args[2].Float()
)
return stat.StdScore(x, mean, std)
}
func TOC(this js.Value, args []js.Value) interface{} {
var (
classes = JSBoolArray(args[0])
weights = JSFloatArray(args[1])
)
min, ntp, max := stat.TOC(classes, weights)
return map[string]interface{}{
"min": min,
"ntp": ntp,
"max": max,
}
}
func Variance(this js.Value, args []js.Value) interface{} {
var (
xs = JSFloatArray(args[0])
weights = JSFloatArray(args[1])
)
if len(weights) == 0 {
weights = nil
}
return stat.Variance(xs, weights)
}

24
stat/src/utils.go Normal file
View file

@ -0,0 +1,24 @@
//go:build js && wasm
// +build js,wasm
package src
import (
"syscall/js"
)
func JSFloatArray(arg js.Value) []float64 {
arr := make([]float64, arg.Length())
for i := 0; i < len(arr); i++ {
arr[i] = arg.Index(i).Float()
}
return arr
}
func JSBoolArray(arg js.Value) []bool {
arr := make([]bool, arg.Length())
for i := 0; i < len(arr); i++ {
arr[i] = arg.Index(i).Bool()
}
return arr
}

114
stat/types.ts Normal file
View file

@ -0,0 +1,114 @@
/**
 * Typed view of the functions the Go/wasm module attaches to its exports
 * object. Each member mirrors a wrapper around the gonum `stat` function of
 * the same name; arguments are positional and must all be supplied.
 *
 * For every `weights` parameter, pass an empty array for unweighted data:
 * the Go wrappers convert an empty weights array to "no weights".
 */
export interface Stat {
  Bhattacharyya: (p: number[], q: number[]) => number;
  BivariateMoment: (
    r: number,
    s: number,
    xs: number[],
    ys: number[],
    weights: number[],
  ) => number;
  /** Empirical cumulative distribution evaluated at `q`. */
  CDF: (q: number, xs: number[], weights: number[]) => number;
  ChiSquare: (p: number[], q: number[]) => number;
  CircularMean: (xs: number[], weights: number[]) => number;
  Correlation: (xs: number[], ys: number[], weights: number[]) => number;
  Covariance: (xs: number[], ys: number[], weights: number[]) => number;
  CrossEntropy: (p: number[], q: number[]) => number;
  Entropy: (p: number[]) => number;
  ExKurtosis: (xs: number[], weights: number[]) => number;
  GeometricMean: (xs: number[], weights: number[]) => number;
  HarmonicMean: (xs: number[], weights: number[]) => number;
  Hellinger: (p: number[], q: number[]) => number;
  /**
   * Bin counts of `xs` for the bin edges in `divs`.
   *
   * The fourth argument is a weights array (the Go wrapper reads it as an
   * array), not a bin count.
   */
  Histogram: (
    counts: number[],
    divs: number[],
    xs: number[],
    weights: number[],
  ) => number[];
  JensenShannon: (p: number[], q: number[]) => number;
  Kendall: (xs: number[], ys: number[], weights: number[]) => number;
  KolmogorovSmirnov: (
    xs: number[],
    xWeights: number[],
    ys: number[],
    yWeights: number[],
  ) => number;
  KullbackLeibler: (p: number[], q: number[]) => number;
  LinearRegression: (
    xs: number[],
    ys: number[],
    weights: number[],
    origin: boolean,
  ) => { alpha: number; beta: number };
  Mean: (xs: number[], weights: number[]) => number;
  MeanStdDev: (
    xs: number[],
    weights: number[],
  ) => { mean: number; stdDev: number };
  MeanVariance: (
    xs: number[],
    weights: number[],
  ) => { mean: number; variance: number };
  Mode: (xs: number[], weights: number[]) => { value: number; count: number };
  Moment: (r: number, xs: number[], weights: number[]) => number;
  MomentAbout: (
    r: number,
    xs: number[],
    mean: number,
    weights: number[],
  ) => number;
  PopMeanStdDev: (
    xs: number[],
    weights: number[],
  ) => { mean: number; stdDev: number };
  PopMeanVariance: (
    xs: number[],
    weights: number[],
  ) => { mean: number; variance: number };
  PopStdDev: (xs: number[], weights: number[]) => number;
  PopVariance: (xs: number[], weights: number[]) => number;
  /** Quantile `q` of `xs`; `type` selects the estimation method. */
  Quantile: (
    q: number,
    type: "linear" | "empirical",
    xs: number[],
    weights: number[],
  ) => number;
  RNoughtSquared: (
    xs: number[],
    ys: number[],
    weights: number[],
    beta: number,
  ) => number;
  /** Note: the thresholds are returned under the key `tresh`. */
  ROC: (
    cutoffs: number[],
    ys: number[],
    classes: boolean[],
    weights: number[],
  ) => { tpr: number[]; fpr: number[]; tresh: number[] };
  RSquared: (
    xs: number[],
    ys: number[],
    weights: number[],
    alpha: number,
    beta: number,
  ) => number;
  RSquaredFrom: (
    estimates: number[],
    ys: number[],
    weights: number[],
  ) => number;
  Skew: (xs: number[], weights: number[]) => number;
  /** Returns a sorted copy of `xs`; the input array is not modified. */
  SortWeighted: (xs: number[], weights: number[]) => number[];
  /** Returns sorted copies of `xs` and `labels`, reordered together. */
  SortWeightedLabeled: (
    xs: number[],
    labels: boolean[],
    weights: number[],
  ) => { xs: number[]; labels: boolean[] };
  StdDev: (xs: number[], weights: number[]) => number;
  StdErr: (std: number, size: number) => number;
  StdScore: (x: number, mean: number, stdDev: number) => number;
  /** The second argument is a weights array. */
  TOC: (
    classes: boolean[],
    weights: number[],
  ) => { min: number[]; ntp: number[]; max: number[] };
  Variance: (xs: number[], weights: number[]) => number;
}