gperc.trainer module
Trainer
This is a generic trainer built for the gperc project. More documentation will be added later.
- class gperc.trainer.Trainer(model, save_folder=None, save_every=1000, client=None)
Bases: object
Generic trainer for the gperc project.
- Parameters
model (torch.nn.Module) – The model to train.
client (function) – A function that takes a dict as input and logs it to a remote server.
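The `client` parameter only needs to be a callable that accepts a dict. The sketch below shows one way to build such a callable. It is an illustration, not part of gperc: `make_client` and the `send` transport are hypothetical names. The assumption is that logged values are plain numbers or objects exposing `.item()`, such as scalar tensors.

```python
import json
from typing import Any, Callable, Dict, List


def make_client(send: Callable[[str], None]) -> Callable[[Dict[str, Any]], None]:
    """Build a callable that could be passed as Trainer(client=...).

    `send` is a hypothetical transport (for example an HTTP POST to an
    experiment tracker); it receives one JSON string per logged dict.
    """
    def client(record: Dict[str, Any]) -> None:
        # Scalar tensors and NumPy scalars expose .item(); unwrap them so the
        # record is JSON-serializable. Other values are passed through as-is.
        payload = {k: (v.item() if hasattr(v, "item") else v) for k, v in record.items()}
        send(json.dumps(payload, sort_keys=True))

    return client


# Usage: collect the JSON lines in a list instead of sending them anywhere.
sent: List[str] = []
client = make_client(sent.append)
client({"step": 10, "loss": 0.5})
print(sent[0])  # {"loss": 0.5, "step": 10}
```

The resulting `client` would then be handed to the trainer as `Trainer(model, client=client)`. Keeping the transport separate makes the client easy to test locally before it is pointed at a real server.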
- save(name, optim=None, lr_scheduler=None)
Save the items (the model and, if provided, the optimizer and lr scheduler) to the self.save_folder/name/ folder.
- Parameters
name (str) – The name of the save folder
optim (torch.optim.Optimizer) – The optimizer to save
lr_scheduler (torch.optim.lr_scheduler) – The lr scheduler to save
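A minimal sketch of what writing items to `save_folder/name/` can look like, assuming each item exposes `state_dict()` as PyTorch modules, optimizers and schedulers do. This is not gperc's implementation: `save_items` and the `*.pkl` file names are hypothetical, and `pickle` stands in for `torch.save` so the sketch runs without PyTorch.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace
from typing import Any, Optional


def save_items(save_folder: str, name: str, model: Any,
               optim: Optional[Any] = None,
               lr_scheduler: Optional[Any] = None) -> str:
    """Sketch: write each item's state_dict() into save_folder/name/."""
    folder = os.path.join(save_folder, name)
    os.makedirs(folder, exist_ok=True)
    items = {"model": model, "optim": optim, "lr_scheduler": lr_scheduler}
    for key, obj in items.items():
        if obj is None:
            continue  # optional items are only written when passed in
        # Hypothetical file name; with real torch objects,
        # torch.save(obj.state_dict(), path) would be the usual choice.
        with open(os.path.join(folder, key + ".pkl"), "wb") as f:
            pickle.dump(obj.state_dict(), f)
    return folder


# Demo with stand-in objects that only expose state_dict().
root = tempfile.mkdtemp()
model = SimpleNamespace(state_dict=lambda: {"weight": [1.0, 2.0]})
optim = SimpleNamespace(state_dict=lambda: {"lr": 1e-3})
out = save_items(root, "step_1000", model, optim=optim)
print(sorted(os.listdir(out)))  # ['model.pkl', 'optim.pkl']
```

Saving state dicts rather than whole objects keeps checkpoints independent of the Python class layout, which is the convention PyTorch recommends.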
- load(save_folder, optim=None, lr_scheduler=None)
Load the model (and, if provided, the optimizer and lr scheduler) from the given save folder. If loading any of these items fails, you will have to check the saved files manually.
- Parameters
save_folder (str) – The folder to load from
optim (torch.optim.Optimizer) – The optimizer to load
lr_scheduler (torch.optim.lr_scheduler) – The lr scheduler to load
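Because a failed load has to be checked by hand, it helps to know exactly which item failed. The sketch below is an illustration, not gperc's behavior. It tries each item independently and returns the failures instead of stopping at the first error. It assumes the hypothetical `<item>.pkl` layout and objects exposing `load_state_dict()`, as PyTorch modules, optimizers and schedulers do. `load_items` is a hypothetical name.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace
from typing import Any, Dict, Optional


def load_items(save_folder: str, model: Any,
               optim: Optional[Any] = None,
               lr_scheduler: Optional[Any] = None) -> Dict[str, str]:
    """Sketch: restore each given item from save_folder, reporting failures.

    Returns a dict mapping item name -> error message for every item that
    could not be loaded, so the caller knows exactly what to check.
    """
    failures: Dict[str, str] = {}
    items = {"model": model, "optim": optim, "lr_scheduler": lr_scheduler}
    for key, obj in items.items():
        if obj is None:
            continue  # only restore the items the caller asked for
        path = os.path.join(save_folder, key + ".pkl")  # hypothetical layout
        try:
            with open(path, "rb") as f:
                obj.load_state_dict(pickle.load(f))
        except Exception as exc:  # keep going; report instead of aborting
            failures[key] = f"{type(exc).__name__}: {exc}"
    return failures


# Demo: only the model checkpoint exists, so the optimizer fails to load.
folder = tempfile.mkdtemp()
with open(os.path.join(folder, "model.pkl"), "wb") as f:
    pickle.dump({"weight": [1.0, 2.0]}, f)

restored: Dict[str, Any] = {}
model = SimpleNamespace(load_state_dict=restored.update)
optim = SimpleNamespace(load_state_dict=lambda sd: None)
failures = load_items(folder, model, optim=optim)
print(restored)          # {'weight': [1.0, 2.0]}
print(sorted(failures))  # ['optim']
```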
- train(optim, train_data, n_steps, test_every=None, test_data=None)
Train the model with the given optimizer, data and number of steps. Test data can optionally be provided as well.
- Parameters
optim (torch.optim.Optimizer) – The optimizer to use
train_data (gperc.Consumer/ArrowConsumer) – The training data; batches must already be created
n_steps (int) – The number of steps to train for
test_every (int) – The number of steps to train for before testing, defaults to n_steps
test_data (gperc.Consumer/ArrowConsumer) – The testing data; batches must already be created
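The interaction between `n_steps`, `test_every` and `test_data` can be sketched as a plain loop. This is an assumption-laden illustration, not the gperc implementation. `train_step` and `eval_step` are hypothetical callables standing in for the forward/backward/`optim.step()` work and the evaluation pass. Reusing batches with `itertools.cycle` when there are fewer batches than steps is also an assumption. Only the `test_every` default of `n_steps` comes from the parameter description.

```python
from itertools import cycle
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple


def train_loop(train_step: Callable[[Any], float],
               eval_step: Callable[[Any], float],
               train_data: Iterable,
               n_steps: int,
               test_every: Optional[int] = None,
               test_data: Optional[Sequence] = None) -> List[Tuple[int, float]]:
    """Sketch of the step/test schedule; returns (step, mean_test_loss) pairs."""
    if test_every is None:
        test_every = n_steps  # documented default: test once, after n_steps
    history: List[Tuple[int, float]] = []
    batches = cycle(train_data)  # assumption: reuse batches if data runs out
    for step in range(1, n_steps + 1):
        train_step(next(batches))
        # Evaluate on the full test set every `test_every` steps.
        if test_data is not None and step % test_every == 0:
            losses = [eval_step(b) for b in test_data]
            history.append((step, sum(losses) / len(losses)))
    return history


# Demo with toy "batches" (plain numbers) and stand-in step functions.
seen: List[int] = []
history = train_loop(
    train_step=lambda b: seen.append(b) or 0.0,  # record batch, fake loss
    eval_step=lambda b: float(b),                # fake test loss per batch
    train_data=[10, 20, 30],
    n_steps=7,
    test_every=3,
    test_data=[1, 2, 3],
)
print(history)  # [(3, 2.0), (6, 2.0)]
print(seen)     # [10, 20, 30, 10, 20, 30, 10]
```

With the default `test_every`, testing happens exactly once, at the final step. With `test_every=3` and `n_steps=7`, as above, it happens at steps 3 and 6, and the last step is not followed by a test.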