fix augment

This commit is contained in:
Anton Nesterov 2024-09-27 22:37:35 +02:00
parent 5d38fb02f9
commit 09bbe3d056
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5

View file

@ -57,11 +57,14 @@ export function polynomialFeatures(
} }
/** /**
* Adds missing rows at given interval, uses mean of previous and next value. * Add rows at given interval, use average to fill values.
* Example for one feature: [1, 2, 4, 5] -> [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5] * Usage:
* ```ts
* let df = augmentMeanForward("price", df, 100);
* ```
* @param feature * @param feature
* @param df * @param df
* @param bin * @param interval
*/ */
export function augmentMeanForward( export function augmentMeanForward(
feature: string, feature: string,
@ -69,27 +72,27 @@ export function augmentMeanForward(
interval = 100, interval = 100,
) { ) {
let sorted = df.sort(feature); let sorted = df.sort(feature);
let result = sorted.head(0); let featIdx = sorted.findIdxByName(feature);
let n: null | number = null; let result = sorted.head(1);
for (let i = 0; i < sorted.height - 1; i++) { for (let i = 0; i < sorted.height; i++) {
let p1 = n ?? sorted.row(i).at(-1); let p1 = sorted.row(i).at(featIdx);
let p2 = sorted.row(i + 1).at(-1); let k = (i + 1) % sorted.height;
let p2 = sorted.row(k).at(featIdx);
if (p2 - p1 > interval) { if (p2 - p1 > interval) {
let avg = (p1 + p2) / 2; for (let j = 0; j < Math.round((p2 - p1) / interval) - 1; j++) {
result = pl.concat([ result = pl.concat([
result, result,
pl.concat([result.tail(2), sorted.slice({ offset: i + 1, length: 2 })]) pl.concat([
.shift(-1) result.tail(1),
.fillNull("mean") sorted.slice({ offset: k, length: 1 }),
.tail(1), sorted.head(1).shift(-1),
]); ])
if (p2 - avg > interval) { .fillNull("mean")
i--; .tail(1),
n = avg; ]);
continue;
} }
} else {
result = pl.concat([result, sorted.slice(1, i)]); result = pl.concat([result, sorted.slice(1, i)]);
n = null;
} }
} }
return result; return result;