mirror of
https://github.com/kolbytn/mindcraft.git
synced 2025-07-30 11:55:29 +02:00
evaluation script for vllm
This commit is contained in:
parent
91a37103fb
commit
44fc1b4618
3 changed files with 25 additions and 8 deletions
|
@ -109,6 +109,7 @@ def launch_parallel_experiments(task_path,
|
|||
exp_name,
|
||||
num_agents=2,
|
||||
model="gpt-4o",
|
||||
api="openai",
|
||||
num_parallel=1,
|
||||
s3=False,
|
||||
bucket_name="mindcraft-experiments",
|
||||
|
@ -141,7 +142,8 @@ def launch_parallel_experiments(task_path,
|
|||
s3=s3,
|
||||
bucket_name=bucket_name,
|
||||
template_profile="profiles/collab_profile.json",
|
||||
model=model)
|
||||
model=model,
|
||||
api=api)
|
||||
time.sleep(5)
|
||||
|
||||
def launch_server_experiment(task_path,
|
||||
|
@ -151,7 +153,8 @@ def launch_server_experiment(task_path,
|
|||
experiments_folder,
|
||||
exp_name="exp",
|
||||
num_agents=2,
|
||||
model="gpt-4o",
|
||||
model="gpt-4o",
|
||||
api="openai",
|
||||
s3=False,
|
||||
bucket_name="mindcraft-experiments",
|
||||
template_profile="profiles/collab_profile.json"):
|
||||
|
@ -175,10 +178,12 @@ def launch_server_experiment(task_path,
|
|||
if num_agents == 2:
|
||||
agent_names = [f"Andy_{session_name}", f"Jill_{session_name}"]
|
||||
models = [model] * 2
|
||||
apis = [api] * 2
|
||||
else:
|
||||
agent_names = [f"Andy_{session_name}", f"Jill_{session_name}", f"Bob_{session_name}"]
|
||||
models = [model] * 3
|
||||
make_profiles(agent_names, models, template_profile=template_profile)
|
||||
apis = [api] * 3
|
||||
make_profiles(agent_names, models, apis, template_profile=template_profile)
|
||||
|
||||
# edit_file("settings.js", {"profiles": [f"./{agent}.json" for agent in agent_names]})
|
||||
agent_profiles = [f"./{agent}.json" for agent in agent_names]
|
||||
|
@ -243,7 +248,7 @@ def launch_server_experiment(task_path,
|
|||
|
||||
# subprocess.run(["tmux", "send-keys", "-t", session_name, f"/op {agent_names[0]}", "C-m"])
|
||||
|
||||
def make_profiles(agent_names, models, template_profile="profiles/collab_profile.json"):
|
||||
def make_profiles(agent_names, models, apis, template_profile="profiles/collab_profile.json"):
|
||||
assert len(agent_names) == len(models)
|
||||
|
||||
with open(template_profile, 'r') as f:
|
||||
|
@ -253,7 +258,14 @@ def make_profiles(agent_names, models, template_profile="profiles/collab_profile
|
|||
|
||||
for index in range(len(agent_names)):
|
||||
profile["name"] = agent_names[index]
|
||||
profile["model"] = models[index]
|
||||
if apis[index] == "vllm":
|
||||
profile["model"] = {
|
||||
"api": "vllm",
|
||||
"model": models[index],
|
||||
"url": "http://localhost:8000/v1"
|
||||
}
|
||||
else:
|
||||
profile["model"] = models[index]
|
||||
|
||||
with open(f"{agent_names[index]}.json", 'w') as f:
|
||||
json.dump(profile, f, indent=4)
|
||||
|
@ -372,7 +384,8 @@ def main():
|
|||
parser.add_argument('--bucket_name', default="mindcraft-experiments", help='Name of the s3 bucket')
|
||||
parser.add_argument('--add_keys', action='store_true', help='Create the keys.json to match the environment variables')
|
||||
parser.add_argument('--template_profile', default="andy.json", help='Model to use for the agents')
|
||||
parser.add_argument('--model', default="gpt-4o", help='Model to use for the agents')
|
||||
parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents')
|
||||
parser.add_argument('--api', default="openai", help='API to use for the agents')
|
||||
# parser.add_argument('--wandb', action='store_true', help='Whether to use wandb')
|
||||
# parser.add_argument('--wandb_project', default="minecraft_experiments", help='wandb project name')
|
||||
|
||||
|
@ -400,7 +413,8 @@ def main():
|
|||
s3=args.s3,
|
||||
bucket_name=args.bucket_name,
|
||||
template_profile=args.template_profile,
|
||||
model=args.model)
|
||||
model=args.model,
|
||||
api=args.api)
|
||||
|
||||
# servers = create_server_files("../server_data/", args.num_parallel)
|
||||
# date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
|
|
@ -42,4 +42,4 @@
|
|||
"type": "techtree",
|
||||
"timeout": 120
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import { Qwen } from "./qwen.js";
|
|||
import { Grok } from "./grok.js";
|
||||
import { DeepSeek } from './deepseek.js';
|
||||
import { OpenRouter } from './openrouter.js';
|
||||
import { VLLM } from './vllm.js';
|
||||
|
||||
export class Prompter {
|
||||
constructor(agent, fp) {
|
||||
|
@ -181,6 +182,8 @@ export class Prompter {
|
|||
model = new DeepSeek(profile.model, profile.url, profile.params);
|
||||
else if (profile.api === 'openrouter')
|
||||
model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params);
|
||||
else if (profile.api === 'vllm')
|
||||
model = new VLLM(profile.model, profile.url, profile.params);
|
||||
else
|
||||
throw new Error('Unknown API:', profile.api);
|
||||
return model;
|
||||
|
|
Loading…
Add table
Reference in a new issue