mirror of
https://github.com/kolbytn/mindcraft.git
synced 2025-08-11 17:55:34 +02:00
set up to use s3 logging instead of wandb
This commit is contained in:
parent
d4565aa68c
commit
719b72da9e
1 changed files with 35 additions and 16 deletions
|
@ -94,7 +94,9 @@ def launch_parallel_experiments(task_path,
|
||||||
exp_name,
|
exp_name,
|
||||||
num_agents=2,
|
num_agents=2,
|
||||||
model="gpt-4o",
|
model="gpt-4o",
|
||||||
num_parallel=1):
|
num_parallel=1,
|
||||||
|
s3=False,
|
||||||
|
bucket_name="mindcraft-experiments"):
|
||||||
|
|
||||||
with open(task_path, 'r', encoding='utf-8') as file:
|
with open(task_path, 'r', encoding='utf-8') as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
|
@ -107,14 +109,21 @@ def launch_parallel_experiments(task_path,
|
||||||
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
|
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
|
||||||
|
|
||||||
servers = create_server_files("../server_data/", num_parallel)
|
servers = create_server_files("../server_data/", num_parallel)
|
||||||
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
date_time = datetime.now().strftime("%m-%d_%H-%M")
|
||||||
experiments_folder = f"experiments/{exp_name}_{date_time}"
|
experiments_folder = f"experiments/{exp_name}_{date_time}"
|
||||||
exp_name = f"{exp_name}_{date_time}"
|
exp_name = f"{exp_name}_{date_time}"
|
||||||
|
|
||||||
# start wandb
|
# start wandb
|
||||||
os.makedirs(experiments_folder, exist_ok=True)
|
os.makedirs(experiments_folder, exist_ok=True)
|
||||||
for i, server in enumerate(servers):
|
for i, server in enumerate(servers):
|
||||||
launch_server_experiment(task_path, task_ids_split[i], num_exp, server, experiments_folder, exp_name)
|
launch_server_experiment(task_path,
|
||||||
|
task_ids_split[i],
|
||||||
|
num_exp,
|
||||||
|
server,
|
||||||
|
experiments_folder,
|
||||||
|
exp_name,
|
||||||
|
s3=s3,
|
||||||
|
bucket_name=bucket_name)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,7 +134,9 @@ def launch_server_experiment(task_path,
|
||||||
experiments_folder,
|
experiments_folder,
|
||||||
exp_name="exp",
|
exp_name="exp",
|
||||||
num_agents=2,
|
num_agents=2,
|
||||||
model="gpt-4o"):
|
model="gpt-4o",
|
||||||
|
s3=False,
|
||||||
|
bucket_name="mindcraft-experiments"):
|
||||||
"""
|
"""
|
||||||
Launch a Minecraft server and run experiments on it.
|
Launch a Minecraft server and run experiments on it.
|
||||||
@param task_path: Path to the task file
|
@param task_path: Path to the task file
|
||||||
|
@ -173,13 +184,14 @@ def launch_server_experiment(task_path,
|
||||||
script_content += "sleep 2\n"
|
script_content += "sleep 2\n"
|
||||||
for agent in agent_names:
|
for agent in agent_names:
|
||||||
cp_cmd = f"cp bots/{agent}/memory.json {experiments_folder}/{task_id}_{agent}_{_}.json"
|
cp_cmd = f"cp bots/{agent}/memory.json {experiments_folder}/{task_id}_{agent}_{_}.json"
|
||||||
|
script_content += f"echo '{cp_cmd}'\n"
|
||||||
script_content += f"{cp_cmd}\n"
|
script_content += f"{cp_cmd}\n"
|
||||||
script_content += "sleep 1\n"
|
script_content += "sleep 1\n"
|
||||||
script_content += f"echo 'Uploading {experiments_folder}/{task_id}_{agent}_{_}.json to wandb'\n"
|
if s3:
|
||||||
wandb_cmd = f"wandb artifact put {experiments_folder}/{task_id}_{agent}_{_}.json --name {exp_name}_{task_id}_{agent}_{_} --type dataset"
|
script_content += f"echo 'Uploading {experiments_folder}/{task_id}_{agent}_{_}.json to s3'\n"
|
||||||
script_content += f"echo '{wandb_cmd}'\n"
|
s3_cmd = f"aws s3 cp bots/{agent}/memory.json s3://{bucket_name}/{experiments_folder}/{task_id}_{agent}_{_}.json"
|
||||||
script_content += f"{wandb_cmd}\n"
|
script_content += f"echo '{s3_cmd}'\n"
|
||||||
script_content += "sleep 1\n"
|
script_content += f"{s3_cmd}\n"
|
||||||
script_content += "sleep 1\n"
|
script_content += "sleep 1\n"
|
||||||
|
|
||||||
# Create a temporary shell script file
|
# Create a temporary shell script file
|
||||||
|
@ -316,14 +328,16 @@ def main():
|
||||||
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
|
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
|
||||||
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
|
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
|
||||||
parser.add_argument('--exp_name', default="exp", help='Name of the experiment')
|
parser.add_argument('--exp_name', default="exp", help='Name of the experiment')
|
||||||
parser.add_argument('--wandb', action='store_true', help='Whether to use wandb')
|
parser.add_argument('--s3', action='store_true', help='Whether to upload to s3')
|
||||||
parser.add_argument('--wandb-project', default="minecraft_experiments", help='wandb project name')
|
parser.add_argument('--bucket_name', default="mindcraft-experiments", help='Name of the s3 bucket')
|
||||||
|
# parser.add_argument('--wandb', action='store_true', help='Whether to use wandb')
|
||||||
|
# parser.add_argument('--wandb_project', default="minecraft_experiments", help='wandb project name')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.wandb:
|
# if args.wandb:
|
||||||
import wandb
|
# import wandb
|
||||||
wandb.init(project=args.wandb_project, name=args.exp_name)
|
# wandb.init(project=args.wandb_project, name=args.exp_name)
|
||||||
|
|
||||||
# kill all tmux session before starting
|
# kill all tmux session before starting
|
||||||
try:
|
try:
|
||||||
|
@ -334,7 +348,12 @@ def main():
|
||||||
# delete all server files
|
# delete all server files
|
||||||
clean_up_server_files(args.num_parallel)
|
clean_up_server_files(args.num_parallel)
|
||||||
if args.task_id is None:
|
if args.task_id is None:
|
||||||
launch_parallel_experiments(args.task_path, num_exp=args.num_exp, exp_name=args.exp_name, num_parallel=args.num_parallel)
|
launch_parallel_experiments(args.task_path,
|
||||||
|
num_exp=args.num_exp,
|
||||||
|
exp_name=args.exp_name,
|
||||||
|
num_parallel=args.num_parallel,
|
||||||
|
s3=args.s3,
|
||||||
|
bucket_name=args.bucket_name)
|
||||||
|
|
||||||
# servers = create_server_files("../server_data/", args.num_parallel)
|
# servers = create_server_files("../server_data/", args.num_parallel)
|
||||||
# date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
# date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
|
Loading…
Add table
Reference in a new issue