Skip to content

Commit

Permalink
[Feature] Add py150 and maxmin (#562)
Browse files Browse the repository at this point in the history
* [feat] add clozeTest_maxmin dataset

* [feat] add py150 datasets

* [feat] change __init__.py in opencompass/datasets

* [fix] pre-commit check

* [fix] rename py150 and maxmin datasets in configs

* [feat] add gen.py of py150 and maxmin in configs/datasets
  • Loading branch information
jingmingzhuo authored Nov 9, 2023
1 parent 889a6b2 commit b3cbef3
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 0 deletions.
4 changes: 4 additions & 0 deletions configs/datasets/clozeTest_maxmin/clozeTest_maxmin_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
from .clozeTest_maxmin_gen_c205fb import maxmin_datasets # noqa: F401, F403
42 changes: 42 additions & 0 deletions configs/datasets/clozeTest_maxmin/clozeTest_maxmin_gen_c205fb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import MaxminDataset
from opencompass.utils.text_postprocessors import first_capital_postprocess


# Columns read from each MaxminDataset row: the two token strings feed the
# prompt, and "answer" is the reference label.
maxmin_reader_cfg = dict(
    input_columns=['nl_tokens', 'pl_tokens'],
    output_column='answer',
)

# Zero-shot generation: a single HUMAN turn carrying the masked code and its
# description, followed by the BOT turn that holds the answer slot.
maxmin_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt=(
                        'Code:{pl_tokens}\n'
                        'The aim of the code: {nl_tokens}\n'
                        'Question: Please tell me what "<mask>" in the code should be '
                        'replaced with and you must response to me only A or B.\n'
                        'A. max\n'
                        'B. min\n'
                        'Answer:'
                    ),
                ),
                dict(role='BOT', prompt='{answer}'),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

# Predictions are taken from the BOT role, passed through
# first_capital_postprocess, and scored with AccEvaluator.
maxmin_eval_cfg = dict(
    evaluator=dict(type=AccEvaluator),
    pred_role='BOT',
    pred_postprocessor=dict(type=first_capital_postprocess),
)

maxmin_datasets = [
    dict(
        type=MaxminDataset,
        abbr='maxmin',
        test_path='data/clozeTest-maxmin/python/clozeTest.json',
        answer_path='data/clozeTest-maxmin/python/answers.txt',
        reader_cfg=maxmin_reader_cfg,
        infer_cfg=maxmin_infer_cfg,
        eval_cfg=maxmin_eval_cfg,
    ),
]
4 changes: 4 additions & 0 deletions configs/datasets/py150/py150_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
from .py150_gen_38b13d import py150_datasets # noqa: F401, F403
41 changes: 41 additions & 0 deletions configs/datasets/py150/py150_gen_38b13d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import BleuEvaluator
from opencompass.datasets import Py150Dataset
from opencompass.utils.text_postprocessors import first_capital_postprocess


# Columns read from each Py150Dataset row: "input" is the code context shown
# to the model and "gt" is the reference next line.
py150_reader_cfg = dict(
    input_columns='input',
    output_column='gt',
)

# Zero-shot generation: one HUMAN turn with the code context, then the BOT
# turn that holds the reference slot.
py150_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt=(
                        "I will give you a part of python code. "
                        "Please write down what the next line of code is. "
                        "Note that you only need to give the next line of code, "
                        "and you don't need to give any other reply.\n"
                        "Code:{input}\n"
                        "Next line:"
                    ),
                ),
                dict(role='BOT', prompt='{gt}'),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

# Predictions are taken from the BOT role and scored with BleuEvaluator.
py150_eval_cfg = dict(
    evaluator=dict(type=BleuEvaluator),
    pred_role='BOT',
)

py150_datasets = [
    dict(
        type=Py150Dataset,
        abbr='py150',
        path='data/py150/test.json',
        reader_cfg=py150_reader_cfg,
        infer_cfg=py150_infer_cfg,
        eval_cfg=py150_eval_cfg,
    ),
]
2 changes: 2 additions & 0 deletions opencompass/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .chid import * # noqa: F401, F403
from .cibench import * # noqa: F401, F403
from .civilcomments import * # noqa: F401, F403
from .clozeTest_maxmin import * # noqa: F401, F403
from .cluewsc import * # noqa: F401, F403
from .cmb import * # noqa: F401, F403
from .cmmlu import * # noqa: F401, F403
Expand Down Expand Up @@ -55,6 +56,7 @@
from .natural_question import * # noqa: F401, F403
from .obqa import * # noqa: F401, F403
from .piqa import * # noqa: F401, F403
from .py150 import * # noqa: F401, F403
from .qasper import * # noqa: F401, F403
from .qaspercut import * # noqa: F401, F403
from .race import * # noqa: F401, F403
Expand Down
35 changes: 35 additions & 0 deletions opencompass/datasets/clozeTest_maxmin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json

from datasets import Dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class MaxminDataset(BaseDataset):
    """Loader for the clozeTest max/min multiple-choice dataset.

    Each produced row has three string fields: ``nl_tokens`` and
    ``pl_tokens`` (the item's token sequences joined with single spaces)
    and ``answer`` (``'A'``, ``'B'`` or ``''``).
    """

    @staticmethod
    def load(test_path, answer_path=None):
        """Build a ``Dataset`` from a test JSON file and optional answers.

        Args:
            test_path: Path to a JSON file holding a list of items, each
                with ``nl_tokens``, ``pl_tokens`` and (when answers are
                used) ``idx``.
            answer_path: Optional path to a text file with one
                ``<idx><CODESPLIT><label>`` entry per line. When ``None``,
                every row gets an empty ``answer``.

        Returns:
            A ``datasets.Dataset`` with one row per test item.

        Raises:
            IndexError: If a non-blank answer line has no ``<CODESPLIT>``.
            KeyError: If an item's ``idx`` is missing from the answers
                file (or a required item field is absent).
        """
        answers = {}
        if answer_path is not None:
            with open(answer_path, 'r', encoding='utf-8') as answer_f:
                for line in answer_f:
                    line = line.strip()
                    if not line:
                        # Tolerate blank/trailing empty lines instead of
                        # failing on the missing separator.
                        continue
                    parts = line.split('<CODESPLIT>')
                    answers[parts[0]] = parts[1]

        # Read with an explicit encoding, matching the answers file, so the
        # result does not depend on the platform's locale default.
        with open(test_path, 'r', encoding='utf-8') as test_f:
            test_data = json.load(test_f)

        datasets = []
        for item in test_data:
            if answer_path is not None:
                # 'max' maps to option A; any other label maps to option B.
                # NOTE(review): the lookup assumes item['idx'] is the same
                # string that appears before <CODESPLIT> - confirm.
                answer = 'A' if answers[item['idx']] == 'max' else 'B'
            else:
                answer = ''
            datasets.append({
                'nl_tokens': ' '.join(item['nl_tokens']),
                'pl_tokens': ' '.join(item['pl_tokens']),
                'answer': answer,
            })
        return Dataset.from_list(datasets)
38 changes: 38 additions & 0 deletions opencompass/datasets/py150.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json
import re

from datasets import Dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


# Matches typed literal placeholders such as ``<STR_LIT:foo>``; group 2 is
# the literal's text. DOTALL lets the text span embedded newlines.
_PY150_LIT_PATTERN = re.compile(r'<(STR|NUM|CHAR)_LIT:(.*?)>', re.S)


def py150_post_process(code):
    """Turn one raw py150 JSON line into a cleaned record.

    The placeholder substitutions are applied to the raw JSON text before
    it is parsed, so they affect every field of the record.

    Steps:
        1. Replace bare ``<NUM_LIT>`` with ``0`` and bare ``<STR_LIT>`` /
           ``<CHAR_LIT>`` with the empty string.
        2. Replace ``<STR_LIT:x>`` / ``<NUM_LIT:x>`` / ``<CHAR_LIT:x>``
           with ``x``.
        3. Parse the JSON, drop ``<s>`` markers from ``input``, split it on
           ``<EOL>``, strip each line and re-join with newlines.
        4. Remove the ``id`` key if present.

    Args:
        code: A single JSON-encoded line containing at least an ``input``
            string field.

    Returns:
        The parsed dict with the cleaned ``input`` and without ``id``.

    Raises:
        json.JSONDecodeError: If the line is not valid JSON after the
            placeholder substitutions.
        KeyError: If the parsed record has no ``input`` field.
    """
    code = (code.replace('<NUM_LIT>', '0')
            .replace('<STR_LIT>', '')
            .replace('<CHAR_LIT>', ''))
    # Single pass over the text; each placeholder becomes its literal value.
    code = _PY150_LIT_PATTERN.sub(r'\2', code)
    record = json.loads(code)
    raw_lines = record['input'].replace('<s>', '').split('<EOL>')
    # Strip the padding left around each line by the token separators. The
    # previous loop computed the stripped value but discarded it.
    record['input'] = '\n'.join(raw_line.strip() for raw_line in raw_lines)
    record.pop('id', None)
    return record


@LOAD_DATASET.register_module()
class Py150Dataset(BaseDataset):
    """Loader for the py150 next-line prediction data (one JSON per line)."""

    @staticmethod
    def load(path):
        """Read a JSON-lines file and return it as a ``Dataset``.

        Args:
            path: Path to a text file where every non-blank line is a JSON
                record accepted by ``py150_post_process``.

        Returns:
            A ``datasets.Dataset`` with one row per non-blank line, each
            row being the dict returned by ``py150_post_process``.
        """
        # The context manager closes the handle deterministically, and the
        # explicit encoding avoids depending on the locale default.
        with open(path, 'r', encoding='utf-8') as data_f:
            rows = [
                py150_post_process(line) for line in data_f
                # Blank lines carry no record and would fail JSON parsing.
                if line.strip()
            ]
        return Dataset.from_list(rows)

0 comments on commit b3cbef3

Please sign in to comment.