quaterion.main module
- class Quaterion
Bases: object
Fine-tuning entry point
Contains methods to launch the actual training and evaluation processes.
- classmethod evaluate(evaluator: Evaluator, dataset: Union[Sized, Iterable, Dataset], model: SimilarityModel) → Dict[str, Tensor]
Compute metrics on a dataset
- Parameters:
evaluator – Object which holds the configuration of which metrics to use and how to obtain samples for them
dataset – Sized object, such as a list, tuple or torch.utils.data.Dataset, on which to compute metrics
model – SimilarityModel instance used to encode objects
- Returns:
Dict[str, torch.Tensor] - dict of computed metrics, where each key is the name of a metric and each value holds that metric's estimated values
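To make the contract of evaluate concrete, here is a minimal, self-contained sketch in plain Python. It does not use Quaterion: SketchEvaluator, evaluate_sketch and precision_at_1 are hypothetical names, embeddings are plain lists of floats rather than torch.Tensor, and the model is reduced to an encode function. It only mirrors the steps described above: obtain samples, encode them with the model, and return a dict mapping metric names to values.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Tuple

Vector = List[float]
Metric = Callable[[List[Vector], List[int]], float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def precision_at_1(embeddings: List[Vector], labels: List[int]) -> float:
    """Share of objects whose nearest other object (by cosine) has the same label."""
    hits = 0
    for i, query in enumerate(embeddings):
        others = [j for j in range(len(embeddings)) if j != i]
        nearest = max(others, key=lambda j: cosine(query, embeddings[j]))
        hits += labels[nearest] == labels[i]
    return hits / len(embeddings)


@dataclass
class SketchEvaluator:
    """Hypothetical Evaluator stand-in: which metrics to use and how many samples to take."""
    metrics: Dict[str, Metric]
    sample_size: int


def evaluate_sketch(
    evaluator: SketchEvaluator,
    dataset: Iterable[Tuple[str, int]],
    encode: Callable[[str], Vector],
) -> Dict[str, float]:
    samples = list(dataset)[: evaluator.sample_size]  # 1. obtain samples
    embeddings = [encode(obj) for obj, _ in samples]  # 2. encode objects with the "model"
    labels = [label for _, label in samples]
    # 3. map each metric name to its computed value
    return {name: metric(embeddings, labels) for name, metric in evaluator.metrics.items()}


# Usage with a toy lookup-table "model" (hypothetical data)
vectors = {"cat": [1.0, 0.0], "kitten": [0.9, 0.1], "car": [0.0, 1.0], "truck": [0.1, 0.9]}
dataset = [("cat", 0), ("kitten", 0), ("car", 1), ("truck", 1)]
evaluator = SketchEvaluator(metrics={"precision@1": precision_at_1}, sample_size=4)
print(evaluate_sketch(evaluator, dataset, vectors.__getitem__))  # {'precision@1': 1.0}
```

In Quaterion itself the metric values are torch.Tensor objects and the Evaluator decides how samples are obtained; this sketch only reproduces the shape of the result.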
- classmethod fit(trainable_model: TrainableModel, trainer: Optional[Trainer], train_dataloader: SimilarityDataLoader, val_dataloader: Optional[SimilarityDataLoader] = None, ckpt_path: Optional[str] = None)
Handle training routine
Assembles data loaders, performs caching, and runs the whole training process.
- Parameters:
trainable_model – model to fit
trainer – pytorch_lightning.Trainer instance that handles the fitting routine internally. If None is passed, a trainer is created with Quaterion.trainer_defaults(). The default parameters are intended as a quick start for training the model; we encourage users to try different parameters if the defaults do not give satisfactory results.
train_dataloader – DataLoader instance to retrieve samples during the training stage
val_dataloader – Optional DataLoader instance to retrieve samples during the validation stage
ckpt_path – Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised. If resuming from a mid-epoch checkpoint, training starts from the beginning of the next epoch.
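The ckpt_path rules can be illustrated with a small sketch. The helper resume_plan is hypothetical and assumes a checkpoint records the 0-based epoch it was written in and how many batches of that epoch had been processed; this is not a description of how Lightning stores checkpoints. It shows the practical consequence of resuming from a mid-epoch checkpoint: the unprocessed batches of the interrupted epoch are skipped.

```python
import os
from typing import Optional, Tuple


def resume_plan(
    ckpt_path: Optional[str],
    epoch: int = 0,
    batches_done: int = 0,
    batches_per_epoch: int = 0,
) -> Tuple[int, int]:
    """Hypothetical helper: return (start_epoch, skipped_batches) for a fit call.

    Assumes the checkpoint recorded the 0-based `epoch` it was saved in and
    how many batches of that epoch (`batches_done`) had been processed.
    """
    if ckpt_path is None:
        return 0, 0  # no checkpoint: train from the first epoch
    if not os.path.isfile(ckpt_path):
        # Documented behaviour: a missing checkpoint file is an error
        raise FileNotFoundError(f"No checkpoint file at {ckpt_path!r}")
    # Documented for mid-epoch checkpoints (assumed here for end-of-epoch ones too):
    # training continues from the beginning of the next epoch, so unprocessed
    # batches of the saved epoch are never seen.
    skipped = batches_per_epoch - batches_done
    return epoch + 1, skipped
```

For example, a checkpoint written after 30 of 100 batches in epoch 2 yields (3, 70): training restarts at epoch 3 and the remaining 70 batches of epoch 2 are not trained on.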
- static trainer_defaults(trainable_model: Optional[TrainableModel] = None, train_dataloader: Optional[SimilarityDataLoader] = None)
Reasonable default parameters for pytorch_lightning.Trainer
This function generates a set of Trainer parameters that are considered “recommended” for most Quaterion use cases. The Quaterion similarity learning training process has characteristics that differentiate it from regular deep learning model training. These default parameters may be overridden if your task needs special behaviour.
Consider overriding default parameters if you need to adjust Trainer parameters:
Example:
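    import pytorch_lightning as pl
    from quaterion import Quaterion

    # `model`, `train_dataloader` and `YourCustomCallback` are defined by the user
    # Start from the recommended defaults, adjusted to the model and dataloader
    trainer_kwargs = Quaterion.trainer_defaults(
        trainable_model=model,
        train_dataloader=train_dataloader,
    )
    # Override or extend individual Trainer arguments
    trainer_kwargs["logger"] = pl.loggers.WandbLogger(
        name="example_model",
        project="example_project",
    )
    trainer_kwargs["callbacks"].append(YourCustomCallback())
    trainer = pl.Trainer(**trainer_kwargs)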
- Parameters:
trainable_model – If provided, default params may be adjusted based on the model configuration
train_dataloader – If provided, trainer params will be adjusted according to dataset
- Returns:
kwargs for pytorch_lightning.Trainer
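As an illustration of how defaults might be adjusted to the training data, here is a sketch that needs no external libraries. trainer_defaults_sketch and its values are hypothetical, and the rule that caps log_every_n_steps at the number of batches per epoch is an assumed example, not necessarily what Quaterion.trainer_defaults() does.

```python
from typing import Any, Dict, Optional, Sized


def trainer_defaults_sketch(train_dataloader: Optional[Sized] = None) -> Dict[str, Any]:
    """Hypothetical stand-in for a defaults factory: returns Trainer-style kwargs."""
    defaults: Dict[str, Any] = {
        "max_epochs": 100,        # hypothetical value
        "log_every_n_steps": 50,  # hypothetical value
        "callbacks": [],          # a fresh list on every call
    }
    if train_dataloader is not None:
        num_batches = len(train_dataloader)  # batches per epoch
        if num_batches > 0:
            # Assumed rule: log at least once per epoch on small datasets
            defaults["log_every_n_steps"] = min(defaults["log_every_n_steps"], num_batches)
    return defaults


# Usage: a tiny "dataloader" of 8 batches (any Sized object works here)
kwargs = trainer_defaults_sketch(train_dataloader=[None] * 8)
kwargs["max_epochs"] = 10                  # override a default
kwargs["callbacks"].append("my_callback")  # extend; does not leak into later calls
print(kwargs["log_every_n_steps"], trainer_defaults_sketch()["callbacks"])  # 8 []
```

Building a new dict (and a new callbacks list) on every call is what makes it safe to modify the result in place, as the example above does with trainer_kwargs['callbacks'].append(...).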