diff --git a/evaluation_script.py b/evaluation_script.py index b1dea33..ea5859b 100644 --- a/evaluation_script.py +++ b/evaluation_script.py @@ -109,6 +109,7 @@ def launch_parallel_experiments(task_path, exp_name, num_agents=2, model="gpt-4o", + api="openai", num_parallel=1, s3=False, bucket_name="mindcraft-experiments", @@ -141,7 +142,8 @@ def launch_parallel_experiments(task_path, s3=s3, bucket_name=bucket_name, template_profile="profiles/collab_profile.json", - model=model) + model=model, + api=api) time.sleep(5) def launch_server_experiment(task_path, @@ -151,7 +153,8 @@ def launch_server_experiment(task_path, experiments_folder, exp_name="exp", num_agents=2, - model="gpt-4o", + model="gpt-4o", + api="openai", s3=False, bucket_name="mindcraft-experiments", template_profile="profiles/collab_profile.json"): @@ -175,10 +178,12 @@ def launch_server_experiment(task_path, if num_agents == 2: agent_names = [f"Andy_{session_name}", f"Jill_{session_name}"] models = [model] * 2 + apis = [api] * 2 else: agent_names = [f"Andy_{session_name}", f"Jill_{session_name}", f"Bob_{session_name}"] models = [model] * 3 - make_profiles(agent_names, models, template_profile=template_profile) + apis = [api] * 3 + make_profiles(agent_names, models, apis, template_profile=template_profile) # edit_file("settings.js", {"profiles": [f"./{agent}.json" for agent in agent_names]}) agent_profiles = [f"./{agent}.json" for agent in agent_names] @@ -243,7 +248,7 @@ def launch_server_experiment(task_path, # subprocess.run(["tmux", "send-keys", "-t", session_name, f"/op {agent_names[0]}", "C-m"]) -def make_profiles(agent_names, models, template_profile="profiles/collab_profile.json"): +def make_profiles(agent_names, models, apis, template_profile="profiles/collab_profile.json"): assert len(agent_names) == len(models) with open(template_profile, 'r') as f: @@ -253,7 +258,14 @@ def make_profiles(agent_names, models, template_profile="profiles/collab_profile for index in range(len(agent_names)): profile["name"] = agent_names[index] - profile["model"] = models[index] + if apis[index] == "vllm": + profile["model"] = { + "api": "vllm", + "model": models[index], + "url": "http://localhost:8000/v1" + } + else: + profile["model"] = models[index] with open(f"{agent_names[index]}.json", 'w') as f: json.dump(profile, f, indent=4) @@ -372,7 +384,8 @@ def main(): parser.add_argument('--bucket_name', default="mindcraft-experiments", help='Name of the s3 bucket') parser.add_argument('--add_keys', action='store_true', help='Create the keys.json to match the environment variables') parser.add_argument('--template_profile', default="andy.json", help='Model to use for the agents') - parser.add_argument('--model', default="gpt-4o", help='Model to use for the agents') + parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents') + parser.add_argument('--api', default="openai", help='API to use for the agents') # parser.add_argument('--wandb', action='store_true', help='Whether to use wandb') # parser.add_argument('--wandb_project', default="minecraft_experiments", help='wandb project name') @@ -400,7 +413,8 @@ def main(): s3=args.s3, bucket_name=args.bucket_name, template_profile=args.template_profile, - model=args.model) + model=args.model, + api=args.api) # servers = create_server_files("../server_data/", args.num_parallel) # date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") diff --git a/multiagent_crafting_tasks.json b/multiagent_crafting_tasks.json index 99cd57f..ad37fbd 100644 --- a/multiagent_crafting_tasks.json +++ b/multiagent_crafting_tasks.json @@ -42,4 +42,4 @@ "type": "techtree", "timeout": 120 } -} \ No newline at end of file +} diff --git a/src/models/prompter.js b/src/models/prompter.js index 6cc54e2..5dfc1a0 100644 --- a/src/models/prompter.js +++ b/src/models/prompter.js @@ -20,6 +20,7 @@ import { Qwen } from "./qwen.js"; import { Grok } from "./grok.js"; import { DeepSeek } from './deepseek.js'; import { OpenRouter } from './openrouter.js'; +import { VLLM } from './vllm.js'; export class Prompter { constructor(agent, fp) { @@ -181,6 +182,8 @@ export class Prompter { model = new DeepSeek(profile.model, profile.url, profile.params); else if (profile.api === 'openrouter') model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params); + else if (profile.api === 'vllm') + model = new VLLM(profile.model, profile.url, profile.params); else throw new Error('Unknown API:', profile.api); return model;