You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

210 lines
7.3KB

  1. use std::collections::VecDeque;
  /// Calculates an online weighted average over a rolling, time-based window.
  ///
  /// Points are appended with [`push`](WeightedMeanWindow::push) and expired with
  /// [`purge`](WeightedMeanWindow::purge); both keep two running sums up to date
  /// so the mean can be read without iterating over the stored points.
  #[derive(Debug, Clone)]
  pub struct WeightedMeanWindow {
      /// The size of the window, in the same units as the `time` values passed
      /// to `push` and `purge`. On `purge`, a `WeightedPoint` is considered
      /// expired if the supplied `time` is more than `size` past the `time`
      /// attribute of that `WeightedPoint`.
      size: u64,

      /// The weights and values, with times, that are "currently" in the
      /// aggregation window. On `push`, items are added to the back of the
      /// deque. On `purge`, items whose `time` is more than `size` behind the
      /// `time` passed to `purge` are removed from the front. In both cases the
      /// running sums in `w_sum` and `sum_w` are updated.
      items: VecDeque<WeightedPoint>,

      /// The sum of `value * weight` for each of the `WeightedPoint`s in `items`.
      w_sum: f64,

      /// The sum of the weights of each of the `WeightedPoint`s in `items`.
      sum_w: f64,
  }

  /// Stores the time, weighted value and weight for an item "currently" inside
  /// the aggregation window of a `WeightedMeanWindow`, allowing its
  /// contribution to be subtracted from the accumulated sums of the window
  /// when the item expires.
  #[derive(Debug, Clone)]
  pub struct WeightedPoint {
      /// Time at which the point was recorded.
      pub time: u64,

      /// value * weight.
      ///
      /// When purging expired items, do not subtract `wt_val * wt`, as `wt_val`
      /// has already been multiplied by `wt`. Instead, simply subtract `wt_val`
      /// from `w_sum`.
      pub wt_val: f64,

      /// Weight of the point.
      pub wt: f64,
  }

  impl WeightedMeanWindow {
      /// Creates an empty window that retains points for `size` time units.
      pub fn new(size: u64) -> Self {
          Self {
              size,
              items: VecDeque::new(),
              w_sum: 0.0,
              sum_w: 0.0,
          }
      }

      /// Removes expired items and updates incremental calculations.
      ///
      /// An item is expired when `time` is strictly more than `size` past the
      /// item's own `time`; an item exactly `size` old is kept. Items are
      /// examined from the front of the queue and the scan stops at the first
      /// item that is still live, so this relies on items having been pushed in
      /// chronological order.
      ///
      /// If `time` is earlier than an item's `time`, that item is treated as
      /// having age zero (never expired) rather than underflowing.
      ///
      /// Returns `true` if any items were removed.
      pub fn purge(&mut self, time: u64) -> bool {
          let mut removed_any = false;

          // The fields are `Copy`, so destructuring through the reference ends
          // the borrow of `self.items` before `pop_front` needs it mutably.
          while let Some(&WeightedPoint { time: item_time, wt_val, wt }) = self.items.front() {
              // `saturating_sub` keeps a `time` older than the item from
              // panicking (debug) or wrapping to a huge age (release).
              if time.saturating_sub(item_time) <= self.size {
                  break;
              }
              self.w_sum -= wt_val;
              self.sum_w -= wt;
              self.items.pop_front();
              removed_any = true;
          }

          // With nothing left in the window, reset the accumulators outright.
          // This discards accumulated floating-point drift and, unlike
          // multiplying by zero, also clears NaN/infinite values left behind
          // by a non-finite sample.
          if self.items.is_empty() {
              self.w_sum = 0.0;
              self.sum_w = 0.0;
          }

          removed_any
      }

      /// Adds a new item, updating incremental calculations in the process.
      ///
      /// Note: it is assumed that `time` is >= the highest `time` value for any
      /// previous item. The expiration logic in `purge` relies on the items
      /// being added to a `WeightedMeanWindow` in chronological order.
      pub fn push(&mut self, time: u64, val: f64, wt: f64) {
          let wt_val = val * wt;
          self.w_sum += wt_val;
          self.sum_w += wt;
          self.items.push_back(WeightedPoint { time, wt_val, wt });
      }

      /// Calculates the weighted mean from the current state of the
      /// incremental accumulators.
      ///
      /// The result is `w_sum / sum_w` with no guard on the divisor, so an
      /// empty window (both sums zero) yields NaN. Use
      /// [`checked_weighted_mean`](WeightedMeanWindow::checked_weighted_mean)
      /// to get `None` in that case instead.
      ///
      /// Note: this value is not cached.
      pub fn weighted_mean(&self) -> f64 {
          self.w_sum / self.sum_w
      }

      /// Checks whether the window `is_empty` before trying to calculate.
      /// Returns `None` if it is empty, otherwise `Some` of the weighted mean.
      ///
      /// Note: this value is not cached.
      pub fn checked_weighted_mean(&self) -> Option<f64> {
          if self.is_empty() {
              None
          } else {
              Some(self.weighted_mean())
          }
      }

      /// Purge, push and get `checked_weighted_mean`, all in one convenient step.
      ///
      /// Because the new item is pushed before the mean is read, the window is
      /// never empty at that point and the result is always `Some`.
      pub fn update(&mut self, time: u64, val: f64, wt: f64) -> Option<f64> {
          self.purge(time);
          self.push(time, val, wt);
          self.checked_weighted_mean()
      }

      /// Number of items currently held in the window.
      pub fn len(&self) -> usize {
          self.items.len()
      }

      /// Returns `true` if the window currently holds no items.
      pub fn is_empty(&self) -> bool {
          self.items.is_empty()
      }
  }
  #[allow(unused)]
  #[cfg(test)]
  mod tests {
      use super::*;
      use approx::assert_relative_eq;

      /// Sample values shared by the tests below.
      const XS: [f64; 10] = [
          0.41305045, 0.93555897, 0.77885094, 0.9896831, 0.79720248,
          0.69497414, 0.34953127, 0.02331158, 0.89858514, 0.38312421,
      ];

      /// Sample weights, paired index-for-index with `XS`.
      const WS: [f64; 10] = [
          0.01256151, 0.58996267, 0.6474601, 0.33013727, 0.92964117,
          0.21427296, 0.42990663, 0.81912449, 0.99428442, 0.71875903,
      ];

      /// Builds a window of the given `size` holding every sample, using the
      /// sample's index as its time.
      fn filled_window(size: u64) -> WeightedMeanWindow {
          let mut window = WeightedMeanWindow::new(size);
          for (i, (&val, &weight)) in XS.iter().zip(WS.iter()).enumerate() {
              window.push(i as u64, val, weight);
          }
          window
      }

      /// Sum of the weights of the samples from index `from` onwards.
      fn expected_sum_w(from: usize) -> f64 {
          WS[from..].iter().sum::<f64>()
      }

      /// Sum of `value * weight` of the samples from index `from` onwards.
      fn expected_w_sum(from: usize) -> f64 {
          XS.iter()
              .zip(WS.iter())
              .skip(from)
              .map(|(&x, &w)| x * w)
              .sum::<f64>()
      }

      #[test]
      fn weighted_mean_output_matches_numpy_average() {
          // The window is far larger than any sample time, so nothing expires.
          let mut w = filled_window(1_000_000_000);
          w.purge(11);
          assert_eq!(w.items.len(), 10);
          assert_relative_eq!(w.weighted_mean(), 0.63599718086101786, epsilon = 0.0001);
      }

      #[test]
      fn checked_weighted_mean_returns_none_when_items_is_empty_and_unchecked_is_nan() {
          let w = WeightedMeanWindow::new(1_000_000_000);
          assert!(w.is_empty());
          assert_relative_eq!(w.sum_w, 0.0f64);
          assert_relative_eq!(w.w_sum, 0.0f64);
          assert!(w.checked_weighted_mean().is_none());
          assert!(w.weighted_mean().is_nan());
      }

      #[test]
      fn purge_expires_items() {
          let mut w = filled_window(10);

          // An item exactly `size` old is still inside the window.
          w.purge(10);
          assert_eq!(w.items.len(), 10);

          // One tick later the oldest item drops out.
          w.purge(11);
          assert_eq!(w.items.len(), 9);
          assert_relative_eq!(w.sum_w, expected_sum_w(1), epsilon = 1e-5);
          assert_relative_eq!(w.w_sum, expected_w_sum(1), epsilon = 1e-5);

          // Purging again at the same time removes nothing further.
          w.purge(11);
          assert_eq!(w.items.len(), 9);

          w.purge(12);
          assert_eq!(w.items.len(), 8);
          assert_relative_eq!(w.sum_w, expected_sum_w(2), epsilon = 1e-5);
          assert_relative_eq!(w.w_sum, expected_w_sum(2), epsilon = 1e-5);
      }
  }