model
- class aitoolbox.torchtrain.model.TTModel
TTModel is an extension of the core PyTorch nn.Module.
TT in TTModel stands for TorchTrain Model.
In addition to the common forward() method required by the base torch.nn.Module, the user also needs to implement the additional AIToolbox-specific get_loss() and get_predictions() methods. Optionally, the user can also implement get_loss_eval() for a specific loss calculation in evaluation mode.
transfer_model_attributes (list or tuple): additional TTModel attributes which need to be transferred to the TTDataParallel level to enable their use in the transferred/exposed class methods. When coding the model's __init__() method, the user should also fill in the string names of the attributes that should be transferred in case the model is wrapped for DP/DDP.
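A minimal sketch of a TTModel subclass, assuming a simple classification task where each batch is an `(inputs, targets)` pair. The class name `SimpleClassifier`, the layer sizes and the `n_classes` entry in `transfer_model_attributes` are illustrative only. The import falls back to plain `nn.Module` so the sketch also runs where AIToolbox is not installed.

```python
import torch
import torch.nn as nn

try:
    from aitoolbox.torchtrain.model import TTModel
except ImportError:
    # Fallback so this sketch runs without AIToolbox installed
    TTModel = nn.Module


class SimpleClassifier(TTModel):
    """Hypothetical example model; not part of AIToolbox."""

    def __init__(self, in_features=4, n_classes=3):
        super().__init__()
        self.n_classes = n_classes
        self.fc = nn.Linear(in_features, n_classes)
        # String names of attributes to expose if the model is wrapped for DP/DDP
        self.transfer_model_attributes = ['n_classes']

    def forward(self, x):
        return self.fc(x)

    def get_loss(self, batch_data, criterion, device):
        # Assumes batch_data is an (inputs, targets) pair
        x, y = batch_data
        x, y = x.to(device), y.to(device)
        return criterion(self(x), y)

    def get_predictions(self, batch_data, device):
        x, y = batch_data
        logits = self(x.to(device))
        # Predicted class index per example; targets returned alongside for metrics
        y_pred = logits.argmax(dim=1).cpu()
        return y_pred, y.cpu(), None
```

With a batch such as `(torch.randn(8, 4), torch.randint(0, 3, (8,)))` and `nn.CrossEntropyLoss()` as the criterion, `get_loss()` returns a scalar loss tensor and `get_predictions()` returns 8 predicted class indices together with the 8 targets.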
- abstract get_loss(batch_data, criterion, device)
Get loss during the training stage.
Called from fit() in TrainLoop.
Executed during the training stage, where the model weights are updated based on the loss returned from this function.
- Parameters:
batch_data (torch.Tensor or list or tuple or dict) – model input data batch
criterion (torch.nn.Module) – loss criterion
device (torch.device) – device on which the model is being trained
- Returns:
loss
- Return type:
torch.Tensor or MultiLoss
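How the returned loss drives the weight update can be sketched as a simplified training step: call `get_loss()`, backpropagate, then step the optimizer. This is an assumption about what `fit()` does conceptually, not TrainLoop's actual code. `TinyRegressor` and `train_one_epoch` are hypothetical names, and the model subclasses plain `nn.Module` so the sketch runs without AIToolbox.

```python
import torch
import torch.nn as nn


class TinyRegressor(nn.Module):
    """Hypothetical model exposing a TTModel-style get_loss()."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

    def get_loss(self, batch_data, criterion, device):
        x, y = batch_data
        return criterion(self(x.to(device)), y.to(device))


def train_one_epoch(model, batches, criterion, optimizer, device):
    """Simplified stand-in for the training part of TrainLoop.fit()."""
    model.train()
    losses = []
    for batch_data in batches:
        optimizer.zero_grad()
        loss = model.get_loss(batch_data, criterion, device)
        loss.backward()   # gradients of the returned loss
        optimizer.step()  # weight update based on those gradients
        losses.append(loss.item())
    return losses
```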
- get_loss_eval(batch_data, criterion, device)
Get loss during the evaluation stage.
Called from evaluate_model_loss() in TrainLoop.
The difference compared with get_loss() is that no backprop weight update is done here: this function is executed in the evaluation stage, not in training.
For simple models this function can just call get_loss() and return its result.
- Parameters:
batch_data (torch.Tensor or list or tuple or dict) – model input data batch
criterion (torch.nn.Module) – loss criterion
device (torch.device) – device on which the model is being trained
- Returns:
loss
- Return type:
torch.Tensor or MultiLoss
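For the simple case described above, `get_loss_eval()` can delegate to `get_loss()`. The sketch also assumes, as a simplification of `evaluate_model_loss()`, that evaluation runs in eval mode under `torch.no_grad()` and averages the per-batch losses. `TinyRegressor` and `evaluate_loss` are hypothetical names.

```python
import torch
import torch.nn as nn


class TinyRegressor(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

    def get_loss(self, batch_data, criterion, device):
        x, y = batch_data
        return criterion(self(x.to(device)), y.to(device))

    def get_loss_eval(self, batch_data, criterion, device):
        # Simple case: the same computation as in training
        return self.get_loss(batch_data, criterion, device)


def evaluate_loss(model, batches, criterion, device):
    """Average evaluation loss; no backprop and no weight update."""
    model.eval()
    with torch.no_grad():
        losses = [model.get_loss_eval(b, criterion, device).item() for b in batches]
    return sum(losses) / len(losses)
```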
- abstract get_predictions(batch_data, device)
Get predictions during the evaluation stage.
- Parameters:
batch_data (torch.Tensor or list or tuple or dict) – model input data batch
device (torch.device) – device on which the model is making the prediction
- Returns:
y_pred, y_test and metadata, where metadata is a dict of lists/torch.Tensors/np.arrays (or None)
- Return type:
(torch.Tensor, torch.Tensor, dict or None)
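Because `get_predictions()` returns one `(y_pred, y_test, metadata)` triple per batch, an evaluation pass has to combine them across batches. A minimal sketch of that accumulation, assuming tensor outputs that can be concatenated along the batch dimension and metadata given as a dict of lists; `collect_predictions` is a hypothetical helper, not part of AIToolbox.

```python
import torch


def collect_predictions(model, batches, device):
    """Concatenate per-batch predictions and targets; merge metadata lists."""
    y_pred_all, y_test_all, metadata_all = [], [], {}
    with torch.no_grad():
        for batch_data in batches:
            y_pred, y_test, metadata = model.get_predictions(batch_data, device)
            y_pred_all.append(y_pred)
            y_test_all.append(y_test)
            # metadata may be None; otherwise extend each list under its key
            for key, values in (metadata or {}).items():
                metadata_all.setdefault(key, []).extend(values)
    return torch.cat(y_pred_all), torch.cat(y_test_all), metadata_all or None
```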
- class aitoolbox.torchtrain.model.TTBasicModel
Bases: TTModel
Extension of the TTModel abstract class with already implemented simple loss and prediction calculation functions.
The pre-implemented get_loss() and get_predictions() take all the data sources provided by the data loader except the last one as input to the model. The last data source is treated as the target variable (*batch_input_data, targets = batch_data).
This base class is mainly meant to be used for simple models. TTBasicModel removes the need to repeatedly duplicate the same code in get_loss() and get_predictions().
- get_loss(batch_data, criterion, device)
Get loss during the training stage.
Called from fit() in TrainLoop.
Executed during the training stage, where the model weights are updated based on the loss returned from this function.
- Parameters:
batch_data (torch.Tensor or list or tuple or dict) – model input data batch
criterion (torch.nn.Module) – loss criterion
device (torch.device) – device on which the model is being trained
- Returns:
loss
- Return type:
torch.Tensor or MultiLoss
- get_predictions(batch_data, device)
Get predictions during the evaluation stage.
- Parameters:
batch_data (torch.Tensor or list or tuple or dict) – model input data batch
device (torch.device) – device on which the model is making the prediction
- Returns:
y_pred, y_test and metadata, where metadata is a dict of lists/torch.Tensors/np.arrays (or None)
- Return type:
(torch.Tensor, torch.Tensor, dict or None)
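The `*batch_input_data, targets = batch_data` convention can be illustrated with a stand-in that reimplements the described behaviour on top of plain `nn.Module`. This is an assumption based on the description above, not AIToolbox's source; `BasicModelSketch` and `TwoInputModel` are hypothetical names. With TTBasicModel itself, only `forward()` would need to be written.

```python
import torch
import torch.nn as nn


class BasicModelSketch(nn.Module):
    """Approximates the described TTBasicModel get_loss()/get_predictions()."""

    def get_loss(self, batch_data, criterion, device):
        # All data sources except the last are model inputs; the last is the target
        *batch_input_data, targets = batch_data
        batch_input_data = [t.to(device) for t in batch_input_data]
        predictions = self(*batch_input_data)
        return criterion(predictions, targets.to(device))

    def get_predictions(self, batch_data, device):
        *batch_input_data, targets = batch_data
        predictions = self(*[t.to(device) for t in batch_input_data])
        return predictions.detach().cpu(), targets.cpu(), None


class TwoInputModel(BasicModelSketch):
    """Only forward() is user-defined; it receives two input tensors."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 + 2, 1)

    def forward(self, x_a, x_b):
        return self.fc(torch.cat([x_a, x_b], dim=1))
```

A batch `(x_a, x_b, y)` from the data loader is split into the two model inputs `x_a`, `x_b` and the target `y`, so the same two methods work for any number of inputs.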
- class aitoolbox.torchtrain.model.TTBasicMultiGPUModel
Bases: TTBasicModel
Extension of the TTModel abstract class with already implemented simple loss and prediction calculation functions which support more evenly balanced GPU utilization when training on multiple GPUs.
The pre-implemented get_loss() and get_predictions() take all the data sources provided by the data loader except the last one as input to the model. The last data source is treated as the target variable (*batch_input_data, targets = batch_data).
In the case of get_loss(), the model's forward() function also receives the targets and criterion arguments, which enables the loss to be calculated inside forward(). The forward() function should therefore have the following parameter signature and should finish with:
def forward(self, *batch_input_data, targets=None, criterion=None):
    # ... predictions calculation via the computational graph ...
    if criterion is not None:
        return criterion(predictions, targets)
    else:
        return predictions
This base class is mainly meant to be used for simple models. TTBasicMultiGPUModel removes the need to repeatedly duplicate the same code in get_loss() and get_predictions().
- get_loss(batch_data, criterion, device)
Get loss during the training stage.
Called from fit() in TrainLoop.
Executed during the training stage, where the model weights are updated based on the loss returned from this function.
- Parameters:
batch_data (torch.Tensor or list or tuple or dict) – model input data batch
criterion (torch.nn.Module) – loss criterion
device (torch.device) – device on which the model is being trained
- Returns:
loss
- Return type:
torch.Tensor or MultiLoss
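A model following the `forward()` signature above, together with a stand-in for the pre-implemented `get_loss()` that passes `targets` and `criterion` into `forward()`. This is a sketch based on the description, not AIToolbox's source; `MultiGPUStyleModel` is a hypothetical name and the model subclasses plain `nn.Module`. When such a model is wrapped in `nn.DataParallel`, tensor keyword arguments like `targets` are split along the batch dimension just like the inputs, so each replica computes the loss for its own chunk of the batch.

```python
import torch
import torch.nn as nn


class MultiGPUStyleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, *batch_input_data, targets=None, criterion=None):
        (x,) = batch_input_data
        predictions = self.fc(x)
        if criterion is not None:
            # Loss computed inside forward(), i.e. on each replica's device
            return criterion(predictions, targets)
        return predictions

    def get_loss(self, batch_data, criterion, device):
        # Same data convention as TTBasicModel: last data source is the target
        *batch_input_data, targets = batch_data
        batch_input_data = [t.to(device) for t in batch_input_data]
        return self(*batch_input_data, targets=targets.to(device), criterion=criterion)
```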
- class aitoolbox.torchtrain.model.MultiGPUModelWrap(model)
Bases: TTBasicMultiGPUModel
Model wrapper optimizing the model for multi-GPU training by moving the loss calculation into forward(), so that it runs on the individual GPUs.
- Parameters:
model (torch.nn.Module or TTModel) – neural network model. The model should follow the basic PyTorch model definition where the forward() function returns predictions.
- forward(*input_data, targets=None, criterion=None)
DP-friendly forward abstraction on top of the wrapped model's usual forward() function.
- Parameters:
*input_data – whatever input data should be passed into the wrapped model’s forward() function
targets – target variables which the model is training to fit
criterion – loss function
- Returns:
PyTorch loss or model output predictions. If the loss function criterion is provided, this function returns the calculated loss; otherwise the model output predictions are returned.
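The wrapping idea can be sketched as a small module that adds the `targets`/`criterion` keywords around an unmodified model whose `forward()` only returns predictions. `LossInForwardWrap` is a hypothetical stand-in written for illustration, not the AIToolbox class; in practice MultiGPUModelWrap(model) would be used directly.

```python
import torch
import torch.nn as nn


class LossInForwardWrap(nn.Module):
    """Hypothetical stand-in for MultiGPUModelWrap's forward() behaviour."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # plain model whose forward() returns predictions

    def forward(self, *input_data, targets=None, criterion=None):
        predictions = self.model(*input_data)
        if criterion is not None:
            # Return the loss instead of the predictions
            return criterion(predictions, targets)
        return predictions
```

Keeping the wrapped model unchanged means an existing predictions-only model can be reused for multi-GPU training without editing its `forward()`.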
- class aitoolbox.torchtrain.model.ModelWrap(model, batch_model_feed_def)
Bases: object
TrainLoop model wrapper combining a PyTorch model and a model feed definition.
Note
Especially useful when training on multiple GPUs, where the TTModel abstract functions can't be used.
ModelWrap can be used as a replacement for TTModel when using the TrainLoop.
- Parameters:
model (torch.nn.Module) – neural network model
batch_model_feed_def (AbstractModelFeedDefinition or None) – data prep definition for batched data. This definition prepares the data for each batch that then gets fed into the neural network.
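A sketch of the idea behind ModelWrap: the loss and prediction logic lives in a separate feed-definition object instead of on the model, so the model itself can stay a plain `nn.Module`. The method names below assume the feed definition mirrors TTModel's methods with the model passed as an extra first argument; check the AbstractModelFeedDefinition documentation for the exact interface. `ClassificationFeedDefinition` is hypothetical, and no AIToolbox classes are imported.

```python
import torch
import torch.nn as nn


class ClassificationFeedDefinition:
    """Hypothetical batch feed definition mirroring TTModel's loss/prediction hooks."""

    def get_loss(self, model, batch_data, criterion, device):
        x, y = batch_data
        return criterion(model(x.to(device)), y.to(device))

    def get_loss_eval(self, model, batch_data, criterion, device):
        return self.get_loss(model, batch_data, criterion, device)

    def get_predictions(self, model, batch_data, device):
        x, y = batch_data
        y_pred = model(x.to(device)).argmax(dim=1).cpu()
        return y_pred, y.cpu(), None


# A plain PyTorch model with no AIToolbox-specific methods
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
feed_def = ClassificationFeedDefinition()
# With AIToolbox installed, the pair would be combined as ModelWrap(model, feed_def)
```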