[Feature] Add lawbench (#460)
* add lawbench

* update requirements

* update
Leymore authored Oct 13, 2023
1 parent fbf5089 commit 861942a
Showing 40 changed files with 3,639 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
@@ -3,7 +3,9 @@ exclude: |
tests/data/|
opencompass/models/internal/|
opencompass/utils/internal/|
-   opencompass/openicl/icl_evaluator/hf_metrics/
+   opencompass/openicl/icl_evaluator/hf_metrics/|
+   opencompass/datasets/lawbench/utils|
+   opencompass/datasets/lawbench/evaluation_functions/
)
repos:
- repo: https://github.com/PyCQA/flake8
62 changes: 62 additions & 0 deletions configs/datasets/lawbench/lawbench_one_shot_gen_002588.py
@@ -0,0 +1,62 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import LawBenchDataset

names = [
    ["1-1", "article_recitation"],
    ["1-2", "knowledge_question_answering"],
    ["2-1", "document_proofreading"],
    ["2-2", "dispute_focus_identification"],
    ["2-3", "marital_disputes_identification"],
    ["2-4", "issue_topic_identification"],
    ["2-5", "reading_comprehension"],
    ["2-6", "named_entity_recognition"],
    ["2-7", "opinion_summarization"],
    ["2-8", "argument_mining"],
    ["2-9", "event_detection"],
    ["2-10", "trigger_word_extraction"],
    ["3-1", "fact_based_article_prediction"],
    ["3-2", "scene_based_article_prediction"],
    ["3-3", "charge_prediction"],
    ["3-4", "prison_term_prediction_wo_article"],
    ["3-5", "prison_term_prediction_w_article"],
    ["3-6", "case_analysis"],
    ["3-7", "criminal_damages_calculation"],
    ["3-8", "consultation"],
]

lawbench_datasets = []
for index, name in names:
    lawbench_reader_cfg = dict(
        input_columns=['instruction', 'question'],
        output_column='answer')

    lawbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role="HUMAN", prompt="{instruction}\n{question}"),
                ]
            ),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer, max_out_len=1024)
    )

    lawbench_eval_cfg = dict(
        evaluator=dict(type='LawBenchEvaluator_' + index.replace('-', '_'))
    )

    lawbench_datasets.append(
        dict(
            abbr='lawbench-' + index + '-' + name + '-1-shot',
            type=LawBenchDataset,
            path='./data/lawbench/one_shot',
            index=index,
            reader_cfg=lawbench_reader_cfg,
            infer_cfg=lawbench_infer_cfg,
            eval_cfg=lawbench_eval_cfg
        )
    )
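
The loop above expands each [index, name] pair into one dataset entry. As an illustration only (not part of the commit), the first entry it generates, assuming index="1-1" and name="article_recitation", looks like this; note how the evaluator type name is derived from the task index:

# Illustration: hand-expanded first entry of lawbench_datasets
# dict(
#     abbr='lawbench-1-1-article_recitation-1-shot',
#     type=LawBenchDataset,
#     path='./data/lawbench/one_shot',
#     index='1-1',
#     eval_cfg=dict(evaluator=dict(type='LawBenchEvaluator_1_1')),
#     ...  # reader_cfg / infer_cfg as built above
# )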
62 changes: 62 additions & 0 deletions configs/datasets/lawbench/lawbench_zero_shot_gen_002588.py
@@ -0,0 +1,62 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import LawBenchDataset

names = [
    ["1-1", "article_recitation"],
    ["1-2", "knowledge_question_answering"],
    ["2-1", "document_proofreading"],
    ["2-2", "dispute_focus_identification"],
    ["2-3", "marital_disputes_identification"],
    ["2-4", "issue_topic_identification"],
    ["2-5", "reading_comprehension"],
    ["2-6", "named_entity_recognition"],
    ["2-7", "opinion_summarization"],
    ["2-8", "argument_mining"],
    ["2-9", "event_detection"],
    ["2-10", "trigger_word_extraction"],
    ["3-1", "fact_based_article_prediction"],
    ["3-2", "scene_based_article_prediction"],
    ["3-3", "charge_prediction"],
    ["3-4", "prison_term_prediction_wo_article"],
    ["3-5", "prison_term_prediction_w_article"],
    ["3-6", "case_analysis"],
    ["3-7", "criminal_damages_calculation"],
    ["3-8", "consultation"],
]

lawbench_datasets = []
for index, name in names:
    lawbench_reader_cfg = dict(
        input_columns=['instruction', 'question'],
        output_column='answer')

    lawbench_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                round=[
                    dict(role="HUMAN", prompt="{instruction}\n{question}"),
                ]
            ),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer, max_out_len=1024)
    )

    lawbench_eval_cfg = dict(
        evaluator=dict(type='LawBenchEvaluator_' + index.replace('-', '_'))
    )

    lawbench_datasets.append(
        dict(
            abbr='lawbench-' + index + '-' + name + '-0-shot',
            type=LawBenchDataset,
            path='./data/lawbench/zero_shot',
            index=index,
            reader_cfg=lawbench_reader_cfg,
            infer_cfg=lawbench_infer_cfg,
            eval_cfg=lawbench_eval_cfg
        )
    )
11 changes: 11 additions & 0 deletions configs/eval_qwen_7b_chat_lawbench.py
@@ -0,0 +1,11 @@
from mmengine.config import read_base

with read_base():
    from .models.qwen.hf_qwen_7b_chat import models
    from .datasets.lawbench.lawbench_zero_shot_gen_002588 import lawbench_datasets as lawbench_zero_shot_datasets
    from .datasets.lawbench.lawbench_one_shot_gen_002588 import lawbench_datasets as lawbench_one_shot_datasets
    from .summarizers.lawbench import summarizer

datasets = lawbench_zero_shot_datasets + lawbench_one_shot_datasets
for d in datasets:
    d["infer_cfg"]["inferencer"]["save_every"] = 1
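
With this config in place, the run can be launched from the OpenCompass repository root with the standard entry point, e.g. python run.py configs/eval_qwen_7b_chat_lawbench.py. The save_every=1 override presumably tells GenInferencer to dump partial predictions after every sample, so long generation runs can be resumed without losing work.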
29 changes: 29 additions & 0 deletions configs/summarizers/groups/lawbench.py
@@ -0,0 +1,29 @@
names = [
    ["1-1", "article_recitation"],
    ["1-2", "knowledge_question_answering"],
    ["2-1", "document_proofreading"],
    ["2-2", "dispute_focus_identification"],
    ["2-3", "marital_disputes_identification"],
    ["2-4", "issue_topic_identification"],
    ["2-5", "reading_comprehension"],
    ["2-6", "named_entity_recognition"],
    ["2-7", "opinion_summarization"],
    ["2-8", "argument_mining"],
    ["2-9", "event_detection"],
    ["2-10", "trigger_word_extraction"],
    ["3-1", "fact_based_article_prediction"],
    ["3-2", "scene_based_article_prediction"],
    ["3-3", "charge_prediction"],
    ["3-4", "prison_term_prediction_wo_article"],
    ["3-5", "prison_term_prediction_w_article"],
    ["3-6", "case_analysis"],
    ["3-7", "criminal_damages_calculation"],
    ["3-8", "consultation"],
]

lawbench_summary_groups = []

_lawbench_0_shot = ['lawbench-' + index + '-' + name + '-0-shot' for index, name in names]
lawbench_summary_groups.append({'name': 'lawbench-0-shot', 'subsets': _lawbench_0_shot})
_lawbench_1_shot = ['lawbench-' + index + '-' + name + '-1-shot' for index, name in names]
lawbench_summary_groups.append({'name': 'lawbench-1-shot', 'subsets': _lawbench_1_shot})
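
For orientation (not part of the commit), the comprehensions above produce abbreviations that line up with the abbr fields built in the dataset configs, e.g.:

# _lawbench_0_shot[:2] == ['lawbench-1-1-article_recitation-0-shot',
#                          'lawbench-1-2-knowledge_question_answering-0-shot']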
58 changes: 58 additions & 0 deletions configs/summarizers/lawbench.py
@@ -0,0 +1,58 @@
from mmengine.config import read_base

with read_base():
    from .groups.lawbench import lawbench_summary_groups

summarizer = dict(
    dataset_abbrs = [
        '--------- 0-shot ---------', # category
        'lawbench-0-shot',
        'lawbench-1-1-article_recitation-0-shot',
        'lawbench-1-2-knowledge_question_answering-0-shot',
        'lawbench-2-1-document_proofreading-0-shot',
        'lawbench-2-2-dispute_focus_identification-0-shot',
        'lawbench-2-3-marital_disputes_identification-0-shot',
        'lawbench-2-4-issue_topic_identification-0-shot',
        'lawbench-2-5-reading_comprehension-0-shot',
        'lawbench-2-6-named_entity_recognition-0-shot',
        'lawbench-2-7-opinion_summarization-0-shot',
        'lawbench-2-8-argument_mining-0-shot',
        'lawbench-2-9-event_detection-0-shot',
        'lawbench-2-10-trigger_word_extraction-0-shot',
        'lawbench-3-1-fact_based_article_prediction-0-shot',
        'lawbench-3-2-scene_based_article_prediction-0-shot',
        'lawbench-3-3-charge_prediction-0-shot',
        'lawbench-3-4-prison_term_prediction_wo_article-0-shot',
        'lawbench-3-5-prison_term_prediction_w_article-0-shot',
        'lawbench-3-6-case_analysis-0-shot',
        'lawbench-3-7-criminal_damages_calculation-0-shot',
        'lawbench-3-8-consultation-0-shot',
        '--------- 1-shot ---------', # category
        'lawbench-1-shot',
        'lawbench-1-1-article_recitation-1-shot',
        'lawbench-1-2-knowledge_question_answering-1-shot',
        'lawbench-2-1-document_proofreading-1-shot',
        'lawbench-2-2-dispute_focus_identification-1-shot',
        'lawbench-2-3-marital_disputes_identification-1-shot',
        'lawbench-2-4-issue_topic_identification-1-shot',
        'lawbench-2-5-reading_comprehension-1-shot',
        'lawbench-2-6-named_entity_recognition-1-shot',
        'lawbench-2-7-opinion_summarization-1-shot',
        'lawbench-2-8-argument_mining-1-shot',
        'lawbench-2-9-event_detection-1-shot',
        'lawbench-2-10-trigger_word_extraction-1-shot',
        'lawbench-3-1-fact_based_article_prediction-1-shot',
        'lawbench-3-2-scene_based_article_prediction-1-shot',
        'lawbench-3-3-charge_prediction-1-shot',
        'lawbench-3-4-prison_term_prediction_wo_article-1-shot',
        'lawbench-3-5-prison_term_prediction_w_article-1-shot',
        'lawbench-3-6-case_analysis-1-shot',
        'lawbench-3-7-criminal_damages_calculation-1-shot',
        'lawbench-3-8-consultation-1-shot',
    ],
    summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
    prompt_db=dict(
        database_path='configs/datasets/log.json',
        config_dir='configs/datasets',
        blacklist='.promptignore'),
)
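
A note on the locals() line above (a sketch of its effect, not part of the commit): after read_base() pulls in the group definitions, the only local name ending in "_summary_groups" is lawbench_summary_groups, so the expression effectively reduces to:

# summary_groups = lawbench_summary_groups
# i.e. the 'lawbench-0-shot' and 'lawbench-1-shot' groups
# defined in configs/summarizers/groups/lawbench.py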
1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
@@ -40,6 +40,7 @@
from .jigsawmultilingual import * # noqa: F401, F403
from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403
from .lambada import * # noqa: F401, F403
from .lawbench import * # noqa: F401, F403
from .lcsts import * # noqa: F401, F403
from .leval import * # noqa: F401, F403
from .longbench import * # noqa: F401, F403
1 change: 1 addition & 0 deletions opencompass/datasets/lawbench/__init__.py
@@ -0,0 +1 @@
from .lawbench import LawBenchDataset # noqa: F401
19 changes: 19 additions & 0 deletions opencompass/datasets/lawbench/evaluation_functions/cjft.py
@@ -0,0 +1,19 @@
from ..utils.function_utils import compute_rouge

# 情景法条识别 (scene-based article prediction)

def compute_cjft(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        predictions.append(prediction)
        references.append(answer)

    # average the per-example ROUGE-L F1 scores
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
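
A usage sketch for the evaluator above (illustration only, not part of the commit). It assumes compute_rouge wraps a standard ROUGE implementation that returns one dict per example, indexed as score["rouge-l"]["f"], and that it handles Chinese word segmentation internally:

# Hypothetical records in the format this function reads:
toy_records = [
    {"origin_prompt": "...", "prediction": "甲应承担违约责任", "refr": "甲应当承担违约责任"},
    {"origin_prompt": "...", "prediction": "合同无效", "refr": "该合同依法应认定为无效"},
]
print(compute_cjft(toy_records))  # {"score": <mean ROUGE-L F1 over the two records>}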
18 changes: 18 additions & 0 deletions opencompass/datasets/lawbench/evaluation_functions/flzx.py
@@ -0,0 +1,18 @@
from ..utils.function_utils import compute_rouge

# 法律咨询 (legal consultation)
def compute_flzx(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        predictions.append(prediction)
        references.append(answer)

    # average the per-example ROUGE-L F1 scores
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
19 changes: 19 additions & 0 deletions opencompass/datasets/lawbench/evaluation_functions/ftcs.py
@@ -0,0 +1,19 @@
from ..utils.function_utils import compute_rouge

# 法条记忆问答 (legal article memorization Q&A)
def compute_ftcs(data_dict):
    """
    Compute the ROUGE-L score between the prediction and the reference.
    """
    references, predictions = [], []
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        answer = answer.replace("答案:", "")
        predictions.append(prediction)
        references.append(answer)

    # average the per-example ROUGE-L F1 scores
    rouge_scores = compute_rouge(predictions, references)
    rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
    average_rouge_l = sum(rouge_ls) / len(rouge_ls)
    return {"score": average_rouge_l}
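
A small illustration (not from the commit) of the prefix handling above: the reference answers in this task are assumed to carry a "答案:" prefix, which is stripped before ROUGE-L scoring.

# Hypothetical reference string, for illustration only:
assert "答案:《刑法》第一百三十三条".replace("答案:", "") == "《刑法》第一百三十三条"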
36 changes: 36 additions & 0 deletions opencompass/datasets/lawbench/evaluation_functions/jdzy.py
@@ -0,0 +1,36 @@
from ..utils.function_utils import multi_choice_judge

"""
multi-choice single-label selection
metric: accuracy
争议焦点 (dispute focus): identify the dispute focuses involved in a case
"""

def compute_jdzy(data_dict):
    """
    Compute the accuracy.
    Each question has 16 candidate dispute-focus labels, stored in option_list.
    A prediction is correct if
    1. The correct answer appears in the prediction, and
    2. Options other than the answer do not appear in the prediction.
    """

    score_list, abstentions = [], 0
    option_list = ["诉讼主体", "租金情况", "利息", "本金争议", "责任认定", "责任划分", "损失认定及处理",
                   "原审判决是否适当", "合同效力", "财产分割", "责任承担", "鉴定结论采信问题", "诉讼时效", "违约", "合同解除", "肇事逃逸"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        if answer[7:-1] == "赔偿":
            # todo: dataset imperfection
            continue
        assert answer.startswith("争议焦点类别:") and answer[7:-1] in option_list, \
            f"answer: {answer} \n question: {question}"

        answer_letter = answer[7:-1]
        judge = multi_choice_judge(prediction, option_list, answer_letter)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
29 changes: 29 additions & 0 deletions opencompass/datasets/lawbench/evaluation_functions/jec_ac.py
@@ -0,0 +1,29 @@
from ..utils.function_utils import multi_choice_judge

"""
Task: multi-choice selection
Metric: accuracy
司法考试-案例分析 (judicial examination - case analysis)
"""
def compute_jec_ac(data_dict):
    """
    Compute the accuracy.
    The JEC dataset has 4 options for each question: A, B, C, D.
    A prediction is correct if
    1. The correct answer appears in the prediction, and
    2. Options other than the answer do not appear in the prediction.
    """
    score_list, abstentions = [], 0
    option_list = ["A", "B", "C", "D"]
    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("正确答案:") and answer[5] in option_list, f"answer[5]: {answer}, question: {question}"

        answer_letter = answer[5]
        judge = multi_choice_judge(prediction, option_list, answer_letter)
        score_list.append(judge["score"])
        abstentions += judge["abstention"]

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions / len(data_dict)}
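
multi_choice_judge is imported from ..utils/function_utils.py, which is not shown in this excerpt. The sketch below is a hypothetical re-implementation inferred from how it is called above and from the two scoring rules stated in the docstrings; it is not the actual helper shipped in the commit:

def multi_choice_judge_sketch(prediction, option_list, answer):
    # Hypothetical: score 1 only if the gold option appears in the prediction
    # and no other option does; count an abstention if no option appears at all.
    hits = [opt for opt in option_list if opt in prediction]
    if not hits:
        return {"score": 0, "abstention": 1}
    return {"score": int(hits == [answer]), "abstention": 0}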