From 597cc15cd1edb5ed8951d052e37e0c1b5cb081de Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Sat, 28 Sep 2024 14:11:00 +0200 Subject: [PATCH] improve splits --- split.ts | 77 ++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/split.ts b/split.ts index 3eaf7da..68468a9 100644 --- a/split.ts +++ b/split.ts @@ -1,16 +1,67 @@ import pl from "npm:nodejs-polars"; -export function trainTestSplit(df: pl.DataFrame, size, ...yFeatures: string[]) { - let shuffle = df.sample(df.height - 1); - let testSize = Math.round(shuffle.shape.height * size); - let trainSize = shuffle.shape.height - testSize; - let [train, test] = [shuffle.head(trainSize), shuffle.tail(testSize)]; - let [trainY, testY] = [train.select(...yFeatures), test.select(...yFeatures)]; - let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)]; - return { - trainX, - trainY, - testX, - testY, - }; +type DfSplit = { + trainX: pl.DataFrame; + trainY: pl.DataFrame; + testX: pl.DataFrame; + testY: pl.DataFrame; + size: number; +}; + +export function sliceK( + df: pl.DataFrame, + size: number, + k: number, + ...yFeatures: string[] +): DfSplit[] { + let testSize = Math.round(df.shape.height * size); + while (testSize % k !== 0) { + testSize -= 1; + } + if (df.shape.height / testSize < k) { + throw new Error( + `k value is too large, max k value is ${df.shape.height / testSize}`, + ); + } + let trainSize = df.shape.height - testSize; + let result: DfSplit[] = []; + let data = df; + for (let i = 0; i < k; i++) { + let [train, test] = [data.head(trainSize), data.tail(testSize)]; + let [trainY, testY] = [ + train.select(...yFeatures), + test.select(...yFeatures), + ]; + let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)]; + result.push({ + trainX, + trainY, + testX, + testY, + size, + }); + data = pl.concat([test, train]); + } + return result; +} + +export function trainTestSplit( + df: pl.DataFrame, + size: number, + shuffle = true, + ...yFeatures: string[] +) { + let data = shuffle ? df.sample(df.height - 1) : df; + const result = sliceK(data, size, 1, ...yFeatures); + return result[0]; +} + +export function kFold( + df: pl.DataFrame, + k: number, + shuffle = true, + ...yFeatures: string[] +): DfSplit[] { + let data = shuffle ? df.sample(df.height - 1) : df; + return sliceK(data, 1 / k, k, ...yFeatures); }