Source code for opr.testing

"""Testing functions implementation."""
import itertools
from typing import Optional, Tuple, Union

import numpy as np
import torch
from pytorch_metric_learning.distances import LpDistance
from sklearn.neighbors import KDTree
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from opr.utils import parse_device


def get_recalls(
    query_embs: np.ndarray,
    db_embs: np.ndarray,
    dist_matrix: np.ndarray,
    dist_thresh: float = 25.0,
    at_n: int = 25,
) -> Tuple[np.ndarray, float, Optional[float]]:
    """Calculate Recall@N, Recall@1% and mean top-1 distance for the given query and db embeddings.

    For every query the nearest database embeddings are retrieved with a KD-tree. A retrieved
    database item counts as a true positive when the corresponding entry of ``dist_matrix``
    is not greater than ``dist_thresh``. Recall values are normalised by the number of queries
    that have at least one true positive in the database.

    Args:
        query_embs (np.ndarray): Query embeddings array.
        db_embs (np.ndarray): Database embeddings array.
        dist_matrix (np.ndarray): Distance matrix of shape (query_len, db_len).
        dist_thresh (float): Distance threshold for positive match. Defaults to 25.0.
        at_n (int): The maximum N value for the Recall@N metric. Defaults to 25.

    Returns:
        Tuple[np.ndarray, float, Optional[float]]: (Recall@N, Recall@1%, mean top-1 distance).
            Recall@N always has ``at_n`` entries; if the database holds fewer than ``at_n``
            items, the entries past the database size repeat the last computable value.
            Both recall values are NaN when no query has a positive match in the database.
            The 'mean top-1 distance' is the mean embedding-space distance to the nearest
            neighbour over the queries whose nearest neighbour is a true positive; it may be
            `None` if Recall@1 = 0.
    """
    db_size = len(db_embs)
    positives_mask = dist_matrix <= dist_thresh
    queries_with_matches = int(positives_mask.any(axis=1).sum())

    one_percent_threshold = max(int(round(db_size / 100.0)), 1)

    # Retrieve enough neighbours to evaluate both Recall@N and Recall@1% (the 1% rank may
    # exceed at_n for large databases), but never more than the database contains, because
    # KDTree.query rejects k larger than the number of indexed points.
    num_neighbors = min(max(at_n, one_percent_threshold), db_size)

    database_tree = KDTree(db_embs)
    distances, indices = database_tree.query(query_embs, k=num_neighbors)

    # retrieved_is_positive[q, r] is True when the r-th nearest neighbour of query q is a positive.
    retrieved_is_positive = np.take_along_axis(positives_mask, indices, axis=1)
    # found_within_rank[q, r] is True when any of the first r + 1 neighbours is a positive.
    found_within_rank = np.logical_or.accumulate(retrieved_is_positive, axis=1)
    hits_at_rank = found_within_rank.sum(axis=0).astype(float)

    if num_neighbors < at_n:
        # The whole database has already been retrieved, so recall cannot change for larger N.
        hits_at_rank = np.pad(hits_at_rank, (0, at_n - num_neighbors), mode="edge")

    # Division by zero (no query has a positive) intentionally yields NaN.
    recalls = hits_at_rank / queries_with_matches
    recall_at_n = recalls[:at_n]
    one_percent_recall = recalls[one_percent_threshold - 1]

    top1_is_positive = retrieved_is_positive[:, 0]
    if top1_is_positive.any():
        mean_top1_distance = np.mean(distances[top1_is_positive, 0])
    else:
        mean_top1_distance = None

    return recall_at_n, one_percent_recall, mean_top1_distance
def test(
    model: nn.Module,
    dataloader: DataLoader,
    distance_threshold: float = 25.0,
    device: Union[str, int, torch.device] = "cuda",
) -> Tuple[np.ndarray, float, float]:
    """Evaluates the model on the test set.

    Descriptors are computed for every sample of the dataloader, the samples are grouped
    by the "track" column, and recall metrics are computed for every ordered pair of
    different tracks (queries of one track against the database of another), then averaged.

    Args:
        model (nn.Module): The model to test.
        dataloader (DataLoader): The data loader for the test set.
        distance_threshold (float): The distance threshold for a correct match. Defaults to 25.0.
        device (Union[str, int, torch.device]): Device ("cpu" or "cuda"). Defaults to "cuda".

    Returns:
        Tuple[np.ndarray, float, float]: Array of AverageRecall@N (N from 1 to 50),
            AverageRecall@1% and mean top-1 distance. Recalls are averaged over all ordered
            track pairs; the mean top-1 distance is averaged over the track pairs for which
            it is defined (at least one query had a correct top-1 match).

    Raises:
        ValueError: If the required coordinate columns are not found in the dataset.
    """
    device = parse_device(device)
    model = model.to(device)

    # NOTE(review): the model's train/eval mode is not changed here; presumably the caller
    # switches it to eval mode before calling this function - confirm.
    with torch.no_grad():
        embeddings_list = []
        for batch in tqdm(dataloader, desc="Calculating test set descriptors", leave=False):
            batch = {e: batch[e].to(device) for e in batch}
            embeddings = model(batch)["final_descriptor"]
            embeddings_list.append(embeddings.cpu().numpy())
            torch.cuda.empty_cache()
        test_embeddings = np.vstack(embeddings_list)

    test_df = dataloader.dataset.dataset_df

    # temporary workaround for working with datasets with frame sequences
    if hasattr(dataloader.dataset, "_sequence_indices") and dataloader.dataset._sequence_indices is not None:
        sequence_indices = [seq[0] for seq in dataloader.dataset._sequence_indices]
        test_df = test_df.iloc[sequence_indices]

    # Everything below indexes `test_embeddings` and the coordinate distance matrix by
    # position, so the frame index must be positional (0..len-1). This returns a new frame
    # and leaves the dataset's own dataframe untouched.
    test_df = test_df.reset_index(drop=True)

    has_query_flag = "in_query" in test_df.columns
    queries = []
    databases = []
    for _, group in test_df.groupby("track"):
        track_indices = group.index.to_list()
        databases.append(track_indices)
        if has_query_flag:
            queries.append(group[group["in_query"]].index.to_list())
        else:
            queries.append(track_indices)

    # Accepted coordinate column namings, in order of preference.
    coords_columns = None
    for candidate in (["northing", "easting"], ["x", "y"], ["tx", "ty"]):
        if all(column in test_df.columns for column in candidate):
            coords_columns = candidate
            break
    if coords_columns is None:
        raise ValueError(
            "Required coordinate columns ('northing'/'easting', 'x'/'y', or 'tx'/'ty') not found in the dataset"
        )

    # Pairwise distances between the coordinates of all samples.
    utms = torch.tensor(test_df[coords_columns].to_numpy())
    dist_fn = LpDistance(normalize_embeddings=False)
    dist_utms = dist_fn(utms).numpy()

    n = 50  # maximum N for Recall@N
    num_tracks = len(queries)
    recalls_at_n = np.zeros((num_tracks, num_tracks, n))
    recalls_at_one_percent = np.zeros((num_tracks, num_tracks, 1))
    top1_distances = np.zeros((num_tracks, num_tracks, 1))

    # Every ordered pair (query track, database track) of different tracks.
    ij_permutations = list(itertools.permutations(range(num_tracks), 2))
    count_r_at_1 = 0

    for i, j in tqdm(ij_permutations, desc="Calculating metrics", leave=False):
        query = queries[i]
        database = databases[j]
        query_embs = test_embeddings[query]
        database_embs = test_embeddings[database]
        distances = dist_utms[query][:, database]
        (
            recalls_at_n[i, j],
            recalls_at_one_percent[i, j],
            top1_distance,
        ) = get_recalls(query_embs, database_embs, distances, at_n=n, dist_thresh=distance_threshold)
        # `None` means no query of this pair had a correct top-1 match; a value of 0.0 is valid.
        if top1_distance is not None:
            count_r_at_1 += 1
            top1_distances[i, j] = top1_distance

    num_pairs = len(ij_permutations)
    mean_recall_at_n = recalls_at_n.sum(axis=(0, 1)) / num_pairs
    mean_recall_at_one_percent = recalls_at_one_percent.sum(axis=(0, 1)).squeeze() / num_pairs

    # Average the top-1 distance only over the pairs where it is defined; otherwise the
    # zero placeholders of the remaining pairs would bias the mean toward zero. When no
    # pair has a value the sum is zero and the result falls back to 0 / num_pairs.
    top1_denominator = count_r_at_1 if count_r_at_1 > 0 else num_pairs
    mean_top1_distance = top1_distances.sum(axis=(0, 1)).squeeze() / top1_denominator

    return mean_recall_at_n, mean_recall_at_one_percent, mean_top1_distance