Skip to content

Commit

Permalink
[metric] Add a new metrics DISTINCT. (PaddlePaddle#183)
Browse files Browse the repository at this point in the history
* Add new metrics DISTINCT.

* Fixed bugs.

* Fixed bugs.

* Optimize the distinct.
  • Loading branch information
xiemoyuan authored Mar 24, 2021
1 parent d84c2bb commit e079ffc
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
| Mcc(Matthews correlation coefficient) | 马修斯相关系数,用以测量二分类的分类性能的指标。可用于GLUE中的CoLA任务 | `paddlenlp.metrics.Mcc` |
| ChunkEvaluator | 计算了块检测的精确率、召回率和F1-score。常用于序列标记任务,如命名实体识别(NER) | `paddlenlp.metrics.ChunkEvaluator` |
| Squad Evalutaion | 用于SQuAD和DuReader-robust的评价指标 | `paddlenlp.metrics.compute_predictions`, `paddlenlp.metrics.squad_evaluate` |
| [Distinct](https://arxiv.org/abs/1510.03055) | 多样性指标,常用来衡量文本生成模型生成的句子形式上的多样性。 | `paddlenlp.metrics.Distinct` |
1 change: 1 addition & 0 deletions paddlenlp/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .bleu import BLEU, BLEUForDuReader
from .rouge import RougeL, RougeLForDuReader, RougeN, Rouge1, Rouge2
from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman
from .distinct import Distinct
132 changes: 132 additions & 0 deletions paddlenlp/metrics/distinct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle

__all__ = ['Distinct']


class Distinct(paddle.metric.Metric):
"""
Distinct is an algorithm for evaluating the textual diversity of the
generated text by calculating the number of distinct n-grams. The larger
the value of n-grams, the higher the diversity of the text. See detail at
https://arxiv.org/abs/1510.03055
`Distinct` could be used as `paddle.metric.Metric` class, or an ordinary
class. When `Distinct` is used as `paddle.metric.Metric` class. A function is
needed that transforms the network output to string list. It should be noted
that the `Distinct` here is different from the `Distinct` calculated in
prediction, and it is only for observation during training and evaluation.
Args:
trans_func (callable, optional): `trans_func` transforms the network
output to string list. Default None. When `Distinct` is used as
`paddle.metric.Metric` class, `trans_func` must be provided. Please
note that the input of `trans_func` is numpy array.
n_size (int, optional): Number of gram for `Distinct` metric. Default: 2.
name (str, optional): Name of `paddle.metric.Metric` instance.
Default: "distinct".
Examples:
1. Using as a general evaluation object.
.. code-block:: python
from paddlenlp.metrics import Distinct
distinct = Distinct()
cand = ["The","cat","The","cat","on","the","mat"]
distinct.add_inst(cand)
print(distinct.score()) # 0.8333333333333334
2. Using as an instance of `paddle.metric.Metric`.
.. code-block:: python
import numpy as np
from functools import partial
import paddle
from paddlenlp.transformers import BertTokenizer
from paddlenlp.metrics import Distinct
def trans_func(logits, tokenizer):
'''Transform the network output `logits` to string list.'''
# [batch_size, seq_len]
token_ids = np.argmax(logits, axis=-1).tolist()
cand_list = []
for ids in token_ids:
tokens = tokenizer.convert_ids_to_tokens(ids)
strings = tokenizer.convert_tokens_to_string(tokens)
cand_list.append(strings.split())
return cand_list
paddle.seed(2021)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
distinct = Distinct(trans_func=partial(trans_func, tokenizer=tokenizer))
batch_size, seq_len, vocab_size = 4, 16, tokenizer.vocab_size
logits = paddle.rand([batch_size, seq_len, vocab_size])
distinct.update(logits.numpy())
print(distinct.accumulate()) # 1.0
"""

def __init__(self, n_size=2, trans_func=None, name="distinct"):
super(Distinct, self).__init__()
self._name = name
self.diff_ngram = set()
self.count = 0.0
self.n_size = n_size
self.trans_func = trans_func

def update(self, output, *args):
"""
Update the metrics states. This method firstly will use `trans_func` to
process the `output` to get the tokenized candidate sentence list. Then
call `add_inst` to process the candidate list one by one.
"""
if isinstance(output, paddle.Tensor):
output = output.numpy()

assert self.trans_func is not None, "The `update` method requires user "\
"to provide `trans_func` when initializing `Distinct`."
cand_list = self.trans_func(output)

for cand in cand_list:
self.add_inst(cand)

def add_inst(self, cand):
"""
Update the states based on the candidate.
Args:
cand (list): Tokenized candidate sentence generated by model.
"""
for i in range(0, len(cand) - self.n_size + 1):
ngram = ' '.join(cand[i:(i + self.n_size)])
self.count += 1
self.diff_ngram.add(ngram)

def reset(self):
self.diff_ngram = set()
self.count = 0.0

def accumulate(self):
"""Calculate the final distinct metric."""
distinct = len(self.diff_ngram) / self.count
return distinct

def score(self):
return self.accumulate()

def name(self):
return self._name

0 comments on commit e079ffc

Please sign in to comment.