fix augment

This commit is contained in:
Anton Nesterov 2024-09-27 22:37:35 +02:00
parent 5d38fb02f9
commit 09bbe3d056
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5

View file

@ -57,11 +57,14 @@ export function polynomialFeatures(
}
/**
* Adds missing rows at given interval, uses mean of previous and next value.
* Example for one feature: [1, 2, 4, 5] -> [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5]
* Add rows at given interval, use average to fill values.
* Usage:
* ```ts
* let df = augmentMeanForward("price", df, 100);
* ```
* @param feature
* @param df
* @param bin
* @param interval
*/
export function augmentMeanForward(
feature: string,
@ -69,27 +72,27 @@ export function augmentMeanForward(
interval = 100,
) {
let sorted = df.sort(feature);
let result = sorted.head(0);
let n: null | number = null;
for (let i = 0; i < sorted.height - 1; i++) {
let p1 = n ?? sorted.row(i).at(-1);
let p2 = sorted.row(i + 1).at(-1);
let featIdx = sorted.findIdxByName(feature);
let result = sorted.head(1);
for (let i = 0; i < sorted.height; i++) {
let p1 = sorted.row(i).at(featIdx);
let k = (i + 1) % sorted.height;
let p2 = sorted.row(k).at(featIdx);
if (p2 - p1 > interval) {
let avg = (p1 + p2) / 2;
result = pl.concat([
result,
pl.concat([result.tail(2), sorted.slice({ offset: i + 1, length: 2 })])
.shift(-1)
.fillNull("mean")
.tail(1),
]);
if (p2 - avg > interval) {
i--;
n = avg;
continue;
for (let j = 0; j < Math.round((p2 - p1) / interval) - 1; j++) {
result = pl.concat([
result,
pl.concat([
result.tail(1),
sorted.slice({ offset: k, length: 1 }),
sorted.head(1).shift(-1),
])
.fillNull("mean")
.tail(1),
]);
}
} else {
result = pl.concat([result, sorted.slice(1, i)]);
n = null;
}
}
return result;