add the classification process for EN WIKI #96
base: main
Changes from 2 commits
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-

import argparse

import pandas as pd
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

from Database.scr.normalize_utils import Logging
if __name__ == "__main__":
    logger = Logging.get_logger("classification training")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--raw_dir",
        dest="raw_dir",
        help="The directory containing the training data",
        type=str,
    )

    args = parser.parse_args()
    logger.info(f"Passed args: {args}")
    # Load the dataset to inspect its structure
    filepath = f"{args.raw_dir}/shuffled_training_dataset.csv"
    df = pd.read_csv(filepath)

    # Manual splitting
    train_df = df[:150]  # First 150 records for training
    validation_df = df[150:250]  # Next 100 records for validation
    evaluation_df = df[250:]  # Remaining records for evaluation
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    # Initialize the tokenizer
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

    # Ensure the Whole_Text column is a string and drop any rows with NaN values in 'Whole_Text'
    train_df = train_df.dropna(subset=["Whole_Text"]).astype({"Whole_Text": "str"})
    validation_df = validation_df.dropna(subset=["Whole_Text"]).astype({"Whole_Text": "str"})
    evaluation_df = evaluation_df.dropna(subset=["Whole_Text"]).astype({"Whole_Text": "str"})

    # Now, tokenize again with the cleaned data
    train_encodings = tokenizer(train_df["Whole_Text"].tolist(), truncation=True, padding=True, max_length=512)
    val_encodings = tokenizer(validation_df["Whole_Text"].tolist(), truncation=True, padding=True, max_length=512)
    test_encodings = tokenizer(evaluation_df["Whole_Text"].tolist(), truncation=True, padding=True, max_length=512)

    from torch.utils.data import Dataset

    class WikipediaDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item["labels"] = torch.tensor(self.labels[idx])
            return item

        def __len__(self):
            return len(self.labels)

    # Create dataset objects
    train_dataset = WikipediaDataset(train_encodings, train_df["Relevance"].tolist())
    val_dataset = WikipediaDataset(val_encodings, validation_df["Relevance"].tolist())
    test_dataset = WikipediaDataset(test_encodings, evaluation_df["Relevance"].tolist())
    from transformers import Trainer, TrainingArguments

    # Define training arguments
    training_args = TrainingArguments(
        output_dir="./results",  # Directory where the results will be saved
        num_train_epochs=10,  # Total number of training epochs
        per_device_train_batch_size=8,  # Batch size per device during training
        per_device_eval_batch_size=8,  # Batch size for evaluation
        warmup_steps=500,  # Number of warmup steps for learning rate scheduler
        weight_decay=0.01,  # Strength of weight decay
        logging_dir="./logs",  # Directory for storing logs
        logging_steps=10,
        evaluation_strategy="epoch",  # Evaluate the model at the end of each epoch
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,  # The instantiated 🤗 Transformers model to be trained
        args=training_args,  # Training arguments, defined above
        train_dataset=train_dataset,  # Training dataset
        eval_dataset=val_dataset,  # Evaluation dataset
    )

    import torch

    trainer.train()

    from transformers import EarlyStoppingCallback

    trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=2))

    # Define the path where you want to save the model
    # Replace 'my_model' with your desired model name
    model_save_path = "./DistilBertForSequenceClassification_WIKI_Natural_disaster"

    # Save the model and tokenizer using the `save_pretrained` method
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)
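Two things stand out in the training script above: the `test_dataset` split is built but never evaluated, and `EarlyStoppingCallback` is registered only after `trainer.train()` has already returned, so it cannot influence training. Below is a sketch of how those pieces are usually wired together with the Hugging Face `Trainer`; the accuracy metric and the extra `TrainingArguments` values are assumptions, not part of this PR.

```python
# Sketch only: evaluate the held-out split and let early stopping act during training.
# The metric choice and the extra arguments below are assumptions, not part of the PR.
import numpy as np
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments


def compute_metrics(eval_pred):
    # eval_pred.predictions are logits; the predicted class is the argmax
    predictions = np.argmax(eval_pred.predictions, axis=-1)
    return {"accuracy": float((predictions == eval_pred.label_ids).mean())}


training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",            # must match evaluation_strategy for early stopping
    load_best_model_at_end=True,      # required by EarlyStoppingCallback
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # registered before training
)
trainer.train()

# Score the split that the current script builds but never uses
test_metrics = trainer.evaluate(test_dataset)
```

With `load_best_model_at_end=True`, the weights saved afterwards come from the best epoch on the validation split rather than simply the last one.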
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-

Review comment: The script names
import argparse

import pandas as pd
import torch
from torch.nn.functional import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Replace 'your-username/your-model-name' with the actual model path on Hugging Face
model = AutoModelForSequenceClassification.from_pretrained(
    "liniiiiii/DistilBertForSequenceClassification_WIKI_Natural_disaster"
)
tokenizer = AutoTokenizer.from_pretrained("liniiiiii/DistilBertForSequenceClassification_WIKI_Natural_disaster")
Review comment: Great job pushing the model to HF :)
from Database.scr.normalize_utils import Logging
Review comment: It's much better to keep all imports at the top to make the code easy to read.
if __name__ == "__main__":
    logger = Logging.get_logger("classification training")
Review comment: It could help to add more useful logs to both .py files.
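One possible reading of this suggestion, as a sketch only; the messages and their placement are assumptions, and `wiki_articles_path`, `df`, and `results` are the variables defined further down in this script.

```python
# Sketch: examples of the extra logs the reviewer may have in mind; wording and placement are assumptions.
logger.info(f"Loading Wikipedia articles from {wiki_articles_path}")
logger.info(f"Loaded {len(df)} articles; starting classification")
logger.info(f"Finished classification; writing {len(results)} predictions to {args.file_dir}")
```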
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--filename",
        dest="filename",
        help="The name of the csv file containing Wikipedia articles in the <file_dir> directory",
        type=str,
    )
    parser.add_argument(
        "-r",
        "--file_dir",
        dest="file_dir",
        help="The directory containing the raw file and the classified output file",
        type=str,
    )
    args = parser.parse_args()
    logger.info(f"Passed args: {args}")
    # File paths
    wiki_articles_path = f"{args.file_dir}/{args.filename}"  # Replace with your file path
    # Load the dataset
    df = pd.read_csv(wiki_articles_path)

    # Prepare the model for evaluation
    model.eval()
    def classify_text(text):
        if not isinstance(text, str):
            text = str(text)  # Convert to string if not already

        # Split the text into segments of 512 tokens
        tokenized_text = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True, max_length=512)
Review comment: This variable is never used anywhere. Maybe check why?
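The comment above the unused variable says the text should be split into 512-token segments, but the function below only truncates to the first 512 tokens. A sketch of what segment-level classification could look like with overlapping windows; the stride, the helper name `classify_text_chunked`, and the "relevant if any window is relevant" rule are assumptions, not part of this PR.

```python
# Sketch: classify long articles in overlapping 512-token windows and aggregate the
# per-window predictions. The stride and the aggregation rule are assumptions.
def classify_text_chunked(text):
    if not isinstance(text, str):
        text = str(text)

    # A fast tokenizer can return one row per overflowing window instead of truncating
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
        stride=128,                      # overlap between consecutive windows
        return_overflowing_tokens=True,  # one row of input_ids per window
    )
    inputs.pop("overflow_to_sample_mapping", None)  # not a model input

    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = softmax(outputs.logits, dim=-1)
    window_predictions = torch.argmax(probabilities, dim=-1)

    # Aggregation rule (assumption): the article is relevant if any window is predicted as label 1
    return int((window_predictions == 1).any().item())
```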
        # Process each segment and aggregate results (you can adjust this part)
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = softmax(outputs.logits, dim=-1)
        prediction = torch.argmax(probabilities, dim=-1).item()
        return prediction

    # Classify each text in the dataset
    results = []
    for _, row in df.iterrows():
        prediction = classify_text(row["Whole_Text"])
        results.append({"source": row["Source"], "prediction": prediction})

    # Create a new DataFrame for the results
    results_df = pd.DataFrame(results)

    # Save the results to a new CSV file
    results_df.to_csv(f"{args.file_dir}/{args.filename.replace('.csv', '_classified.csv')}", index=False)
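The loop above runs one forward pass per article, and the README below notes that running it over the whole collection is slow. A batched variant is one way to cut the runtime; this is a sketch only, the batch size of 16 and the GPU fallback are assumptions, and it reuses `df`, `model`, and `tokenizer` from this script.

```python
# Sketch: classify several articles per forward pass; batch size and device handling are assumptions.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

texts = df["Whole_Text"].astype(str).tolist()
batch_size = 16
predictions = []
for start in range(0, len(texts), batch_size):
    batch = texts[start : start + batch_size]
    inputs = tokenizer(batch, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())

results_df = pd.DataFrame({"source": df["Source"], "prediction": predictions})
```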
@@ -0,0 +1,10 @@
***This is the classification process for English Wikipedia articles related to climate disasters.***
# Files description
- `Classfication_wikipedia.py` is a script used for training the BERT model; the training data is `shuffled_training_dataset.csv`
- `DistilBertForSequenceClassification_WIKI_Natural_disaster` is our trained model, available at https://huggingface.co/liniiiiii/DistilBertForSequenceClassification_WIKI_Natural_disaster
- `wikipedia_dataset_preforclassify_20240229.csv` contains all the articles we collected using keyword search
- `Classifier_implement.py` is a script that applies the classification model; an example command for running it is:
```shell
poetry run python3 BERT_Classification_EN_Wikipedia/Classifier_implement.py --filename wikipedia_dataset_preforclassify_20240229.csv --file_dir BERT_Classification_EN_Wikipedia
```

Review comment: Correct the directory name.
It takes a long time to run on all the articles we collected, so we recommend running it only on new Wikipedia articles added after 20240229.
Review comment: Use properly written dates rather than timestamps (which are harder to read).
Review comment: Why is the name of this file
Review comment: Did you also mean to push
Review comment: It's good to think about why you want this directory to be in the root folder. Is it related to the Database? Maybe that would be a more appropriate location.