From e0faeef729316ae23d7c6454ee6beb2ea604e9a9 Mon Sep 17 00:00:00 2001 From: Jonathan Strong Date: Wed, 15 Apr 2020 01:51:54 -0400 Subject: [PATCH] streamlined --hard-mode for csv, tests for WeightedMeanWindow --- Cargo.toml | 3 + ...-hard-query-output-of-reference-impl.ipynb | 514 ++++++++++++++++++ src/csv.rs | 95 ++-- src/windows.rs | 140 ++++- 4 files changed, 691 insertions(+), 61 deletions(-) create mode 100644 notebooks/verifying-hard-query-output-of-reference-impl.ipynb diff --git a/Cargo.toml b/Cargo.toml index 133a7d2..c5d7478 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,9 @@ chrono = { version = "0.4", features = ["serde"] } clap = "2" itertools-num = "0.1" +[dev-dependencies] +approx = "0.3" + [profile.release] lto = "fat" panic = "abort" diff --git a/notebooks/verifying-hard-query-output-of-reference-impl.ipynb b/notebooks/verifying-hard-query-output-of-reference-impl.ipynb new file mode 100644 index 0000000..ed5d412 --- /dev/null +++ b/notebooks/verifying-hard-query-output-of-reference-impl.ipynb @@ -0,0 +1,514 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1min 8s, sys: 7.72 s, total: 1min 16s\n", + "Wall time: 1min 16s\n", + "\n", + "RangeIndex: 92331988 entries, 0 to 92331987\n", + "Data columns (total 7 columns):\n", + "time int64\n", + "amount float64\n", + "exch object\n", + "price float64\n", + "server_time int64\n", + "side object\n", + "ticker object\n", + "dtypes: float64(2), int64(2), object(3)\n", + "memory usage: 4.8+ GB\n" + ] + } + ], + "source": [ + "%time df = pd.read_csv('/xfs/sample.csv')\n", + "df.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timeamountexchpriceserver_timesideticker
015619392000024793721.4894bnce292.70001561939199919000064NaNeth_usd
115619392000110356440.0833btfx10809.00001561939199927000064bidbtc_usd
215619392000110557120.8333btfx10809.00001561939199927000064bidbtc_usd
315619392000190376170.0831bnce10854.10001561939199935000064NaNbtc_usd
415619392000264504710.1250okex123.21001561939200026450432askltc_usd
\n", + "
" + ], + "text/plain": [ + " time amount exch price server_time side \\\n", + "0 1561939200002479372 1.4894 bnce 292.7000 1561939199919000064 NaN \n", + "1 1561939200011035644 0.0833 btfx 10809.0000 1561939199927000064 bid \n", + "2 1561939200011055712 0.8333 btfx 10809.0000 1561939199927000064 bid \n", + "3 1561939200019037617 0.0831 bnce 10854.1000 1561939199935000064 NaN \n", + "4 1561939200026450471 0.1250 okex 123.2100 1561939200026450432 ask \n", + "\n", + " ticker \n", + "0 eth_usd \n", + "1 btc_usd \n", + "2 btc_usd \n", + "3 btc_usd \n", + "4 ltc_usd " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(True, True, True)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SECOND = int(1e9)\n", + "\n", + "example_time = 1567295920000000000\n", + "\n", + "last_5min = (df['time'] > example_time - SECOND * 60 * 5) & (df['time'] <= example_time)\n", + "last_15min = (df['time'] > example_time - SECOND * 60 * 15) & (df['time'] <= example_time)\n", + "last_60min = (df['time'] > example_time - SECOND * 60 * 60) & (df['time'] <= example_time)\n", + "of_btc_usd = df['ticker'] == 'btc_usd'\n", + "of_gdax = df['exch'] == 'gdax'\n", + "of_bmex = df['exch'] == 'bmex'\n", + "\n", + "g5 = last_5min & of_btc_usd & of_gdax\n", + "b5 = last_5min & of_btc_usd & of_bmex\n", + "g15 = last_15min & of_btc_usd & of_gdax\n", + "b15 = last_15min & of_btc_usd & of_bmex\n", + "g60 = last_60min & of_btc_usd & of_gdax\n", + "b60 = last_60min & of_btc_usd & of_bmex\n", + "\n", + "ratio_5min = ((df.loc[b5, 'price'] * df.loc[b5, 'amount']).sum() / df.loc[b5, 'amount'].sum()) / ((df.loc[g5, 'price'] * df.loc[g5, 'amount']).sum() / df.loc[g5, 'amount'].sum())\n", + "ratio_15min = ((df.loc[b15, 'price'] * df.loc[b15, 'amount']).sum() / 
df.loc[b15, 'amount'].sum()) / ((df.loc[g15, 'price'] * df.loc[g15, 'amount']).sum() / df.loc[g15, 'amount'].sum())\n", + "ratio_60min = ((df.loc[b60, 'price'] * df.loc[b60, 'amount']).sum() / df.loc[b60, 'amount'].sum()) / ((df.loc[g60, 'price'] * df.loc[g60, 'amount']).sum() / df.loc[g60, 'amount'].sum())\n", + "\n", + "abs(ratio_5min - 1.000474060563638) < 1e-6, abs(ratio_15min - 1.0005019306061411) < 1e-6, abs(ratio_60min - 1.0002338013889658) < 1e-6" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timelastbmex_5mingdax_5minn_bmex_p5n_gdax_p5r5r15r60
0156193921000000000010758.580010760.720510760.459322281.00001.00001.0000
1156193922000000000010770.000010763.681110761.2528230751.00021.00021.0002
2156193923000000000010758.010010761.884310760.15964181201.00021.00021.0002
3156193924000000000010752.500010757.682910760.36305071470.99980.99980.9998
4156193925000000000010772.690010757.570210763.08405371910.99950.99950.9995
\n", + "
" + ], + "text/plain": [ + " time last bmex_5min gdax_5min n_bmex_p5 n_gdax_p5 \\\n", + "0 1561939210000000000 10758.5800 10760.7205 10760.4593 22 28 \n", + "1 1561939220000000000 10770.0000 10763.6811 10761.2528 230 75 \n", + "2 1561939230000000000 10758.0100 10761.8843 10760.1596 418 120 \n", + "3 1561939240000000000 10752.5000 10757.6829 10760.3630 507 147 \n", + "4 1561939250000000000 10772.6900 10757.5702 10763.0840 537 191 \n", + "\n", + " r5 r15 r60 \n", + "0 1.0000 1.0000 1.0000 \n", + "1 1.0002 1.0002 1.0002 \n", + "2 1.0002 1.0002 1.0002 \n", + "3 0.9998 0.9998 0.9998 \n", + "4 0.9995 0.9995 0.9995 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ref = pd.read_csv('../var/hard.csv')\n", + "ref.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1567295750000000000, 1567295760000000000, 1567295770000000000, 1567295780000000000, 1567295790000000000, 1567295800000000000, 1567295810000000000, 1567295820000000000, 1567295830000000000,\n", + " 1567295840000000000, 1567295850000000000, 1567295860000000000, 1567295870000000000, 1567295880000000000, 1567295890000000000, 1567295900000000000, 1567295910000000000, 1567295920000000000,\n", + " 1567295930000000000, 1567295940000000000, 1567295950000000000, 1567295960000000000, 1567295970000000000, 1567295980000000000, 1567295990000000000])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ref['time'].tail(25).values" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "finished in 487.8sec\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "start = time.time()\n", + "rows = []\n", + "\n", + "for example_time in ref['time'].tail(25).values:\n", + " last_5min = (df['time'] > example_time - 
SECOND * 60 * 5) & (df['time'] <= example_time)\n", + " last_15min = (df['time'] > example_time - SECOND * 60 * 15) & (df['time'] <= example_time)\n", + " last_60min = (df['time'] > example_time - SECOND * 60 * 60) & (df['time'] <= example_time)\n", + " of_btc_usd = df['ticker'] == 'btc_usd'\n", + " of_gdax = df['exch'] == 'gdax'\n", + " of_bmex = df['exch'] == 'bmex'\n", + "\n", + " g5 = last_5min & of_btc_usd & of_gdax\n", + " b5 = last_5min & of_btc_usd & of_bmex\n", + " g15 = last_15min & of_btc_usd & of_gdax\n", + " b15 = last_15min & of_btc_usd & of_bmex\n", + " g60 = last_60min & of_btc_usd & of_gdax\n", + " b60 = last_60min & of_btc_usd & of_bmex\n", + "\n", + " ratio_5min = ((df.loc[b5, 'price'] * df.loc[b5, 'amount']).sum() / df.loc[b5, 'amount'].sum()) / ((df.loc[g5, 'price'] * df.loc[g5, 'amount']).sum() / df.loc[g5, 'amount'].sum())\n", + " ratio_15min = ((df.loc[b15, 'price'] * df.loc[b15, 'amount']).sum() / df.loc[b15, 'amount'].sum()) / ((df.loc[g15, 'price'] * df.loc[g15, 'amount']).sum() / df.loc[g15, 'amount'].sum())\n", + " ratio_60min = ((df.loc[b60, 'price'] * df.loc[b60, 'amount']).sum() / df.loc[b60, 'amount'].sum()) / ((df.loc[g60, 'price'] * df.loc[g60, 'amount']).sum() / df.loc[g60, 'amount'].sum())\n", + " rows.append(dict(example_time=example_time, r5=ratio_5min, r15=ratio_15min, r60=ratio_60min))\n", + " \n", + "took = time.time() - start\n", + "print('finished in {:.1f}sec'.format(took))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3.342554545733013" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hypothetical_full_took = (took / 25) * 5401808\n", + "hypothetical_full_took / 60 / 60 / 24 / 365" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.1016643329480867" + ] + }, + "execution_count": 20, + 
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "92331988 / 908204336" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "19.513984975814818" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "took / 25" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('105,410,800.2', 105410800.15423629)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'{:,.1f}'.format(hypothetical_full_took), hypothetical_full_took" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "r5_delta 0.000000000002368\n", + "r15_delta 0.000000000010704\n", + "r60_delta 0.000000000005513\n", + "dtype: object" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(pd.DataFrame(rows).join(ref.set_index('time'), on='example_time', rsuffix='_rust')\n", + " .assign(r5_delta=lambda df: abs(df['r5'] - df['r5_rust']))\n", + " .assign(r15_delta=lambda df: abs(df['r15'] - df['r15_rust']))\n", + " .assign(r60_delta=lambda df: abs(df['r60'] - df['r60_rust']))\n", + ")[['r5_delta','r15_delta','r60_delta']].max(axis=0).map(lambda x: '{:.15f}'.format(x))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/csv.rs b/src/csv.rs index 50be1b4..70c1225 100644 --- a/src/csv.rs +++ b/src/csv.rs @@ 
-15,7 +15,7 @@ use serde::{Serialize, Deserialize}; use slog::Drain; use pretty_toa::ThousandsSep; use markets::crypto::{Exchange, Ticker, Side}; -use pipelines::windows::WeightedAvgWindow; +use pipelines::windows::WeightedMeanWindow; // equivalent to panic! but without the ugly 'thread main panicked' yada yada @@ -214,24 +214,28 @@ fn hard_mode( let mut ratios: Lookbacks = Default::default(); - let mut bwindows: Lookbacks = + let mut bmex_windows: Lookbacks = Lookbacks { - p5: WeightedAvgWindow::new(ONE_SECOND * 60 * 5 ), - p15: WeightedAvgWindow::new(ONE_SECOND * 60 * 15), - p60: WeightedAvgWindow::new(ONE_SECOND * 60 * 60), + p5: WeightedMeanWindow::new(ONE_SECOND * 60 * 5 ), + p15: WeightedMeanWindow::new(ONE_SECOND * 60 * 15), + p60: WeightedMeanWindow::new(ONE_SECOND * 60 * 60), }; - let mut gwindows = bwindows.clone(); + let mut gdax_windows = bmex_windows.clone(); #[inline(always)] - fn do_purge(windows: &mut Lookbacks, prices: &mut Lookbacks, time: u64) { - if windows.p5.purge(time) { prices.p5 = windows.p5 .checked_wt_mean().unwrap_or(NAN); } - if windows.p15.purge(time) { prices.p15 = windows.p15.checked_wt_mean().unwrap_or(NAN); } - if windows.p60.purge(time) { prices.p60 = windows.p60.checked_wt_mean().unwrap_or(NAN); } + fn do_purge(windows: &mut Lookbacks, prices: &mut Lookbacks, time: u64) { + //if windows.p5.purge(time) { prices.p5 = windows.p5 .checked_weighted_mean().unwrap_or(NAN); } + //if windows.p15.purge(time) { prices.p15 = windows.p15.checked_weighted_mean().unwrap_or(NAN); } + //if windows.p60.purge(time) { prices.p60 = windows.p60.checked_weighted_mean().unwrap_or(NAN); } + windows.p5 .purge(time); + windows.p15.purge(time); + windows.p60.purge(time); } + #[allow(unused)] #[inline(always)] - fn do_update(windows: &mut Lookbacks, prices: &mut Lookbacks, time: u64, price: f64, amount: f64) { + fn do_update(windows: &mut Lookbacks, prices: &mut Lookbacks, time: u64, price: f64, amount: f64) { //prices.p5 = windows.p5 .update(time, price, 
amount).unwrap_or(NAN); //prices.p15 = windows.p15.update(time, price, amount).unwrap_or(NAN); //prices.p60 = windows.p60.update(time, price, amount).unwrap_or(NAN); @@ -239,21 +243,20 @@ fn hard_mode( windows.p5 .push(time, price, amount); windows.p15.push(time, price, amount); windows.p60.push(time, price, amount); - } macro_rules! update { // in macro to avoid repeating code once outside loop, and again in loop body ($trade:ident) => {{ match $trade.exch { e!(bmex) => { - do_update(&mut bwindows, &mut bprices, $trade.time, $trade.price, $trade.amount); - //do_purge(&mut gwindows, &mut gprices, $trade.time); + do_update(&mut bmex_windows, &mut bprices, $trade.time, $trade.price, $trade.amount); + //do_purge(&mut gdax_windows, &mut gprices, $trade.time); last_price = $trade.price; } e!(gdax) => { - do_update(&mut gwindows, &mut gprices, $trade.time, $trade.price, $trade.amount); - //do_purge(&mut bwindows, &mut bprices, $trade.time); + do_update(&mut gdax_windows, &mut gprices, $trade.time, $trade.price, $trade.amount); + //do_purge(&mut bmex_windows, &mut bprices, $trade.time); last_price = $trade.price; } @@ -264,11 +267,11 @@ fn hard_mode( wtr.write_record(&[ "time", - "last", - "bmex_5min", - "gdax_5min", - "n_bmex_p5", - "n_gdax_p5", + //"last", + //"bmex_5min", + //"gdax_5min", + //"n_bmex_p5", + //"n_gdax_p5", "r5", "r15", "r60", @@ -279,7 +282,7 @@ fn hard_mode( //"n_gdax_p15", //"n_gdax_p60", //"gdax_p5_is_empty", - //"gdax_p5_checked_wt_mean", + //"gdax_p5_checked_weighted_mean", //"tradetime_minus_cur_bucket", ]).map_err(|e| format!("writing CSV headers to output file failed: {}", e))?; @@ -306,48 +309,48 @@ fn hard_mode( "n written" => n_written, "trade.time" => trade.time, "cur_bucket" => cur_bucket, - "gdax p5 len" => gwindows.p5.len(), - "gdax p5 wt avg" => gwindows.p5.wt_mean(), + "gdax p5 len" => gdax_windows.p5.len(), + "gdax p5 wt avg" => gdax_windows.p5.weighted_mean(), ); - do_purge(&mut gwindows, &mut gprices, cur_bucket); - do_purge(&mut 
bwindows, &mut bprices, cur_bucket); + do_purge(&mut gdax_windows, &mut gprices, cur_bucket); + do_purge(&mut bmex_windows, &mut bprices, cur_bucket); debug!(logger, "finished purge"; "n" => n, "n written" => n_written, "trade.time" => trade.time, "cur_bucket" => cur_bucket, - "gdax p5 len" => gwindows.p5.len(), - "gdax p5 wt avg" => gwindows.p5.wt_mean(), + "gdax p5 len" => gdax_windows.p5.len(), + "gdax p5 wt avg" => gdax_windows.p5.weighted_mean(), ); - ratios.p5 = bwindows.p5 .checked_wt_mean().unwrap_or(NAN) / gwindows.p5 .checked_wt_mean().unwrap_or(NAN); - ratios.p15 = bwindows.p15.checked_wt_mean().unwrap_or(NAN) / gwindows.p15.checked_wt_mean().unwrap_or(NAN); - ratios.p60 = bwindows.p60.checked_wt_mean().unwrap_or(NAN) / gwindows.p60.checked_wt_mean().unwrap_or(NAN); + ratios.p5 = bmex_windows.p5 .weighted_mean() / gdax_windows.p5 .weighted_mean(); + ratios.p15 = bmex_windows.p15.weighted_mean() / gdax_windows.p15.weighted_mean(); + ratios.p60 = bmex_windows.p60.weighted_mean() / gdax_windows.p60.weighted_mean(); - //ratios.p5 = bwindows.p5 .wt_mean() / gwindows.p5 .wt_mean(); - //ratios.p15 = bwindows.p15.wt_mean() / gwindows.p15.wt_mean(); - //ratios.p60 = bwindows.p60.wt_mean() / gwindows.p60.wt_mean(); + //ratios.p5 = bmex_windows.p5 .weighted_mean() / gdax_windows.p5 .weighted_mean(); + //ratios.p15 = bmex_windows.p15.weighted_mean() / gdax_windows.p15.weighted_mean(); + //ratios.p60 = bmex_windows.p60.weighted_mean() / gdax_windows.p60.weighted_mean(); wtr.write_record(&[ &format!("{}", cur_bucket), - &format!("{}", last_price), - &format!("{}", bwindows.p5.checked_wt_mean().unwrap_or(NAN)), - &format!("{}", gwindows.p5.checked_wt_mean().unwrap_or(NAN)), - &format!("{}", bwindows.p5.len()), - &format!("{}", gwindows.p5.len()), + //&format!("{}", last_price), + //&format!("{}", bmex_windows.p5.checked_weighted_mean().unwrap_or(NAN)), + //&format!("{}", gdax_windows.p5.checked_weighted_mean().unwrap_or(NAN)), + //&format!("{}", bmex_windows.p5.len()), 
+ //&format!("{}", gdax_windows.p5.len()), &format!("{}", ratios.p5), &format!("{}", ratios.p15), &format!("{}", ratios.p60), - //&format!("{}", bwindows.p15.len()), - //&format!("{}", gwindows.p60.len()), - //&format!("{}", gwindows.p15.len()), - //&format!("{}", gwindows.p15.len()), - //&format!("{}", bwindows.p60.len()), - //&format!("{}", bwindows.p5.is_empty()), - //&format!("{:?}", bwindows.p5.checked_wt_mean()), + //&format!("{}", bmex_windows.p15.len()), + //&format!("{}", gdax_windows.p60.len()), + //&format!("{}", gdax_windows.p15.len()), + //&format!("{}", gdax_windows.p15.len()), + //&format!("{}", bmex_windows.p60.len()), + //&format!("{}", bmex_windows.p5.is_empty()), + //&format!("{:?}", bmex_windows.p5.checked_weighted_mean()), //&format!("{}", trade.time - cur_bucket), ]).map_err(|e| { diff --git a/src/windows.rs b/src/windows.rs index db391a2..85f515a 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -1,5 +1,28 @@ use std::collections::VecDeque; +/// Calculates online weighted average for a rolling, time-based window +#[derive(Clone)] +pub struct WeightedMeanWindow { + /// The size of the window. On `purge`, any `WeightedPoint` items are considered + /// expired if the supplied `time` parameter is greater than `size` from the + /// `time` attribute of that `WeightedPoint` item. + size: u64, + /// The weights and values with times that are "currently" in the aggregation + /// window. On `push`, items are added to the "back" of the `VecDeque`. On `purge`, + /// items with a `time` that is > `size` difference relative to the `time` passed + /// to `purge` are considered expired and removed. In both cases, adding and removing, + /// the incremental accumulated sums in `w_sum` and `sum_w` are updated. + items: VecDeque<WeightedPoint>, + /// The sum of the value * weight for each of the `WeightedPoint`s in `items`. + w_sum: f64, + /// The sum of the weights of each of the `WeightedPoint`s in `items`. 
+ sum_w: f64, +} + +/// Stores the time, value and weight for an item "currently" inside the +/// aggregation window of a `WeightedMeanWindow`, allowing its value and +/// weight to be subtracted from the accumulated sums of the window when +/// the item becomes expired. #[derive(Debug, Clone)] pub struct WeightedPoint { pub time: u64, @@ -12,16 +35,7 @@ pub struct WeightedPoint { pub wt: f64, } -#[derive(Clone)] -pub struct WeightedAvgWindow { - size: u64, - items: VecDeque, - w_sum: f64, - sum_w: f64, - //w_mean: f64, -} - -impl WeightedAvgWindow { +impl WeightedMeanWindow { pub fn new(size: u64) -> Self { Self { size, @@ -35,9 +49,19 @@ impl WeightedAvgWindow { /// /// Returns `true` if any items were removed. pub fn purge(&mut self, time: u64) -> bool { + + // this is somewhat awkwardly implemented, but there is not anything like + // `drain_while` on `VecDeque` (or `Vec`) that would work like `take_while`, + // except also removing the items. Since we need the data in the items we + // are removing to update `sum_w` and `w_sum`, we loop over the expired + // items first, counting them in `n_remove`, then actually remove them + // in a second pass. + let mut n_remove = 0; { + // extra scope needed to shush the borrow checker + let items = &self.items; let w_sum = &mut self.w_sum; let sum_w = &mut self.sum_w; @@ -52,7 +76,10 @@ impl WeightedAvgWindow { for _ in 0..n_remove { self.items.pop_front(); } - // when items is empty, set w_sum, sum_w to 0.0 + // when items is empty, set w_sum, sum_w to 0.0. the motive + // of this approach, versus an if block with assignment, is + // for the code to be "branchless" and do the same work each + // time, in a cache- and branch predictor-friendly manner. let zeroer: f64 = ( ! self.items.is_empty()) as u8 as f64; self.w_sum *= zeroer; self.sum_w *= zeroer; @@ -61,6 +88,10 @@ impl WeightedAvgWindow { } /// Add a new item, updating incremental calculations in the process. 
+ /// + /// Note: it is assumed that `time` is >= the highest `time` value for any previous + /// item. The expiration logic in `purge` relies on the items being added to a + /// `WeightedMeanWindow` in chronological order. pub fn push(&mut self, time: u64, val: f64, wt: f64) { let wt_val: f64 = val * wt; self.w_sum += wt_val; @@ -72,24 +103,26 @@ impl WeightedAvgWindow { /// accumulators. /// /// Note; this value is not cached. - pub fn wt_mean(&self) -> f64 { + pub fn weighted_mean(&self) -> f64 { self.w_sum / self.sum_w } /// Checks whether items `is_empty` before trying to calculate. /// Returns None if items is empty. - pub fn checked_wt_mean(&self) -> Option<f64> { + /// + /// Note: this value is not cached. + pub fn checked_weighted_mean(&self) -> Option<f64> { match self.is_empty() { true => None, false => Some(self.w_sum / self.sum_w), } } - /// Purge, push and get `checked_wt_mean`, all in one convenient step. + /// Purge, push and get `checked_weighted_mean`, all in one convenient step. pub fn update(&mut self, time: u64, val: f64, wt: f64) -> Option<f64> { self.purge(time); self.push(time, val, wt); - self.checked_wt_mean() + self.checked_weighted_mean() } pub fn len(&self) -> usize { self.items.len() } @@ -97,3 +130,80 @@ impl WeightedAvgWindow { pub fn is_empty(&self) -> bool { self.items.is_empty() } } +#[allow(unused)] +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_relative_eq; + + #[test] + fn weighted_mean_output_matches_numpy_average() { + + let xs: Vec<f64> = vec![ 0.41305045, 0.93555897, 0.77885094, 0.9896831 , 0.79720248, + 0.69497414, 0.34953127, 0.02331158, 0.89858514, 0.38312421 ]; + + let ws: Vec<f64> = vec![ 0.01256151, 0.58996267, 0.6474601 , 0.33013727, 0.92964117, + 0.21427296, 0.42990663, 0.81912449, 0.99428442, 0.71875903 ]; + + let mut w = WeightedMeanWindow::new(1_000_000_000); + + for (i, (val, weight)) in xs.iter().cloned().zip(ws.iter().cloned()).enumerate() { + w.push(i as u64, val, weight); + } + + w.purge(11); + + 
assert_eq!(w.items.len(), 10); + assert_relative_eq!(w.weighted_mean(), 0.63599718086101786, epsilon = 0.0001); + } + + #[test] + fn checked_weighted_mean_returns_none_when_items_is_empty_and_unchecked_is_nan() { + let w = WeightedMeanWindow::new(1_000_000_000); + assert!(w.is_empty()); + assert_relative_eq!(w.sum_w, 0.0f64); + assert_relative_eq!(w.w_sum, 0.0f64); + + assert!(w.checked_weighted_mean().is_none()); + assert!(w.weighted_mean().is_nan()); + } + + #[test] + fn purge_expires_items() { + let xs: Vec = vec![ 0.41305045, 0.93555897, 0.77885094, 0.9896831 , 0.79720248, + 0.69497414, 0.34953127, 0.02331158, 0.89858514, 0.38312421 ]; + + let ws: Vec = vec![ 0.01256151, 0.58996267, 0.6474601 , 0.33013727, 0.92964117, + 0.21427296, 0.42990663, 0.81912449, 0.99428442, 0.71875903 ]; + + let xs_times_ws: Vec = xs.iter().zip(ws.iter()).map(|(&x,&w)| x * w).collect(); + + let mut w = WeightedMeanWindow::new(10); + + for (i, (val, weight)) in xs.iter().cloned().zip(ws.iter().cloned()).enumerate() { + w.push(i as u64, val, weight); + } + + w.purge(10); + + assert_eq!(w.items.len(), 10); + + w.purge(11); + + assert_eq!(w.items.len(), 9); + assert_relative_eq!(w.sum_w, (&ws[1..]).iter().sum::(), epsilon = 1e-5); + assert_relative_eq!(w.w_sum, (&xs_times_ws[1..]).iter().sum::(), epsilon = 1e-5); + + w.purge(11); + + assert_eq!(w.items.len(), 9); + + w.purge(12); + + assert_eq!(w.items.len(), 8); + assert_relative_eq!(w.sum_w, (&ws[2..]).iter().sum::(), epsilon = 1e-5); + assert_relative_eq!(w.w_sum, (&xs_times_ws[2..]).iter().sum::(), epsilon = 1e-5); + } +} + +