evaluation script add no pruning and block conversation

This commit is contained in:
Isadora White 2025-03-26 01:13:39 -05:00
parent 38c701a8fb
commit f862964edd

View file

@ -200,7 +200,9 @@ def launch_parallel_experiments(task_path,
insecure_coding=False,
url="http://127.0.0.1:8000/v1",
max_messages=15,
num_examples=2):
num_examples=2,
no_pruning=False,
block_conversation=False):
with open(task_path, 'r', encoding='utf-8') as file:
content = file.read()
@ -255,7 +257,9 @@ def launch_parallel_experiments(task_path,
task_type=task_type,
s3_path=s3_path,
max_messages=max_messages,
num_examples=num_examples)
num_examples=num_examples,
no_pruning=no_pruning,
block_conversation=block_conversation)
time.sleep(5)
total_num_tasks = len(task_ids)
@ -301,7 +305,9 @@ def launch_server_experiment(task_path,
task_type="techtree",
s3_path="",
max_messages=15,
num_examples=2):
num_examples=2,
no_pruning=False,
block_conversation=False):
"""
Launch a Minecraft server and run experiments on it.
@ -375,12 +381,17 @@ def launch_server_experiment(task_path,
# op_script_content = "sleep 5\n\op @p" * 20
# op_script_file = f"./tmp/op_script_{session_name}.sh"
# make_script_file_and_run(op_script_content, "server_" + session_name, op_script_file)
if task_type == "cooking":
set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", BLOCKED_ACTIONS_COOKING)
elif task_type == "techtree":
set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", BLOCKED_ACTIONS_CRAFTING)
elif task_type == "construction":
set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", BLOCKED_ACTIONS_CONSTRUCTION)
blocked_actions = []
if not no_pruning:
if task_type == "cooking":
blocked_actions = BLOCKED_ACTIONS_COOKING
elif task_type == "techtree":
blocked_actions = BLOCKED_ACTIONS_CRAFTING
elif task_type == "construction":
blocked_actions = BLOCKED_ACTIONS_CONSTRUCTION
if block_conversation:
blocked_actions += ["!endConversation", "!startConversation"]
set_environment_variable_tmux_session(session_name, "BLOCKED_ACTIONS", blocked_actions)
script_content = ""
for task_id in task_ids:
@ -650,6 +661,8 @@ def main():
parser.add_argument('--url', default="http://127.0.0.1:8000/v1")
parser.add_argument('--max_messages', default=15, type=int, help='Maximum number of messages before summarizing')
parser.add_argument('--num_examples', default=2, type=int, help='Maximum number of turns before summarizing')
parser.add_argument('--no-pruning', action='store_true', help='Disable pruning of the actions')
parser.add_argument('--block_conversation', action='store_true', help='Block conversation actions')
args = parser.parse_args()
print(args)
@ -677,7 +690,9 @@ def main():
num_agents=args.num_agents,
url=args.url,
max_messages=args.max_messages,
num_examples=args.num_examples)
num_examples=args.num_examples,
no_pruning=args.no_pruning,
block_conversation=args.block_conversation)
if __name__ == "__main__":
main()