Skip to content

Commit

Permalink
Merge pull request #75 from Living-with-machines/develop
Browse files Browse the repository at this point in the history
version 1.2.0
  • Loading branch information
kasra-hosseini authored Sep 15, 2020
2 parents 9019cff + 675b772 commit 38d8142
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 57 deletions.
127 changes: 76 additions & 51 deletions DeezyMatch/rnn_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- https://medium.com/intel-student-ambassadors/implementing-attention-models-in-pytorch-f947034b3e66
"""

import copy
from datetime import datetime
import glob
import numpy as np
Expand Down Expand Up @@ -96,7 +97,7 @@ def gru_lstm_network(dl_inputs, model_name, train_dc, valid_dc=False, test_dc=Fa
do_validation = 1
else:
do_validation = int(do_validation)

# --- create the model
cprint('[INFO]', bc.dgreen, 'create a two_parallel_rnns model')
model_gru = two_parallel_rnns(main_architecture, vocab_size, embedding_dim, rnn_hidden_dim, output_dim,
Expand Down Expand Up @@ -134,28 +135,14 @@ def gru_lstm_network(dl_inputs, model_name, train_dc, valid_dc=False, test_dc=Fa
model_path=os.path.join(dl_inputs["general"]["models_dir"], model_name),
csv_sep=dl_inputs['preprocessing']["csv_sep"],
map_flag=map_flag,
do_validation=do_validation
do_validation=do_validation,
early_stopping_patience=dl_inputs["gru_lstm"]["early_stopping_patience"],
model_name=model_name
)

# --- save the model
cprint('[INFO]', bc.lgreen, 'saving the model')
model_path = os.path.join(dl_inputs["general"]["models_dir"],
model_name,
model_name + '.model')
if not os.path.isdir(os.path.dirname(model_path)):
os.makedirs(os.path.dirname(model_path))
torch.save(model_gru, model_path)
torch.save(model_gru.state_dict(), model_path + "_state_dict")

"""
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
"""

# --- print some simple stats on the run
print_stats(start_time)

# ------------------- fine_tuning --------------------
def fine_tuning(pretrained_model_path, dl_inputs, model_name,
train_dc, valid_dc=False, test_dc=False):
Expand Down Expand Up @@ -230,32 +217,19 @@ def fine_tuning(pretrained_model_path, dl_inputs, model_name,
model_path=os.path.join(dl_inputs["general"]["models_dir"], model_name),
csv_sep=dl_inputs['preprocessing']["csv_sep"],
map_flag=map_flag,
do_validation=do_validation
do_validation=do_validation,
early_stopping_patience=dl_inputs["gru_lstm"]["early_stopping_patience"],
model_name=model_name
)

# --- save the model
cprint('[INFO]', bc.lgreen, 'saving the model')
model_path = os.path.join(dl_inputs["general"]["models_dir"],
model_name,
model_name + '.model')
if not os.path.isdir(os.path.dirname(model_path)):
os.makedirs(os.path.dirname(model_path))
torch.save(pretrained_model, model_path)
torch.save(pretrained_model.state_dict(), model_path + "_state_dict")

"""
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
"""

# --- print some simple stats on the run
print_stats(start_time)

# ------------------- fit --------------------
def fit(model, train_dl, valid_dl, loss_fn, opt, epochs=3,
pooling_mode='attention', device='cpu',
tboard_path=False, model_path=False, csv_sep="\t", map_flag=False, do_validation=1):
tboard_path=False, model_path=False, csv_sep="\t", map_flag=False, do_validation=1,
early_stopping_patience=False, model_name="default"):

num_batch_train = len(train_dl)
num_batch_valid = len(valid_dl)
Expand All @@ -280,6 +254,11 @@ def fit(model, train_dl, valid_dl, loss_fn, opt, epochs=3,
print_summary = True
wtrain_counter = 0
wvalid_counter = 0

# initialize early stopping parameters
es_loss = False
es_stop = False

for epoch in tnrange(epochs):
if train_dl:
model.train()
Expand Down Expand Up @@ -364,18 +343,39 @@ def fit(model, train_dl, valid_dl, loss_fn, opt, epochs=3,

if valid_dl and (((epoch+1) % do_validation) == 0):
valid_desc = 'Epoch: {}/{}; Valid'.format(epoch+1, epochs)
test_model(model,
valid_dl,
eval_mode="valid",
valid_desc=valid_desc,
pooling_mode=pooling_mode,
device=device,
model_path=model_path,
tboard_writer=tboard_writer,
csv_sep=csv_sep,
epoch=epoch+1,
map_flag=map_flag
)
valid_loss = test_model(model,
valid_dl,
eval_mode="valid",
valid_desc=valid_desc,
pooling_mode=pooling_mode,
device=device,
model_path=model_path,
tboard_writer=tboard_writer,
csv_sep=csv_sep,
epoch=epoch+1,
map_flag=map_flag,
output_loss=True)

if (not es_loss) or (valid_loss <= es_loss):
es_loss = valid_loss
es_model = copy.deepcopy(model)
es_checkpoint = epoch + 1
es_counter = 0
else:
es_counter += 1

if early_stopping_patience:
if es_counter >= early_stopping_patience:
# --- save the model
checkpoint_path = os.path.join(model_path,
model_name + '.model')
if not os.path.isdir(os.path.dirname(checkpoint_path)):
os.makedirs(os.path.dirname(checkpoint_path))
cprint('[INFO]', bc.lgreen,
f'saving the model (early stopped) with least valid loss (checkpoint: {es_checkpoint}) at {checkpoint_path}')
torch.save(es_model, checkpoint_path)
torch.save(es_model.state_dict(), checkpoint_path + "_state_dict")
es_stop = True

if model_path:
# --- save the model
Expand All @@ -385,13 +385,30 @@ def fit(model, train_dl, valid_dl, loss_fn, opt, epochs=3,
os.makedirs(os.path.dirname(checkpoint_path))
torch.save(model, checkpoint_path)
torch.save(model.state_dict(), checkpoint_path + "_state_dict")

if es_stop:
cprint('[INFO]', bc.dgreen, 'Early stopping at epoch: {}, selected epoch: {}'.format(epoch+1, es_checkpoint))
return

if model_path and epoch > 0:
# --- save the model with least validation loss
model_path_save = os.path.join(model_path,
model_name + '.model')
if not os.path.isdir(os.path.dirname(model_path_save)):
os.makedirs(os.path.dirname(model_path_save))
cprint(f'[INFO]', bc.lgreen,
f'saving the model with least valid loss (checkpoint: {es_checkpoint}) at {model_path_save}')
torch.save(es_model, model_path_save)
torch.save(es_model.state_dict(), model_path_save + "_state_dict")


# ------------------- test_model --------------------
def test_model(model, test_dl, eval_mode='test', valid_desc=None,
pooling_mode='attention', device='cpu', evaluation=True,
output_state_vectors=False, output_preds=False,
output_preds_file=False, model_path=False, tboard_writer=False,
csv_sep="\t", epoch=1, map_flag=False, print_epoch=True):
csv_sep="\t", epoch=1, map_flag=False, print_epoch=True,
output_loss=False):

model.eval()

Expand Down Expand Up @@ -549,6 +566,9 @@ def test_model(model, test_dl, eval_mode='test', valid_desc=None,
if test_map:
tboard_writer.add_scalar('Test/Map', test_map, epoch)
tboard_writer.flush()

if output_loss:
return test_loss

if output_preds or map_flag:
return all_preds
Expand Down Expand Up @@ -641,6 +661,11 @@ def forward(self, x1_seq, len1, x2_seq, len2, pooling_mode='hstates', device="cp
self.h1, self.c1 = self.init_hidden(x1_seq.size(1), device)
x1_embs_not_packed = self.emb(x1_seq)
x1_embs = pack_padded_sequence(x1_embs_not_packed, len1, enforce_sorted=False)
# To avoid the following issue:
# RNN module weights are not part of single contiguous chunk of memory.
# This means they need to be compacted at every call, possibly greatly increasing memory usage.
# To compact weights again call flatten_parameters().
self.rnn_1.flatten_parameters()
if self.main_architecture.lower() in ["lstm"]:
rnn_out_1, (self.h1, self.c1) = self.rnn_1(x1_embs, (self.h1, self.c1))
elif self.main_architecture.lower() in ["gru", "rnn"]:
Expand Down
6 changes: 6 additions & 0 deletions DeezyMatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ def read_input_file(input_file_path):
dl_inputs['general']['device'] = device
cprint('[INFO]', bc.lgreen, 'pytorch will use: {}'.format(dl_inputs['general']['device']))

if not "early_stopping_patience" in dl_inputs["gru_lstm"]:
dl_inputs["gru_lstm"]["early_stopping_patience"] = False

if dl_inputs['gru_lstm']["early_stopping_patience"] <= 0:
dl_inputs['gru_lstm']["early_stopping_patience"] = False

# XXX separation in the input CSV file
# Hardcoded, see issue #38
dl_inputs['preprocessing']['csv_sep'] = "\t"
Expand Down
33 changes: 28 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
<a href="https://pypi.org/project/DeezyMatch/">
<img alt="PyPI" src="https://img.shields.io/pypi/v/DeezyMatch">
</a>
<a href="https://doi.org/10.5281/zenodo.3983555">
<img src="https://zenodo.org/badge/DOI/10.5281/zenodo.3983555.svg" alt="DOI">
</a>
<a href="https://github.com/Living-with-machines/DeezyMatch/blob/master/LICENSE">
<img alt="License" src="https://img.shields.io/badge/License-MIT-yellow.svg">
</a>
Expand Down Expand Up @@ -41,6 +38,7 @@ Table of contents
* [Candidate ranking on-the-fly](#candidate-ranking-on-the-fly)
* [Tips / Suggestions on DeezyMatch functionalities](#tips--suggestions-on-deezymatch-functionalities)
- [Examples on how to run DeezyMatch](./examples)
- [How to cite DeezyMatch](#how-to-cite-deezymatch)
- [Credits](#credits)

## Installation
Expand Down Expand Up @@ -401,26 +399,32 @@ Summary of the arguments/flags:
---
This command generates a figure `log_test001.png` and stores it in `models/test001` directory.
This command generates a figure `log_t001.png` and stores it in `models/test001` directory.
<p align="center">
<img src="https://raw.githubusercontent.com/Living-with-machines/DeezyMatch/master/figs/log_t001.png" alt="Example output of plot_log module" width="100%">
</p>
DeezyMatch stores models, vocabularies, input file, log file and checkpoints (for each epoch) in the following directory structure:
DeezyMatch stores models, vocabularies, input file, log file and checkpoints (for each epoch) in the following directory structure (unless the `validation` option in the input file is not equal to 1). When DeezyMatch finishes the last epoch, it will also save the model with the least validation loss (`test001.model` in the following directory structure). Moreover, DeezyMatch has an `early stopping` functionality. This can be activated by setting the `early_stopping_patience` option in the input file. This option specifies the number of epochs with no improvement after which training will be stopped and the model with the least validation loss will be saved.
```bash
models/
└── test001
├── checkpoint00001.model
├── checkpoint00001.model_state_dict
├── checkpoint00002.model
├── checkpoint00002.model_state_dict
├── checkpoint00003.model
├── checkpoint00003.model_state_dict
├── checkpoint00004.model
├── checkpoint00004.model_state_dict
├── checkpoint00005.model
├── checkpoint00005.model_state_dict
├── input_dfm.yaml
├── log_t001.png
├── log.txt
├── test001.model
├── test001.model_state_dict
└── test001.vocab
```
Expand Down Expand Up @@ -1021,6 +1025,25 @@ This adaptive search algorithm significantly reduces the computation time to fin
In most use cases, `search_size` can be set `>= num_candidates`. However, if `num_candidates` is very large, it is better to set the `search_size` to lower values. Let's clarify this in an example. First, assume `num_candidates=4` (number of desired candidates is 4 for each query). If we set the `search_size` to a value less than 4, let's say 2, DeezyMatch needs to do at least two iterations. In the first iteration, it looks at the closest 2 candidate vectors (as `search_size` is 2). In the second iteration, candidate vectors 3 and 4 will be examined. So two iterations. Another choice is `search_size=4`. Here, DeezyMatch looks at 4 candidates in the first iteration; if they pass the threshold, the process concludes. If not, it will search candidates 5-8 in the next iteration. Now, let's assume `num_candidates=1001` (i.e., number of desired candidates is 1001 for each query). If we set the `search_size=1000`, DeezyMatch has to search at least 2000 candidates (2 x 1000 `search_size`). If we set `search_size=100`, this time, DeezyMatch has to search at least 1100 candidates (11 x 100 `search_size`). So 900 fewer vectors. In the end, it is a trade-off between iterations and `search_size`.
## How to cite DeezyMatch
Please consider acknowledging DeezyMatch if it helps you to obtain results and figures for publications or presentations, by citing:
```text
Hosseini, Nanni and Coll Ardanuy (2020), DeezyMatch: A Flexible Deep Learning Approach to Fuzzy String Matching, EMNLP: System Demonstrations.
```
and in BibTeX:
```bibtex
@inproceedings{hosseini2020deezy,
title={DeezyMatch: A Flexible Deep Learning Approach to Fuzzy String Matching},
author={Hosseini, Kasra and Nanni, Federico and Coll Ardanuy, Mariona},
booktitle={EMNLP: System Demonstrations},
year={2020}
}
```
## Credits
This project extensively uses the ideas/neural-network-architecture published in https://github.com/ruipds/Toponym-Matching.
5 changes: 5 additions & 0 deletions inputs/input_dfm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ gru_lstm:
# shuffle when creating DataLoader
dl_shuffle: True
random_seed: 123
# Early stopping:
# Number of epochs with no improvement after which training will be stopped and
# the model with the least validation loss will be saved
# If 0 or negative, early stopping will be deactivated
early_stopping_patience: -1

# if -1 or 1, perform the validation step in every epoch;
# if 0, no validation will be done
Expand Down
5 changes: 5 additions & 0 deletions inputs/input_dfm_notebook_001.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ gru_lstm:
# shuffle when creating DataLoader
dl_shuffle: True
random_seed: 123
# Early stopping:
# Number of epochs with no improvement after which training will be stopped and
# the model with the least validation loss will be saved
# If 0 or negative, early stopping will be deactivated
early_stopping_patience: -1

# if -1 or 1, perform the validation step in every epoch;
# if 0, no validation will be done
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name="DeezyMatch",
version="1.1.0",
version="1.2.0",
description="A Flexible Deep Learning Approach to Fuzzy String Matching",
author=u"The LwM Development Team",
#author_email="",
Expand Down

0 comments on commit 38d8142

Please sign in to comment.