| | import torch |
| | import math |
| | import torch.nn.functional as F |
| |
|
| |
|
def log_sum_exp(x, axis=None):
    """
    Numerically stable log-sum-exp.

    Args:
        x: Input tensor.
        axis: Axis over which to perform the sum. ``None`` reduces over
            all elements (the original implementation crashed in this case
            because ``torch.max(x, None)`` is not a valid call).
    Returns:
        torch.Tensor: ``log(sum(exp(x)))`` reduced along ``axis``.
    """
    # torch.logsumexp performs the same max-subtraction trick the original
    # did by hand, but also handles non-leading axes (the manual version
    # broke broadcasting when the reduced dim was not dim 0 of a 2-D input)
    # and rows that are entirely -inf.
    if axis is None:
        return torch.logsumexp(x.reshape(-1), dim=0)
    return torch.logsumexp(x, dim=axis)
| |
|
| |
|
def get_positive_expectation(p_samples, measure='JSD', average=True):
    """
    Computes the positive part of a divergence / difference.

    Args:
        p_samples: Positive samples.
        measure: Estimator to use ('GAN', 'JSD', 'X2', 'KL', 'RKL',
            'DV', 'H2' or 'W1').
        average: Average the result over samples.
    Returns:
        torch.Tensor
    Raises:
        ValueError: If ``measure`` is not one of the supported estimators.
    """
    # One lambda per estimator; evaluated lazily so only the requested
    # measure is computed.
    estimators = {
        'GAN': lambda s: -F.softplus(-s),
        'JSD': lambda s: math.log(2.) - F.softplus(-s),
        'X2': lambda s: s ** 2,
        'KL': lambda s: s + 1.,
        'RKL': lambda s: -torch.exp(-s),
        'DV': lambda s: s,
        'H2': lambda s: torch.ones_like(s) - torch.exp(-s),
        'W1': lambda s: s,
    }
    if measure not in estimators:
        raise ValueError('Unknown measurement {}'.format(measure))
    expectation = estimators[measure](p_samples)
    return expectation.mean() if average else expectation
| |
|
| |
|
def get_negative_expectation(q_samples, measure='JSD', average=True):
    """
    Computes the negative part of a divergence / difference.

    Args:
        q_samples: Negative samples.
        measure: Estimator to use ('GAN', 'JSD', 'X2', 'KL', 'RKL',
            'DV', 'H2' or 'W1').
        average: Average the result over samples.
    Returns:
        torch.Tensor
    Raises:
        ValueError: If ``measure`` is not one of the supported estimators.
    """
    # One lambda per estimator; evaluated lazily so only the requested
    # measure is computed (the 'DV' branch reduces over dim 0 of the batch).
    estimators = {
        'GAN': lambda s: F.softplus(-s) + s,
        'JSD': lambda s: F.softplus(-s) + s - math.log(2.),
        'X2': lambda s: -0.5 * ((torch.sqrt(s ** 2) + 1.) ** 2),
        'KL': lambda s: torch.exp(s),
        'RKL': lambda s: s - 1.,
        'DV': lambda s: log_sum_exp(s, 0) - math.log(s.size(0)),
        'H2': lambda s: torch.exp(s) - 1.,
        'W1': lambda s: s,
    }
    if measure not in estimators:
        raise ValueError('Unknown measurement {}'.format(measure))
    expectation = estimators[measure](q_samples)
    return expectation.mean() if average else expectation
| |
|
| |
|
def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
    """
    QV-CL module
    Computing the Contrastive Loss between the video and query.
    :param video: video rep (bsz, Lv, dim)
    :param query: query rep (bsz, dim)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{qv}
    """
    # Matched clips are positives; every other valid (unpadded) clip is a
    # negative.
    pos_mask = match_labels.type(torch.float32)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask

    # Clip-query similarity scores: (bsz, Lv, dim) x (bsz, dim, 1) -> (bsz, Lv)
    scores = torch.matmul(video, query.unsqueeze(2)).squeeze(2)

    def _masked_mean(values, m):
        # Per-sample mean over the masked positions; epsilon guards against
        # an all-zero mask.
        return torch.sum(values * m, dim=1) / (torch.sum(m, dim=1) + 1e-12)

    e_pos = _masked_mean(
        get_positive_expectation(scores * pos_mask, measure, average=False),
        pos_mask)
    e_neg = _masked_mean(
        get_negative_expectation(scores * neg_mask, measure, average=False),
        neg_mask)

    # Per-sample loss: negative expectation minus positive expectation.
    return e_neg - e_pos
| |
|
| |
|
def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
    """
    VV-CL module
    Computing the Contrastive loss between the start/end clips and the video
    :param video: video rep (bsz, Lv, dim)
    :param st_ed_indices: (bsz, 2)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{vv} (scalar, mean over the batch)
    """
    # Matched clips are positives; every other valid (unpadded) clip is a
    # negative.
    pos_mask = match_labels.type(torch.float32)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask

    # Gather the start/end boundary clip features for each sample.
    st_indices, ed_indices = st_ed_indices[:, 0], st_ed_indices[:, 1]
    # Fix: build the batch indices on the same device as `video`; the
    # original created them on CPU, which breaks advanced indexing when the
    # inputs live on GPU.
    batch_indices = torch.arange(0, video.shape[0], device=video.device).long()
    video_s = video[batch_indices, st_indices, :]
    video_e = video[batch_indices, ed_indices, :]

    # Similarity of every clip against the start/end clips: (bsz, Lv)
    res_s = torch.matmul(video, video_s.unsqueeze(2)).squeeze(2)
    res_e = torch.matmul(video, video_e.unsqueeze(2)).squeeze(2)

    # Positive expectations (masked per-sample means), summed over the two
    # boundary anchors.
    E_s_pos = get_positive_expectation(res_s * pos_mask, measure, average=False)
    E_s_pos = torch.sum(E_s_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)
    E_e_pos = get_positive_expectation(res_e * pos_mask, measure, average=False)
    E_e_pos = torch.sum(E_e_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)
    E_pos = E_s_pos + E_e_pos

    # Negative expectations, same masked-mean scheme.
    E_s_neg = get_negative_expectation(res_s * neg_mask, measure, average=False)
    E_s_neg = torch.sum(E_s_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)
    E_e_neg = get_negative_expectation(res_e * neg_mask, measure, average=False)
    E_e_neg = torch.sum(E_e_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)
    E_neg = E_s_neg + E_e_neg

    E = E_neg - E_pos
    return torch.mean(E)
| |
|