API References
DATA
Dataloaders
Special version of
DataLoader designed to work with data represented as
DataLoader designed to work with data represented as
Datasets
Wrapper that converts a standard classification dataset into a dataset compatible with
Samples
Represents a group of similar objects, all of which should match one another within the group.
Represents a pair of objects, their similarity, and their relationship with other pairs.
DISTANCES
Compute cosine similarities (and their interpretation as distances).
Compute dot product similarities (and their interpretation as distances).
Compute Euclidean distances (and their interpretation as similarities).
Compute Manhattan distances (and their interpretation as similarities).
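The four measures listed above can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: the function names are hypothetical, vectors are plain Python lists, and the conversion "cosine distance = 1 - cosine similarity" is one common convention assumed here for illustration.

```python
import math


def dot_product(a, b):
    """Dot product similarity of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    """Dot product of the two vectors after scaling each to unit length."""
    norm = math.sqrt(dot_product(a, a)) * math.sqrt(dot_product(b, b))
    if norm == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot_product(a, b) / norm


def cosine_distance(a, b):
    """Assumed convention: distance = 1 - similarity."""
    return 1.0 - cosine_similarity(a, b)


def euclidean_distance(a, b):
    """Straight-line (L2) distance."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def manhattan_distance(a, b):
    """Sum of absolute coordinate differences (L1 distance)."""
    return sum(abs(x - y) for x, y in zip(a, b))


origin, point = [0.0, 0.0], [3.0, 4.0]
print(euclidean_distance(origin, point))  # 5.0
print(manhattan_distance(origin, point))  # 7.0
```

Similarity measures grow as objects get closer while distance measures shrink, which is why each entry above mentions an interpretation in the opposite direction.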
EVAL
Counters
Attach a batch-wise metric to
Calculate metrics over whole datasets.
Group metrics
Base class for group metrics.
Compute the retrieval R-precision score for group-based data.
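As a rough illustration of R-precision for a single query: if R items in the collection share the query's group, the score is the fraction of the top R retrieved items that belong to that group. The sketch below assumes a hypothetical function name and a plain list of group labels ordered from nearest to farthest, with the query itself excluded. It does not reflect the library's actual signature.

```python
def retrieval_r_precision(ranked_groups, query_group):
    """R-precision for one query.

    ranked_groups: group labels of retrieved items, nearest first
                   (the query itself is assumed to be excluded).
    query_group:   group label of the query.
    """
    # R is the number of relevant items available for this query.
    r = sum(1 for group in ranked_groups if group == query_group)
    if r == 0:
        return 0.0
    # Count how many of the top-R results are relevant.
    hits = sum(1 for group in ranked_groups[:r] if group == query_group)
    return hits / r


ranking = ["cat", "dog", "cat", "cat", "dog"]
# Three relevant items exist; two of the top three results are relevant.
score = retrieval_r_precision(ranking, "cat")
```

Averaging this value over all queries in a dataset gives a single score.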
Pair metrics
Base class for metrics computation for pair-based data.
Calculates retrieval precision@k for pair-based datasets.
Calculates retrieval reciprocal rank for pair-based datasets.
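Both pair metrics can be sketched for a single query whose results are already ranked. The function names and the input format (a list of 0/1 relevance flags, nearest result first) are assumptions made for this example, not the library's interface.

```python
def precision_at_k(relevance, k):
    """Fraction of the top-k results that are relevant."""
    if k <= 0:
        raise ValueError("k must be positive")
    return sum(relevance[:k]) / k


def reciprocal_rank(relevance):
    """1 / rank of the first relevant result, or 0.0 if there is none."""
    for rank, is_relevant in enumerate(relevance, start=1):
        if is_relevant:
            return 1.0 / rank
    return 0.0


flags = [0, 1, 1, 0]  # the second and third results are relevant
print(precision_at_k(flags, 2))  # 0.5
print(reciprocal_rank(flags))    # 0.5
```

Precision@k rewards every relevant item in the top k, whereas reciprocal rank only looks at how early the first relevant item appears.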
Samplers
Perform selection of embeddings and targets for group-based tasks.
Perform selection of embeddings and targets for pair-based tasks.
LOSSES
Base
Base class for group losses.
Base class for pairwise losses.
Implementations
Additive Angular Margin Loss as defined in https://arxiv.org/abs/1801.07698
Contrastive loss.
Implements Multiple Negatives Ranking Loss as described in https://arxiv.org/pdf/1705.00652.pdf
Regular cross-entropy loss.
Implements Triplet Loss as defined in https://arxiv.org/abs/1503.03832
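To make two of the losses above concrete, here is a minimal per-sample sketch that operates on precomputed distances. The function names, the default margins, and the exact contrastive form (no 0.5 scaling factor) are assumptions for illustration. The library's implementations may differ in scaling, reduction, and choice of distance.

```python
def contrastive_loss(distance, is_similar, margin=1.0):
    """Pull similar pairs together; push dissimilar pairs beyond the margin."""
    if is_similar:
        return distance ** 2
    return max(0.0, margin - distance) ** 2


def triplet_loss(dist_anchor_positive, dist_anchor_negative, margin=0.5):
    """Hinge on the gap between anchor-positive and anchor-negative distances."""
    return max(0.0, dist_anchor_positive - dist_anchor_negative + margin)


# A dissimilar pair that is already farther apart than the margin costs nothing.
far_pair = contrastive_loss(1.5, is_similar=False, margin=1.0)

# A triplet whose negative is closer to the anchor than its positive is penalised.
hard_triplet = triplet_loss(0.9, 0.2, margin=0.5)
```

In both cases the margin sets how far apart dissimilar objects must be before they stop contributing to the loss.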
Extras
Provides a simple wrapper for using losses and miners from pytorch-metric-learning.
MAIN
Fine-tuning entry point.
TRAIN
TrainableModel
Base class for models to be trained.
Cache
Determine cache settings.
Available tensor devices to be used for caching.
UTILS
Handle train stage.
Creates a 3D mask of valid triplets for the batch-all strategy.
Creates a 2D mask of valid anchor-positive pairs.
Creates a 2D mask of valid anchor-negative pairs.
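The three masks can be sketched from a list of per-sample labels. This is an illustrative version using nested Python lists and hypothetical function names. The library itself works on tensors, and its exact signatures are not shown here. A triplet (i, j, k) is treated as valid when i and j are distinct samples with the same label and k has a different label.

```python
def anchor_positive_mask(labels):
    """mask[i][j] is True when i != j and both samples share a label."""
    n = len(labels)
    return [[i != j and labels[i] == labels[j] for j in range(n)]
            for i in range(n)]


def anchor_negative_mask(labels):
    """mask[i][k] is True when the two samples have different labels."""
    n = len(labels)
    return [[labels[i] != labels[k] for k in range(n)] for i in range(n)]


def triplet_mask(labels):
    """mask[i][j][k] is True when (anchor i, positive j, negative k) is valid."""
    positives = anchor_positive_mask(labels)
    negatives = anchor_negative_mask(labels)
    n = len(labels)
    return [[[positives[i][j] and negatives[i][k] for k in range(n)]
             for j in range(n)]
            for i in range(n)]


labels = [0, 0, 1]
mask = triplet_mask(labels)
# Only (0, 1, 2) and (1, 0, 2) are valid triplets for these labels.
valid = [(i, j, k)
         for i in range(3) for j in range(3) for k in range(3)
         if mask[i][j][k]]
```

In the batch-all strategy, a mask like this selects every valid triplet in the batch before the per-triplet losses are averaged.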