From 09bbe3d056f89d6934aeca93293ec63e30919ce6 Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Fri, 27 Sep 2024 22:37:35 +0200 Subject: [PATCH] fix augment --- encoding.ts | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/encoding.ts b/encoding.ts index 9e9a840..9592b1c 100644 --- a/encoding.ts +++ b/encoding.ts @@ -57,11 +57,14 @@ export function polynomialFeatures( } /** - * Adds missing rows at given interval, uses mean of previous and next value. - * Example for one feature: [1, 2, 4, 5] -> [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5] + * Add rows at given interval, use average to fill values. + * Usage: + * ```ts + * let df = augmentMeanForward("price", df, 100); + * ``` * @param feature * @param df - * @param bin + * @param interval */ export function augmentMeanForward( feature: string, @@ -69,27 +72,27 @@ export function augmentMeanForward( interval = 100, ) { let sorted = df.sort(feature); - let result = sorted.head(0); - let n: null | number = null; - for (let i = 0; i < sorted.height - 1; i++) { - let p1 = n ?? sorted.row(i).at(-1); - let p2 = sorted.row(i + 1).at(-1); + let featIdx = sorted.findIdxByName(feature); + let result = sorted.head(1); + for (let i = 0; i < sorted.height; i++) { + let p1 = sorted.row(i).at(featIdx); + let k = (i + 1) % sorted.height; + let p2 = sorted.row(k).at(featIdx); if (p2 - p1 > interval) { - let avg = (p1 + p2) / 2; - result = pl.concat([ - result, - pl.concat([result.tail(2), sorted.slice({ offset: i + 1, length: 2 })]) - .shift(-1) - .fillNull("mean") - .tail(1), - ]); - if (p2 - avg > interval) { - i--; - n = avg; - continue; + for (let j = 0; j < Math.round((p2 - p1) / interval) - 1; j++) { + result = pl.concat([ + result, + pl.concat([ + result.tail(1), + sorted.slice({ offset: k, length: 1 }), + sorted.head(1).shift(-1), + ]) + .fillNull("mean") + .tail(1), + ]); } + } else { result = pl.concat([result, sorted.slice(1, i)]); - n = null; } } return result;