"""Checkpoint saving/loading and data loader helpers."""

import os
import json
import pickle

import torch

import comet.src.data.atomic as atomic_data
import comet.src.data.conceptnet as conceptnet_data
import comet.src.data.config as cfg
import comet.utils.utils as utils

# Special vocabulary tokens marking sequence start, sequence end, and
# blanked-out spans.
start_token = "<START>"
end_token = "<END>"
blank_token = "<blank>"


def save_checkpoint(state, filename):
    print("Saving model to {}".format(filename))
    torch.save(state, filename)


def save_step(model, vocab, optimizer, opt, length, lrs):
    # Serialize the model, optimizer, and training metadata for this step.
    # Test runs (cfg.test_save) are written to a throwaway directory.
    if cfg.test_save:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="garbage/models/", is_dir=False, eval_=True))
    else:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="models/", is_dir=False, eval_=True))
    save_checkpoint({
        "epoch": length,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "opt": opt,
        "vocab": vocab,
        "epoch_learning_rates": lrs},
        name)


def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    # Write evaluation statistics in the format implied by the extension.
    if cfg.test_save:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="garbage/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    else:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="results/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)

    print("Saving {} {} to {}".format(split, eval_type, name))

    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        raise ValueError("Unsupported extension: {}".format(ext))


def load_checkpoint(filename, gpu=True):
    # Checkpoints are always loaded onto the CPU via map_location; move the
    # state to a GPU after loading if needed.
    if os.path.exists(filename):
        checkpoint = torch.load(
            filename, map_location=lambda storage, loc: storage)
    else:
        print("No model found at {}".format(filename))
        checkpoint = None
    return checkpoint


def make_data_loader(opt, *args):
    if opt.dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    elif opt.dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)
    else:
        raise ValueError("Unknown dataset: {}".format(opt.dataset))


def set_max_sizes(data_loader, force_split=None):
    # Record the total number of sequences available for each split (or for
    # a single split if force_split is given).
    data_loader.total_size = {}
    if force_split is not None:
        data_loader.total_size[force_split] = \
            data_loader.sequences[force_split]["total"].size(0)
        return
    for split in data_loader.sequences:
        data_loader.total_size[split] = \
            data_loader.sequences[split]["total"].size(0)
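

# A minimal usage sketch showing how a training step saved by save_step can be
# restored with load_checkpoint. `SomeModel`, the Adam settings, and the
# checkpoint path are hypothetical stand-ins for illustration; they are not
# defined in this module.
#
#     checkpoint = load_checkpoint("models/example.pickle")
#     if checkpoint is not None:
#         model = SomeModel(checkpoint["opt"])          # hypothetical model class
#         model.load_state_dict(checkpoint["state_dict"])
#         optimizer = torch.optim.Adam(model.parameters())
#         optimizer.load_state_dict(checkpoint["optimizer"])
#         start_epoch = checkpoint["epoch"]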