shortcuts/split.ts

17 lines
564 B
TypeScript
Raw Normal View History

2024-09-26 01:41:33 +00:00
import pl from "npm:nodejs-polars";
export function trainTestSplit(df: pl.DataFrame, size, ...yFeatures: string[]) {
let shuffle = df.sample(df.height - 1);
let testSize = Math.round(shuffle.shape.height * size);
let trainSize = shuffle.shape.height - testSize;
let [train, test] = [shuffle.head(trainSize), shuffle.tail(testSize)];
let [trainY, testY] = [train.select(...yFeatures), test.select(...yFeatures)];
let [trainX, testX] = [train.drop(...yFeatures), test.drop(...yFeatures)];
return {
trainX,
trainY,
testX,
testY,
};
}