Source code for gperc.trainer

"""
Trainer
=======

This is a generic trainer built for the ``gperc`` project: it wraps a simple train/eval step, checkpoint saving/loading, and per-batch metric logging. More documentation will be added later.
"""

import os
import torch
from tqdm.auto import trange

from .utils import timeit

class Trainer():
    def __init__(self, model, save_folder = None, save_every = 1000, client = None):
        """
        Generic trainer for the ``gperc`` project.

        Args:
            model (torch.nn.Module): The model to train.
            save_folder (str): Folder to write the model config and checkpoints to, no checkpointing if ``None``.
            save_every (int): Number of steps between checkpoints.
            client (function): A function that takes a dict as input and logs it to a remote server.
        """
        self.model = model
        self.save_folder = save_folder
        self.save_every = save_every
        self.client = client
        self.model_config = model.config
        self.device = next(self.model.parameters()).device

        if save_folder is not None:
            os.makedirs(self.save_folder, exist_ok = True)  # create this just in case
            with open(os.path.join(self.save_folder, "config.json"), "w") as f:
                f.write(self.model_config.to_json())
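A minimal construction sketch. The trainer only requires a ``torch.nn.Module`` whose ``.config`` attribute supports item access (it reads ``config["n_classes"]``) and exposes ``to_json()``; the ``TinyModel`` and ``_Config`` stand-ins below are hypothetical, in real usage you would pass a ``gperc`` model:

    import json
    import torch
    from gperc.trainer import Trainer

    class _Config(dict):
        # stand-in config: the Trainer only needs `.to_json()` and key access
        def to_json(self):
            return json.dumps(self)

    class TinyModel(torch.nn.Module):
        # hypothetical stand-in for a gperc model
        def __init__(self, n_classes = 4):
            super().__init__()
            self.config = _Config(n_classes = n_classes)
            self.fc = torch.nn.Linear(8, n_classes)

        def forward(self, x):
            return self.fc(x)

    trainer = Trainer(
        model = TinyModel(),
        save_folder = "./ckpts",            # config.json is written here at construction
        save_every = 1000,
        client = lambda meta: print(meta),  # any callable that accepts a dict of metrics
    )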
    def save(self, name, optim = None, lr_scheduler = None):
        """Save the items to the ``self.save_folder/name/`` folder.

        Args:
            name (str): The name of the save folder
            optim (torch.optim.Optimizer): The optimizer to save
            lr_scheduler (torch.optim.lr_scheduler): The lr scheduler to save
        """
        if self.save_folder is None:
            print("No save folder specified, skipping saving.")
            return

        # create a new folder for the current step
        step_folder = os.path.join(self.save_folder, name)
        print(f"Saving in folder: {step_folder}")
        os.makedirs(step_folder, exist_ok = True)
        torch.save(self.model.state_dict(), os.path.join(step_folder, "model.pt"))
        if optim is not None:
            torch.save(optim.state_dict(), os.path.join(step_folder, "optim.pt"))
        if lr_scheduler is not None:
            torch.save(lr_scheduler.state_dict(), os.path.join(step_folder, "lr_scheduler.pt"))
    def load(self, save_folder, optim = None, lr_scheduler = None):
        """
        Load the model (and optionally the optimizer and lr scheduler) from the given save folder.
        If any of these fails, you will have to check the checkpoint manually.

        Args:
            save_folder (str): The folder to load from
            optim (torch.optim.Optimizer): The optimizer to load
            lr_scheduler (torch.optim.lr_scheduler): The lr scheduler to load
        """
        self.model.load_state_dict(torch.load(os.path.join(save_folder, "model.pt")))
        if optim is not None and os.path.exists(os.path.join(save_folder, "optim.pt")):
            optim.load_state_dict(torch.load(os.path.join(save_folder, "optim.pt")))
        if lr_scheduler is not None and os.path.exists(os.path.join(save_folder, "lr_scheduler.pt")):
            lr_scheduler.load_state_dict(torch.load(os.path.join(save_folder, "lr_scheduler.pt")))
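A checkpointing sketch reusing the hypothetical ``trainer`` from above; the optimizer choice and folder names are illustrative. ``save`` writes into ``save_folder/name/``, while ``load`` expects the full path to that step folder:

    import os
    import torch

    optim = torch.optim.Adam(trainer.model.parameters(), lr = 3e-4)

    # writes ./ckpts/step_0/model.pt and ./ckpts/step_0/optim.pt
    trainer.save("step_0", optim = optim)

    # restore them later from the same step folder
    trainer.load(os.path.join("./ckpts", "step_0"), optim = optim)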
    def __call__(self, batch, step, n_bytes, n_classes, pbar, train = True, grad_clip = 1.0, optim = None):
        """
        Run a single train or eval step on a batch and return the metrics (meta) for that batch.

        Args:
            batch (dict): The batch from ``gperc.Consumer`` or ``gperc.ArrowConsumer``
            step (int): The current step number
            n_bytes (int): The number of bytes in the dataset
            n_classes (int): The number of classes in the dataset
            pbar (tqdm.auto.trange): The progress bar
            train (bool): Whether to train or test
            grad_clip (float): The gradient clipping value, defaults to 1.0
            optim (torch.optim.Optimizer): The optimizer to use, must be defined if ``train == True``
        """
        # initial setup
        if train:
            self.model.train()
        else:
            self.model.eval()
        prefix = "train" if train else "val"
        batch_meta = batch.pop("meta", {})  # create meta dict
        batch_meta[f"{prefix}/step"] = step

        # forward pass
        batch = {k: v.to(self.device) for k, v in batch.items()}
        forward_pass_time, out = timeit(self.model)(batch["input_array"])
        batch_meta[f"{prefix}/forward_pass_time"] = forward_pass_time

        # make it so that any task can be run with this
        out = out.contiguous().view(-1, self.model_config["n_classes"])
        target = batch["class"].contiguous().view(-1)
        loss = torch.nn.functional.cross_entropy(out, target, reduction = "none")
        acc = out.argmax(-1).eq(target).float().mean()
        _mean_loss = loss.mean()

        # capture the meta for the processed batch
        am = batch["attention_mask"]
        bytes_by_class = {c: 0 for c in range(n_classes)}
        loss_by_class, acc_by_class = bytes_by_class.copy(), bytes_by_class.copy()
        for i, c in enumerate(batch["class"].tolist()):
            bytes_by_class[c] += n_bytes * am[i].sum().item()
            loss_by_class[c] += loss[i].item()
            acc_by_class[c] += out[i].argmax(-1).eq(target[i]).float().mean().item()

        # update the logging dict
        bytes_processed = am.sum().item() * n_bytes
        batch_meta.update({
            f"{prefix}/bytes_processed": bytes_processed,
            f"{prefix}/bytes_processed_per_second": bytes_processed / forward_pass_time,
            f"{prefix}/class_wise_bytes_processed": bytes_by_class,
            f"{prefix}/loss_avg": _mean_loss.item(),
            f"{prefix}/acc_avg": acc.item(),
            f"{prefix}/loss_class": loss_by_class,
            f"{prefix}/acc_class": acc_by_class,
        })

        if train:
            # run the backward pass
            optim.zero_grad()
            _mean_loss.backward()
            for p in self.model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(-grad_clip, grad_clip)  # clip gradient values
            backward_pass_time, _ = timeit(optim.step)()
            batch_meta["train/backward_pass_time"] = backward_pass_time
            batch_meta["train/total_time"] = forward_pass_time + backward_pass_time

        pbar.set_description(f"[{prefix}] loss: {_mean_loss.item():.4f} | acc: {acc.item():.4f}")
        return batch_meta
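A single-step sketch with a hand-built batch, reusing the hypothetical ``trainer`` and ``optim`` from above. The batch keys (``input_array``, ``class``, ``attention_mask``) are the ones the step function reads; the tensor shapes are assumptions that match ``TinyModel``, real batches come from ``gperc.Consumer``/``ArrowConsumer``:

    import torch
    from tqdm.auto import trange

    pbar = trange(1)
    batch = {
        "input_array": torch.randn(2, 8),   # 2 examples, feature dim matches TinyModel
        "class": torch.tensor([0, 3]),      # one label per example
        "attention_mask": torch.ones(2, 8), # used for the bytes-processed stats
    }
    meta = trainer(batch, step = 0, n_bytes = 1, n_classes = 4, pbar = pbar, train = True, optim = optim)
    print(meta["train/loss_avg"], meta["train/acc_avg"])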
    def train(self, optim, train_data, n_steps, test_every = None, test_data = None):
        """Train the model with the given optimizer and data for a set number of steps, optionally evaluating along the way.

        Args:
            optim (torch.optim.Optimizer): The optimizer to use
            train_data (``gperc.Consumer/ArrowConsumer``): The training data, batches must be created
            n_steps (int): The number of steps to train for
            test_every (int): The number of training steps between evaluations, no evaluation if ``None``
            test_data (``gperc.Consumer/ArrowConsumer``): The testing data, batches must be created
        """
        pbar = trange(n_steps)
        min_loss = float("inf")
        for i in pbar:
            batch = train_data.get_next_batch()
            batch_meta = self(
                batch = batch,
                step = i,
                n_bytes = train_data.n_bytes,
                n_classes = train_data.n_classes,
                pbar = pbar,
                grad_clip = 1.0,
                optim = optim,
                train = True
            )

            if i and test_every is not None and test_data is not None and i % test_every == 0:
                pbar_val = trange(len(test_data._batches))
                # we need to capture the meta for the test batches since there can be more than one
                metas = []
                for _ in pbar_val:
                    batch = test_data.get_next_batch()
                    batch_meta_val = self(
                        batch = batch,
                        step = i,
                        n_bytes = test_data.n_bytes,
                        n_classes = test_data.n_classes,
                        pbar = pbar_val,
                        grad_clip = 1.0,
                        optim = optim,
                        train = False
                    )
                    metas.append(batch_meta_val)

                # mean over all batches
                test_meta = {}
                for k in metas[0].keys():
                    values = [m[k] for m in metas]
                    if isinstance(values[0], dict):
                        test_meta[k] = {k2: sum(m[k][k2] for m in metas) / len(metas) for k2 in values[0].keys()}
                    else:
                        test_meta[k] = sum(values) / len(values)
                batch_meta.update(test_meta)

                print("val/loss:", test_meta["val/loss_avg"])
                print("val/acc:", test_meta["val/acc_avg"])
                if min_loss > test_meta["val/loss_avg"]:
                    min_loss = test_meta["val/loss_avg"]
                    self.save(f"step_{i}", optim, lr_scheduler = None)

            # log
            if self.client is not None:
                self.client(batch_meta)
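An end-to-end sketch, again reusing the hypothetical ``trainer`` and ``optim``. ``ToyData`` below is an assumption that mimics only the parts of the ``Consumer`` interface the trainer touches (``get_next_batch``, ``n_bytes``, ``n_classes``, and ``_batches`` for the eval progress bar); with a real ``gperc.Consumer``/``ArrowConsumer`` you would create batches on it first and pass it directly:

    import torch

    class ToyData:
        # hypothetical stand-in for a gperc Consumer with batches already created
        def __init__(self):
            self.n_bytes = 1
            self.n_classes = 4
            self._batches = [None] * 2  # only len() is used, for the eval progress bar

        def get_next_batch(self):
            return {
                "input_array": torch.randn(2, 8),
                "class": torch.randint(0, 4, (2,)),
                "attention_mask": torch.ones(2, 8),
            }

    train_data, test_data = ToyData(), ToyData()
    trainer.train(optim, train_data, n_steps = 50, test_every = 10, test_data = test_data)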