| | from collections import defaultdict |
| | import logging |
| | import sys |
| |
|
# Module-level logger; all evaluation output (banners, counts, scores) is
# emitted through it at INFO level.
logger = logging.getLogger(__name__)
| |
|
| |
|
class EvalCounts:
    """Running tallies for a single evaluation metric.

    Attributes:
        correct_cnt: number of gold (truth) items seen.
        pred_cnt: number of predicted items seen.
        pred_correct_cnt: number of predicted items that match the gold.
        correct_types_cnt, pred_types_cnt, pred_correct_types_cnt: the same
            three tallies broken down by type name; each is a
            ``defaultdict(int)`` so an unseen type reads as 0.
    """

    def __init__(self):
        # Overall tallies (ints are immutable, so chained assignment is safe).
        self.pred_correct_cnt = self.correct_cnt = self.pred_cnt = 0
        # Per-type tallies; three distinct dicts, one per tally.
        self.pred_correct_types_cnt, self.correct_types_cnt, self.pred_types_cnt = (
            defaultdict(int),
            defaultdict(int),
            defaultdict(int),
        )
| |
|
| |
|
def _read_sentences(file_path, label2idx):
    """Parse ``file_path`` into sentences, one list of items per tag index.

    Sentences are separated by blank lines. Each non-blank line is
    tab-separated and starts with a tag. Lines whose tag is not a key of
    ``label2idx`` are skipped. Empty sentences (from repeated or trailing
    blank lines) are dropped.
    """
    def new_sentence():
        return [[] for _ in range(len(label2idx))]

    sents = []
    sent = new_sentence()
    with open(file_path, 'r') as fin:
        for line in fin:
            line = line.strip('\r\n')
            if line == "":
                if any(sent):
                    sents.append(sent)
                sent = new_sentence()
                continue

            words = line.split('\t')
            tag = words[0]
            if tag not in label2idx:
                # Not needed by any requested metric.
                continue
            column = sent[label2idx[tag]]

            # NOTE(review): eval() executes arbitrary code taken from the file.
            # The fields look like Python literals (spans are later used as
            # set/dict keys, so presumably tuples); consider ast.literal_eval
            # if results files can come from an untrusted source.
            if tag in ('Sequence-Label-True', 'Sequence-Label-Pred',
                       'Ent-Label-True', 'Ent-Label-Pred',
                       'Rel-Label-True', 'Rel-Label-Pred'):
                # Space-separated labels, one per position.
                column.extend(words[1].split(' '))
            elif tag in ('Separate-Position-True', 'Separate-Position-Pred'):
                # Each line contributes one group of positions.
                column.append(words[1].split(' '))
            elif tag == 'Ent-Span-Pred':
                column.append(eval(words[1]))
            elif tag in ('Ent-True', 'Ent-Pred'):
                # [entity type, span]
                column.append([words[1], eval(words[2])])
            elif tag in ('Rel-True', 'Rel-Pred'):
                # [relation type, head span, tail span]
                column.append([words[1], eval(words[2]), eval(words[3])])

    if any(sent):
        sents.append(sent)
    return sents


def eval_file(file_path, eval_metrics):
    """Evaluate a results file and log the scores of each requested metric.

    The file is read as blocks of lines separated by blank lines, one block
    per sentence. Every non-blank line is tab-separated and starts with a tag
    naming what the line holds (e.g. ``Ent-True``, ``Rel-Pred``). Lines whose
    tag no requested metric needs are ignored.

    Args:
        file_path (str): path of the results file.
        eval_metrics (list): metric names to compute, drawn from 'token',
            'ent-label', 'rel-label', 'separate-position', 'span', 'ent',
            'rel' and 'exact-rel'.

    Returns:
        list: the score returned by ``report`` for each distinct metric, in
        the order the metrics first appear in ``eval_metrics``.

    Raises:
        KeyError: if ``eval_metrics`` contains an unknown metric name.
    """
    # Tags each metric reads. 'span', 'rel' and 'exact-rel' also need the
    # entity lines, because evaluate() matches predicted spans / relation
    # arguments against the gold and predicted entity spans.
    metric2labels = {
        'token': ['Sequence-Label-True', 'Sequence-Label-Pred'],
        'ent-label': ['Ent-Label-True', 'Ent-Label-Pred'],
        'rel-label': ['Rel-Label-True', 'Rel-Label-Pred'],
        'separate-position': ['Separate-Position-True', 'Separate-Position-Pred'],
        'span': ['Ent-True', 'Ent-Span-Pred'],
        'ent': ['Ent-True', 'Ent-Pred'],
        'rel': ['Ent-True', 'Ent-Pred', 'Rel-True', 'Rel-Pred'],
        'exact-rel': ['Ent-True', 'Ent-Pred', 'Rel-True', 'Rel-Pred'],
    }
    labels = set()
    for metric in eval_metrics:
        labels.update(metric2labels[metric])
    # Sorted so each sentence's column layout is deterministic across runs.
    label2idx = {label: idx for idx, label in enumerate(sorted(labels))}

    sents = _read_sentences(file_path, label2idx)

    counts = {metric: EvalCounts() for metric in eval_metrics}
    for sent in sents:
        evaluate(sent, counts, label2idx)

    results = []
    logger.info("-" * 22 + "START" + "-" * 23)
    for metric, count in counts.items():
        # Centre the metric name inside a 50-character banner.
        left_offset = (50 - len(metric)) // 2
        logger.info("-" * left_offset + metric + "-" * (50 - left_offset - len(metric)))
        results.append(report(count))
    logger.info("-" * 23 + "END" + "-" * 24)

    return results
| |
|
| |
|
def _column(sent, label2idx, label):
    """Return the items parsed for tag ``label`` in ``sent``, or [] if the tag was not loaded."""
    idx = label2idx.get(label)
    return [] if idx is None else sent[idx]


def _count_tag_pairs(count, pairs, null_tag, arc_key=None):
    """Accumulate position-aligned (gold, pred) tags into ``count``.

    A tag equal to ``null_tag`` means "no label" and is not counted. A pair
    counts as correct when both tags are equal and non-null. When ``arc_key``
    is given, the per-type bucket of that name also tallies unlabeled
    detection: every non-null tag counts, and a position is correct whenever
    gold and pred are both non-null.
    """
    for gold, pred in pairs:
        if gold != null_tag:
            count.correct_cnt += 1
            if arc_key is not None:
                count.correct_types_cnt[arc_key] += 1
            count.correct_types_cnt[gold] += 1
            # "+= 0" only registers the type so report() iterates over it.
            count.pred_correct_types_cnt[gold] += 0
        if pred != null_tag:
            count.pred_cnt += 1
            if arc_key is not None:
                count.pred_types_cnt[arc_key] += 1
            count.pred_types_cnt[pred] += 1
            count.pred_correct_types_cnt[pred] += 0
        if arc_key is not None and gold != null_tag and pred != null_tag:
            count.pred_correct_types_cnt[arc_key] += 1
        if gold == pred and gold != null_tag:
            count.pred_correct_cnt += 1
            count.pred_correct_types_cnt[gold] += 1


def _count_typed_sets(count, gold_by_type, pred_by_type):
    """Accumulate per-type set overlap between gold and predicted items into ``count``.

    Types are visited in sorted order so the per-type report is deterministic.
    """
    for key in sorted(set(gold_by_type) | set(pred_by_type)):
        gold = gold_by_type.get(key, set())
        pred = pred_by_type.get(key, set())
        hits = len(gold & pred)
        count.correct_cnt += len(gold)
        count.correct_types_cnt[key] += len(gold)
        count.pred_cnt += len(pred)
        count.pred_types_cnt[key] += len(pred)
        count.pred_correct_cnt += hits
        count.pred_correct_types_cnt[key] += hits


def _index_entities(entities):
    """Index ``[type, span]`` pairs as (span -> type, type -> set of spans)."""
    span2ent = {}
    ent2spans = defaultdict(set)
    for ent, span in entities:
        span2ent[span] = ent
        ent2spans[ent].add(span)
    return span2ent, ent2spans


def _group_relations(rels, span2ent, exact):
    """Group ``[rel, span1, span2]`` triples by relation type.

    Triples whose argument spans are not entity spans (keys of ``span2ent``)
    are dropped. Each kept triple becomes ``(span1, span2)`` or, when
    ``exact`` is true, ``(span1, type1, span2, type2)`` so that argument
    entity types must match as well.
    """
    grouped = defaultdict(set)
    for rel, span1, span2 in rels:
        if span1 not in span2ent or span2 not in span2ent:
            continue
        if exact:
            grouped[rel].add((span1, span2ent[span1], span2, span2ent[span2]))
        else:
            grouped[rel].add((span1, span2))
    return grouped


def evaluate(sent, counts, label2idx):
    """Add one sentence's gold/predicted annotations to the metric counters.

    Metrics:
        token, ent-label, rel-label: position-aligned tag comparison ('O' /
            'None' mean no label; the label metrics also keep an 'Arc' bucket
            for unlabeled detection).
        separate-position: overlap of gold vs predicted position groups.
        span: predicted spans vs gold entity spans, ignoring types.
        ent: (entity type, span) matches.
        rel: relation type and both argument spans match; arguments must be
            entity spans (gold relations against gold entities, predicted
            relations against predicted entities).
        exact-rel: as 'rel', and the argument entity types must match too.

    Args:
        sent (list): one list of parsed items per tag, indexed via ``label2idx``.
        counts (dict): metric name -> EvalCounts; only metrics present are
            updated, in place.
        label2idx (dict): tag name (e.g. 'Ent-True') -> index into ``sent``.
            A tag missing from it is treated as having no items.
    """
    def column(label):
        return _column(sent, label2idx, label)

    if 'token' in counts:
        _count_tag_pairs(counts['token'],
                         zip(column('Sequence-Label-True'), column('Sequence-Label-Pred')),
                         null_tag='O')

    if 'ent-label' in counts:
        _count_tag_pairs(counts['ent-label'],
                         zip(column('Ent-Label-True'), column('Ent-Label-Pred')),
                         null_tag='None', arc_key='Arc')

    if 'rel-label' in counts:
        _count_tag_pairs(counts['rel-label'],
                         zip(column('Rel-Label-True'), column('Rel-Label-Pred')),
                         null_tag='None', arc_key='Arc')

    if 'separate-position' in counts:
        count = counts['separate-position']
        for gold_positions, pred_positions in zip(column('Separate-Position-True'),
                                                  column('Separate-Position-Pred')):
            count.correct_cnt += len(gold_positions)
            count.pred_cnt += len(pred_positions)
            count.pred_correct_cnt += len(set(gold_positions) & set(pred_positions))

    correct_span2ent, correct_ent2idx = _index_entities(column('Ent-True'))
    pred_span2ent, pred_ent2idx = _index_entities(column('Ent-Pred'))

    if 'span' in counts:
        count = counts['span']
        correct_span = set(correct_span2ent)
        pred_span = set(column('Ent-Span-Pred'))
        count.correct_cnt += len(correct_span)
        count.pred_cnt += len(pred_span)
        count.pred_correct_cnt += len(correct_span & pred_span)

    if 'ent' in counts:
        _count_typed_sets(counts['ent'], correct_ent2idx, pred_ent2idx)

    for metric, exact in (('rel', False), ('exact-rel', True)):
        if metric in counts:
            _count_typed_sets(
                counts[metric],
                _group_relations(column('Rel-True'), correct_span2ent, exact),
                _group_relations(column('Rel-Pred'), pred_span2ent, exact),
            )
| |
|
def _log_scores(truth_cnt, pred_cnt, correct_cnt, p, r, f):
    """Log raw counts followed by precision / recall / F1 as percentages."""
    logger.info("truth cnt: %s pred cnt: %s correct cnt: %s", truth_cnt, pred_cnt, correct_cnt)
    logger.info("precision: %6.2f%%", 100 * p)
    logger.info("recall: %6.2f%%", 100 * r)
    logger.info("f1: %6.2f%%", 100 * f)


def report(counts):
    """Log evaluation results for one metric, overall and per type.

    Types are listed in the key order of ``counts.pred_correct_types_cnt``.

    Arguments:
        counts {EvalCounts} -- accumulated counters for the metric

    Returns:
        float -- overall f1 score (0 when undefined)
    """
    p, r, overall_f1 = calculate_metrics(counts.pred_correct_cnt, counts.pred_cnt,
                                         counts.correct_cnt)
    _log_scores(counts.correct_cnt, counts.pred_cnt, counts.pred_correct_cnt, p, r, overall_f1)

    # ``type_name`` rather than ``type`` to avoid shadowing the builtin.
    for type_name in counts.pred_correct_types_cnt:
        truth_cnt = counts.correct_types_cnt[type_name]
        pred_cnt = counts.pred_types_cnt[type_name]
        correct_cnt = counts.pred_correct_types_cnt[type_name]
        p, r, f = calculate_metrics(correct_cnt, pred_cnt, truth_cnt)
        logger.info("-" * 50)
        logger.info("type: %s", type_name)
        _log_scores(truth_cnt, pred_cnt, correct_cnt, p, r, f)

    return overall_f1
| |
|
| |
|
def calculate_metrics(pred_correct_cnt, pred_cnt, correct_cnt):
    """Compute precision, recall and F1 from raw counts.

    Arguments:
        pred_correct_cnt {int} -- number of predictions that match the truth
        pred_cnt {int} -- number of predictions
        correct_cnt {int} -- number of truth items

    Returns:
        tuple -- (precision, recall, f1); each is 0 when its denominator is 0
    """
    def ratio(numerator, denominator):
        # An undefined ratio is reported as 0 instead of raising.
        if denominator == 0:
            return 0
        return numerator / denominator

    true_pos = pred_correct_cnt
    false_pos = pred_cnt - pred_correct_cnt
    false_neg = correct_cnt - pred_correct_cnt

    precision = ratio(true_pos, true_pos + false_pos)
    recall = ratio(true_pos, true_pos + false_neg)
    f1 = ratio(2 * precision * recall, precision + recall)
    return precision, recall, f1
| |
|
| |
|
if __name__ == '__main__':
    # Without a configured handler, INFO-level records are not shown, and the
    # whole report is logged at INFO.
    logging.basicConfig(level=logging.INFO)
    # eval_file requires both the results path and the metrics to compute.
    if len(sys.argv) < 3:
        sys.exit("usage: {} RESULTS_FILE METRIC [METRIC ...]".format(sys.argv[0]))
    eval_file(sys.argv[1], sys.argv[2:])
| |
|