add url option to evaluation script

This commit is contained in:
Isadora White 2025-03-09 23:14:26 -07:00
parent 5103cd82eb
commit 406ebe6072

View file

@ -115,7 +115,8 @@ def launch_parallel_experiments(task_path,
bucket_name="mindcraft-experiments",
template_profile="profiles/tasks/collab_profile.json",
world_name="Forest",
insecure_coding=False):
insecure_coding=False,
url="http://127.0.0.1:8000/v1"):
with open(task_path, 'r', encoding='utf-8') as file:
content = file.read()
@ -147,7 +148,8 @@ def launch_parallel_experiments(task_path,
model=model,
api=api,
insecure_coding=insecure_coding,
num_agents=num_agents)
num_agents=num_agents,
url=url)
time.sleep(5)
def launch_server_experiment(task_path,
@ -162,7 +164,8 @@ def launch_server_experiment(task_path,
s3=False,
bucket_name="mindcraft-experiments",
template_profile="profiles/tasks/collab_profile.json",
insecure_coding=False):
insecure_coding=False,
url="http://127.0.0.1:8000/v1"):
"""
Launch a Minecraft server and run experiments on it.
@param task_path: Path to the task file
@ -194,7 +197,7 @@ def launch_server_experiment(task_path,
agent_names = [f"Andy_{session_name}", f"Jill_{session_name}", f"Bob_{session_name}"]
models = [model] * 3
apis = [api] * 3
make_profiles(agent_names, models, apis, template_profile=template_profile)
make_profiles(agent_names, models, apis, template_profile=template_profile, url=url)
agent_profiles = [f"./{agent}.json" for agent in agent_names]
@ -279,7 +282,7 @@ def launch_server_experiment(task_path,
# subprocess.run(["tmux", "send-keys", "-t", session_name, f"/op {agent_names[0]}", "C-m"])
def make_profiles(agent_names, models, apis, template_profile="profiles/collab_profile.json"):
def make_profiles(agent_names, models, apis, template_profile="profiles/collab_profile.json", url="http://127.0.0.1:8000/v1"):
assert len(agent_names) == len(models)
with open(template_profile, 'r') as f:
@ -293,7 +296,7 @@ def make_profiles(agent_names, models, apis, template_profile="profiles/collab_p
profile["model"] = {
"api": "vllm",
"model": models[index],
"url": "http://127.0.0.1:8000/v1"
"url": url
}
else:
profile["model"] = models[index]
@ -421,6 +424,7 @@ def main():
parser.add_argument('--api', default="openai", help='API to use for the agents')
parser.add_argument('--world_name', default="Forest", help='Name of the world')
parser.add_argument('--insecure_coding', action='store_true', help='Enable insecure coding')
parser.add_argument('--url', default="http://127.0.0.1:8000/v1")
args = parser.parse_args()
print(args)
@ -446,7 +450,8 @@ def main():
api=args.api,
world_name=args.world_name,
insecure_coding=args.insecure_coding,
num_agents=args.num_agents)
num_agents=args.num_agents,
url=args.url)
cmd = "aws s3"
if __name__ == "__main__":