diff --git a/evaluation_script.py b/evaluation_script.py index 4da8b59..11ade9b 100644 --- a/evaluation_script.py +++ b/evaluation_script.py @@ -200,7 +200,9 @@ def launch_parallel_experiments(task_path, insecure_coding=False, url="http://127.0.0.1:8000/v1", max_messages=15, - num_examples=2): + num_examples=2, + no_pruning=False, + block_conversation=False): with open(task_path, 'r', encoding='utf-8') as file: content = file.read() @@ -255,7 +257,9 @@ def launch_parallel_experiments(task_path, task_type=task_type, s3_path=s3_path, max_messages=max_messages, - num_examples=num_examples) + num_examples=num_examples, + no_pruning=no_pruning, + block_conversation=block_conversation) time.sleep(5) total_num_tasks = len(task_ids) @@ -301,7 +305,9 @@ def launch_server_experiment(task_path, task_type="techtree", s3_path="", max_messages=15, - num_examples=2): + num_examples=2, + no_pruning=False, + block_conversation=False): """ Launch a Minecraft server and run experiments on it. @@ -375,12 +381,17 @@ def launch_server_experiment(task_path, # op_script_content = "sleep 5\n\op @p" * 20 # op_script_file = f"./tmp/op_script_{session_name}.sh" # make_script_file_and_run(op_script_content, "server_" + session_name, op_script_file) - if task_type == "cooking": - set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", BLOCKED_ACTIONS_COOKING) - elif task_type == "techtree": - set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", BLOCKED_ACTIONS_CRAFTING) - elif task_type == "construction": - set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", BLOCKED_ACTIONS_CONSTRUCTION) + blocked_actions = [] + if not no_pruning: + if task_type == "cooking": + blocked_actions = BLOCKED_ACTIONS_COOKING + elif task_type == "techtree": + blocked_actions = BLOCKED_ACTIONS_CRAFTING + elif task_type == "construction": + blocked_actions = BLOCKED_ACTIONS_CONSTRUCTION + if block_conversation: + blocked_actions += ["!endConversation", "!startConversation"] + set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", blocked_actions) script_content = "" for task_id in task_ids: @@ -650,6 +661,8 @@ def main(): parser.add_argument('--url', default="http://127.0.0.1:8000/v1") parser.add_argument('--max_messages', default=15, type=int, help='Maximum number of messages before summarizing') parser.add_argument('--num_examples', default=2, type=int, help='Maximum number of turns before summarizing') + parser.add_argument('--no-pruning', action='store_true', help='Disable pruning of the actions') + parser.add_argument('--block_conversation', action='store_true', help='Block conversation actions') args = parser.parse_args() print(args) @@ -677,7 +690,9 @@ def main(): num_agents=args.num_agents, url=args.url, max_messages=args.max_messages, - num_examples=args.num_examples) + num_examples=args.num_examples, + no_pruning=args.no_pruning, + block_conversation=args.block_conversation) if __name__ == "__main__": main() \ No newline at end of file