diff --git a/.github/workflows/run_eval_tests.yml b/.github/workflows/run_eval_tests.yml
new file mode 100644
index 000000000..3c47f43a0
--- /dev/null
+++ b/.github/workflows/run_eval_tests.yml
@@ -0,0 +1,40 @@
+name: Run Unit Tests for Evaluation scripts via Pytest
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  test:
+    name: Run Tests
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2.4.0
+        with:
+          fetch-depth: 0
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v2.3.1
+        with:
+          python-version: "3.11"
+      - id: cache-dependencies
+        name: Cache dependencies
+        uses: actions/cache@v2.1.7
+        with:
+          path: ${{ github.workspace }}/.venv
+          key: dependencies-${{ hashFiles('**/poetry.lock') }}
+          restore-keys: dependencies-
+      - name: Install dependencies
+        if: steps.cache-dependencies.outputs.cache-hit != 'true'
+        run: |
+          python3 -m pip install -U pip poetry
+          poetry --version
+          poetry check --no-interaction
+          poetry config virtualenvs.in-project true
+          poetry install --no-interaction
+      - name: Run tests
+        run: |
+          poetry run pytest -ra -s tests
diff --git a/Database/gold/ImpactDB_DataTable_Validation.xlsx b/Database/gold/ImpactDB_DataTable_Validation.xlsx
index 9a243cb63..1d5c6cf13 100644
Binary files a/Database/gold/ImpactDB_DataTable_Validation.xlsx and b/Database/gold/ImpactDB_DataTable_Validation.xlsx differ
diff --git a/Database/gold/gold_from_excel/Affected.parquet b/Database/gold/gold_from_excel/Affected.parquet
deleted file mode 100644
index ab2283090..000000000
--- a/Database/gold/gold_from_excel/Affected.parquet
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:04cf088bd965b9c65af9e28b2427e64d0acbada0fdb67b9f956b12321be5f4bc
-size 27316
diff --git a/Database/gold/gold_from_excel/Buildings_Damaged.parquet b/Database/gold/gold_from_excel/Buildings_Damaged.parquet
deleted file mode 100644
index 21f800396..000000000
--- a/Database/gold/gold_from_excel/Buildings_Damaged.parquet
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6536b9f8e83bac69a746f0e12afec9dde21ab1e0bb740c06952ba8245a500dd8
-size 28526
diff --git a/Database/gold/gold_from_excel/Damage.parquet b/Database/gold/gold_from_excel/Damage.parquet
deleted file mode 100644
index 680b7f4c9..000000000
--- a/Database/gold/gold_from_excel/Damage.parquet
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:33397cf34322b3713754580e9fdac2597264380bb9dc02ec956d723b19bf90da
-size 29985
diff --git a/Database/gold/gold_from_excel/Deaths.parquet b/Database/gold/gold_from_excel/Deaths.parquet
deleted file mode 100644
index ddf839d9d..000000000
--- a/Database/gold/gold_from_excel/Deaths.parquet
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f008ca007eb1c12a9b94899df06d50aaf7771710743e35af322004254b2ba1e7
-size 28477
diff --git a/Database/gold/gold_from_excel/Displaced.parquet b/Database/gold/gold_from_excel/Displaced.parquet
deleted file mode 100644
index 7da37003b..000000000
--- a/Database/gold/gold_from_excel/Displaced.parquet
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:61f0a601dfb9191c17b42050a3d414edf9629583a79a708889077c34f40a84bb
-size 27660
diff --git a/Database/gold/gold_from_excel/Events.parquet b/Database/gold/gold_from_excel/Events.parquet
deleted file mode 100644
index 696474fe0..000000000
--- a/Database/gold/gold_from_excel/Events.parquet
+++ /dev/null
@@ -1,3 +0,0 @@
-version
https://git-lfs.github.com/spec/v1 -oid sha256:86155cdf45bc57dfee0c8508da5f704abef572334652d0408a46b6b0ea6c93ce -size 40090 diff --git a/Database/gold/gold_from_excel/Homeless.parquet b/Database/gold/gold_from_excel/Homeless.parquet deleted file mode 100644 index 7ee79297b..000000000 --- a/Database/gold/gold_from_excel/Homeless.parquet +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:45e933c3d06ba8ecc8dba20349cab8864d32cd11247bff693bfc633df1c16095 -size 27116 diff --git a/Database/gold/gold_from_excel/Injured.parquet b/Database/gold/gold_from_excel/Injured.parquet deleted file mode 100644 index b27d720a9..000000000 --- a/Database/gold/gold_from_excel/Injured.parquet +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9e0045f950ff2d740bb506183ec8fb1493da5643d10ace107234db976eb074ee -size 27320 diff --git a/Database/gold/gold_from_excel/Insured_Damage.parquet b/Database/gold/gold_from_excel/Insured_Damage.parquet deleted file mode 100644 index 571eb7bc8..000000000 --- a/Database/gold/gold_from_excel/Insured_Damage.parquet +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7b7799dcee9ed7c84ebd9fee1bca6b150b29025df1dec0d6dd5c0e953034421d -size 28771 diff --git a/Database/gold/specific_instances/Affected.parquet b/Database/gold/specific_instances/Affected.parquet new file mode 100644 index 000000000..d6662437a --- /dev/null +++ b/Database/gold/specific_instances/Affected.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb361a5a29bf3e1643b7b2b6c21694857b1041c796477b4724a5503f1b649c98 +size 28742 diff --git a/Database/gold/specific_instances/Buildings_Damaged.parquet b/Database/gold/specific_instances/Buildings_Damaged.parquet new file mode 100644 index 000000000..d0333a37d --- /dev/null +++ b/Database/gold/specific_instances/Buildings_Damaged.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3e831bdddee5f3213a72ae131e6ea855d4f6685582b21577d20759252cad3fe +size 29952 diff --git a/Database/gold/specific_instances/Damage.parquet b/Database/gold/specific_instances/Damage.parquet new file mode 100644 index 000000000..e89ae8efe --- /dev/null +++ b/Database/gold/specific_instances/Damage.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84e7bf563c5f02dc15a3e259e8d3b3c186ab4025fd30b8ac6dfb325a7c4dfd19 +size 31411 diff --git a/Database/gold/specific_instances/Deaths.parquet b/Database/gold/specific_instances/Deaths.parquet new file mode 100644 index 000000000..25675ede2 --- /dev/null +++ b/Database/gold/specific_instances/Deaths.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ade788503ddf610051387b91266a4eac0b7568092e48290d8151d7e8df56e729 +size 29903 diff --git a/Database/gold/specific_instances/Displaced.parquet b/Database/gold/specific_instances/Displaced.parquet new file mode 100644 index 000000000..032e0379b --- /dev/null +++ b/Database/gold/specific_instances/Displaced.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ce776a668ab63f42b93ad4968aba57cec0f88148c818d93554493b8e33600ab +size 29086 diff --git a/Database/gold/specific_instances/Events.parquet b/Database/gold/specific_instances/Events.parquet new file mode 100644 index 000000000..c40daa37a --- /dev/null +++ b/Database/gold/specific_instances/Events.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:61ad566b2afcf68ea6a094345ffdb7f8ccb97cc7937ba76c296bbedc799670cd +size 42799 diff --git a/Database/gold/specific_instances/Homeless.parquet b/Database/gold/specific_instances/Homeless.parquet new file mode 100644 index 000000000..ccd044e6a --- /dev/null +++ b/Database/gold/specific_instances/Homeless.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc2468c5305c93f9a6f88ca617e1b7e0dc123b64f129d2948cf98ba57f35002e +size 28542 diff --git a/Database/gold/specific_instances/Injured.parquet b/Database/gold/specific_instances/Injured.parquet new file mode 100644 index 000000000..37da394c9 --- /dev/null +++ b/Database/gold/specific_instances/Injured.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96cb177948896be74f3d046f999664ab2ba42d3ff6016a439d356da93b64a6cd +size 28746 diff --git a/Database/gold/specific_instances/Insured_Damage.parquet b/Database/gold/specific_instances/Insured_Damage.parquet new file mode 100644 index 000000000..b2f382664 --- /dev/null +++ b/Database/gold/specific_instances/Insured_Damage.parquet @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67d0697c096c2ad53131c629dd7c85626cbd006a7c604244fdb61e2ccbb3fa1d +size 30197 diff --git a/Database/gold_from_excel.py b/Database/gold_from_excel.py index a8cc109a9..316c73c44 100644 --- a/Database/gold_from_excel.py +++ b/Database/gold_from_excel.py @@ -1,9 +1,10 @@ -import pathlib import argparse +import pathlib import re import pandas as pd -from scr.normalize_utils import Logging + +from Database.scr.normalize_utils import Logging pd.set_option("display.max_rows", None) pd.set_option("display.max_columns", None) @@ -58,6 +59,7 @@ def flatten(xss): # main and specific impact events have these three column sets in common shared_cols = [ "Event_ID", + "Event_ID_decimal", "Source", "Event_Name", ] @@ -115,7 +117,8 @@ def flatten(xss): for i in ["Insured_Damage", "Damage"]: convert_to_boolean.extend([x for x in specific_impacts_columns[i] if "_Adjusted" in x and "_Year" not in x]) -convert_to_float = ["Event_ID"] +convert_to_float = ["Event_ID_decimal"] + def flatten_data_table(): logger.info("Loading excel file...") @@ -197,7 +200,7 @@ def flatten_data_table(): ) logger.info("Splitting main events from specific impact") - data_table["main"] = data_table.Event_ID.apply(lambda x: float(x).is_integer()) + data_table["main"] = data_table.Event_ID_decimal.apply(lambda x: float(x).is_integer()) data_table["main"].value_counts() logger.info("Storing Main Events table") @@ -271,4 +274,4 @@ def flatten_data_table(): logger.info(f"Creating {args.output_dir} if it does not exist!") pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) - flatten_data_table() \ No newline at end of file + flatten_data_table() diff --git a/Database/merge_json_output.py b/Database/merge_json_output.py index 3235bed5f..93137fb0c 100644 --- a/Database/merge_json_output.py +++ b/Database/merge_json_output.py @@ -1,6 +1,7 @@ import argparse import pathlib -from scr.normalize_utils import Logging, NormalizeJsonOutput + +from Database.scr.normalize_utils import Logging, NormalizeJsonOutput if __name__ == "__main__": logger = Logging.get_logger("merge-mixtral-or-mistral-output") @@ -34,7 +35,7 @@ logger.info(args) logger.info(f"Creating {args.output_dir} if it does not exist!") - pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) json_utils = NormalizeJsonOutput() dfs = 
json_utils.merge_json(args.input_dir)
diff --git a/Database/output/nlp4climate/README.md b/Database/output/README.md
similarity index 73%
rename from Database/output/nlp4climate/README.md
rename to Database/output/README.md
index c9ccf1eb5..b5d68ce71 100644
--- a/Database/output/nlp4climate/README.md
+++ b/Database/output/README.md
@@ -1,5 +1,7 @@
 #### Post-processed files
 
+This is where parsed LLM outputs are stored in `.parquet` format.
+
 Suggested breakdown:
 
 ```shell
@@ -8,8 +10,8 @@
 └── nlp4climate  # <-- ℹ️ Broader name to group experiments
     ├── dev  # <-- ℹ️ dev set, specific to this group of experiments
     │   ├── gpt4_experiment.parquet
-    │   └── mistral_experiment.json
+    │   └── mistral_experiment.parquet
     └── test  # <-- ℹ️ test set, specific to this group of experiments
         ├── gpt4_experiment.parquet
-        └── mistral_experiment.json
+        └── mistral_experiment.parquet
 ```
diff --git a/Database/parse_events.py b/Database/parse_events.py
index f17ddbb58..a7eac14f8 100644
--- a/Database/parse_events.py
+++ b/Database/parse_events.py
@@ -1,11 +1,12 @@
 import argparse
-import re
 import pathlib
+import re
+
 import pandas as pd
-from scr.normalize_locations import NormalizeLocation
-from scr.normalize_numbers import NormalizeNumber
-from scr.normalize_utils import Logging, NormalizeUtils
+from Database.scr.normalize_locations import NormalizeLocation
+from Database.scr.normalize_numbers import NormalizeNumber
+from Database.scr.normalize_utils import Logging, NormalizeUtils
 
 if __name__ == "__main__":
     logger = Logging.get_logger("parse_events")
@@ -81,7 +82,7 @@
     logger.info(f"Passed args: {args}")
 
     logger.info(f"Creating {args.output_dir} if it does not exist!")
-    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
 
     utils = NormalizeUtils()
     nlp = utils.load_spacy_model(args.spaCy_model_name)
@@ -138,7 +139,7 @@
             logger.info(f"Normalizing boolean column {inflation_adjusted_col}")
             events[inflation_adjusted_col] = events[inflation_adjusted_col].replace(
                 {_no: False, _yes: True}, regex=True
-            )
+            )
 
     logger.info("Normalizing nulls")
     events = utils.replace_nulls(events)
@@ -203,7 +204,6 @@
     )
 
     if args.location_column in events.columns and args.country_column in events.columns:
-
         logger.info("Normalizing Locations")
         events["Location_Tmp"] = events["Location"].apply(
             lambda locations: (
@@ -312,9 +312,11 @@
         sub_event = pd.concat([sub_event.Event_ID, sub_event[col].apply(pd.Series)], axis=1)
 
-        logger.info(f"Dropping any columns with non-str column names due to None types in the dicts {[c for c in sub_event.columns if not isinstance(c, str)]}")
+        logger.info(
+            f"Dropping any columns with non-str column names due to None types in the dicts {[c for c in sub_event.columns if not isinstance(c, str)]}"
+        )
         sub_event = sub_event[[c for c in sub_event.columns if isinstance(c, str)]]
-
+
         logger.info(f"Normalizing nulls for subevent {col}")
         sub_event = utils.replace_nulls(sub_event)
@@ -322,7 +324,7 @@
             col
             for col in sub_event.columns
             if col.startswith("Num_")
-            or col.endswith("_Damage")
+            or col.endswith("Damage")
             and "Date" not in col
             and args.location_column not in col
         ]
@@ -389,10 +391,10 @@
                 lambda country: (norm_loc.get_gadm_gid(country=country) if country else None)
             )
 
-            '''
+            """
             logger.info(f"Dropping columns with no locations for subevent {col}")
             sub_event.dropna(subset=[f"Location_{location_col}"], how="all", inplace=True)
-            '''
+            """
 
             logger.info(f"Normalizing location names for subevent {col}")
             sub_event[
                 [
@@ -427,7 +429,7 @@
             )
 
             def normalize_location_rows_if_country(row):
-                # if location and country are identical in subevents, generalize country normalization
+                # if location and country are identical in subevents, generalize country normalization
                 if row[f"Location_{location_col}"] == row[args.country_column]:
                     for i in ["Norm", "Type", "GeoJson", "GID"]:
                         row[f"Location_{location_col}_{i}"] = row[f"Country_{i}"]
diff --git a/Database/scr/normalize_locations.py b/Database/scr/normalize_locations.py
index 961f107af..ad4af6674 100644
--- a/Database/scr/normalize_locations.py
+++ b/Database/scr/normalize_locations.py
@@ -485,7 +485,7 @@ def get_gadm_gid(
     @staticmethod
     def extract_locations(
         text: str,
-    ) -> tuple[list] | None:
+    ) -> tuple[list, list]:
         """
         Extracts countries and sublocations from the '|, &' string format
         Example:
@@ -496,7 +496,7 @@ def extract_locations(
         try:
             split_by_pipe = text.split("|")
         except BaseException:
-            return
+            return [], []
         try:
             if split_by_pipe:
                 for s in split_by_pipe:
@@ -507,7 +507,7 @@
                 locations.extend([locations_tmp])
                 return countries, locations
         except BaseException:
-            return
+            return [], []
 
     def _debug(self, response):
         self.logger.debug(type(response))
diff --git a/Evaluation/__init__.py b/Evaluation/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/Evaluation/comparer.py b/Evaluation/comparer.py
index ebd1636b1..00aa7a8fb 100644
--- a/Evaluation/comparer.py
+++ b/Evaluation/comparer.py
@@ -1,4 +1,4 @@
-import normaliser
+from Evaluation.normaliser import Normaliser
 
 
 class Comparer:
@@ -8,7 +8,7 @@ def __init__(self, null_penalty: bool, target_columns: list[str]):
         """Initialisation."""
         # Penalty score if one field is None, but not the other
         self.null_penalty = null_penalty
-        self.norm = normaliser.Normaliser()
+        self.norm = Normaliser()
         self.target_columns = target_columns
 
     def target_col(self, l) -> list:
diff --git a/Evaluation/matcher.py b/Evaluation/matcher.py
new file mode 100644
index 000000000..04f8e2146
--- /dev/null
+++ b/Evaluation/matcher.py
@@ -0,0 +1,134 @@
+from statistics import mean
+
+from Evaluation.comparer import Comparer
+from Evaluation.utils import Logging
+
+
+class SpecificInstanceMatcher:
+    """Matches and pads specific instances (subevents) from two separate lists.
+    'Padded' specific instances will have NoneType objects as values"""
+
+    def __init__(self, threshold: float = 0.6, null_penalty: float = 0.5):
+        self.logger = Logging.get_logger("specific instance matcher")
+
+        self.threshold = threshold
+        self.int_cat: list[str] = [
+            "Num_Min",
+            "Num_Max",
+            "Adjusted_Year",
+            "Start_Date_Day",
+            "Start_Date_Month",
+            "Start_Date_Year",
+            "End_Date_Day",
+            "End_Date_Month",
+            "End_Date_Year",
+        ]
+        self.bool_cat: list[str] = ["Adjusted"]
+        self.str_cat: list[str] = ["Country_Norm", "Unit"]
+        self.list_cat: list[str] = ["Location_Norm"]
+
+        self.comp = Comparer(null_penalty, [])
+
+    @staticmethod
+    def create_pad(specific_instance: dict) -> dict:
+        padded = {}
+        for k in specific_instance.keys():
+            # preserve "Event_ID"; every other field is padded with None
+            padded[k] = specific_instance[k] if k == "Event_ID" else None
+        return padded
+
+    def calc_similarity(self, gold_instance: dict, sys_list: list[dict]) -> list[float]:
+        score_list: list[float] = []
+        for si in sys_list:
+            scores = []
+            for k in gold_instance.keys():
+                if k in self.int_cat:
+                    r = self.comp.integer(gold_instance[k], si[k])
+                elif k in self.bool_cat:
+                    r = self.comp.boolean(gold_instance[k], si[k])
+                elif k in self.str_cat:
+                    r = self.comp.string(gold_instance[k], si[k])
+                elif k in self.list_cat:
+                    r = self.comp.sequence(gold_instance[k], si[k])
+                try:
+                    # r is unset for unsupported keys (it is deleted after each
+                    # use), so this append raises and the key is skipped
+                    scores.append(1 - r)
+                    del r
+                except Exception:
+                    if k != "Event_ID":
+                        self.logger.warning(f"Unsupported column name: {k} will be ignored during matching.")
+
+            score_list.append(mean(scores))
+
+        # index of mean score corresponds to sys_list item
+        return score_list
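+
+    # Note on calc_similarity (illustrative): each supported field comparison
+    # returns an error in [0, 1] (0 = identical), and 1 - error is averaged, so
+    # every entry in score_list is a similarity score for one sys instance. For
+    # example (from the unit tests), a gold instance
+    # {"Num_Min": 0, "Num_Max": 10, "Start_Date_Year": 2030} scores ~0.9999
+    # against a sys instance that differs only by one year.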
+
+    def schema_checker(self, gold_list: list[dict], sys_list: list[dict]) -> bool:
+        # in case the sys output or gold is an empty list
+        if len(gold_list) == 0 or len(sys_list) == 0:
+            return True
+
+        for g in range(len(gold_list)):
+            # check that all column names in the gold are consistent
+            if sorted(gold_list[0].keys()) != sorted(gold_list[g].keys()):
+                self.logger.error(
+                    f"Gold file contains entries with inconsistent column names at specific instance #{g}: {gold_list[g].keys()}. Expected columns: {gold_list[0].keys()}"
+                )
+                return False
+
+        for s in range(len(sys_list)):
+            # if all gold columns are consistent, check that they are consistent with the sys_list ones
+            try:
+                assert all([e in sys_list[s].keys() for e in gold_list[0].keys()])
+            except AssertionError:
+                self.logger.error(
+                    f"Inconsistent columns found in sys file!: {[e for e in sys_list[s].keys() if e not in gold_list[0].keys()]}"
+                )
+                return False
+
+        # only report success once every sys instance has been checked
+        return True
+
+    def match(self, gold_list: list[dict], sys_list: list[dict]) -> tuple[list[dict], list[dict]]:
+        if not self.schema_checker(gold_list, sys_list):
+            self.logger.error("Please check the column names in your gold and sys files.")
+            raise ValueError("Mismatched schemas in gold and sys files.")
+
+        gold, sys, similarity, gold_matched, sys_matched = [], [], [], [], []
+        similarity_matrix = [self.calc_similarity(si, sys_list) for si in gold_list]
+        best_matches = [
+            (gi, si, similarity_matrix[gi][si])
+            for gi in range(len(similarity_matrix))
+            for si in range(len(similarity_matrix[gi]))
+            if similarity_matrix[gi][si] > self.threshold
+        ]
+        best_matches.sort(key=lambda x: x[2], reverse=True)
+
+        # greedily keep the highest-scoring matches, one per gold/sys instance
+        for gi, si, sim in best_matches:
+            if gi not in gold_matched and si not in sys_matched:
+                gold.append(gold_list[gi])
+                sys.append(sys_list[si])
+                gold_matched.append(gi)
+                sys_matched.append(si)
+                similarity.append(sim)
+
+        # pad remaining unmatched specific instances
+        for gi in range(len(gold_list)):
+            if gi not in gold_matched:
+                gold.append(gold_list[gi])
+                sys.append(self.create_pad(gold_list[gi]))
+
+        for si in range(len(sys_list)):
+            if si not in sys_matched:
+                sys.append(sys_list[si])
+                gold.append(self.create_pad(sys_list[si]))
+
+        assert len(gold) == len(sys), (
+            f"Something went wrong! number of specific instances in gold: {len(gold)}; in sys: {len(sys)}"
+        )
+
+        for ds in [gold, sys]:
+            counter = 0
+            for si in ds:
+                si["Event_ID"] = f"{si['Event_ID']}-{counter}"
+                counter += 1
+
+        return (gold, sys)
diff --git a/Evaluation/weights.py b/Evaluation/weights.py
index 5fc727c85..b76df4ff6 100644
--- a/Evaluation/weights.py
+++ b/Evaluation/weights.py
@@ -112,4 +112,4 @@
         # "Country_Norm": 1,
         # "Location_Norm": 1,
     },
-}
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 6b8b38017..7f33228fb 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,8 @@ If you have generated some LLM output and would like to test it against the dev
 Choose a new experiment name! You will use this for the whole pipeline.
 
-#### PRESTEPS
+#### PRESTEP (before Step 2)
+
 
 - Normalizing JSON output for Mistral/Mixtral
 
 If the system output is split across several files (such as Mixtral and Mistral system outputs), then first merge it:
@@ -128,7 +129,7 @@ Also, this config will result in evaluating only on this smaller set of columns,
 ```
 
-##### (B) Evaluate
+##### (B) Evaluate main events
 
 When your config is ready, run the evaluation script:
 
 ```shell
@@ -146,6 +147,37 @@ poetry run python3 Evaluation/evaluator.py --sys-file Database/output/nlp4clima
 --weights_config nlp4climate
 ```
 
+#### Evaluate sub events (i.e. specific instances)
+
+Specific instances can be evaluated using the same script (`Evaluation/evaluator.py`), which automatically matches specific instances in the gold data with those in the system output. If no match exists for a specific instance, it is paired with a "padded" instance whose values are all NULL, so that the system is penalized both for missing a specific instance and for predicting extra specific instances that are absent from the gold dataset.
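+
+As a minimal sketch of what the matcher does under the hood, it can also be called directly (dummy values below, mirroring the fixtures in `tests/test_specific_instance_matcher.py`):
+
+```python
+from Evaluation.matcher import SpecificInstanceMatcher
+
+# threshold: minimum similarity for a gold/sys pair to count as a match;
+# null_penalty: error assigned when one field is None and the other is not
+matcher = SpecificInstanceMatcher(threshold=0.6, null_penalty=0.5)
+
+gold_list = [{"Event_ID": "aA3", "Num_Min": 1}]
+sys_list = [{"Event_ID": "aA3", "Num_Min": 1000}]
+
+# match() returns two equal-length lists: matched pairs first, then every
+# unmatched instance paired with a padded counterpart whose fields (except
+# Event_ID) are None; Event_ID values get a positional suffix such as "aA3-0".
+gold, sys = matcher.match(gold_list=gold_list, sys_list=sys_list)
+```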
+
+Below is a script that evaluates two dummy sets (gold and sys) to showcase a working example and the correct schema for the `.parquet` files. Sub events are evaluated separately from main events.
+
+```shell
+poetry run python3 Evaluation/evaluator.py \
+--sys-file tests/specific_instance_eval/test_sys_list_death.parquet \
+--gold-file tests/specific_instance_eval/test_gold_list_death.parquet \
+--model-name "specific_instance_eval_test/dev/deaths" \
+--event_type sub \
+--weights_config specific_instance \
+--specific_instance_type deaths
+```
+If run properly, you should see the results in `Database/evaluation_results/specific_instance_eval_test`:
+
+```shell
+Database/evaluation_results/specific_instance_eval_test
+└── dev
+    └── deaths
+        ├── all_27_deaths_avg_per_event_id_results.csv  # <- average error rate grouped by event_id
+        ├── all_27_deaths_avg_results.json  # <- overall average results
+        ├── all_27_deaths_results.csv  # <- results for each pair of gold/sys
+        ├── gold_deaths.parquet  # <- modified gold file with matches + padded specific instances
+        └── sys_deaths.parquet  # <- modified sys file with matches + padded specific instances
+```
+
+> [!WARNING]
+> Do not commit these files to your branch or to `main`, big thanks!
+
 ### Parsing and normalization
 
 If you have new events to add to the database, first parse them and insert them.
@@ -203,14 +235,6 @@ poetry run python3 Database/gold_from_excel.py --input-file "Database/gold/Impac
 ```
 
 These results are not split to test/dev.
-The plan is to expand this functionality further and evaluate subevents
-
-To be implemented:
-- [ ] How to evaluate subevents when the gold may contain more/less than the system output? Maybe subevents can be matched by location and timestamp and evaluated accordingly -- finding too many could be penalized.
-- [ ] Match the short uuids (generated by [Database/scr/normalize_utils.pyrandom_short_uuid](Database/scr/normalize_utils.pyrandom_short_uuid)) in the excel sheet for the ones that already exist in the dev and test sets.
-- [ ] Make any edits (if needed) to the evaluation script so it can handle subevents
-
-(Input appreciated! Just email @i-be-snek)
 
 > [!IMPORTANT]
 > Please don't track or push excel sheets into the repository
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 000000000..768a60430
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,12 @@
+
+[pytest]
+testpaths = tests
+python_files = test_*.py
+addopts = -rf --import-mode=importlib
+pythonpath = .
+
+# live logging settings
+log_cli = true
+log_cli_level = INFO
+log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
+log_cli_date_format = %Y-%m-%d %H:%M:%S
diff --git a/tests/specific_instance_eval/test_gold_list_death.parquet b/tests/specific_instance_eval/test_gold_list_death.parquet
new file mode 100644
index 000000000..626f7861d
--- /dev/null
+++ b/tests/specific_instance_eval/test_gold_list_death.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b6c58941d521870609b56226511726ae68596d49205a3b6b09ee793b372c071
+size 11610
diff --git a/tests/specific_instance_eval/test_sys_list_death.parquet b/tests/specific_instance_eval/test_sys_list_death.parquet
new file mode 100644
index 000000000..61950344d
--- /dev/null
+++ b/tests/specific_instance_eval/test_sys_list_death.parquet
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4def986f0d4de7c8bb5e8fc84c886dc4d9c5d01004d1d8466eb22d83ce6e283d
+size 9050
diff --git a/tests/test_specific_instance_matcher.py b/tests/test_specific_instance_matcher.py
new file mode 100644
index 000000000..cdcb32a56
--- /dev/null
+++ b/tests/test_specific_instance_matcher.py
@@ -0,0 +1,216 @@
+import pytest
+
+from Evaluation.matcher import SpecificInstanceMatcher
+
+
+class TestSpecificInstanceMatcher:
+    @pytest.mark.parametrize(
+        "gold_instance, sys_list, expected",
+        [
+            (
+                {"Num_Min": 0, "Num_Max": 10, "Start_Date_Year": 2030},
+                [
+                    {"Num_Min": 2, "Num_Max": 91, "Start_Date_Year": 2030},
+                    {"Num_Min": 0, "Num_Max": 10, "Start_Date_Year": 2031},
+                ],
+                [0.39933993399339934, 0.9999179184109004],
+            ),
+        ],
+    )
+    def test_calc_similarity(self, gold_instance, sys_list, expected):
+        matcher = SpecificInstanceMatcher()
+        if expected:
+            assert matcher.calc_similarity(gold_instance, sys_list) == expected
+        else:
+            with pytest.raises(UnboundLocalError):
+                matcher.calc_similarity(gold_instance, sys_list)
+
+    @pytest.mark.parametrize(
+        "test_gold_list, test_sys_list, expected_gold, expected_sys",
+        [
+            (
+                # gold_list
+                [
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": 2,
+                        "Num_Max": 82,
+                        "Start_Date_Year": 2030,
+                        "Location_Norm": ["Amman", "Zarqa"],
+                    },
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": None,
+                        "Num_Max": 91,
+                        "Start_Date_Year": 2030,
+                        "Location_Norm": ["Uppsala", "Stockholm"],
+                    },
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": 0,
+                        "Num_Max": 10,
+                        "Start_Date_Year": 2031,
+                        "Location_Norm": ["Paris", "Lyon"],
+                    },
+                ],
+                # sys_list
+                [
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": 0,
+                        "Num_Max": 11,
+                        "Start_Date_Year": 2031,
+                        "Location_Norm": ["Lyon"],
+                    },
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": 1,
+                        "Num_Max": 84,
+                        "Start_Date_Year": 2029,
+                        "Location_Norm": ["Uppsala", "Zarqa"],
+                    },
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": 2,
+                        "Num_Max": 91,
+                        "Start_Date_Year": 2030,
+                        "Location_Norm": ["Stockholm"],
+                    },
+                    {
+                        "Event_ID": "aA3",
+                        "Num_Min": 7,
+                        "Num_Max": 30,
+                        "Start_Date_Year": 2030,
+                        "Location_Norm": ["Uppsala", "Linköping"],
+                    },
+                ],
+                # gold
+                [
+                    {
+                        "Event_ID": "aA3-0",
+                        "Num_Min": 0,
+                        "Num_Max": 10,
+                        "Start_Date_Year": 2031,
+                        "Location_Norm": ["Paris", "Lyon"],
+                    },
+                    {
+                        "Event_ID": "aA3-1",
+                        "Num_Min": None,
+                        "Num_Max": 91,
+                        "Start_Date_Year": 2030,
+                        "Location_Norm": ["Uppsala", "Stockholm"],
+                    },
+                    {
+                        "Event_ID": "aA3-2",
+                        "Num_Min": 2,
+                        "Num_Max": 82,
+                        "Start_Date_Year": 2030,
+                        "Location_Norm": ["Amman", "Zarqa"],
+                    },
+                    {
+                        "Event_ID": "aA3-3",
+                        "Num_Min": None,
+                        "Num_Max": None,
+                        "Start_Date_Year": None,
+                        "Location_Norm": None,
+                    },
+                ],
+                # sys
+                [
+ { + "Event_ID": "aA3-0", + "Num_Min": 0, + "Num_Max": 11, + "Start_Date_Year": 2031, + "Location_Norm": ["Lyon"], + }, + { + "Event_ID": "aA3-1", + "Num_Min": 2, + "Num_Max": 91, + "Start_Date_Year": 2030, + "Location_Norm": ["Stockholm"], + }, + { + "Event_ID": "aA3-2", + "Num_Min": 1, + "Num_Max": 84, + "Start_Date_Year": 2029, + "Location_Norm": ["Uppsala", "Zarqa"], + }, + { + "Event_ID": "aA3-3", + "Num_Min": 7, + "Num_Max": 30, + "Start_Date_Year": 2030, + "Location_Norm": ["Uppsala", "Linköping"], + }, + ], + ), + ( + [{"Event_ID": "aA3", "Num_Min": 1}], + [{"Event_ID": "aA3", "Num_Min": 1000}], + [ + {"Event_ID": "aA3-0", "Num_Min": 1}, + {"Event_ID": "aA3-1", "Num_Min": None}, + ], + [ + {"Event_ID": "aA3-0", "Num_Min": None}, + {"Event_ID": "aA3-1", "Num_Min": 1000}, + ], + ), + # empty lists as input + ([], [], [], []), + # empty sys_list as input + ( + [{"Event_ID": "aA3B4", "Start_Date_Year": 2030}], + [], + [{"Event_ID": "aA3B4-0", "Start_Date_Year": 2030}], + [{"Event_ID": "aA3B4-0", "Start_Date_Year": None}], + ), + # empty gold_list as input + ( + [], + [{"Event_ID": "aA3C4", "Start_Date_Year": 2030}], + [{"Event_ID": "aA3C4-0", "Start_Date_Year": None}], + [{"Event_ID": "aA3C4-0", "Start_Date_Year": 2030}], + ), + # inconsistent schema + ( + [ + { + "Event_ID": "aA3", + "Num_Min": 0, + "Num_Max": 10, + "Start_Date_Year": 2030, + } + ], + [ + { + "Event_ID": "aA3-0", + "Num_Mix": 2, + "Num_Max": 91, + "Start_Date_Year": 2030, + }, + { + "Event_ID": "aA3-1", + "Num_Min": 0, + "Num_Max": 10, + "Start_Date_Year": 2031, + }, + ], + None, + None, + ), + ], + ) + def test_matcher(self, test_gold_list, test_sys_list, expected_gold, expected_sys): + matcher = SpecificInstanceMatcher(threshold=0.6, null_penalty=0.5) + if expected_gold is not None and expected_sys is not None: + assert matcher.match(gold_list=test_gold_list, sys_list=test_sys_list) == ( + expected_gold, + expected_sys, + ) + else: + with pytest.raises(BaseException): + matcher.match(test_gold_list, test_sys_list)