Set up S3 logging instead of wandb

This commit is contained in:
Isadora White 2025-02-21 17:02:21 -08:00
parent d4565aa68c
commit 719b72da9e

View file

@ -94,7 +94,9 @@ def launch_parallel_experiments(task_path,
exp_name,
num_agents=2,
model="gpt-4o",
num_parallel=1):
num_parallel=1,
s3=False,
bucket_name="mindcraft-experiments"):
with open(task_path, 'r', encoding='utf-8') as file:
content = file.read()
@ -107,14 +109,21 @@ def launch_parallel_experiments(task_path,
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
servers = create_server_files("../server_data/", num_parallel)
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
date_time = datetime.now().strftime("%m-%d_%H-%M")
experiments_folder = f"experiments/{exp_name}_{date_time}"
exp_name = f"{exp_name}_{date_time}"
# create the local folder that collects this run's results
os.makedirs(experiments_folder, exist_ok=True)
for i, server in enumerate(servers):
launch_server_experiment(task_path, task_ids_split[i], num_exp, server, experiments_folder, exp_name)
launch_server_experiment(task_path,
task_ids_split[i],
num_exp,
server,
experiments_folder,
exp_name,
s3=s3,
bucket_name=bucket_name)
time.sleep(5)
@ -125,7 +134,9 @@ def launch_server_experiment(task_path,
experiments_folder,
exp_name="exp",
num_agents=2,
model="gpt-4o"):
model="gpt-4o",
s3=False,
bucket_name="mindcraft-experiments"):
"""
Launch a Minecraft server and run experiments on it.
@param task_path: Path to the task file
@ -173,14 +184,15 @@ def launch_server_experiment(task_path,
script_content += "sleep 2\n"
for agent in agent_names:
cp_cmd = f"cp bots/{agent}/memory.json {experiments_folder}/{task_id}_{agent}_{_}.json"
script_content += f"echo '{cp_cmd}'\n"
script_content += f"{cp_cmd}\n"
script_content += "sleep 1\n"
script_content += f"echo 'Uploading {experiments_folder}/{task_id}_{agent}_{_}.json to wandb'\n"
wandb_cmd = f"wandb artifact put {experiments_folder}/{task_id}_{agent}_{_}.json --name {exp_name}_{task_id}_{agent}_{_} --type dataset"
script_content += f"echo '{wandb_cmd}'\n"
script_content += f"{wandb_cmd}\n"
script_content += "sleep 1\n"
script_content += "sleep 1\n"
if s3:
script_content += f"echo 'Uploading {experiments_folder}/{task_id}_{agent}_{_}.json to s3'\n"
s3_cmd = f"aws s3 cp bots/{agent}/memory.json s3://{bucket_name}/{experiments_folder}/{task_id}_{agent}_{_}.json"
script_content += f"echo '{s3_cmd}'\n"
script_content += f"{s3_cmd}\n"
script_content += "sleep 1\n"
# Create a temporary shell script file
script_file = f"./tmp/experiment_script_{session_name}.sh"
@ -316,14 +328,16 @@ def main():
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
parser.add_argument('--exp_name', default="exp", help='Name of the experiment')
parser.add_argument('--wandb', action='store_true', help='Whether to use wandb')
parser.add_argument('--wandb-project', default="minecraft_experiments", help='wandb project name')
parser.add_argument('--s3', action='store_true', help='Whether to upload to s3')
parser.add_argument('--bucket_name', default="mindcraft-experiments", help='Name of the s3 bucket')
# parser.add_argument('--wandb', action='store_true', help='Whether to use wandb')
# parser.add_argument('--wandb_project', default="minecraft_experiments", help='wandb project name')
args = parser.parse_args()
if args.wandb:
import wandb
wandb.init(project=args.wandb_project, name=args.exp_name)
# if args.wandb:
# import wandb
# wandb.init(project=args.wandb_project, name=args.exp_name)
# kill all tmux sessions before starting
try:
@ -334,7 +348,12 @@ def main():
# delete all server files
clean_up_server_files(args.num_parallel)
if args.task_id is None:
launch_parallel_experiments(args.task_path, num_exp=args.num_exp, exp_name=args.exp_name, num_parallel=args.num_parallel)
launch_parallel_experiments(args.task_path,
num_exp=args.num_exp,
exp_name=args.exp_name,
num_parallel=args.num_parallel,
s3=args.s3,
bucket_name=args.bucket_name)
# servers = create_server_files("../server_data/", args.num_parallel)
# date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")