|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" BLEURT metric. """ |
|
|
|
import os |
|
|
|
import datasets |
|
from bleurt import score |
|
|
|
import evaluate |
|
|
|
|
|
# Module-level logger obtained from `evaluate`'s logging helper; it is used to
# warn when the default checkpoint is selected in `_download_and_prepare`.
# NOTE(review): the trailing `|` looks like extraction residue — confirm against the upstream file.
logger = evaluate.logging.get_logger(__name__) |
|
|
|
|
|
# BibTeX entry for the BLEURT paper (Sellam, Das and Parikh, ACL 2020); passed
# to `evaluate.MetricInfo` as `citation` in `BLEURT._info`.
# NOTE(review): the trailing `|` characters and `|`-only lines look like
# extraction residue and end up inside the string as written — confirm against
# the upstream file.
_CITATION = """\ |
|
@inproceedings{bleurt, |
|
title={BLEURT: Learning Robust Metrics for Text Generation}, |
|
author={Thibault Sellam and Dipanjan Das and Ankur P. Parikh}, |
|
booktitle={ACL}, |
|
year={2020}, |
|
url={https://arxiv.org/abs/2004.04696} |
|
} |
|
""" |
|
|
|
# Human-readable summary of the metric. It is passed to `evaluate.MetricInfo`
# as `description` in `BLEURT._info` and handed to the `add_start_docstrings`
# decorator applied to the `BLEURT` class.
_DESCRIPTION = """\
BLEURT is a learnt evaluation metric for Natural Language Generation. It is built using multiple phases of transfer learning starting from a pretrained BERT model (Devlin et al. 2018)
and then employing another pre-training phase using synthetic data. Finally it is trained on WMT human annotations. You may run BLEURT out-of-the-box or fine-tune
it for your specific application (the latter is expected to perform better).

See the project's README at https://github.com/google-research/bleurt#readme for more information.
"""
|
|
|
# Usage documentation for the metric: accepted inputs, returned value and a
# doctest-style example. It is passed to `evaluate.MetricInfo` as
# `inputs_description` in `BLEURT._info` and handed to the
# `add_start_docstrings` decorator applied to the `BLEURT` class.
#
# The checkpoint entry mirrors `BLEURT._download_and_prepare`: the checkpoint
# comes from the configuration name, and the fallback is `bleurt-base-128`.
_KWARGS_DESCRIPTION = """
BLEURT score.

Args:
    `predictions` (list of str): prediction/candidate sentences
    `references` (list of str): reference sentences
    `checkpoint` BLEURT checkpoint. It is not an argument of `compute`: select it with the
        configuration name given at load time, e.g. evaluate.load('bleurt', 'bleurt-large-512').
        Will default to bleurt-base-128 if no configuration name is given.

Returns:
    'scores': List of scores.
Examples:

    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> bleurt = evaluate.load("bleurt")
    >>> results = bleurt.compute(predictions=predictions, references=references)
    >>> print([round(v, 2) for v in results["scores"]])
    [1.03, 1.04]
"""
|
|
|
# Download URL of every supported BLEURT checkpoint, keyed by checkpoint name.
# Each archive is published as "<bucket>/<checkpoint name>.zip"; the original
# checkpoints and the BLEURT-20 family live in two different buckets. Note the
# casing: original checkpoints use lower-case names, BLEURT-20 ones upper-case.
CHECKPOINT_URLS = {
    checkpoint: f"https://storage.googleapis.com/{bucket}/{checkpoint}.zip"
    for bucket, checkpoints in (
        (
            "bleurt-oss",
            (
                "bleurt-tiny-128",
                "bleurt-tiny-512",
                "bleurt-base-128",
                "bleurt-base-512",
                "bleurt-large-128",
                "bleurt-large-512",
            ),
        ),
        (
            "bleurt-oss-21",
            (
                "BLEURT-20-D3",
                "BLEURT-20-D6",
                "BLEURT-20-D12",
                "BLEURT-20",
            ),
        ),
    )
    for checkpoint in checkpoints
}
|
|
|
|
|
# `evaluate` metric that scores text with a BLEURT checkpoint through
# `bleurt.score.BleurtScorer`. The checkpoint is chosen from the metric's
# configuration name (see `_download_and_prepare`).
# A `#` comment is used here rather than a class docstring so that the
# docstring assembled by the decorator below is left exactly as it was.
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class BLEURT(evaluate.Metric):
    def _info(self):
        """Build the `evaluate.MetricInfo` describing this metric.

        Returns:
            An `evaluate.MetricInfo` carrying the module-level description,
            citation and usage text, the input schema (one string prediction
            and one string reference per example) and the project links.
        """
        project_url = "https://github.com/google-research/bleurt"
        paper_url = "https://arxiv.org/abs/2004.04696"

        input_schema = datasets.Features(
            {
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Value("string", id="sequence"),
            }
        )

        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage=project_url,
            inputs_description=_KWARGS_DESCRIPTION,
            features=input_schema,
            codebase_urls=[project_url],
            reference_urls=[project_url, paper_url],
        )

    def _download_and_prepare(self, dl_manager):
        """Resolve the configured checkpoint, download it and create the scorer.

        Args:
            dl_manager: object whose `download_and_extract(url)` is called to
                fetch and unpack the checkpoint archive; its return value is
                used as a directory path.

        Raises:
            KeyError: if the configuration name matches no key of
                `CHECKPOINT_URLS`, neither lower-cased nor upper-cased.

        Side effects:
            May overwrite `self.config_name` (default case) and always sets
            `self.scorer` on success.
        """
        # No explicit configuration: fall back to the base checkpoint and tell
        # the user how to request a bigger one.
        if self.config_name == "default":
            logger.warning(
                "Using default BLEURT-Base checkpoint for sequence maximum length 128. "
                "You can use a bigger model for better results with e.g.: evaluate.load('bleurt', 'bleurt-large-512')."
            )
            self.config_name = "bleurt-base-128"

        # Keys of CHECKPOINT_URLS are either entirely lower-case or entirely
        # upper-case, so try both casings of the requested name, lower first.
        requested = self.config_name
        for candidate in (requested.lower(), requested.upper()):
            if candidate in CHECKPOINT_URLS:
                checkpoint_name = candidate
                break
        else:
            raise KeyError(
                f"{self.config_name} model not found. You should supply the name of a model checkpoint for bleurt in {CHECKPOINT_URLS.keys()}"
            )

        archive_url = CHECKPOINT_URLS[checkpoint_name]
        extracted_dir = dl_manager.download_and_extract(archive_url)
        # The scorer is pointed at the sub-directory named after the checkpoint
        # inside the extracted archive.
        checkpoint_dir = os.path.join(extracted_dir, checkpoint_name)
        self.scorer = score.BleurtScorer(checkpoint_dir)

    def _compute(self, predictions, references):
        """Score `predictions` (as candidates) against `references`.

        Requires `self.scorer`, which `_download_and_prepare` sets.

        Returns:
            A dict with the single key "scores" holding whatever the scorer's
            `score` call returns.
        """
        return {"scores": self.scorer.score(references=references, candidates=predictions)}
|
|