gperc.trainer module

Trainer

This is a generic trainer built for the gperc project. More documentation will be added later.

class gperc.trainer.Trainer(model, save_folder=None, save_every=1000, client=None)

Bases: object

Generic trainer for the gperc project.

Parameters
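  • model (torch.nn.Module) – The model to train.

  • save_folder (str, optional) – The root folder for checkpoints; save() writes to save_folder/name/.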

  • client (function) – A function that takes a dict as input and logs it to a remote server.
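
The client only needs to be a callable that accepts a dict. As a minimal sketch (the helper name make_jsonl_client and the file-based logging are hypothetical, standing in for a remote server), such a client could append each dict to a local JSON-lines file:

```python
import json
from typing import Any, Callable, Dict


def make_jsonl_client(path: str) -> Callable[[Dict[str, Any]], None]:
    """Return a hypothetical client that appends each dict as one JSON line to `path`.

    Stand-in for a remote logger: the Trainer only documents that `client`
    takes a dict, so any callable with this shape should fit.
    """
    def client(metrics: Dict[str, Any]) -> None:
        # Append mode, so repeated calls accumulate one record per line.
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(metrics) + "\n")

    return client
```

It could then be passed as Trainer(model, client=make_jsonl_client("train_log.jsonl")). Which keys the trainer puts into each dict is not documented here.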

save(name, optim=None, lr_scheduler=None)

Save the model, and optionally the optimizer and learning-rate scheduler, to the folder self.save_folder/name/.

Parameters
  • name (str) – The name of the subfolder of self.save_folder to save into

  • optim (torch.optim.Optimizer) – The optimizer to save

  • lr_scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning-rate scheduler to save

load(save_folder, optim=None, lr_scheduler=None)

Load the model, and optionally the optimizer and learning-rate scheduler, from the given save folder. If any of these fails to load, you will need to inspect the saved files manually.

Parameters
  • save_folder (str) – The folder to load from

  • optim (torch.optim.Optimizer) – The optimizer to load

  • lr_scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning-rate scheduler to load
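
The save/load round trip can be sketched as follows. This is a hypothetical, framework-free illustration, not gperc's implementation: the file names, the use of pickle, and the choice to return failed components instead of raising are all assumptions. It works with any object exposing state_dict() and load_state_dict(), which torch modules, optimizers and schedulers all do; for real torch objects, torch.save and torch.load would be the usual choice.

```python
import os
import pickle
from typing import Any, List, Optional

# Hypothetical file names; the real gperc checkpoint layout may differ.
FILES = {"model": "model.pkl", "optim": "optim.pkl", "lr_scheduler": "lr_scheduler.pkl"}


def save_checkpoint(save_folder: str, name: str, model: Any,
                    optim: Optional[Any] = None,
                    lr_scheduler: Optional[Any] = None) -> str:
    """Write each object's state_dict() into save_folder/name/ and return that path."""
    folder = os.path.join(save_folder, name)
    os.makedirs(folder, exist_ok=True)
    objs = {"model": model, "optim": optim, "lr_scheduler": lr_scheduler}
    for key, obj in objs.items():
        if obj is None:
            continue  # optional components are skipped when not given
        with open(os.path.join(folder, FILES[key]), "wb") as f:
            pickle.dump(obj.state_dict(), f)
    return folder


def load_checkpoint(folder: str, model: Any,
                    optim: Optional[Any] = None,
                    lr_scheduler: Optional[Any] = None) -> List[str]:
    """Restore each given object from `folder`; return the names of parts that failed."""
    failed = []
    objs = {"model": model, "optim": optim, "lr_scheduler": lr_scheduler}
    for key, obj in objs.items():
        if obj is None:
            continue
        try:
            with open(os.path.join(folder, FILES[key]), "rb") as f:
                obj.load_state_dict(pickle.load(f))
        except Exception:
            # Report instead of raising, so the caller can check these parts manually.
            failed.append(key)
    return failed
```

Returning the list of failed components mirrors the note above that a failed load has to be checked by hand.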

train(optim, train_data, n_steps, test_every=None, test_data=None)

Train the model for n_steps steps with the given optimizer and training data. Optionally, provide test data to evaluate the model during training.

Parameters
  • optim (torch.optim.Optimizer) – The optimizer to use

  • train_data (gperc.Consumer or ArrowConsumer) – The training data; batches must already be created

  • n_steps (int) – The number of steps to train for

  • test_every (int) – The number of training steps between evaluations on test_data; defaults to n_steps

  • test_data (gperc.Consumer or ArrowConsumer) – The test data; batches must already be created
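
How test_every relates to n_steps can be sketched with a small framework-free loop. The names run_training, train_step and test_fn are hypothetical, and a plain iterable of batches stands in for a Consumer; this only illustrates the documented default (test_every = n_steps), not gperc's actual training loop.

```python
from typing import Any, Callable, Iterable, List, Optional


def run_training(train_step: Callable[[Any], float],
                 train_batches: Iterable[Any],
                 n_steps: int,
                 test_fn: Optional[Callable[[], float]] = None,
                 test_every: Optional[int] = None) -> List[int]:
    """Hypothetical loop: one train_step per batch, test_fn every `test_every` steps.

    Assumes at least n_steps batches are available (batches are pre-created).
    Returns the 1-based step numbers after which testing ran.
    """
    if test_every is None:
        test_every = n_steps  # mirrors the documented default
    tested_at = []
    batches = iter(train_batches)
    for step in range(1, n_steps + 1):
        train_step(next(batches))
        if test_fn is not None and step % test_every == 0:
            test_fn()
            tested_at.append(step)
    return tested_at
```

With n_steps=10 and test_every=4, testing runs after steps 4 and 8; with the default, it runs once, after step 10.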