quaterion.eval.accumulators.pair_accumulator module
- class PairAccumulator
Bases: Accumulator
Accumulates embeddings, labels, pairs and subgroups for pair-based tasks.
Keeps track of the current accumulated size, which is needed to index pairs correctly across batches.
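Why the accumulated size matters can be shown with a small sketch. It assumes, without confirmation from this reference, that pair indices are local to each batch and must be shifted by the number of embeddings accumulated so far. The variable names are hypothetical and this is not the library's code.

```python
import torch

# Two hypothetical batches; embeddings are scalars equal to their global position
# so that indexing results are easy to read. Pair indices are local to each batch.
batches = [
    (torch.arange(0, 3), torch.tensor([[0, 1], [1, 2]])),
    (torch.arange(3, 6), torch.tensor([[0, 2]])),
]

accumulated_size = 0
embeddings, naive_pairs, shifted_pairs = [], [], []
for batch_embeddings, batch_pairs in batches:
    embeddings.append(batch_embeddings)
    naive_pairs.append(batch_pairs)
    # Offset by the number of embeddings accumulated from earlier batches.
    shifted_pairs.append(batch_pairs + accumulated_size)
    accumulated_size += batch_embeddings.shape[0]

all_embeddings = torch.cat(embeddings)
print(all_embeddings[torch.cat(naive_pairs)].tolist())    # [[0, 1], [1, 2], [0, 2]]
print(all_embeddings[torch.cat(shifted_pairs)].tolist())  # [[0, 1], [1, 2], [3, 5]]
```

Without the offset, the second batch's pair `[0, 2]` silently points at embeddings from the first batch.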
- reset()
Reset accumulator state.
Clears the accumulator status and size, along with the accumulated embeddings, labels, pairs and subgroups.
- update(embeddings: Tensor, labels: Tensor, pairs: LongTensor, subgroups: Tensor, device=None)
Update accumulator state.
Moves the provided embeddings and groups (labels, pairs and subgroups) to the proper device and adds them to the accumulated state.
- Parameters:
embeddings – embeddings to accumulate
labels – labels to distinguish similar and dissimilar objects.
pairs – indices of the objects that form each pair.
subgroups – subgroup numbers used to determine which samples can be considered negatives.
device – device on which to store the accumulated embeddings and groups.
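A minimal sketch of what `update` and `reset` might do, assuming per-batch tensors are kept in Python lists and pair indices are shifted by the accumulated size. `MiniPairAccumulator` is a hypothetical stand-in written for illustration, not the quaterion class, and the shapes used below (one label per pair, one subgroup per embedding) are assumptions.

```python
import torch


class MiniPairAccumulator:
    """Hypothetical, simplified stand-in for PairAccumulator (not the quaterion class)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Drop every accumulated batch and restart the size counter.
        self._embeddings = []
        self._labels = []
        self._pairs = []
        self._subgroups = []
        self._accumulated_size = 0

    def update(self, embeddings, labels, pairs, subgroups, device=None) -> None:
        # Fall back to the device the embeddings already live on.
        if device is None:
            device = embeddings.device
        self._embeddings.append(embeddings.to(device))
        self._labels.append(labels.to(device))
        # Assumption: pair indices are batch-local, so shift them past
        # the embeddings accumulated from earlier batches.
        self._pairs.append(pairs.to(device) + self._accumulated_size)
        self._subgroups.append(subgroups.to(device))
        self._accumulated_size += embeddings.shape[0]


acc = MiniPairAccumulator()
for _ in range(2):
    acc.update(
        embeddings=torch.randn(4, 8),
        labels=torch.tensor([1.0, 0.0]),       # one label per pair (assumed)
        pairs=torch.tensor([[0, 1], [2, 3]]),  # indices local to this batch
        subgroups=torch.zeros(4, dtype=torch.long),
        device="cpu",
    )
```

Because every input is moved with `Tensor.to(device)` before it is stored, the accumulated state ends up on a single device regardless of where each batch was produced.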
- property labels
Concatenate the list of labels into a Tensor.
Avoids concatenating labels after each batch during accumulation; instead, they are concatenated only when the property is accessed.
- Returns:
torch.Tensor – batch of labels
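The lazy-concatenation pattern described for this property can be sketched as follows. `LazyLabels` is a hypothetical class used only to show the idea of appending per-batch tensors and calling `torch.cat` once on access; it is not the library's implementation.

```python
import torch


class LazyLabels:
    """Hypothetical illustration: store per-batch tensors, concatenate only on access."""

    def __init__(self) -> None:
        self._labels = []

    def add(self, batch_labels: torch.Tensor) -> None:
        # Appending to a Python list is cheap; no tensor data is copied here.
        self._labels.append(batch_labels)

    @property
    def labels(self) -> torch.Tensor:
        # One torch.cat over all batches, performed only when the property is read.
        return torch.cat(self._labels) if self._labels else torch.empty(0)


store = LazyLabels()
store.add(torch.tensor([1.0, 0.0]))
store.add(torch.tensor([1.0]))
print(store.labels)  # tensor([1., 0., 1.])
```

Concatenating on every update would copy all previously accumulated labels each time, so total copying grows quadratically with the number of batches; deferring the concatenation to property access needs one copy per read.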
- property pairs: LongTensor
Concatenate the list of pairs into a Tensor.
Avoids concatenating pairs after each batch during accumulation; instead, they are concatenated only when the property is accessed.
- Returns:
torch.Tensor – batch of pairs
- property state: Dict[str, Tensor]
Accumulated state
- Returns:
Dict[str, torch.Tensor] – dictionary with the accumulated embeddings, labels, pairs and subgroups.
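A sketch of how such a state dictionary could be assembled and then consumed in a pair-based evaluation. The helper `build_state`, the dictionary key names, and the cosine-similarity scoring step are assumptions for illustration, not part of the documented API.

```python
from typing import Dict, List

import torch
import torch.nn.functional as F


def build_state(
    embeddings: List[torch.Tensor],
    labels: List[torch.Tensor],
    pairs: List[torch.Tensor],
    subgroups: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: turn per-batch lists into a single state dictionary."""
    # The key names are assumptions chosen to mirror the attribute names.
    return {
        "embeddings": torch.cat(embeddings),
        "labels": torch.cat(labels),
        "pairs": torch.cat(pairs),
        "subgroups": torch.cat(subgroups),
    }


state = build_state(
    embeddings=[torch.randn(2, 3), torch.randn(2, 3)],
    labels=[torch.tensor([1.0]), torch.tensor([0.0])],
    pairs=[torch.tensor([[0, 1]]), torch.tensor([[2, 3]])],  # already globally indexed
    subgroups=[torch.tensor([0, 0]), torch.tensor([1, 1])],
)
print({name: tuple(t.shape) for name, t in state.items()})
# {'embeddings': (4, 3), 'labels': (2,), 'pairs': (2, 2), 'subgroups': (4,)}

# Example consumer: score each pair by the cosine similarity of its two embeddings.
emb = state["embeddings"]
left, right = emb[state["pairs"][:, 0]], emb[state["pairs"][:, 1]]
pair_scores = F.cosine_similarity(left, right, dim=1)
```

Each row of `pairs` selects two rows of `embeddings`, which is why pair indices must refer to positions in the concatenated tensor.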
- property subgroups
Concatenate the list of subgroups into a Tensor.
Avoids concatenating subgroups after each batch during accumulation; instead, they are concatenated only when the property is accessed.
- Returns:
torch.Tensor – batch of subgroups