Skip to content

Commit

Permalink
Merge branch 'main' into handle-ranges-per-guidelines
Browse files Browse the repository at this point in the history
  • Loading branch information
i-be-snek authored Jul 10, 2024
2 parents f4f3b0d + 4c821dd commit e02592f
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions Evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,25 @@
gold = gold[gold["Article_From"] == args.score]

logger.info(f"Evaluation limited to {sys.shape} events from source {args.score}")
assert len(sys.sort_values("Event_ID")["Event_ID"].to_list()) == len(
gold.sort_values("Event_ID")["Event_ID"].to_list()
), f"Missing events! {set(sys.sort_values('Event_ID')['Event_ID'].to_list()) ^ set(gold.sort_values('Event_ID')['Event_ID'].to_list())}"

# Add dummy rows for missing events
missing_ids = set(sys['Event_ID'].to_list()) ^ set(gold['Event_ID'].to_list())
if missing_ids:
logger.info(f"Missing events! {missing_ids}. The columns in these events will be constructed with `NoneType` objects. The system output will be penalized for missing events with the selected null penalty ({args.null_penalty})")
gold_cols = list(gold.columns)
rows_to_add = []
for event_id in missing_ids:
# Create a dictionary for the new row with all columns set to "" except Country_Norm which excepts a list
new_row = {col: None for col in gold_cols}
for col in ["Country_Norm", "Location_Norm"]:
if col in gold_cols:
new_row[col] = "[]"
new_row['Event_ID'] = event_id # Set the 'Event_ID'
rows_to_add.append(new_row)

missing_rows = pd.DataFrame(rows_to_add)
sys = pd.concat([sys, missing_rows], ignore_index=True).sort_values('Event_ID')
sys.replace({np.nan: None}, inplace=True)

# Specify null penalty
null_penalty = args.null_penalty
Expand All @@ -132,9 +148,11 @@
sys = sys.sort_values("Event_ID")
gold = gold.sort_values("Event_ID")

for col in ["Country_Norm"]:
sys[col] = sys[col].apply(ast.literal_eval)
gold[col] = gold[col].apply(ast.literal_eval)
for col in ["Country_Norm", "Location_Norm"]:
if col in sys.columns:
sys[col] = sys[col].apply(ast.literal_eval)
if col in gold.columns:
gold[col] = gold[col].apply(ast.literal_eval)

logger.info("Parsed strings to lists or dicts")

Expand Down

0 comments on commit e02592f

Please sign in to comment.