evaluation script for vllm

This commit is contained in:
Isadora White 2025-03-03 06:10:52 +00:00
parent 91a37103fb
commit 44fc1b4618
3 changed files with 25 additions and 8 deletions

View file

@ -109,6 +109,7 @@ def launch_parallel_experiments(task_path,
exp_name,
num_agents=2,
model="gpt-4o",
api="openai",
num_parallel=1,
s3=False,
bucket_name="mindcraft-experiments",
@ -141,7 +142,8 @@ def launch_parallel_experiments(task_path,
s3=s3,
bucket_name=bucket_name,
template_profile="profiles/collab_profile.json",
model=model)
model=model,
api=api)
time.sleep(5)
def launch_server_experiment(task_path,
@ -151,7 +153,8 @@ def launch_server_experiment(task_path,
experiments_folder,
exp_name="exp",
num_agents=2,
model="gpt-4o",
model="gpt-4o",
api="openai",
s3=False,
bucket_name="mindcraft-experiments",
template_profile="profiles/collab_profile.json"):
@ -175,10 +178,12 @@ def launch_server_experiment(task_path,
if num_agents == 2:
agent_names = [f"Andy_{session_name}", f"Jill_{session_name}"]
models = [model] * 2
apis = [api] * 2
else:
agent_names = [f"Andy_{session_name}", f"Jill_{session_name}", f"Bob_{session_name}"]
models = [model] * 3
make_profiles(agent_names, models, template_profile=template_profile)
apis = [api] * 3
make_profiles(agent_names, models, apis, template_profile=template_profile)
# edit_file("settings.js", {"profiles": [f"./{agent}.json" for agent in agent_names]})
agent_profiles = [f"./{agent}.json" for agent in agent_names]
@ -243,7 +248,7 @@ def launch_server_experiment(task_path,
# subprocess.run(["tmux", "send-keys", "-t", session_name, f"/op {agent_names[0]}", "C-m"])
def make_profiles(agent_names, models, template_profile="profiles/collab_profile.json"):
def make_profiles(agent_names, models, apis, template_profile="profiles/collab_profile.json"):
assert len(agent_names) == len(models)
with open(template_profile, 'r') as f:
@ -253,7 +258,14 @@ def make_profiles(agent_names, models, template_profile="profiles/collab_profile
for index in range(len(agent_names)):
profile["name"] = agent_names[index]
profile["model"] = models[index]
if apis[index] == "vllm":
profile["model"] = {
"api": "vllm",
"model": models[index],
"url": "http://localhost:8000/v1"
}
else:
profile["model"] = models[index]
with open(f"{agent_names[index]}.json", 'w') as f:
json.dump(profile, f, indent=4)
@ -372,7 +384,8 @@ def main():
parser.add_argument('--bucket_name', default="mindcraft-experiments", help='Name of the s3 bucket')
parser.add_argument('--add_keys', action='store_true', help='Create the keys.json to match the environment variables')
parser.add_argument('--template_profile', default="andy.json", help='Model to use for the agents')
parser.add_argument('--model', default="gpt-4o", help='Model to use for the agents')
parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents')
parser.add_argument('--api', default="openai", help='API to use for the agents')
# parser.add_argument('--wandb', action='store_true', help='Whether to use wandb')
# parser.add_argument('--wandb_project', default="minecraft_experiments", help='wandb project name')
@ -400,7 +413,8 @@ def main():
s3=args.s3,
bucket_name=args.bucket_name,
template_profile=args.template_profile,
model=args.model)
model=args.model,
api=args.api)
# servers = create_server_files("../server_data/", args.num_parallel)
# date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

View file

@ -42,4 +42,4 @@
"type": "techtree",
"timeout": 120
}
}
}

View file

@ -20,6 +20,7 @@ import { Qwen } from "./qwen.js";
import { Grok } from "./grok.js";
import { DeepSeek } from './deepseek.js';
import { OpenRouter } from './openrouter.js';
import { VLLM } from './vllm.js';
export class Prompter {
constructor(agent, fp) {
@ -181,6 +182,8 @@ export class Prompter {
model = new DeepSeek(profile.model, profile.url, profile.params);
else if (profile.api === 'openrouter')
model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params);
else if (profile.api === 'vllm')
model = new VLLM(profile.model, profile.url, profile.params);
else
throw new Error('Unknown API:', profile.api);
return model;