-
Notifications
You must be signed in to change notification settings - Fork 0
/
auto_evaluation.py
69 lines (52 loc) · 1.92 KB
/
auto_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Automatic evaluation on generation results
"""
import argparse
import json
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
from nltk.tokenize.casual import casual_tokenize
from nltk.util import ngrams
from tabulate import tabulate
from tqdm import tqdm
def calculate_dist_metric(responses: List[str], n: int) -> float:
    """Compute the distinct-n diversity score of a collection of responses.

    Each response is tokenized with NLTK's ``casual_tokenize`` and split into
    its order-``n`` n-grams. The score is the number of unique n-grams across
    all responses divided by the total number of n-grams produced.

    Args:
        responses: Generated response strings.
        n: N-gram order (e.g. 2 for distinct-2).

    Returns:
        The distinct-n ratio (unique n-grams / total n-grams). Returns ``0.0``
        when there are no n-grams at all, for example when ``responses`` is
        empty or every response has fewer than ``n`` tokens. The original code
        raised ``ZeroDivisionError`` in that case.
    """
    num_all_ngrams = 0
    all_ngram_set = set()
    # Tokenize one response at a time rather than materializing every token
    # list up front; the counts and the n-gram set come out the same.
    for resp in responses:
        token_ngrams = list(ngrams(casual_tokenize(resp), n))
        num_all_ngrams += len(token_ngrams)
        all_ngram_set.update(token_ngrams)
    if num_all_ngrams == 0:
        # Nothing to measure; avoid dividing by zero.
        return 0.0
    return len(all_ngram_set) / num_all_ngrams
def calculate_average_length(responses: List[str]) -> float:
    """Return the mean token count of the given responses.

    Tokens are counted with NLTK's ``casual_tokenize``. The mean is computed
    with ``np.mean``, so an empty ``responses`` list yields NaN (NumPy emits a
    "Mean of empty slice" warning) rather than raising.
    """
    token_counts = []
    for response in responses:
        token_counts.append(len(casual_tokenize(response)))
    return np.mean(token_counts)
def evaluate_single_result(results_path,
                           dist_n_list):
    """Compute automatic metrics for one JSON-Lines file of generation results.

    Args:
        results_path: Path to a JSONL file. Each non-blank line is expected
            to be a JSON object with a ``"response"`` field holding the
            generated text.
        dist_n_list: N-gram orders for which a ``dist_{n}`` metric is reported.

    Returns:
        Dict mapping metric name to value: ``"avg_length"`` plus one
        ``"dist_{n}"`` entry per element of ``dist_n_list``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"response"`` field.
    """
    # Decode as UTF-8 explicitly so non-ASCII responses read the same on every
    # platform, regardless of the locale's default encoding. Skip blank lines
    # (e.g. a trailing newline at EOF), which json.loads would reject.
    with open(results_path, encoding="utf-8") as f:
        examples = [json.loads(line) for line in f if line.strip()]
    responses = [ex["response"] for ex in examples]
    metrics = {
        "avg_length": calculate_average_length(responses),
    }
    for n in dist_n_list:
        metrics[f"dist_{n}"] = calculate_dist_metric(responses, n)
    return metrics
def main(args):
    """Evaluate every result file and print a metrics table to stdout.

    Each file is labelled by its file stem (``runs/gpt2.jsonl`` -> ``gpt2``).
    If a stem has already been used by an earlier file, the later file is
    labelled by its full path instead. This keeps one result from silently
    overwriting another.

    Args:
        args: Parsed CLI namespace providing ``result_paths`` (list of JSONL
            paths) and ``dist_n_list`` (list of n-gram orders).
    """
    all_metrics = {}
    for result_path in tqdm(args.result_paths):
        model_name = Path(result_path).stem
        if model_name in all_metrics:
            # Same stem from a different directory: disambiguate with the path.
            model_name = str(result_path)
        all_metrics[model_name] = evaluate_single_result(
            result_path,
            args.dist_n_list,
        )
    # DataFrame(dict) puts each file in a column; transpose so each row is a
    # result file and each column is a metric.
    all_metrics_df = pd.DataFrame(all_metrics).transpose()
    print(tabulate(all_metrics_df, headers="keys"))
if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate them, run main().
    parser = argparse.ArgumentParser(description=__doc__)
    # required=True: without it, omitting the flag passes None into main(),
    # which then fails with an opaque TypeError instead of a usage message.
    parser.add_argument(
        "--result-paths",
        type=str,
        nargs="+",
        required=True,
        help="JSONL result files; each line needs a \"response\" field.",
    )
    parser.add_argument(
        "--dist-n-list",
        type=int,
        nargs="*",
        default=[2, 3],
        help="N-gram orders for distinct-n metrics (default: 2 3).",
    )
    args = parser.parse_args()
    # An n-gram order below 1 has no meaning for distinct-n; reject it early
    # with a usage error.
    if any(n < 1 for n in args.dist_n_list):
        parser.error("--dist-n-list values must be positive integers")
    main(args)