Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
sess = requests.Session()

def parse_args():
""" Parse arguments """
parser = argparse.ArgumentParser()
parser.add_argument('--link', '-l', type=str, required=True, help='Share link of Tsinghua Cloud')
parser.add_argument('--password', '-p', type=str, default='', help='Password of the share link')
Expand All @@ -16,6 +17,7 @@ def parse_args():
return parser.parse_args()

def get_share_key(url):
""" Get share key from share link """
prefix = 'https://cloud.tsinghua.edu.cn/d/'
if not url.startswith(prefix):
raise ValueError('Share link of Tsinghua Cloud should start with {}'.format(prefix))
Expand All @@ -26,6 +28,7 @@ def get_share_key(url):


def dfs_search_files(share_key: str, path="/"):
""" DFS search all files in the share link """
global sess
filelist = []
print('https://cloud.tsinghua.edu.cn/api/v2.1/share-links/{}/dirents/?path={}'.format(share_key, path))
Expand All @@ -40,6 +43,7 @@ def dfs_search_files(share_key: str, path="/"):
return filelist

def download_single_file(url: str, fname: str):
""" Download single file """
global sess
resp = sess.get(url, stream=True)
total = int(resp.headers.get('content-length', 0))
Expand All @@ -58,6 +62,7 @@ def download_single_file(url: str, fname: str):
bar.update(size)

def download(url, save_dir):
""" Download all files in the share link """
share_key = get_share_key(url)

print("Searching for files to be downloaded...")
Expand Down Expand Up @@ -101,6 +106,7 @@ def download(url, save_dir):
return flag

def make_data(sample):
""" Make data for training """
src = ""
for ix, ref in enumerate(sample['references']):
src += "Reference [%d]: %s\\" % (ix+1, ref)
Expand Down
8 changes: 8 additions & 0 deletions train_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from torch.utils.data.distributed import DistributedSampler

class QuestionReferenceDensity(torch.nn.Module):
""" Question Reference Density Model """

def __init__(self) -> None:
""" Initialize the model """
super().__init__()
self.question_encoder = AutoModel.from_pretrained("facebook/contriever-msmarco")
self.reference_encoder = AutoModel.from_pretrained("facebook/contriever-msmarco")
Expand All @@ -17,12 +20,14 @@ def __init__(self) -> None:
print("Number of parameter: %.2fM" % (total / 1e6))

def mean_pooling(self, token_embeddings, mask):
    """Average token embeddings over the positions selected by ``mask``.

    Positions where ``mask`` is zero/False are zeroed out before summing
    along dim 1, and the sum is divided by the per-row mask total.

    Args:
        token_embeddings: Tensor reduced along dim 1; presumably
            (batch, tokens, hidden) encoder output -- confirm with caller.
        mask: Tensor with one fewer trailing dimension than
            ``token_embeddings``; presumably a 0/1 attention mask.

    Returns:
        Tensor of pooled embeddings with dim 1 removed. Rows whose mask is
        entirely zero come back as all zeros rather than NaN.
    """
    keep = mask[..., None].bool()
    summed = token_embeddings.masked_fill(~keep, 0.).sum(dim=1)
    counts = mask.sum(dim=1)[..., None]
    # A fully masked row has a zero count; dividing by it would produce NaN
    # and poison any loss computed downstream. The numerator is already zero
    # for such rows, so substituting 1 yields a clean zero vector.
    counts = counts.masked_fill(counts == 0, 1)
    return summed / counts


def forward(self, question, pos, neg):
""" Forward """
global args

q = self.question_encoder(**question)
Expand All @@ -41,6 +46,7 @@ def forward(self, question, pos, neg):

@staticmethod
def loss(l_pos, l_neg):
""" Loss """
return torch.nn.functional.cross_entropy(torch.cat([l_pos, l_neg], dim=1), torch.arange(0, len(l_pos), dtype=torch.long, device=args.device))

@staticmethod
Expand All @@ -53,6 +59,7 @@ def acc(l_pos, l_neg):


class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
""" Linear warmup and then linear decay. """
def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
self.warmup = warmup
self.total = total
Expand Down Expand Up @@ -104,6 +111,7 @@ def save(name):
model.reference_encoder.save_pretrained(os.path.join(log_dir, name, "reference_encoder"))

def train(max_epoch = 10, eval_step = 200, save_step = 400, print_step = 50):
""" Train the model """
step = 0
for epoch in range(0, max_epoch):
print("EPOCH %d"%epoch)
Expand Down
3 changes: 2 additions & 1 deletion web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@

"""

def query(query: str):
def query(query: str):
""" Query the model """

refs = []
answer = "Loading ..."
Expand Down