-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e37ee4b
commit 9d2a4fe
Showing
28 changed files
with
635 additions
and
230 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"buildTargets":[],"launchTargets":[],"customConfigurationProvider":{"workspaceBrowse":{"browsePath":[],"compilerArgs":[]},"fileIndex":[]}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
make.exe --dry-run --always-make --keep-going --print-directory | ||
'make.exe' is not recognized as an internal or external command, | ||
operable program or batch file. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"makefile.extensionOutputFolder": "./.vscode" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
make.exe all --print-data-base --no-builtin-variables --no-builtin-rules --question | ||
'make.exe' is not recognized as an internal or external command, | ||
operable program or batch file. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
|
||
|
||
|
Empty file.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from sklearn.model_selection import train_test_split | ||
import numpy as np | ||
import pandas as pd | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
import torch | ||
from torch import nn, optim | ||
from torch.utils.data import Dataset, DataLoader | ||
import torch.nn.functional as F | ||
import warnings | ||
warnings.filterwarnings('ignore') | ||
import sklearn | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import confusion_matrix,classification_report | ||
from collections import defaultdict | ||
from textwrap import wrap | ||
from joblib import load, dump | ||
import pickle | ||
from tqdm import tqdm | ||
import transformers | ||
import datetime | ||
import matplotlib.pylab as pylab | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
from sklearn.decomposition import PCA | ||
from scipy.stats import energy_distance | ||
from fastdist import fastdist | ||
df_imdb = pd.read_csv('.//data//IMDB.csv') | ||
# df_imdb = df_imdb.sample(2000) | ||
df_imdb.reset_index(drop=True,inplace=True) | ||
sns.countplot(df_imdb.sentiment) | ||
plt.ylabel('Samples') | ||
plt.xlabel('IMDB Movie Sentiments') | ||
plt.show() | ||
# sns.countplot(df_embeddings.predicted_raw_difference) | ||
df = df_imdb | ||
df_profile = df_imdb | ||
df_profile.columns = ['number','doc', 'labels_original'] | ||
df_profile['labels'] = df_profile['labels_original'] | ||
|
||
le = LabelEncoder() | ||
df_profile['labels']= le.fit_transform(df_profile['labels']) | ||
|
||
# X = df_profile.review | ||
X = df_profile.doc | ||
y = df_profile.labels | ||
# z = df_profile.user_name | ||
X_train,X_test,y_train,y_test= train_test_split(X,y,stratify=y,test_size=0.2, random_state=47) | ||
print('number of training samples:', len(X_train)) | ||
print('number of test samples:', len(X_test)) | ||
train_df = pd.DataFrame({'doc':X_train, | ||
'labels':y_train}) | ||
test_df = pd.DataFrame({'doc':X_test, | ||
'labels':y_test}) | ||
train_df.reset_index(drop=True,inplace=True) | ||
test_df.reset_index(drop=True,inplace=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# ============================================================================= | ||
# Main file | ||
# ============================================================================= | ||
from src import utils | ||
from data import dataloader | ||
from src import metric,model,plot,trainer | ||
|
||
|
||
def run(): | ||
"""Builds model, loads data, trains and evaluates""" | ||
model = Protoformer('twitter-uni') | ||
# DistilBERT(twitter-uni) BERT(imdb) RoBERTa(arxiv-10) | ||
model.load_data('twitter-uni') | ||
# twitter-uni, imdb, arxiv-10 | ||
model.build() | ||
model.train() | ||
model.evaluate() | ||
|
||
if __name__ == '__main__': | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,13 @@ | ||
# local package | ||
-e . | ||
|
||
# external requirements | ||
click | ||
Sphinx | ||
coverage | ||
awscli | ||
flake8 | ||
python-dotenv>=0.5.1 | ||
shap==0.35.0 | ||
shap | ||
numpy | ||
pandas | ||
matplotlib | ||
transformers | ||
sentence-transformers | ||
torch==1.8.1 | ||
|
||
torch | ||
fastdist | ||
sklearn | ||
seaborn | ||
pickle | ||
joblib | ||
tqdm | ||
pkbar | ||
fastdist |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# ============================================================================= | ||
# Evaluation metrics | ||
# ============================================================================= | ||
from sklearn.metrics import f1_score,classification_report | ||
def acc_cal(big_idx, targets): | ||
n_correct = (big_idx==targets).sum().item() | ||
return n_correct |
Oops, something went wrong.