1515 experiment returns, as reported by `imitation.scripts.analyze`.
1616"""
1717
18- import numpy as np
1918import pandas as pd
2019import scipy
2120
2221from imitation .data import types
2322
2423
25- def compare_results_to_baseline (results_file : types .AnyPath ) -> pd .DataFrame :
24+ def compare_results_to_baseline (results_filename : types .AnyPath ) -> pd .DataFrame :
2625 """Compare benchmark results to baseline results.
2726
2827 Args:
29- results_file : Path to a CSV file containing experiment results.
28+ results_filename : Path to a CSV file containing experiment results.
3029
3130 Returns:
3231 A DataFrame with the columns `algo`, `env_name` and `pvalue`, holding the
3332 p-values comparing the experiment results to the baseline results.
3433 """
35- data = pd .read_csv (results_file )
36- data ["imit_return" ] = data ["imit_return_summary" ].apply (
37- lambda x : float (x .split (" " )[0 ]),
38- )
39- summary = (
40- data [["algo" , "env_name" , "imit_return" ]]
41- .groupby (["algo" , "env_name" ])
42- .describe ()
43- )
44- summary .columns = summary .columns .get_level_values (1 )
45- summary = summary .reset_index ()
46-
47- # Table 2 (https://arxiv.org/pdf/2211.11972.pdf)
48- # todo: store results in this repo outside this file
49- baseline = pd .DataFrame .from_records (
50- [
51- {
52- "algo" : "??exp_command=bc" ,
53- "env_name" : "seals/Ant-v0" ,
54- "mean" : 1953 ,
55- "margin" : 123 ,
56- },
57- {
58- "algo" : "??exp_command=bc" ,
59- "env_name" : "seals/HalfCheetah-v0" ,
60- "mean" : 3446 ,
61- "margin" : 130 ,
62- },
63- ],
64- )
65- baseline ["count" ] = 5
66- baseline ["confidence_level" ] = 0.95
67- # Back out the standard deviation from the margin of error.
34+ results_summary = load_and_summarize_csv (results_filename )
35+ baseline_summary = load_and_summarize_csv ("baseline.csv" )
6836
69- t_score = scipy .stats .t .ppf (
70- 1 - ((1 - baseline ["confidence_level" ]) / 2 ),
71- baseline ["count" ] - 1 ,
72- )
73- std_err = baseline ["margin" ] / t_score
74-
75- baseline ["std" ] = std_err * np .sqrt (baseline ["count" ])
76-
77- comparison = pd .merge (summary , baseline , on = ["algo" , "env_name" ])
37+ comparison = pd .merge (results_summary , baseline_summary , on = ["algo" , "env_name" ])
7838
7939 comparison ["pvalue" ] = scipy .stats .ttest_ind_from_stats (
8040 comparison ["mean_x" ],
@@ -88,6 +48,30 @@ def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
8848 return comparison [["algo" , "env_name" , "pvalue" ]]
8949
9050
def load_and_summarize_csv(results_filename: types.AnyPath) -> pd.DataFrame:
    """Load a results CSV file and summarize the statistics.

    The CSV must contain the columns `algo`, `env_name` and
    `imit_return_summary`. The return of each row is the first
    whitespace-separated token of `imit_return_summary`, parsed as a float.

    Args:
        results_filename: Path to a CSV file containing experiment results.

    Returns:
        A DataFrame with one row per (`algo`, `env_name`) pair. Besides those
        two columns it holds the `describe()` statistics of the parsed returns:
        `count`, `mean`, `std`, `min`, `25%`, `50%`, `75%` and `max`.
    """
    data = pd.read_csv(results_filename)
    # Split on any run of whitespace (not on a single " ") so that leading or
    # repeated blanks cannot produce an empty first token, which `float` rejects.
    data["imit_return"] = data["imit_return_summary"].apply(
        lambda x: float(x.split()[0]),
    )
    summary = (
        data[["algo", "env_name", "imit_return"]]
        .groupby(["algo", "env_name"])
        .describe()
    )
    # `describe()` yields two-level columns ("imit_return", <statistic>);
    # keep only the statistic names so the result has flat column labels.
    summary.columns = summary.columns.get_level_values(1)
    return summary.reset_index()
73+
74+
9175def main () -> None : # pragma: no cover
9276 """Run the script."""
9377 import sys
0 commit comments