use std::collections::VecDeque;

/// Calculates an online weighted average over a rolling, time-based window.
///
/// Items are appended with [`WeightedMeanWindow::push`] and expired with
/// [`WeightedMeanWindow::purge`]; both keep two running sums up to date so
/// the mean can be read at any time without iterating over the items.
#[derive(Clone)]
pub struct WeightedMeanWindow {
    /// The size of the window. On `purge`, a `WeightedPoint` is considered
    /// expired if the supplied `time` parameter is more than `size` past the
    /// `time` attribute of that `WeightedPoint`.
    size: u64,

    /// The weights and values, with times, that are "currently" in the
    /// aggregation window. On `push`, items are added to the back of the
    /// `VecDeque`. On `purge`, items whose `time` is more than `size` behind
    /// the `time` passed to `purge` are considered expired and removed from
    /// the front. In both cases the incremental sums `w_sum` and `sum_w` are
    /// updated.
    items: VecDeque<WeightedPoint>,

    /// The sum of `value * weight` for each `WeightedPoint` in `items`.
    w_sum: f64,

    /// The sum of the weights of each `WeightedPoint` in `items`.
    sum_w: f64,
}

/// Stores the time, weighted value and weight for an item "currently" inside
/// the aggregation window of a `WeightedMeanWindow`, allowing its
/// contribution to be subtracted from the accumulated sums of the window
/// when the item expires.
#[derive(Debug, Clone)]
pub struct WeightedPoint {
    /// Timestamp the item was pushed with.
    pub time: u64,

    /// value * weight.
    ///
    /// When purging expired items, do not subtract `wt_val * wt`, as `wt_val`
    /// has already been multiplied by `wt`. Instead, simply subtract `wt_val`
    /// from `w_sum`.
    pub wt_val: f64,

    /// The weight of the item.
    pub wt: f64,
}

impl WeightedMeanWindow {
    /// Creates an empty window spanning `size` time units.
    pub fn new(size: u64) -> Self {
        Self {
            size,
            items: VecDeque::new(),
            w_sum: 0.0,
            sum_w: 0.0,
        }
    }

    /// Removes expired items and updates the incremental calculations.
    ///
    /// An item is expired when `time` is more than `size` past the item's own
    /// `time`. If `time` is earlier than an item's `time`, the elapsed time is
    /// treated as zero, so the item is kept rather than the subtraction
    /// underflowing.
    ///
    /// Returns `true` if any items were removed.
    pub fn purge(&mut self, time: u64) -> bool {
        let size = self.size;
        let mut removed_any = false;

        // Items are stored in chronological order, so expired items are
        // always at the front; stop at the first item still in the window.
        while let Some(front) = self.items.front() {
            // `saturating_sub` guards against `time < front.time`, where a
            // plain `-` would panic (debug) or wrap to a huge value (release).
            if time.saturating_sub(front.time) <= size {
                break;
            }
            self.w_sum -= front.wt_val;
            self.sum_w -= front.wt;
            self.items.pop_front();
            removed_any = true;
        }

        // When `items` is empty, set `w_sum` and `sum_w` to zero so rounding
        // residue from the running subtraction does not leak into later
        // results. Multiplying by 0.0 / 1.0 instead of branching is a
        // deliberate choice: the same work is done on every call.
        // Note that a non-finite accumulator is not cleared by this
        // multiplication (`NaN * 0.0` is `NaN`).
        let keep = f64::from(u8::from(!self.items.is_empty()));
        self.w_sum *= keep;
        self.sum_w *= keep;

        removed_any
    }

    /// Adds a new item, updating the incremental calculations in the process.
    ///
    /// Note: it is assumed that `time` is >= the highest `time` value of any
    /// previous item. The expiration logic in `purge` relies on items being
    /// added to a `WeightedMeanWindow` in chronological order.
    pub fn push(&mut self, time: u64, val: f64, wt: f64) {
        let wt_val = val * wt;
        self.w_sum += wt_val;
        self.sum_w += wt;
        self.items.push_back(WeightedPoint { time, wt_val, wt });
    }

    /// Calculates the weighted mean from the current state of the incremental
    /// accumulators.
    ///
    /// The result is `NaN` when the window is empty (`0.0 / 0.0`); use
    /// `checked_weighted_mean` to get an `Option` instead.
    ///
    /// Note: this value is not cached.
    pub fn weighted_mean(&self) -> f64 {
        self.w_sum / self.sum_w
    }

    /// Checks whether the window `is_empty` before trying to calculate.
    ///
    /// Returns `None` if there are no items. Only emptiness is checked: if
    /// the weights of the items present sum to zero, the returned value is
    /// still the raw, non-finite quotient.
    ///
    /// Note: this value is not cached.
    pub fn checked_weighted_mean(&self) -> Option<f64> {
        if self.is_empty() {
            None
        } else {
            Some(self.weighted_mean())
        }
    }

    /// Purge, push and get `checked_weighted_mean`, all in one convenient step.
    pub fn update(&mut self, time: u64, val: f64, wt: f64) -> Option<f64> {
        self.purge(time);
        self.push(time, val, wt);
        self.checked_weighted_mean()
    }

    /// Number of items currently inside the window.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` when the window holds no items.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `epsilon` (absolute) of `expected`.
    fn assert_close(actual: f64, expected: f64, epsilon: f64) {
        assert!(
            (actual - expected).abs() <= epsilon,
            "expected {} to be within {} of {}",
            actual,
            epsilon,
            expected
        );
    }

    fn sample_values() -> Vec<f64> {
        vec![
            0.41305045, 0.93555897, 0.77885094, 0.9896831, 0.79720248,
            0.69497414, 0.34953127, 0.02331158, 0.89858514, 0.38312421,
        ]
    }

    fn sample_weights() -> Vec<f64> {
        vec![
            0.01256151, 0.58996267, 0.6474601, 0.33013727, 0.92964117,
            0.21427296, 0.42990663, 0.81912449, 0.99428442, 0.71875903,
        ]
    }

    #[test]
    fn weighted_mean_output_matches_numpy_average() {
        let xs = sample_values();
        let ws = sample_weights();

        let mut w = WeightedMeanWindow::new(1_000_000_000);
        for (i, (val, weight)) in xs.iter().cloned().zip(ws.iter().cloned()).enumerate() {
            w.push(i as u64, val, weight);
        }
        w.purge(11);

        assert_eq!(w.items.len(), 10);
        assert_close(w.weighted_mean(), 0.63599718086101786, 0.0001);
    }

    #[test]
    fn checked_weighted_mean_returns_none_when_items_is_empty_and_unchecked_is_nan() {
        let w = WeightedMeanWindow::new(1_000_000_000);

        assert!(w.is_empty());
        assert_close(w.sum_w, 0.0, f64::EPSILON);
        assert_close(w.w_sum, 0.0, f64::EPSILON);
        assert!(w.checked_weighted_mean().is_none());
        assert!(w.weighted_mean().is_nan());
    }

    #[test]
    fn purge_expires_items() {
        let xs = sample_values();
        let ws = sample_weights();
        let xs_times_ws: Vec<f64> = xs.iter().zip(ws.iter()).map(|(&x, &w)| x * w).collect();

        let mut w = WeightedMeanWindow::new(10);
        for (i, (val, weight)) in xs.iter().cloned().zip(ws.iter().cloned()).enumerate() {
            w.push(i as u64, val, weight);
        }

        w.purge(10);
        assert_eq!(w.items.len(), 10);

        w.purge(11);
        assert_eq!(w.items.len(), 9);
        assert_close(w.sum_w, ws[1..].iter().sum::<f64>(), 1e-5);
        assert_close(w.w_sum, xs_times_ws[1..].iter().sum::<f64>(), 1e-5);

        // Purging again at the same time must not remove anything further.
        w.purge(11);
        assert_eq!(w.items.len(), 9);

        w.purge(12);
        assert_eq!(w.items.len(), 8);
        assert_close(w.sum_w, ws[2..].iter().sum::<f64>(), 1e-5);
        assert_close(w.w_sum, xs_times_ws[2..].iter().sum::<f64>(), 1e-5);
    }
}