
Commit

initial commit
Guo Zhengrui committed Dec 23, 2024
1 parent 31a40af commit 303e7f8
Showing 3 changed files with 242 additions and 5 deletions.
5 changes: 3 additions & 2 deletions main_test_AllinOne.py
@@ -1,7 +1,7 @@
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer
from modules.tokenizers import Tokenizer, MedicalReportTokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.tester_AllinOne import Tester
@@ -99,7 +99,8 @@ def main():
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)

tokenizer = Tokenizer(args)
# tokenizer = Tokenizer(args)
tokenizer = MedicalReportTokenizer(args)
test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)
model = HistGenModel(args, tokenizer)

4 changes: 2 additions & 2 deletions main_train_AllinOne.py
@@ -1,7 +1,7 @@
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer, ModernTokenizer
from modules.tokenizers import Tokenizer, MedicalReportTokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
@@ -110,7 +110,7 @@ def main():


# tokenizer = Tokenizer(args)
tokenizer = ModernTokenizer(args)
tokenizer = MedicalReportTokenizer(args)

train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
238 changes: 237 additions & 1 deletion modules/tokenizers.py
@@ -229,4 +229,240 @@ def batch_encode(self, reports: List[str],
truncation=True,
return_tensors='pt'
)
return encodings['input_ids']

class MedicalReportTokenizer(object):
def __init__(self, args):
self.ann_path = args.ann_path
self.threshold = args.threshold
self.dataset_name = args.dataset_name

# Special tokens
self.bos_token = '<bos>'
self.eos_token = '<eos>'
self.pad_token = '<pad>'
self.unk_token = '<unk>'

# Add medical domain vocabulary
self.domain_vocab = set([
# Common anatomical terms
'granuloma', 'parametria', 'lymph', 'nodes', 'cervix', 'uterus',
'ovary', 'ovaries', 'fallopian', 'tubes', 'endometrium', 'myometrium',
'serosa', 'parametrium', 'vagina', 'vulva',

# Common pathological terms
'carcinoma', 'adenocarcinoma', 'sarcoma', 'lymphoma', 'metastasis',
'neoplasm', 'tumor', 'lesion', 'mass', 'cyst', 'polyp', 'nodule', 'scars', 'adult',
'dome-shaped', 'flap', 'prophylactic',

# Descriptive terms
'focal', 'diffuse', 'acute', 'chronic', 'mild', 'moderate', 'severe',
'benign', 'malignant', 'invasive', 'metastatic', 'astrocytomas', 'anaplastic', 'cataractous',

"mm.", "cm.", "%", "mm,", "cm,", "%,",

# Add measurement combinations
*[f"{i}mm" for i in range(1, 101)], # 1mm to 100mm
*[f"{i}cm" for i in range(1, 51)], # 1cm to 50cm
*[f"{i}%" for i in range(1, 100)], # 1cm to 50cm
*[f"{i}.{j}" for i in range(10) for j in range(10)], # Common decimal numbers
*[f"{i}.{j}cm" for i in range(10) for j in range(10)],
*[f"{i}.{j}mm" for i in range(10) for j in range(10)],
*[f"{i}.{j}%" for i in range(10) for j in range(10)],
*[f"{i}th" for i in range(1, 101)],
])

# Map dataset names to cleaning functions
self.clean_report_map = {
'iu_xray': self.clean_report_iu_xray,
'wsi_report': self.clean_report_pathology,
'mimic_cxr': self.clean_report_mimic_cxr
}
self.clean_report = self.clean_report_map.get(self.dataset_name, self.clean_report_mimic_cxr)

self.preserve_case = (self.dataset_name == 'wsi_report')
with open(self.ann_path, 'r') as f:
    self.ann = json.load(f)
self.token2idx, self.idx2token = self.create_vocabulary()

def create_vocabulary(self):
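# Vocabulary = training-report tokens whose count meets the threshold,
# plus the hand-curated medical domain vocabulary above.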
total_tokens = []
for example in self.ann['train']:
tokens = self.clean_report(example['report']).split()
total_tokens.extend(tokens)

counter = Counter(total_tokens)
vocab_tokens = set(k for k, v in counter.items() if v >= self.threshold)
vocab_tokens.update(self.domain_vocab)
vocab_tokens = sorted(list(vocab_tokens))

special_tokens = [self.pad_token, self.bos_token, self.eos_token, self.unk_token]
vocab = special_tokens + vocab_tokens

token2idx = {token: idx for idx, token in enumerate(vocab)}
idx2token = {idx: token for idx, token in enumerate(vocab)}
return token2idx, idx2token

def normalize_whitespace(self, text):
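# WSI reports get lighter whitespace handling; other datasets are
# flattened to a single space-separated line.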
if self.dataset_name == 'wsi_report':
text = re.sub(r'\n\s*\n', '\n', text)
text = re.sub(r' +', ' ', text)
else:
text = ' '.join(text.split())
text = text.replace('\n', ' ')
return text

def normalize_punctuation(self, text):
# Collapse runs of periods into one, then add a space after each period
# unless it is a decimal point
text = re.sub(r'\.{2,}', '.', text) # Replace multiple periods first
text = re.sub(r'\.(?![\d])', '. ', text) # Add space after period unless followed by digit

if self.dataset_name == 'wsi_report':
# Preserve commas in numbers (e.g., 1,000)
text = re.sub(r'([^0-9]),([^0-9])', r'\1, \2', text) # Add a space after commas that are not inside numbers (keeps 1,000 intact)
text = re.sub(r'([.,;:])\1+', r'\1', text)
else:
text = re.sub(r'\d+\.\s*', '', text)
text = re.sub(r'\s*\.\s*', '. ', text)

return text.strip()

def clean_sentence(self, sentence):
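# WSI sentences keep case, digits, and measurement units; other
# datasets are lowercased and stripped of punctuation.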
if self.dataset_name == 'wsi_report':
# Preserve numbers, decimal points, and measurements
# Allow numbers, letters, basic punctuation, and measurement units
sentence = re.sub(r'[^a-zA-Z0-9\s.,;:()\-\.]', ' ', sentence)

# Ensure measurements stay together (no space between number and unit)
sentence = re.sub(r'(\d+)\s+(mm|cm|ml|cc|µm)', r'\1\2', sentence)

# Normalize spaces
sentence = re.sub(r'\s+', ' ', sentence)
cleaned = sentence.strip()
else:
sentence = re.sub(r'[^a-zA-Z0-9\s.,;:]', ' ', sentence)
cleaned = sentence.lower().strip()
cleaned = re.sub(r'[.,;:]', '', cleaned)

return cleaned

def tokenize_with_numbers(self, text):
"""Special tokenization that preserves numbers and measurements."""
# Split on spaces but preserve special patterns
tokens = []
words = text.split()

for word in words:
# Check if word contains a number pattern
if re.search(r'\d', word):
# Keep measurements together (e.g., "7mm", "4.5cm")
if re.match(r'^\d+\.?\d*(?:mm|cm|ml|cc|µm)?$', word):
tokens.append(word)
else:
# Split other numeric patterns
numeric_parts = re.findall(r'(\d+\.?\d*)|([^0-9]+)', word)
tokens.extend(part for parts in numeric_parts for part in parts if part)
else:
tokens.append(word)

return tokens

def __call__(self, report):
"""Convert a report to token ids."""
cleaned_report = self.clean_report(report)

if self.dataset_name == 'wsi_report':
tokens = self.tokenize_with_numbers(cleaned_report)
else:
tokens = cleaned_report.split()

tokens = [self.bos_token] + tokens + [self.eos_token]
return [self.get_id_by_token(token) for token in tokens]

def clean_report_pathology(self, report):
"""Special handling for pathology reports."""
# Normalize underscores and whitespace
report = re.sub(r'_+', ' ', report)
report = self.normalize_whitespace(report)
report = self.normalize_punctuation(report)

# Split on periods while preserving them
sentences = re.split(r'(?<=\.)\s+', report)
sentences = [s.strip() for s in sentences if s.strip()]

# Clean each sentence
cleaned_sentences = [self.clean_sentence(sent) for sent in sentences]
cleaned_sentences = [sent for sent in cleaned_sentences if sent]

# Join with proper punctuation
if cleaned_sentences:
# Add period if not present
result = '. '.join(s.rstrip('.') for s in cleaned_sentences) + '.'
# Clean up any double periods
result = re.sub(r'\.+', '.', result)
return result
return ''

def clean_report_iu_xray(self, report):
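# Cleaning for IU X-ray reports: lowercased sentences joined with ' . '.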
# Normalize whitespace and punctuation
report = self.normalize_whitespace(report)
report = self.normalize_punctuation(report)

# Split into sentences and clean each one
sentences = [s.strip() for s in report.split('.') if s.strip()]
cleaned_sentences = [self.clean_sentence(sent) for sent in sentences]
cleaned_sentences = [sent for sent in cleaned_sentences if sent]

# Join sentences with proper spacing
return ' . '.join(cleaned_sentences) + ' .'

def clean_report_mimic_cxr(self, report):
# Additional preprocessing for MIMIC-CXR
report = re.sub(r'_+', ' ', report) # Replace underscores with space
report = self.normalize_whitespace(report)
report = self.normalize_punctuation(report)

sentences = [s.strip() for s in report.split('.') if s.strip()]
cleaned_sentences = [self.clean_sentence(sent) for sent in sentences]
cleaned_sentences = [sent for sent in cleaned_sentences if sent]

return ' . '.join(cleaned_sentences) + ' .'

def get_token_by_id(self, id):
return self.idx2token.get(id, self.unk_token)

def get_id_by_token(self, token):
# For non-WSI reports, convert to lowercase for token lookup
if not self.preserve_case:
token = token.lower()
return self.token2idx.get(token, self.token2idx[self.unk_token])

def get_vocab_size(self):
return len(self.token2idx)

def decode(self, ids):
"""Convert token ids back to text."""
tokens = []
for idx in ids:
if idx == self.token2idx[self.eos_token]:
break
if idx != self.token2idx[self.pad_token]:
token = self.idx2token[idx]
if token != self.bos_token: # Skip BOS token in output
tokens.append(token)

# Join tokens with proper spacing
if self.dataset_name == 'wsi_report':
# For WSI reports, handle punctuation specially
text = ''
for i, token in enumerate(tokens):
if token in '.,;:':
text = text.rstrip() + token + ' '
else:
text += token + ' '
return text.strip()
else:
return ' '.join(tokens)

def decode_batch(self, ids_batch):
"""Convert a batch of token ids back to texts."""
return [self.decode(ids) for ids in ids_batch]
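
A minimal usage sketch of the new MedicalReportTokenizer (an illustrative example, not part of the commit): it assumes an argparse-style namespace carrying the ann_path, threshold, and dataset_name fields read in __init__, and an annotation JSON with a 'train' split of {'report': ...} entries; the file path and threshold below are placeholders.

# Hypothetical usage sketch; paths and values are placeholder assumptions.
import argparse

args = argparse.Namespace(
    ann_path='data/annotation.json',  # assumed annotation file with a 'train' split
    threshold=3,                      # assumed minimum token frequency for the vocabulary
    dataset_name='wsi_report',        # selects clean_report_pathology and preserves case
)

tokenizer = MedicalReportTokenizer(args)
report = "Tumor measures 4.5cm in greatest dimension."
ids = tokenizer(report)           # <bos> ... <eos> token ids
text = tokenizer.decode(ids)      # ids back to text, stopping at <eos>
print(tokenizer.get_vocab_size(), ids, text)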
