improve splits

This commit is contained in:
Anton Nesterov 2024-09-28 14:11:00 +02:00
parent 09bbe3d056
commit 597cc15cd1
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5

View file

@ -1,16 +1,67 @@
import pl from "npm:nodejs-polars"; import pl from "npm:nodejs-polars";
export function trainTestSplit(df: pl.DataFrame, size, ...yFeatures: string[]) { type DfSplit = {
let shuffle = df.sample(df.height - 1); trainX: pl.DataFrame;
let testSize = Math.round(shuffle.shape.height * size); trainY: pl.DataFrame;
let trainSize = shuffle.shape.height - testSize; testX: pl.DataFrame;
let [train, test] = [shuffle.head(trainSize), shuffle.tail(testSize)]; testY: pl.DataFrame;
let [trainY, testY] = [train.select(...yFeatures), test.select(...yFeatures)]; size: number;
let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)]; };
return {
trainX, export function sliceK(
trainY, df: pl.DataFrame,
testX, size: number,
testY, k: number,
}; ...yFeatures: string[]
): DfSplit[] {
let testSize = Math.round(df.shape.height * size);
while (testSize % k !== 0) {
testSize -= 1;
}
if (df.shape.height / testSize < k) {
throw new Error(
`k value is too large, max k value is ${df.shape.height / testSize}`,
);
}
let trainSize = df.shape.height - testSize;
let result: DfSplit[] = [];
let data = df;
for (let i = 0; i < k; i++) {
let [train, test] = [data.head(trainSize), data.tail(testSize)];
let [trainY, testY] = [
train.select(...yFeatures),
test.select(...yFeatures),
];
let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)];
result.push({
trainX,
trainY,
testX,
testY,
size,
});
data = pl.concat([test, train]);
}
return result;
}
export function trainTestSplit(
df: pl.DataFrame,
size: number,
shuffle = true,
...yFeatures: string[]
) {
let data = shuffle ? df.sample(df.height - 1) : df;
const result = sliceK(data, size, 1, ...yFeatures);
return result[0];
}
export function kFold(
df: pl.DataFrame,
k: number,
shuffle = true,
...yFeatures: string[]
): DfSplit[] {
let data = shuffle ? df.sample(df.height - 1) : df;
return sliceK(data, 1 / k, k, ...yFeatures);
} }