diff --git a/Project/backend/codebase/graph_analysis/graph_analysis.py b/Project/backend/codebase/graph_analysis/graph_analysis.py index 6c278cb..42b909e 100644 --- a/Project/backend/codebase/graph_analysis/graph_analysis.py +++ b/Project/backend/codebase/graph_analysis/graph_analysis.py @@ -108,7 +108,7 @@ def analyze_graph_structure(G): - Here, node 0, 1 (1.0) has the highest closeness centrality because it is connected to all other nodes (node 2, 3 = 0.75) - Closeness Centrality show the average distance of a node to all other nodes in the network """ - n = 20 # Number of top nodes to return + n = 20 if num_nodes > 20 else 5 # Number of top nodes to return # Calculate centrality measures degree_centrality = get_top_n_central_nodes(nx.degree_centrality(G), n) betweenness_centrality = get_top_n_central_nodes(nx.betweenness_centrality(G), n) diff --git a/Project/backend/codebase/graph_creator/embedding_handler.py b/Project/backend/codebase/graph_creator/embedding_handler.py index 67782ee..a7ef024 100644 --- a/Project/backend/codebase/graph_creator/embedding_handler.py +++ b/Project/backend/codebase/graph_creator/embedding_handler.py @@ -1,299 +1,216 @@ import pandas as pd import os +import logging +from graph_creator.models.graph_job import GraphJob from sentence_transformers import SentenceTransformer from langchain_community.vectorstores import FAISS import pickle from scipy.cluster.hierarchy import linkage, fcluster from scipy.spatial.distance import pdist, cosine import numpy as np +from sklearn.exceptions import NotFittedError +class embeddings_handler: -def save_data( - graph_dir, graph_id, vector_store, embedding_dict, merged_nodes, node_to_merged -): - for name, data in zip( - ["faiss_index", "embedding_dict", "merged_nodes", "node_to_merged"], - [vector_store, embedding_dict, merged_nodes, node_to_merged], - ): - with open(os.path.join(graph_dir, f"{graph_id}_{name}.pkl"), "wb") as f: - pickle.dump(data, f) - - -def load_data(graph_dir, graph_id): - return [ - pickle.load(open(os.path.join(graph_dir, f"{graph_id}_{name}.pkl"), "rb")) - for name in ["faiss_index", "embedding_dict", "merged_nodes", "node_to_merged"] - ] - - -def generate_embeddings_and_merge_duplicates( - data, - graph_id, - model_name="xlm-r-bert-base-nli-stsb-mean-tokens", - save_dir="Project/backend/codebase/embeddings", - threshold=0.2, -): - """ - Generates embeddings for nodes in the given data and merges duplicate nodes based on a threshold. - - Args: - data (pd.DataFrame): The input data containing 'node_1', 'node_2', and 'edge' columns. - model_name (str, optional): The name of the pre-trained model to use for generating embeddings. Defaults to 'xlm-r-bert-base-nli-stsb-mean-tokens'. - save_dir (str, optional): The directory to save the generated embeddings and other files. Defaults to 'embeddings'. - threshold (float, optional): The threshold value for hierarchical clustering. Nodes with a cosine distance below this threshold will be merged. Defaults to 0.2. - - Returns: - tuple: A tuple containing the following elements: - - embedding_dict (dict): A dictionary mapping nodes to their corresponding embeddings. - - merged_nodes (dict): A dictionary mapping merged node names to the original nodes in each cluster. - - merged_df (pd.DataFrame): A DataFrame containing the merged data with updated node names. - - vector_store (FAISS): The FAISS index created with original embeddings mapped to merged nodes. - - model (SentenceTransformer): The SentenceTransformer model used for generating embeddings. 
- - node_to_merged (dict): A dictionary mapping original nodes to their corresponding merged node names. - """ - - # Create a directory for the graph if it doesn't exist - graph_dir = os.path.join(save_dir, graph_id) - os.makedirs(graph_dir, exist_ok=True) - - all_nodes = pd.concat([data["node_1"], data["node_2"]]).unique() - model = SentenceTransformer(model_name) + def __init__(self, g_job: GraphJob, lazyLoad=False): + # Get graph (document) uuid + self.graph_id = g_job.id - embeddings = model.encode(all_nodes) - embedding_dict = {node: emb for node, emb in zip(all_nodes, embeddings)} + # Store embeddings in directory + self.save_dir = ".media/embeddings" - # Hierarchical Clustering - distance_matrix = pdist(embeddings, "cosine") - Z = linkage(distance_matrix, "ward") - labels = fcluster(Z, threshold, criterion="distance") + # Model used for embedding + self.model_name = "all-MiniLM-L6-v2" - merged_nodes = {} - node_to_merged = {} - # for label in set(labels): - # cluster = [all_nodes[i] for i in range(len(all_nodes)) if labels[i] == label] - # merged_name = "_".join(cluster) - # # merged_name = max(cluster, key=lambda x: sum(1 for c in x if c.isupper()) / len(x)) # Choose the longest node with the most uppercase characters - # merged_nodes[merged_name] = cluster - # for node in cluster: - # node_to_merged[node] = merged_name + # Ensure the embeddings directory exists + self.graph_dir = os.path.join(self.save_dir, str(self.graph_id)) # Convert UUID to string - # The under the command find the average embedding of the cluster and assign the node with the smallest cosine distance to the average embedding as the representative node - for label in set(labels): - cluster = [all_nodes[i] for i in range(len(all_nodes)) if labels[i] == label] - average_embedding = np.mean([embedding_dict[node] for node in cluster], axis=0) - representative_node = min( - cluster, key=lambda node: cosine(embedding_dict[node], average_embedding) + # Check if Graph already embedded + os.makedirs(self.graph_dir, exist_ok=True) + self.isEmbedded = os.path.isdir(self.graph_dir) and all( + os.path.isfile(os.path.join(self.graph_dir, f"{self.graph_id}_{name}.pkl")) + for name in ["faiss_index", "embedding_dict", "merged_nodes", "node_to_merged"] ) - merged_nodes[representative_node] = cluster - for node in cluster: - node_to_merged[node] = representative_node - - # Update edges in the graph and avoid duplicate edges - seen_edges = set() - merged_data = [] - for _, row in data.iterrows(): - node_1 = node_to_merged.get(row["node_1"], row["node_1"]) - node_2 = node_to_merged.get(row["node_2"], row["node_2"]) - edge = row["edge"] - edge_tuple = (node_1, node_2, edge) - if node_1 != node_2 and edge_tuple not in seen_edges: - merged_data.append( + self.embeddings = self.load_data() if self.isEmbedded and not lazyLoad else None + + def delete_embeddings(self): + files = [os.path.join(self.graph_dir, f"{self.graph_id}_{name}.pkl") + for name in ["faiss_index", "embedding_dict", "merged_nodes", "node_to_merged"]] + for file in files: + if os.path.exists(file): + os.remove(file) + if os.path.exists(self.graph_dir): + os.rmdir(self.graph_dir) + + def is_embedded(self): + return self.isEmbedded + + def save_data(self, vector_store, embedding_dict, merged_nodes, node_to_merged): + """ + Serialize and make variables of embedding step persistant + + Args: + vector_store : langchain_community.vectorstores.faiss.FAISS + embedding_dict : dict + merged_nodes : dict + node_to_merged : dict + """ + # store dictionaries + for name, data in 
zip(
+            ["faiss_index", "embedding_dict", "merged_nodes", "node_to_merged"],
+            [vector_store, embedding_dict, merged_nodes, node_to_merged],
+        ):
+            with open(os.path.join(self.graph_dir, f"{self.graph_id}_{name}.pkl"), "wb") as f:
+                pickle.dump(data, f)
+
+    def load_data(self) -> list:
+        # Returns [faiss_index, embedding_dict, merged_nodes, node_to_merged]; entries that fail to load are None
+        loaded_data = []
+        for name in ["faiss_index", "embedding_dict", "merged_nodes", "node_to_merged"]:
+            file_path = os.path.join(self.graph_dir, f"{self.graph_id}_{name}.pkl")
+            #print(f"Loading {file_path}")
+            try:
+                with open(file_path, "rb") as f:
+                    data = pickle.load(f)
+                loaded_data.append(data)
+                #print(f"Loaded {name}: {data}")
+            except Exception as e:
+                logging.error(f"Error loading {file_path}: {e}")
+                loaded_data.append(None)
+        return loaded_data
+
+    def generate_embeddings_and_merge_duplicates(
+        self,
+        data,
+        threshold=0.2,
+    ):
+        """
+        Generates embeddings for nodes in the given data and merges duplicate nodes based on a threshold.
+
+        Args:
+            data (pd.DataFrame): The input data containing 'node_1', 'node_2', 'edge', 'chunk_id', 'topic_node_1', and 'topic_node_2' columns.
+            threshold (float, optional): The threshold value for hierarchical clustering. Nodes with a cosine distance below this threshold will be merged. Defaults to 0.2.
+
+        Returns:
+            pd.DataFrame: The merged data with duplicate nodes collapsed onto their cluster's representative node.
+ """ + # Debug: Print the DataFrame columns + print("DataFrame Columns:", data.columns) + + # Ensure columns are as expected + expected_columns = ["node_1", "node_2", "edge", "chunk_id", "topic_node_1", "topic_node_2"] + for col in expected_columns: + if col not in data.columns: + raise ValueError(f"Missing expected column: {col}") + + #work on copy of dataframe + data = data.copy() + + all_nodes = pd.concat([data["node_1"], data["node_2"]]).unique() + model = SentenceTransformer(self.model_name) + + embeddings = model.encode(all_nodes) + embedding_dict = {node: emb for node, emb in zip(all_nodes, embeddings)} + + # Hierarchical Clustering + distance_matrix = pdist(embeddings, "cosine") + Z = linkage(distance_matrix, "ward") + labels = fcluster(Z, threshold, criterion="distance") + + merged_nodes = {} + node_to_merged = {} + + # The under the command find the average embedding of the cluster and assign the node with the smallest cosine distance to the average embedding as the representative node + for label in set(labels): + cluster = [all_nodes[i] for i in range(len(all_nodes)) if labels[i] == label] + average_embedding = np.mean([embedding_dict[node] for node in cluster], axis=0) + representative_node = min( + cluster, key=lambda node: cosine(embedding_dict[node], average_embedding) + ) + merged_nodes[representative_node] = cluster + for node in cluster: + node_to_merged[node] = representative_node + + # Update edges in the graph and avoid duplicate edges + seen_edges = set() + merged_data = [] + for _, row in data.iterrows(): + node_1 = node_to_merged.get(row["node_1"], row["node_1"]) + node_2 = node_to_merged.get(row["node_2"], row["node_2"]) + edge = row["edge"] + edge_tuple = (node_1, node_2, edge) + if node_1 != node_2 and edge_tuple not in seen_edges: + merged_data.append( + { + "node_1": node_1, + "node_2": node_2, + "edge": edge, + "chunk_id": row["chunk_id"], + "topic_node_1": row["topic_node_1"], + "topic_node_2": row["topic_node_2"], + "original_Node_1": row["node_1"], + "original_Node_2": row["node_2"], + } + ) + seen_edges.add(edge_tuple) + + merged_df = pd.DataFrame(merged_data).drop_duplicates() + + # Create FAISS index with original embeddings, but map to merged nodes + vector_store = FAISS.from_embeddings( + [(node_to_merged[node], emb) for node, emb in embedding_dict.items()], + embedding=model, + ) + try: + self.save_data( + vector_store, embedding_dict, merged_nodes, node_to_merged + ) + except Exception as e: + logging.error(e) + + return merged_df + + def search_graph(self, query, k=20): + + if not self.isEmbedded: + logging.error("No embeddings found!") + return None + + # Load the model + model = SentenceTransformer(self.model_name) + vector_store, embedding_dict, merged_nodes, node_to_merged = self.embeddings + + query_embedding = model.encode([query])[0] + results = vector_store.similarity_search_with_score_by_vector(query_embedding, k=k) + similar_nodes = [] + visited_nodes = ( + set() + ) # Set to track which merged nodes have been added to the result + + for doc, score in results: + merged_node = doc.page_content + if merged_node in visited_nodes: + continue # Skip this node if it has already been added + + original_nodes = merged_nodes[merged_node] + similarities = [ + 1 - cosine(query_embedding, embedding_dict[node]) for node in original_nodes + ] + avg_similarity = np.mean(similarities) + similar_nodes.append( { - "node_1": node_1, - "node_2": node_2, - "edge": edge, - "original_node_1": row["node_1"], - "original_node_2": row["node_2"], + "merged_node": 
merged_node, + "original_nodes": original_nodes, + "similarity": avg_similarity, + "individual_similarities": dict(zip(original_nodes, similarities)), } ) - seen_edges.add(edge_tuple) - - merged_df = pd.DataFrame(merged_data).drop_duplicates() - - # Create FAISS index with original embeddings, but map to merged nodes - vector_store = FAISS.from_embeddings( - [(node_to_merged[node], emb) for node, emb in embedding_dict.items()], - embedding=model.encode, - ) - - save_data( - graph_dir, graph_id, vector_store, embedding_dict, merged_nodes, node_to_merged - ) - - return embedding_dict, merged_nodes, merged_df, vector_store, model, node_to_merged - - -def search_graph(query, graph_id, save_dir="Project/backend/codebase/embeddings", k=20): - # Load the model - model_name = "xlm-r-bert-base-nli-stsb-mean-tokens" - model = SentenceTransformer(model_name) - vector_store, embedding_dict, merged_nodes, node_to_merged = load_data( - os.path.join(save_dir, graph_id), graph_id - ) - - query_embedding = model.encode([query])[0] - results = vector_store.similarity_search_with_score(query, k=k) - similar_nodes = [] - visited_nodes = ( - set() - ) # Set to track which merged nodes have been added to the result - - for doc, score in results: - merged_node = doc.page_content - if merged_node in visited_nodes: - continue # Skip this node if it has already been added - - original_nodes = merged_nodes[merged_node] - similarities = [ - 1 - cosine(query_embedding, embedding_dict[node]) for node in original_nodes - ] - avg_similarity = np.mean(similarities) - similar_nodes.append( - { - "merged_node": merged_node, - "original_nodes": original_nodes, - "similarity": avg_similarity, - "individual_similarities": dict(zip(original_nodes, similarities)), - } - ) - visited_nodes.add(merged_node) # Mark this node as visited - return similar_nodes - - -def main(): - # Example data - data = pd.DataFrame( - { - "node_1": [ - "Name", - "Straße", - "Wohnort", - "Geburtsname", - "Geburtsort", - "Adresse", - "Geburtsdatum", - "Name", - "Strasse", - "wohnort", - "GeburtsName", - "geburtsort", - "Anschrift", - "GeburtsDaten", - "Namen", - "Straßen", - "Wohnorte", - "Geburtsnamen", - "Geburtsorte", - "Adressen", - "Geburtsdaten", - "Vorname", - "Hausnummer", - "PLZ", - "Mädchenname", - "Geburtsstadt", - "Anschrift", - "Geburtsjahr", - ], - "node_2": [ - "Vorname", - "Hausnummer", - "PLZ", - "Mädchenname", - "Geburtsstadt", - "Anschrift", - "Geburtsjahr", - "Spitzname", - "hausnummer", - "Postleitzahl", - "MädchenName", - "GeburtsStadt", - "Adresse", - "GeburtsJahr", - "Vornamen", - "Hausnummern", - "PLZs", - "Mädchennamen", - "Geburtsstädte", - "Adressen", - "Geburtsjahre", - "Nickname", - "HouseNumber", - "PostalCode", - "MaidenName", - "BirthCity", - "Address", - "YearOfBirth", - ], - "edge": [ - "related_as_personal_details", - "located_at", - "located_in", - "related_as_birth_details", - "related_as_place_of_birth", - "located_at", - "related_as_birth_details", - "related_as_personal_details", - "located_at", - "located_in", - "related_as_birth_details", - "related_as_place_of_birth", - "located_at", - "related_as_birth_details", - "related_as_personal_details", - "located_at", - "located_in", - "related_as_birth_details", - "related_as_place_of_birth", - "located_at", - "related_as_birth_details", - "related_as_personal_details", - "located_at", - "located_in", - "related_as_birth_details", - "related_as_place_of_birth", - "located_at", - "related_as_birth_details", - ], - } - ) - - # Process the data - model_name = 
"xlm-r-bert-base-nli-stsb-mean-tokens" - threshold = 0.2 - graph_id = "example_graph" - embedding_dict, merged_nodes, updated_data, vector_store, model, node_to_merged = ( - generate_embeddings_and_merge_duplicates( - data, graph_id, model_name=model_name, threshold=threshold - ) - ) - - # Output results - print("Process completed.") - print(f"Number of merged nodes: {len(merged_nodes)}") - print(f"Number of embeddings: {len(embedding_dict)}") - print("Merged Nodes:") - - for merged, original in merged_nodes.items(): - print(f" {merged}: {original}") - - print("\nUpdated DataFrame:") - print(updated_data) - print("\n") - # Example of a future search - query = "What is the name at birth?" - - results = search_graph(query, graph_id, k=20) - - print(f"Similar nodes to '{query}':") - for result in results: - print(f" Merged Node: {result['merged_node']}") - print(f" Original Nodes: {result['original_nodes']}") - print(f" Average Similarity: {result['similarity']:.4f}") - print(" Individual Similarities:") - for node, sim in result["individual_similarities"].items(): - print(f" {node}: {sim:.4f}") - print() - - -if __name__ == "__main__": - main() + visited_nodes.add(merged_node) # Mark this node as visited + return similar_nodes \ No newline at end of file diff --git a/Project/backend/codebase/graph_creator/graph_creator_main.py b/Project/backend/codebase/graph_creator/graph_creator_main.py index edae7b9..59ee706 100644 --- a/Project/backend/codebase/graph_creator/graph_creator_main.py +++ b/Project/backend/codebase/graph_creator/graph_creator_main.py @@ -1,6 +1,7 @@ import logging - +from graph_creator.embedding_handler import embeddings_handler from graph_creator import graph_handler +import os from graph_creator.services.llm.llama_gemini_combination import llama_gemini_combination from graph_creator.models.graph_job import GraphJob from graph_creator.services import netx_graphdb @@ -88,6 +89,12 @@ def create_and_store_graph(uuid, entities_and_relations, chunks, llm_handler): chunks[i] = chunks[i].dict() combined = graph_handler.connect_with_llm(df_e_and_r, chunks, llm_handler) + # Create an instance of the embeddings handler + embeddings_handler_instance = embeddings_handler(GraphJob(id=uuid)) + + # Generate embeddings and merge duplicates + combined = embeddings_handler_instance.generate_embeddings_and_merge_duplicates(combined) + # get graph db service graph_db_service = netx_graphdb.NetXGraphDB() diff --git a/Project/backend/codebase/graph_creator/graph_handler.py b/Project/backend/codebase/graph_creator/graph_handler.py index 0788915..3ed4dd5 100644 --- a/Project/backend/codebase/graph_creator/graph_handler.py +++ b/Project/backend/codebase/graph_creator/graph_handler.py @@ -134,7 +134,8 @@ def index_entity_relation_table(entity_and_relation_df, entities): entities_dict[entities[i]] = i relations = [] - for i, row in entity_and_relation_df.iterrows(): + entity_and_relation_df_withoutna = entity_and_relation_df.dropna() + for i, row in entity_and_relation_df_withoutna.iterrows(): relations.append([entities_dict[row["node_1"]], entities_dict[row["node_2"]]]) return entities_dict, relations @@ -213,7 +214,8 @@ def get_entities_by_chunk(entity_and_relation_df, entities_dict): A dictionary containing all entities per chunk as ids """ entities_by_chunk = {} - for i, row in entity_and_relation_df.iterrows(): + entity_and_relation_df_withoutna = entity_and_relation_df.dropna() + for i, row in entity_and_relation_df_withoutna.iterrows(): if row["chunk_id"] in entities_by_chunk: 
entities_by_chunk[row["chunk_id"]].append(entities_dict[row["node_1"]])
             entities_by_chunk[row["chunk_id"]].append(entities_dict[row["node_2"]])
@@ -333,15 +335,18 @@
     """
 
     for relation in new_relations:
-        node_1 = relation["node_1"]
-        node_2 = relation["node_2"]
-        edge = relation["edge"]
-        chunk_id = relation["chunk_id"]
-
-        pos = len(entity_and_relation_df.index)
-        entity_and_relation_df.loc[pos] = [node_1, node_2, edge, chunk_id]
-
-    return entity_and_relation_df
+        try:
+            node_1 = relation["node_1"]
+            node_2 = relation["node_2"]
+            edge = relation["edge"]
+            chunk_id = relation["chunk_id"]
+
+            pos = len(entity_and_relation_df.index)
+            entity_and_relation_df.loc[pos] = [node_1, node_2, edge, chunk_id]
+        except ValueError:
+            print("Error in add_relations_to_data:", node_1, node_2, edge, chunk_id)
+            pass
+    return entity_and_relation_df.dropna()
 
 
 def add_topic(data: pd.DataFrame, max_topics: int = 25) -> pd.DataFrame:
diff --git a/Project/backend/codebase/graph_creator/router.py b/Project/backend/codebase/graph_creator/router.py
index aff172f..39d734e 100644
--- a/Project/backend/codebase/graph_creator/router.py
+++ b/Project/backend/codebase/graph_creator/router.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import uuid
@@ -7,6 +8,8 @@ from fastapi import UploadFile, File, HTTPException
 from starlette.responses import JSONResponse
+from graph_creator.embedding_handler import embeddings_handler
+from graph_creator.schemas.graph_query import QueryRequest
 import graph_creator.graph_creator_main as graph_creator_main
 from graph_creator.dao.graph_job_dao import GraphJobDAO
 from graph_creator.schemas.graph_job import GraphJobCreate
@@ -193,6 +196,9 @@ async def delete_graph_job(
     graph_job_id = graph_job.id
     await graph_job_dao.delete_graph_job(graph_job)
     netx_services.delete_graph(graph_job_id)
+    graphEmbeddingsHandler = embeddings_handler(graph_job, lazyLoad=True)
+    graphEmbeddingsHandler.delete_embeddings()
+
 
 @router.post("/create_graph/{graph_job_id}")
@@ -298,3 +304,54 @@ async def query_graph(
     graph = netx_services.load_graph(graph_job_id=graph_job_id)
     graph_keywords = analyze_graph_structure(graph)
     return graph_keywords
+
+
+@router.post("/graph_search/{graph_job_id}")
+async def graph_search(
+    graph_job_id: uuid.UUID,
+    request: QueryRequest,
+    graph_job_dao: GraphJobDAO = Depends(),
+):
+    """
+    Reads a graph job by id and tries to answer a query about the graph using embeddings.
+
+    Args:
+        graph_job_id (uuid.UUID): ID of the graph job to be read.
+        request (QueryRequest): contains the user query
+        graph_job_dao (GraphJobDAO): graph job database access object
+
+    Returns:
+        Answer to the user's question about the graph
+
+    Raises:
+        HTTPException: If there is no graph job with the given ID.
+ """ + + g_job = await graph_job_dao.get_graph_job_by_id(graph_job_id) + + if not g_job: + raise HTTPException(status_code=404, detail="Graph job not found") + if g_job.status != GraphStatus.GRAPH_READY: + raise HTTPException( + status_code=400, + detail="No graph created for this job!", + ) + + user_query = request.query + #print(f"Received query: {user_query}") + + graphEmbeddingsHandler = embeddings_handler(g_job) + + if graphEmbeddingsHandler.is_embedded(): + #do search + result = graphEmbeddingsHandler.search_graph(user_query, k=4) + #print(result) + answer = json.dumps(result) + else: + #can't answer because no embeddings exist + answer = 'No embeddings found' + + return JSONResponse( + content={"answer": answer}, + status_code=200, + ) diff --git a/Project/backend/codebase/graph_creator/schemas/graph_query.py b/Project/backend/codebase/graph_creator/schemas/graph_query.py new file mode 100644 index 0000000..5b975f9 --- /dev/null +++ b/Project/backend/codebase/graph_creator/schemas/graph_query.py @@ -0,0 +1,4 @@ +from pydantic import BaseModel + +class QueryRequest(BaseModel): + query: str \ No newline at end of file diff --git a/Project/backend/codebase/graph_creator/services/file_handler.py b/Project/backend/codebase/graph_creator/services/file_handler.py index 787fa39..ed19dc3 100644 --- a/Project/backend/codebase/graph_creator/services/file_handler.py +++ b/Project/backend/codebase/graph_creator/services/file_handler.py @@ -2,10 +2,11 @@ import os from pathlib import Path -from langchain_community.document_loaders import PyPDFLoader +from langchain_community.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader from langchain_community.document_loaders import TextLoader from langchain_community.document_loaders import Docx2txtLoader from langchain_community.document_loaders import UnstructuredPowerPointLoader +from langchain_core.documents import Document from langchain_text_splitters import ( RecursiveCharacterTextSplitter, @@ -18,43 +19,68 @@ class FileHandler: def __init__(self, file_location: str): self.file_location = file_location self.file_loader = { - ".pdf": PyPDFLoader, - ".txt": TextLoader, - ".docx": Docx2txtLoader, - ".pptx": UnstructuredPowerPointLoader, - ".json": RecursiveJsonSplitter, + ".pdf": (PyPDFLoader, {}), + ".txt": (TextLoader, {}), + ".docx": (Docx2txtLoader, {}), + ".pptx": ( + UnstructuredPowerPointLoader, + {"mode": "elements", "strategy": "fast", "join_docs_by_page": True} + ), + ".json": (RecursiveJsonSplitter, {}), } if not os.path.isfile(self.file_location): raise ValueError("Invalid file path.") def process_file_into_chunks(self): - file_loader = self._get_file_loader() + file_loader, kwargs = self._get_file_loader() if file_loader == RecursiveJsonSplitter: return self._get_json_chunks() - loader = file_loader(self.file_location) + join_docs_by_page = kwargs.pop("join_docs_by_page", False) + loader = file_loader(self.file_location, **kwargs) docs = loader.load() - splits = self._process_doc_to_chunks(docs) + splits = self._process_doc_to_chunks(docs, join_docs_by_page=join_docs_by_page) return splits @staticmethod - def _process_doc_to_chunks(docs): + def _process_doc_to_chunks(docs, join_docs_by_page: bool): if not docs: raise ValueError("Failed to load documents.") + if join_docs_by_page: + new_docs = [] + current_doc = Document(page_content="") + current_page = None + new_docs.append(current_doc) + for doc in docs: + if doc.page_content == "": + continue + doc_current_page = doc.metadata.get("page_number", None) + # if 
doc_current_page is None + if current_page != doc_current_page and doc.metadata.get("category", None) not in ["PageBreak", None]: + current_doc = Document( + page_content=doc.page_content, + metadata={"page": doc_current_page - 1 if doc_current_page else "No page"} + ) + current_page = doc_current_page + new_docs.append(current_doc) + else: + current_doc.page_content += f"\n {doc.page_content}" + else: + new_docs = docs # splits text into chunks including metadata for mapping from chunk to pdf page (splits[0].metadata['page']) text_splitter = RecursiveCharacterTextSplitter( chunk_size=os.getenv("CHUNK_SIZE", 1500), chunk_overlap=150 ) - splits = text_splitter.split_documents(docs) + splits = text_splitter.split_documents(new_docs) return splits def _get_file_loader(self): _, extension = os.path.splitext(self.file_location) - loader = self.file_loader.get(extension) + loader, kwargs = self.file_loader.get(extension) if loader is None: raise ValueError("File format does not have a loader!") - return loader + return loader, kwargs def _get_json_chunks(self): json_data = json.loads(Path(self.file_location).read_text()) diff --git a/Project/frontend/src/components/App/index.css b/Project/frontend/src/components/App/index.css index 57d9e87..5a5132c 100644 --- a/Project/frontend/src/components/App/index.css +++ b/Project/frontend/src/components/App/index.css @@ -58,22 +58,3 @@ img { justify-content: center; gap: 10px; } - -.main_wrapper { - display: flex; - flex-direction: column; - align-items: center; - gap: 20px; - margin: 20px; - min-width: 100%; - min-height: 100%; -} - -.Appcontainer { - display: flex; - flex-direction: column; - align-items: center; - gap: 20px; - min-width: 100%; - min-height: 100%; -} diff --git a/Project/frontend/src/components/App/index.tsx b/Project/frontend/src/components/App/index.tsx index 4729727..da5ff12 100644 --- a/Project/frontend/src/components/App/index.tsx +++ b/Project/frontend/src/components/App/index.tsx @@ -5,14 +5,11 @@ import { Routes, } from 'react-router-dom'; import { - AppBar, createTheme, CssBaseline, - Divider, Paper, Stack, ThemeProvider, - Toolbar, Typography, } from '@mui/material'; @@ -37,7 +34,7 @@ function App() { { - const containerRef = useRef(null); - - useEffect(() => { - if (!graphData) return; - - // Set up SVG dimensions - const width = window.innerWidth; - const height = window.innerHeight - 50; - - // Clear previous graph - d3.select(containerRef.current).select('svg').remove(); - - // Create SVG element - const svg = d3 - .select(containerRef.current) - .append('svg') - .attr('width', width) - .attr('height', height) - .call( - d3.zoom().on('zoom', (event) => { - svg.attr('transform', event.transform); - }), - ) - .append('g'); - - // Set up the simulation - const simulation = d3 - .forceSimulation(graphData.nodes) - .force( - 'link', - d3 - .forceLink(graphData.edges) - .id((d) => d.id) - .distance(100), - ) - .force('charge', d3.forceManyBody().strength(-300)) - .force('center', d3.forceCenter(width / 2, height / 2)); - - // Apply different layout algorithms - if (layout === 'hierarchical') { - simulation.force('y', d3.forceY().strength(0.1)); - simulation.force('x', d3.forceX().strength(0.1)); - } - - // Create links - const link = svg - .append('g') - .attr('class', 'links') - .selectAll('line') - .data(graphData.edges) - .enter() - .append('line') - .attr('stroke-width', 2) - .attr('stroke', '#fff'); - - // Create link labels - const linkLabels = svg - .append('g') - .attr('class', 'link-labels') - .selectAll('text') - 
.data(graphData.edges) - .enter() - .append('text') - .attr('class', 'link-label') - .attr('dx', 15) - .attr('dy', '.35em') - .text((d) => d.label) // Correctly reading the label property for edges - .attr('fill', '#fff'); - - // Create nodes - const node = svg - .append('g') - .attr('class', 'nodes') - .selectAll('circle') - .data(graphData.nodes) - .enter() - .append('circle') - .attr('r', 25) - .attr('fill', '#69b3a2') - .attr('stroke', '#508e7f') - .call( - d3 - .drag() - .on('start', dragstarted) - .on('drag', dragged) - .on('end', dragended), - ); - - // Node labels - const nodeLabels = svg - .append('g') - .attr('class', 'node-labels') - .selectAll('text') - .data(graphData.nodes) - .enter() - .append('text') - .attr('class', 'node-label') - .attr('dx', 15) - .attr('dy', '.35em') - .text((d) => d.label) // Correctly reading the label property for nodes - .attr('fill', '#fff'); - - // Update simulation - simulation.nodes(graphData.nodes).on('tick', ticked); - - simulation.force('link').links(graphData.edges); - - function ticked() { - link - .attr('x1', (d) => d.source.x) - .attr('y1', (d) => d.source.y) - .attr('x2', (d) => d.target.x) - .attr('y2', (d) => d.target.y); - - node.attr('cx', (d) => d.x).attr('cy', (d) => d.y); - - nodeLabels.attr('x', (d) => d.x).attr('y', (d) => d.y); - - linkLabels - .attr('x', (d) => (d.source.x + d.target.x) / 2) - .attr('y', (d) => (d.source.y + d.target.y) / 2); - } - - function dragstarted(event, d) { - if (!event.active) simulation.alphaTarget(0.3).restart(); - d.fx = d.x; - d.fy = d.y; - } - - function dragged(event, d) { - d.fx = event.x; - d.fy = event.y; - } - - function dragended(event, d) { - if (!event.active) simulation.alphaTarget(0); - d.fx = null; - d.fy = null; - } - - // Stabilize nodes after a certain time - setTimeout(() => { - simulation.alphaTarget(0).restart(); - }, 5000); // 5 seconds stabilization time - }, [graphData, layout]); - - return ( -
- ); -}; - -export default D3Graph; diff --git a/Project/frontend/src/components/Graph/FloatingControlCard_d3.jsx b/Project/frontend/src/components/Graph/FloatingControlCard_d3.jsx deleted file mode 100644 index 4992e88..0000000 --- a/Project/frontend/src/components/Graph/FloatingControlCard_d3.jsx +++ /dev/null @@ -1,43 +0,0 @@ -import React from 'react'; -import { - Card, - CardContent, - FormControl, - InputLabel, - Select, - MenuItem, - Box, -} from '@mui/material'; - -const FloatingControlCard = ({ layout, setLayout }) => { - return ( - - - - Layout - - - - - ); -}; - -export default FloatingControlCard; diff --git a/Project/frontend/src/components/Graph/FloatingControlCard_sigma.jsx b/Project/frontend/src/components/Graph/FloatingControlCard_sigma.jsx deleted file mode 100644 index 9aba734..0000000 --- a/Project/frontend/src/components/Graph/FloatingControlCard_sigma.jsx +++ /dev/null @@ -1,130 +0,0 @@ -import React from 'react'; -import { - Card, - CardContent, - FormControl, - InputLabel, - Select, - MenuItem, - Slider, - Typography, - Box, -} from '@mui/material'; - -const FloatingControlCard = ({ - layout, - setLayout, - physicsOptions, - handlePhysicsChange, -}) => { - const handleSliderChange = (name) => (event, value) => { - handlePhysicsChange(name, value); - }; - - const renderSliders = () => { - return ( - - Iterations - - Barnes Hut Theta - - Gravity - - Scaling Ratio - - Edge Weight Influence - - Edge Length - - - ); - }; - - return ( - - - - Layout - - - {renderSliders()} - - - ); -}; - -export default FloatingControlCard; diff --git a/Project/frontend/src/components/Graph/index_3js.tsx b/Project/frontend/src/components/Graph/index_3js.tsx deleted file mode 100644 index af11d2a..0000000 --- a/Project/frontend/src/components/Graph/index_3js.tsx +++ /dev/null @@ -1,202 +0,0 @@ -import { useEffect, useState, useRef } from 'react'; -import { - select, - forceSimulation, - forceLink, - forceManyBody, - forceCenter, - forceCollide, - zoom, - drag, -} from 'd3'; -import { useParams } from 'react-router-dom'; -import './index.css'; -import { VISUALIZE_API_PATH } from '../../constant'; - -const GraphVisualization = () => { - const svgRef = useRef(); - const { fileId = '' } = useParams(); - const [graphData, setGraphData] = useState(null); - const [isLoading, setIsLoading] = useState(true); - - useEffect(() => { - const API = `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`; - fetch(API) - .then((res) => res.json()) - .then((data) => { - setGraphData(data); - setIsLoading(false); - }) - .catch((error) => { - console.error('Error fetching graph data:', error); - setIsLoading(false); - }); - }, [fileId]); - - useEffect(() => { - if (!graphData) return; - - const svg = select(svgRef.current); - const width = svgRef.current.clientWidth; - const height = svgRef.current.clientHeight; - - const g = svg.append('g'); - - const simulation = forceSimulation(graphData.nodes) - .force( - 'link', - forceLink(graphData.edges) - .id((d) => d.id) - .distance(100), - ) - .force('charge', forceManyBody().strength(-200)) - .force('center', forceCenter(width / 2, height / 2)) - .force('collide', forceCollide().radius(30).strength(1)) - .alphaDecay(0.01) - .alphaMin(0.001); - - const link = g - .selectAll('line') - .data(graphData.edges) - .enter() - .append('line') - .attr('stroke', '#999') - .attr('stroke-width', 1.5) - .attr('id', (d, i) => `link${i}`); - - const node = g - .selectAll('circle') - .data(graphData.nodes) - .enter() - .append('circle') - .attr('r', 
15) - .attr('fill', '#69b3a2') - .call( - drag() - .on('start', (event, d) => { - if (!event.active) simulation.alphaTarget(0.3).restart(); - d.fx = d.x; - d.fy = d.y; - }) - .on('drag', (event, d) => { - d.fx = event.x; - d.fy = event.y; - }) - .on('end', (event, d) => { - if (!event.active) simulation.alphaTarget(0); - d.fx = null; - d.fy = null; - }), - ) - .on('mouseover', function (event, d) { - select(this).attr('fill', 'orange'); - svg - .select('#tooltip') - .style('display', 'block') - .html(`ID: ${d.id}
Label: ${d.label || 'N/A'}`) - .style('left', `${event.pageX + 10}px`) - .style('top', `${event.pageY + 10}px`); - }) - .on('mouseout', function () { - select(this).attr('fill', '#69b3a2'); - svg.select('#tooltip').style('display', 'none'); - }); - - const nodeLabels = g - .selectAll('text.node-label') - .data(graphData.nodes) - .enter() - .append('text') - .attr('class', 'node-label') - .attr('dy', -20) - .attr('text-anchor', 'middle') - .style('font-size', '12px') - .style('pointer-events', 'none') - .text((d) => d.id); - - // Create path elements for each link - const linkPaths = g - .selectAll('path') - .data(graphData.edges) - .enter() - .append('path') - .attr('class', 'link-path') - .attr('id', (d, i) => `link-path${i}`) - .attr('fill', 'none') - .attr('stroke', 'none'); - - const edgeLabels = g - .selectAll('text.edge-label') - .data(graphData.edges) - .enter() - .append('text') - .attr('class', 'edge-label') - .attr('dy', -5) - .style('font-size', '10px') - .style('pointer-events', 'none') - .append('textPath') - .attr('xlink:href', (d, i) => `#link-path${i}`) - .attr('startOffset', '50%') - .style('text-anchor', 'middle') - .text((d) => d.id); - - const zoomBehavior = zoom().on('zoom', (event) => { - g.attr('transform', event.transform); - }); - - svg.call(zoomBehavior); - - simulation.on('tick', () => { - link - .attr('x1', (d) => d.source.x) - .attr('y1', (d) => d.source.y) - .attr('x2', (d) => d.target.x) - .attr('y2', (d) => d.target.y); - - node.attr('cx', (d) => d.x).attr('cy', (d) => d.y); - - nodeLabels.attr('x', (d) => d.x).attr('y', (d) => d.y); - - linkPaths.attr( - 'd', - (d) => `M${d.source.x},${d.source.y} L${d.target.x},${d.target.y}`, - ); - - edgeLabels - .attr('x', (d) => (d.source.x + d.target.x) / 2) - .attr('y', (d) => (d.source.y + d.target.y) / 2); - }); - - svg - .append('foreignObject') - .attr('id', 'tooltip') - .style('position', 'absolute') - .style('background', '#fff') - .style('border', '1px solid #ccc') - .style('padding', '10px') - .style('display', 'none') - .append('xhtml:div') - .style('font-size', '10px') - .html('Tooltip'); - - return () => simulation.stop(); - }, [graphData]); - - if (isLoading) { - return
Loading graph...
; - } - if (!graphData) { - return
Sorry, an error has occurred!
; - } - return ( -
-

Graph Visualization

- -
- ); -}; - -export default GraphVisualization; diff --git a/Project/frontend/src/components/Graph/index_sigma.tsx b/Project/frontend/src/components/Graph/index_sigma.tsx deleted file mode 100644 index 23f2e85..0000000 --- a/Project/frontend/src/components/Graph/index_sigma.tsx +++ /dev/null @@ -1,132 +0,0 @@ -import { useEffect, useState } from 'react'; -import { MultiDirectedGraph } from 'graphology'; -import { SigmaContainer, useSigma } from '@react-sigma/core'; -import { useParams } from 'react-router-dom'; -import '@react-sigma/core/lib/react-sigma.min.css'; -import EdgeCurveProgram, { - DEFAULT_EDGE_CURVATURE, - indexParallelEdgesIndex, -} from '@sigma/edge-curve'; -import { EdgeArrowProgram } from 'sigma/rendering'; -import forceAtlas2 from 'graphology-layout-forceatlas2'; -import './index.css'; -import { VISUALIZE_API_PATH } from '../../constant'; - -const ForceAtlas2Layout = ({ maxIterations }) => { - const sigma = useSigma(); - const graph = sigma.getGraph(); - const [iterations, setIterations] = useState(0); - - useEffect(() => { - const settings = { - iterations: maxIterations, - barnesHutOptimize: true, - barnesHutTheta: 0.5, - slowDown: 1, - gravity: 1, - scalingRatio: 10, - edgeWeightInfluence: 1, - strongGravityMode: true, - adjustSizes: true, - }; - - const applyLayout = () => { - forceAtlas2.assign(graph, settings); - setIterations((prev) => prev + 1); - }; - - if (iterations < maxIterations) { - const interval = setInterval(applyLayout, 100); - return () => clearInterval(interval); - } - }, [graph, iterations, maxIterations]); - - return null; -}; - -export default function GraphVisualization() { - const [graphData, setGraphData] = useState(null); - const { fileId = '' } = useParams(); - const [isLoading, setIsLoading] = useState(true); - - useEffect(() => { - const API = `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`; - fetch(API) - .then((res) => res.json()) - .then((graphData) => { - const graph = new MultiDirectedGraph(); - graphData?.nodes?.forEach((node) => { - const { id, ...rest } = node; - graph.addNode(id, { - ...rest, - size: 15, // just for testing, i am making all the same size - x: Math.random() * 1000, - y: Math.random() * 1000, - }); - }); - graphData?.edges?.forEach((edge) => { - const { id, source, target, ...rest } = edge; - graph.addEdgeWithKey(id, source, target, { - ...rest, - size: 2, // edge - }); - }); - indexParallelEdgesIndex(graph, { - edgeIndexAttribute: 'parallelIndex', - edgeMaxIndexAttribute: 'parallelMaxIndex', - }); - graph.forEachEdge((edge, { parallelIndex, parallelMaxIndex }) => { - if (typeof parallelIndex === 'number') { - graph.mergeEdgeAttributes(edge, { - type: 'curved', - curvature: - DEFAULT_EDGE_CURVATURE + - (3 * DEFAULT_EDGE_CURVATURE * parallelIndex) / - (parallelMaxIndex || 1), - }); - } else { - graph.setEdgeAttribute(edge, 'type', 'straight'); - } - }); - setGraphData(graph); - }) - .catch((error) => { - console.log('Error fetching graphData:', error); - }) - .finally(() => { - setIsLoading(false); - }); - }, [fileId]); - - if (isLoading) { - return
Loading graph...
; - } - if (!graphData) { - return
Sorry, an error has occurred!
; - } - return ( -
-

Graph Visualization

- - - -
- ); -} diff --git a/Project/frontend/src/components/Graph/index_visjs.tsx b/Project/frontend/src/components/Graph/index_visjs.tsx index 621fe9a..cdf2547 100644 --- a/Project/frontend/src/components/Graph/index_visjs.tsx +++ b/Project/frontend/src/components/Graph/index_visjs.tsx @@ -2,7 +2,7 @@ import React, { useEffect, useRef, useState } from 'react'; import { Network, Options } from 'vis-network/standalone/esm/vis-network'; import { useParams } from 'react-router-dom'; import './index.css'; -import { KEYWORDS_API_PATH, VISUALIZE_API_PATH } from '../../constant'; +import { KEYWORDS_API_PATH, VISUALIZE_API_PATH, GRAPH_SEARCH_API_PATH } from '../../constant'; import SearchIcon from '@mui/icons-material/Search'; import { Box, @@ -202,6 +202,8 @@ const GraphVisualization: React.FC = () => { const { fileId = '' } = useParams<{ fileId: string }>(); const [graphData, setGraphData] = useState(null); const [isLoading, setIsLoading] = useState(true); + const [searchIsLoading, setSearchIsLoading] = useState(false); + const [answerText, setAnswerText] = useState(""); const [layout, setLayout] = useState('barnesHut'); const [searchQuery, setSearchQuery] = useState(''); const [keywords, setKeywords] = useState([]); @@ -227,70 +229,6 @@ const GraphVisualization: React.FC = () => { const theme = useTheme(); const isMobile = useMediaQuery(theme.breakpoints.down('sm')); - useEffect(() => { - const fetchGraphData = async () => { - try { - const response = await fetch( - `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`, - ); - const data = await response.json(); - setGraphData(data); - - // Get the list of unique topics - const uniqueTopics = Array.from( - new Set(data.nodes.map((node) => node.topic)), - ); - - // Create color scheme for the topics - const colorSchemes = [ - d3.schemeCategory10, - d3.schemePaired, - d3.schemeSet1, - ]; - const uniqueColors = Array.from(new Set(colorSchemes.flat())); - - const otherIndex = uniqueTopics.indexOf('other'); - if (otherIndex !== -1) { - uniqueTopics.splice(otherIndex, 1); - } - - const topicColorMap: ITopicColourMap = uniqueTopics.reduce( - (acc: ITopicColourMap, topic, index) => { - acc[topic] = uniqueColors[index % uniqueColors.length]; - return acc; - }, - {}, - ); - - if (otherIndex !== -1) { - topicColorMap['other'] = - uniqueColors[uniqueTopics.length % uniqueColors.length]; - } - - setTopicColorMap(topicColorMap); - } catch (error) { - console.error('Error fetching graph data:', error); - } finally { - setIsLoading(false); - } - }; - - const fetchKeywords = async () => { - try { - const response = await fetch( - `${import.meta.env.VITE_BACKEND_HOST}${KEYWORDS_API_PATH.replace(':fileId', fileId)}`, - ); - const data = await response.json(); - setKeywords(data); - } catch (error) { - console.error('Error fetching keywords:', error); - } - }; - - fetchGraphData(); - fetchKeywords(); - }, [fileId]); - useEffect(() => { switch (layout) { case 'barnesHut': @@ -439,9 +377,27 @@ const GraphVisualization: React.FC = () => { } }; - const performSearch = () => { + const performSearch = async () => { // Perform the search based on searchQuery - console.log('Searching for:', searchQuery); + const API = `${import.meta.env.VITE_BACKEND_HOST}${GRAPH_SEARCH_API_PATH.replace(':fileId', fileId)}`; + + setSearchIsLoading(true); + try { + const response = await fetch(API, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ query: searchQuery }), + }); + const result = await response.json(); 
+ setAnswerText(result.answer); + setSearchIsLoading(false); + } catch (error) { + console.error("Error fetching the search results:", error); + setAnswerText("An error occurred while fetching the search results."); + setSearchIsLoading(false); + } }; if (isLoading) { @@ -548,6 +504,10 @@ const GraphVisualization: React.FC = () => { }} sx={{ marginBottom: '10px' }} /> + {searchIsLoading ? <> + + Searching... + : <>} { readOnly: true, }} sx={{ marginBottom: '10px' }} + value={searchIsLoading ? '' : answerText} /> @@ -587,19 +548,7 @@ const GraphVisualization: React.FC = () => { - {/* setStabilizationComplete(false)} - sx={{ - position: 'absolute', - top: '10px', - right: '10px', - zIndex: 1000, - }} - /> */} + { - const { fileId = '' } = useParams(); - const [graphData, setGraphData] = useState(null); - const [isLoading, setIsLoading] = useState(true); - const [layout, setLayout] = useState('barnesHut'); - const [physicsOptions, setPhysicsOptions] = useState({ - gravitationalConstant: -20000, - springLength: 100, - springConstant: 0.1, - damping: 0.09, - }); - - const fetchGraphData = async () => { - try { - const response = await fetch( - `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`, - ); - const data = await response.json(); - setGraphData(data); - } catch (error) { - console.error('Error fetching graph data:', error); - } finally { - setIsLoading(false); - } - }; - - useEffect(() => { - fetchGraphData(); - }, [fileId]); - - const handlePhysicsChange = (name, value) => { - setPhysicsOptions((prevOptions) => ({ - ...prevOptions, - [name]: value, - })); - }; - - if (isLoading) { - return
Loading graph...
; - } - - if (!graphData) { - return
Sorry, an error has occurred!
; - } - - return ( -
-

Graph Visualization

- - -
- ); -}; - -export default GraphVisualization; diff --git a/Project/frontend/src/components/Graph/index_visjs_sigma.tsx b/Project/frontend/src/components/Graph/index_visjs_sigma.tsx deleted file mode 100644 index b446b56..0000000 --- a/Project/frontend/src/components/Graph/index_visjs_sigma.tsx +++ /dev/null @@ -1,168 +0,0 @@ -import { useEffect, useState } from 'react'; -import { MultiDirectedGraph } from 'graphology'; -import { SigmaContainer, useSigma } from '@react-sigma/core'; -import { useParams } from 'react-router-dom'; -import '@react-sigma/core/lib/react-sigma.min.css'; -import EdgeCurveProgram, { - DEFAULT_EDGE_CURVATURE, - indexParallelEdgesIndex, -} from '@sigma/edge-curve'; -import { EdgeArrowProgram } from 'sigma/rendering'; -import forceAtlas2 from 'graphology-layout-forceatlas2'; -import FloatingControlCard from './FloatingControlCard_sigma'; -import './index.css'; -import { VISUALIZE_API_PATH } from '../../constant'; - -const ForceAtlas2Layout = ({ settings, onIteration, restart }) => { - const sigma = useSigma(); - const graph = sigma.getGraph(); - const [iterations, setIterations] = useState(0); - - useEffect(() => { - if (restart) { - setIterations(0); - } - - const applyLayout = () => { - forceAtlas2.assign(graph, { ...settings, adjustSizes: true }); - setIterations((prev) => { - const newIteration = prev + 1; - onIteration(newIteration); - return newIteration; - }); - }; - - if (iterations < settings.iterations) { - const interval = setInterval(applyLayout, 10); // Reduce interval for faster calculations - return () => clearInterval(interval); - } - }, [graph, iterations, settings, onIteration, restart]); - - return null; -}; - -export default function GraphVisualization() { - const [graphData, setGraphData] = useState(null); - const { fileId = '' } = useParams(); - const [isLoading, setIsLoading] = useState(true); - const [layout, setLayout] = useState('forceAtlas2Based'); - const [physicsOptions, setPhysicsOptions] = useState({ - iterations: 200, - barnesHutOptimize: true, - barnesHutTheta: 0.5, - slowDown: 0.1, // Faster calculations - gravity: 5, // Stronger gravity - scalingRatio: 10, - edgeWeightInfluence: 1, - strongGravityMode: true, - adjustSizes: true, - edgeLength: 100, // Added for edge length - }); - const [restart, setRestart] = useState(false); - - useEffect(() => { - const API = `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`; - fetch(API) - .then((res) => res.json()) - .then((graphData) => { - const graph = new MultiDirectedGraph(); - graphData?.nodes?.forEach((node) => { - const { id, ...rest } = node; - graph.addNode(id, { - ...rest, - size: 15, // just for testing, i am making all the same size - x: Math.random() * 1000, - y: Math.random() * 1000, - }); - }); - graphData?.edges?.forEach((edge) => { - const { id, source, target, ...rest } = edge; - graph.addEdgeWithKey(id, source, target, { - ...rest, - size: 2, // edge - length: physicsOptions.edgeLength, // Set edge length - }); - }); - indexParallelEdgesIndex(graph, { - edgeIndexAttribute: 'parallelIndex', - edgeMaxIndexAttribute: 'parallelMaxIndex', - }); - graph.forEachEdge((edge, { parallelIndex, parallelMaxIndex }) => { - if (typeof parallelIndex === 'number') { - graph.mergeEdgeAttributes(edge, { - type: 'curved', - curvature: - DEFAULT_EDGE_CURVATURE + - (3 * DEFAULT_EDGE_CURVATURE * parallelIndex) / - (parallelMaxIndex || 1), - }); - } else { - graph.setEdgeAttribute(edge, 'type', 'straight'); - } - }); - setGraphData(graph); - }) - .catch((error) => 
{ - console.log('Error fetching graphData:', error); - }) - .finally(() => { - setIsLoading(false); - }); - }, [fileId, physicsOptions.edgeLength]); // Re-fetch data when edge length changes - - const handlePhysicsChange = (name, value) => { - setPhysicsOptions((prevOptions) => ({ - ...prevOptions, - [name]: value, - })); - setRestart(true); - }; - - useEffect(() => { - if (restart) { - setRestart(false); - } - }, [restart]); - - if (isLoading) { - return
Loading graph...
; - } - if (!graphData) { - return
Sorry, an error has occurred!
; - } - return ( -
-

Graph Visualization

- - - console.log(`Iteration: ${iteration}`)} - restart={restart} - /> - -
- ); -} diff --git a/Project/frontend/src/components/GraphList/index.tsx b/Project/frontend/src/components/GraphList/index.tsx index 9a284d4..67d4a6a 100644 --- a/Project/frontend/src/components/GraphList/index.tsx +++ b/Project/frontend/src/components/GraphList/index.tsx @@ -33,7 +33,7 @@ interface IGraphList { updated_at: string | null; } -interface notification { +export interface Notification { show: boolean; severity: messageSeverity; message: string; @@ -58,7 +58,7 @@ const GraphList = () => { const [loading, setLoading] = React.useState(true); const [error, setError] = React.useState(null); const [generating, setGenerating] = React.useState(null); - const [notification, setNotification] = React.useState({ + const [notification, setNotification] = React.useState({ show: false, severity: messageSeverity.SUCCESS, message: '', @@ -113,7 +113,7 @@ const GraphList = () => { }); }; - const notify = (n: notification) => { + const notify = (n: Notification) => { setNotification(n); }; diff --git a/Project/frontend/src/components/Graph_page/FloatingControlCard.tsx b/Project/frontend/src/components/Graph_page/FloatingControlCard.tsx index 5c0ff62..0dd4874 100644 --- a/Project/frontend/src/components/Graph_page/FloatingControlCard.tsx +++ b/Project/frontend/src/components/Graph_page/FloatingControlCard.tsx @@ -9,7 +9,6 @@ import { InputLabel, Select, MenuItem, - Slider, Typography, Box, } from '@mui/material'; @@ -22,8 +21,14 @@ const FloatingControlCard = ({ physicsOptions, handlePhysicsChange, restartStabilization, +}: { + layout: string, + setLayout: (layout: string) => void, + physicsOptions: any, + handlePhysicsChange: (name: string, value: any) => void, + restartStabilization: () => void, }) => { - const handleSliderChange = (name) => (event, value) => { + const handleSliderChange = (name: string) => (event: any, value: any) => { handlePhysicsChange(name, value); restartStabilization(); }; @@ -35,7 +40,7 @@ const FloatingControlCard = ({ Direction { handlePhysicsChange('sortMethod', e.target.value); restartStabilization(); @@ -186,7 +191,7 @@ const FloatingControlCard = ({ Shake Towards @@ -293,7 +293,7 @@ const FloatingControlCard = ({ {renderSliders()} void; - searchGraph: (event: React.KeyboardEvent) => void; + performSearch: (query: string) => void; + searchResults: SearchResult[]; + searchIsLoading: boolean; + fileId: string; } -const drawerWidth = 450; - const GraphInfoPanel: React.FC = ({ open, toggleDrawer, @@ -37,15 +52,24 @@ const GraphInfoPanel: React.FC = ({ keywords, searchQuery, setSearchQuery, - searchGraph, + performSearch, + searchResults, + searchIsLoading, + fileId, }) => { + const formatdateTime = () => { + const options: Intl.DateTimeFormatOptions = { year: 'numeric', month: 'numeric', day: 'numeric', hour: 'numeric', minute: 'numeric' }; + const date = new Date(graphData.graph_created_at).toLocaleDateString('UTC', options); + return date; + } + return ( = ({ variant="persistent" anchor="left" open={open} - > + > {open ? 
: } @@ -69,7 +93,7 @@ const GraphInfoPanel: React.FC = ({ Created At - {graphData.graph_created_at} + {formatdateTime()} @@ -83,7 +107,10 @@ const GraphInfoPanel: React.FC = ({ setSearchQuery(keyword)} + onClick={() => { + setSearchQuery(keyword); + performSearch(keyword); + }} sx={{ margin: '2px' }} clickable /> @@ -96,29 +123,58 @@ const GraphInfoPanel: React.FC = ({ fullWidth value={searchQuery} onChange={(e) => setSearchQuery(e.target.value)} - onKeyDown={searchGraph} + onKeyDown={(e) => { + if (e.key === 'Enter') { + performSearch(searchQuery); + } + }} InputProps={{ endAdornment: ( - + performSearch(searchQuery)} /> ), }} sx={{ marginBottom: '10px' }} /> - + {searchIsLoading && ( + + + Searching... + + )} + {searchResults.length > 0 && ( + + {searchResults.map((result, index) => ( + + } + aria-controls={`panel${index}a-content`} + id={`panel${index}a-header`} + > + {result.merged_node} + + + Similarity: {result.similarity.toFixed(4)} + + Original Nodes: + {result.original_nodes.map((node, idx) => ( + + {node} + + Similarity: {result.individual_similarities[node].toFixed(4)} + + + ))} + + + ))} + + )} ); }; export default GraphInfoPanel; + diff --git a/Project/frontend/src/components/Graph_page/GraphVisualization.tsx b/Project/frontend/src/components/Graph_page/GraphVisualization.tsx index a8662a8..93ec890 100644 --- a/Project/frontend/src/components/Graph_page/GraphVisualization.tsx +++ b/Project/frontend/src/components/Graph_page/GraphVisualization.tsx @@ -15,7 +15,7 @@ import ChevronLeftIcon from '@mui/icons-material/ChevronLeft'; import VisGraph from './VisGraph'; import LoadingScreen from './LoadingScreen'; import ErrorScreen from './ErrorScreen'; -import { KEYWORDS_API_PATH, VISUALIZE_API_PATH } from '../../constant'; +import { KEYWORDS_API_PATH, VISUALIZE_API_PATH, GRAPH_SEARCH_API_PATH } from '../../constant'; import Navbar from '../Navbar/Navbar'; import GraphInfoPanel from './GraphInfoPanel'; import FloatingControlCard from './FloatingControlCard.tsx'; @@ -25,9 +25,8 @@ import { getOptions, physicsOptionsByLayout, } from './config'; -// import './index.css'; import * as d3 from 'd3'; -import Legend from './Legend'; // Importiere die Legend-Komponente +import Legend from './Legend'; interface GraphData { nodes: Array<{ @@ -35,6 +34,8 @@ interface GraphData { label?: string; topic: string; pages: string; + x?: number; + y?: number; [key: string]: any; }>; edges: Array<{ source: string; target: string; [key: string]: any }>; @@ -46,6 +47,15 @@ interface ITopicColourMap { [key: string]: string; } +interface SearchResult { + merged_node: string; + original_nodes: string[]; + similarity: number; + individual_similarities: { + [key: string]: number; + }; +} + const Main = styled('main', { shouldForwardProp: (prop) => prop !== 'open' })<{ open?: boolean; }>(({ theme, open }) => ({ @@ -70,7 +80,7 @@ const DrawerHeader = styled('div')(({ theme }) => ({ })); const GraphVisualization: React.FC = () => { - const { fileId } = useParams<{ fileId: string }>(); + const { fileId = '' } = useParams<{ fileId: string }>(); const [graphData, setGraphData] = useState(null); const [isLoading, setIsLoading] = useState(true); const [layout, setLayout] = useState('barnesHut'); @@ -79,8 +89,9 @@ const GraphVisualization: React.FC = () => { const [stabilizationComplete, setStabilizationComplete] = useState(false); const [topicColorMap, setTopicColorMap] = useState({}); const [physicsOptions, setPhysicsOptions] = useState(initialPhysicsOptions); - const [drawerOpen, setDrawerOpen] = useState(true); + 
const [searchResults, setSearchResults] = useState([]); + const [searchIsLoading, setSearchIsLoading] = useState(false); const theme = useTheme(); const isSmallScreen = useMediaQuery(theme.breakpoints.down('sm')); @@ -101,54 +112,60 @@ const GraphVisualization: React.FC = () => { }, [isSmallScreen]); useEffect(() => { - const fetchGraphData = async () => { - try { - const response = await fetch( - `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`, - ); - const data = await response.json(); - setGraphData(data); + if (!graphData) return; - // Get the list of unique topics - const uniqueTopics = Array.from( - new Set(data.nodes.map((node) => node.topic)), - ); + // Get the list of unique topics + const uniqueTopics = Array.from( + new Set(graphData?.nodes.map((node) => node.topic)), + ); - // Create color scheme for the topics - const colorSchemes = [ - d3.schemeCategory10, - d3.schemePaired, - d3.schemeSet1, - ]; - const uniqueColors = Array.from(new Set(colorSchemes.flat())); + // Create color scheme for the topics + const colorSchemes = [d3.schemeCategory10, d3.schemePaired, d3.schemeSet1]; + const uniqueColors = Array.from(new Set(colorSchemes.flat())); - const otherIndex = uniqueTopics.indexOf('other'); - if (otherIndex !== -1) { - uniqueTopics.splice(otherIndex, 1); - } + const otherIndex = uniqueTopics.indexOf('other'); + if (otherIndex !== -1) { + uniqueTopics.splice(otherIndex, 1); + } - const topicColorMap: ITopicColourMap = uniqueTopics.reduce( - (acc: ITopicColourMap, topic, index) => { - acc[topic] = uniqueColors[index % uniqueColors.length]; - return acc; - }, - {}, - ); + const topicColorMap: ITopicColourMap = uniqueTopics.reduce( + (acc: ITopicColourMap, topic, index) => { + acc[topic] = uniqueColors[index % uniqueColors.length]; + return acc; + }, + {}, + ); - if (otherIndex !== -1) { - topicColorMap['other'] = - uniqueColors[uniqueTopics.length % uniqueColors.length]; - } + if (otherIndex !== -1) { + topicColorMap['other'] = + uniqueColors[uniqueTopics.length % uniqueColors.length]; + } + setTopicColorMap(topicColorMap); + }, [graphData]); - setTopicColorMap(topicColorMap); - } catch (error) { - console.error('Error fetching graph data:', error); - } finally { + useEffect(() => { + const fetchGraphData = async () => { + if (sessionStorage.getItem(fileId)) { + const savedGraphData = JSON.parse(sessionStorage.getItem(fileId) as string); + setGraphData(savedGraphData); + setStabilizationComplete(true); // Assume layout was previously stabilized setIsLoading(false); + } else { + try { + const response = await fetch( + `${import.meta.env.VITE_BACKEND_HOST}${VISUALIZE_API_PATH.replace(':fileId', fileId)}`, + ); + const data = await response.json(); + setGraphData(data); + sessionStorage.setItem(fileId, JSON.stringify(data)); + } catch (error) { + console.error(error); + } finally { + setIsLoading(false); + } } }; - fetchGraphData(); const fetchKeywords = async () => { try { const response = await fetch( @@ -161,6 +178,7 @@ const GraphVisualization: React.FC = () => { } }; + fetchGraphData(); fetchKeywords(); }, [fileId]); @@ -182,14 +200,32 @@ const GraphVisualization: React.FC = () => { const options = getOptions(layout, physicsOptions); - const searchGraph = (event: React.KeyboardEvent) => { - if (event.key === 'Enter') { - performSearch(); - } - }; + const performSearch = async (query: string) => { + if (!query.trim()) return; - const performSearch = () => { - console.log('Searching for:', searchQuery); + setSearchIsLoading(true); + 
setSearchResults([]); // Reset results before new search + + try { + const API = `${import.meta.env.VITE_BACKEND_HOST}${GRAPH_SEARCH_API_PATH.replace(':fileId', fileId)}`; + const response = await fetch(API, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ query }), + }); + const result = await response.json(); + + // Parse the result directly + const parsedResults: SearchResult[] = JSON.parse(result.answer); + setSearchResults(parsedResults); + } catch (error) { + console.error('Error fetching the search results:', error); + setSearchResults([]); + } finally { + setSearchIsLoading(false); + } }; if (isLoading) { @@ -200,13 +236,6 @@ const GraphVisualization: React.FC = () => { return ; } - const formattedDate = new Date( - graphData.graph_created_at, - ).toLocaleDateString(); - const formattedTime = new Date( - graphData.graph_created_at, - ).toLocaleTimeString(); - return ( @@ -242,7 +271,10 @@ const GraphVisualization: React.FC = () => { keywords={keywords} searchQuery={searchQuery} setSearchQuery={setSearchQuery} - searchGraph={searchGraph} + performSearch={performSearch} + searchResults={searchResults} + searchIsLoading={searchIsLoading} + fileId={fileId} /> {!drawerOpen && ( @@ -271,13 +303,16 @@ const GraphVisualization: React.FC = () => { }} > - {graphData && !isStabilizingRef.current && ( + {graphData && ( )} diff --git a/Project/frontend/src/components/Graph_page/Legend.tsx b/Project/frontend/src/components/Graph_page/Legend.tsx index 6396a11..6920a09 100644 --- a/Project/frontend/src/components/Graph_page/Legend.tsx +++ b/Project/frontend/src/components/Graph_page/Legend.tsx @@ -1,5 +1,12 @@ import React from 'react'; -import { Box } from '@mui/material'; +import { + Accordion, + AccordionDetails, + AccordionSummary, + Box, + Typography, +} from '@mui/material'; +import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; type ITopicColourMap = Record; @@ -9,43 +16,70 @@ const Legend: React.FC<{ topicColorMap: ITopicColourMap }> = ({ return ( - - {Object.entries(topicColorMap).map(([topic, color]) => ( + + } + > + Topic / Color Legend + + - - - {topic.substring(topic.indexOf('_') + 1)} - + + {Object.entries(topicColorMap).map(([topic, color]) => ( + + + + {topic.substring(topic.indexOf('_') + 1)} + + + ))} + - ))} - + + ); }; diff --git a/Project/frontend/src/components/Graph_page/PersistentDrawerControls.tsx b/Project/frontend/src/components/Graph_page/PersistentDrawerControls.tsx deleted file mode 100644 index c95e4ba..0000000 --- a/Project/frontend/src/components/Graph_page/PersistentDrawerControls.tsx +++ /dev/null @@ -1,318 +0,0 @@ -import React, { useMemo, useCallback } from 'react'; -import { - Accordion, - AccordionSummary, - AccordionDetails, - Card, - CardContent, - FormControl, - InputLabel, - Select, - MenuItem, -} from '@mui/material'; -import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; - -const FloatingControlCard = ({ - layout, - setLayout, - physicsOptions, - handlePhysicsChange, - restartStabilization, -}) => { - const handleSliderChange = useCallback( - (name) => (event, value) => { - handlePhysicsChange(name, value); - restartStabilization(); - }, - [handlePhysicsChange, restartStabilization], - ); - - const renderSliders = useMemo(() => { - switch (layout) { - case 'barnesHut': - return ( - - - - - - - ); - case 'forceAtlas2Based': - return ( - - - - - - - ); - case 'hierarchical': - return ( - - - - - - - - Direction - - Sort Method - - Shake Towards - - - ); - case 'repulsion': - return ( - - - 
- - - - - ); - default: - return null; - } - }, [layout, physicsOptions, handleSliderChange, restartStabilization]); - - return ( - - - } - style={{ backgroundColor: '#383838', color: '#fff' }} - > - Physics Options - - - - - - {' '} - {/* Add margin bottom to the Box wrapping InputLabel */} - - Layout - - - - - {renderSliders()} - Stabilization Iterations - - - - - - ); -}; - -export default PersistentDrawerControls; diff --git a/Project/frontend/src/components/Graph_page/VisGraph.tsx b/Project/frontend/src/components/Graph_page/VisGraph.tsx index 0b4adc7..51df15c 100644 --- a/Project/frontend/src/components/Graph_page/VisGraph.tsx +++ b/Project/frontend/src/components/Graph_page/VisGraph.tsx @@ -1,6 +1,7 @@ import React, { useEffect, useState, useRef } from 'react'; import { Network, Options } from 'vis-network/standalone/esm/vis-network'; -import { Box, CircularProgress, Typography } from '@mui/material'; +import { Box, Snackbar } from '@mui/material'; +import { useEditMode } from './useEditMode'; // Import the custom hook interface GraphData { nodes: Array<{ @@ -8,6 +9,8 @@ interface GraphData { label?: string; topic: string; pages: string; + x?: number; + y?: number; [key: string]: any; }>; edges: Array<{ source: string; target: string; [key: string]: any }>; @@ -23,37 +26,119 @@ const VisGraph: React.FC<{ graphData: GraphData; options: Options; setStabilizationComplete: React.Dispatch>; + stabilizationComplete: boolean; + setGraphData: React.Dispatch>; topicColorMap: ITopicColourMap; -}> = ({ graphData, options, setStabilizationComplete, topicColorMap }) => { + isStabilizingRef: React.MutableRefObject; + fileId: string; +}> = ({ + graphData, + options, + setStabilizationComplete, + stabilizationComplete, + setGraphData, + topicColorMap, + isStabilizingRef, + fileId, +}) => { const containerRef = useRef(null); const networkRef = useRef(null); const [stabilizationProgress, setStabilizationProgress] = useState(0); const [isStabilizing, setIsStabilizing] = useState(false); + const [showStabilizingMessage, setShowStabilizingMessage] = useState(false); - useEffect(() => { - const handleResize = () => { - if (networkRef.current) { - networkRef.current.redraw(); - networkRef.current.fit(); + const updateNodeStyles = () => { + if (!networkRef.current) return; + + const allNodes = networkRef.current.body.nodes; + const allEdges = networkRef.current.body.edges; + + if (!editMode || !firstClick) { + // Reset all nodes and edges to normal state + for (let nodeId in allNodes) { + allNodes[nodeId].setOptions({ + color: { background: topicColorMap[graphData.nodes.find(node => node.id === nodeId)?.topic || ''], border: 'white' }, + opacity: 1, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, // Add border to labels + }); } - }; + for (let edgeId in allEdges) { + allEdges[edgeId].setOptions({ color: '#ccccff', opacity: 1, font: { color: 'white' }, hidden: false }); + } + return; + } - window.addEventListener('resize', handleResize); + // Reset all nodes and edges to dimmed state + for (let nodeId in allNodes) { + allNodes[nodeId].setOptions({ + color: { background: '#2f2f2f', border: '#2f2f2f' }, + opacity: 0.3, + font: { color: '#2f2f2f', strokeWidth: 0 }, // Remove label border when dimmed + }); + } + for (let edgeId in allEdges) { + allEdges[edgeId].setOptions({ color: '#2f2f2f', opacity: 0.1, font: { color: '#2f2f2f' }, hidden: true }); + } - return () => { - window.removeEventListener('resize', handleResize); - }; - }, []); + // Highlight selected nodes and their neighbors + 
selectedNodes.forEach(nodeId => { + const connectedNodes = networkRef.current.getConnectedNodes(nodeId); + const connectedEdges = networkRef.current.getConnectedEdges(nodeId); + + allNodes[nodeId].setOptions({ + opacity: 1, + color: { background: '#ffd700', border: 'white' }, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, // Add label border for selected nodes + }); + connectedNodes.forEach((id: any) => { + if (selectedNodes.has(id)) { + allNodes[id].setOptions({ + opacity: 1, + color: { background: '#ffd700', border: 'white' }, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, // Add label border for connected selected nodes + }); + } else { + allNodes[id].setOptions({ + opacity: 1, + color: { background: topicColorMap[graphData.nodes.find(node => node.id === id)?.topic || ''], border: 'white' }, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, // Add label border for connected nodes + }); + } + }); + connectedEdges.forEach((id: any) => { + if (selectedNodes.has(networkRef.current?.body.edges[id].fromId) && selectedNodes.has(networkRef.current?.body.edges[id].toId)) { + allEdges[id].setOptions({ opacity: 1, color: '#ffd700', font: { color: 'white' }, hidden: false }); + } else { + allEdges[id].setOptions({ opacity: 0.5, color: '#ccccff', font: { color: 'white' }, hidden: false }); + } + }); + }); + }; + + const { + editMode, + selectedNodes, + setSelectedNodes, + highlightActiveNodes, + resetHighlight, + setFirstClick, + firstClick, + } = useEditMode(networkRef, topicColorMap, graphData, updateNodeStyles); useEffect(() => { if (!containerRef.current || !graphData) return; + const savedPositions = sessionStorage.getItem(`${fileId}_positions`); + const parsedPositions = savedPositions ? JSON.parse(savedPositions) : {}; + const data = { nodes: graphData.nodes.map((node) => ({ id: node.id, label: node.label || node.id, shape: 'dot', size: 25, + x: parsedPositions[node.id]?.x ?? node.x, + y: parsedPositions[node.id]?.y ?? 
node.y, ...node, title: `Found in pages: ${node.pages} Topic: ${node.topic.substring(node.topic.indexOf('_') + 1)}`, @@ -62,16 +147,17 @@ const VisGraph: React.FC<{ border: 'white', highlight: { background: '#69b3a2', - border: '#508e7f', + border: 'white', }, }, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, // Add border to labels })), edges: graphData.edges.map((edge) => ({ from: edge.source, to: edge.target, ...edge, arrows: { - to: { enabled: false }, // Entfernt die Pfeile in Richtung des Ziels + to: { enabled: false }, from: { enabled: false }, }, color: { @@ -79,6 +165,9 @@ const VisGraph: React.FC<{ highlight: '#ffff00', hover: '#ffffff', }, + font: { + color: 'white', + }, })), }; @@ -89,39 +178,120 @@ const VisGraph: React.FC<{ const network = new Network(containerRef.current, data, options); networkRef.current = network; - setIsStabilizing(true); - setStabilizationProgress(0); + if (!stabilizationComplete) { + setIsStabilizing(true); + setStabilizationProgress(0); + setShowStabilizingMessage(true); + network.setOptions({ physics: true }); - const stabilizationProgressHandler = (params: any) => { - const progress = (params.iterations / params.total) * 100; - setStabilizationProgress((prevProgress) => - Math.max(prevProgress, progress), - ); - }; + const stabilizationProgressHandler = (params: any) => { + const progress = (params.iterations / params.total) * 100; + setStabilizationProgress((prevProgress) => + Math.max(prevProgress, progress), + ); + }; - const stabilizationIterationsDoneHandler = () => { - setStabilizationProgress(100); - setIsStabilizing(false); - setStabilizationComplete(true); - network.fit(); - }; + const stabilizationIterationsDoneHandler = () => { + const positions = network.getPositions(); + const updatedNodes = graphData.nodes.map((node) => ({ + ...node, + x: positions[node.id].x, + y: positions[node.id].y, + })); + + const updatedGraphData = { + ...graphData, + nodes: updatedNodes, + }; + + setGraphData(updatedGraphData); + sessionStorage.setItem(fileId, JSON.stringify(updatedGraphData)); + sessionStorage.setItem(`${fileId}_positions`, JSON.stringify(positions)); // Save positions - network.on('stabilizationProgress', stabilizationProgressHandler); - network.on( - 'stabilizationIterationsDone', - stabilizationIterationsDoneHandler, - ); + network.setOptions({ physics: false }); + setStabilizationComplete(true); + setIsStabilizing(false); + setShowStabilizingMessage(false); + }; + + network.on('stabilizationProgress', stabilizationProgressHandler); + network.on('stabilizationIterationsDone', stabilizationIterationsDoneHandler); + } else { + network.setOptions({ physics: false }); + } + + network.on('dragStart', function (params) { + if (params.nodes.length > 0) { + network.setOptions({ physics: true }); + } + }); + + network.on('dragEnd', function (params) { + if (params.nodes.length > 0) { + setTimeout(() => { + if (networkRef.current) { + const positions = networkRef.current.getPositions(); + const updatedNodes = graphData.nodes.map((node) => ({ + ...node, + x: positions[node.id].x, + y: positions[node.id].y, + })); + + const updatedGraphData = { + ...graphData, + nodes: updatedNodes, + }; + setGraphData(updatedGraphData); + sessionStorage.setItem(fileId, JSON.stringify(updatedGraphData)); + sessionStorage.setItem(`${fileId}_positions`, JSON.stringify(positions)); // Save positions + + // Ensure that the view doesn't change after dragging + networkRef.current.setOptions({ physics: false }); + } + }, 70000); // Delay before disabling 
physics + } + }); + + network.on('hoverNode', highlightActiveNodes); + network.on('blurNode', resetHighlight); + + network.on('click', function (params) { + if (editMode && params.nodes.length > 0) { + const nodeId = params.nodes[0]; + setSelectedNodes((prevSelectedNodes) => { + const newSelectedNodes = new Set(prevSelectedNodes); + if (newSelectedNodes.has(nodeId)) { + newSelectedNodes.delete(nodeId); + } else { + newSelectedNodes.add(nodeId); + } + return newSelectedNodes; + }); + setFirstClick(true); // Set first click to true on the first node click in edit mode + updateNodeStyles(); // Update styles immediately after selecting nodes + } + }); return () => { - network.off('stabilizationProgress', stabilizationProgressHandler); - network.off( - 'stabilizationIterationsDone', - stabilizationIterationsDoneHandler, - ); - network.destroy(); - networkRef.current = null; + if (networkRef.current) { + networkRef.current.off('stabilizationProgress'); + networkRef.current.off('stabilizationIterationsDone'); + networkRef.current.off('dragStart'); + networkRef.current.off('dragEnd'); + networkRef.current.off('hoverNode', highlightActiveNodes); + networkRef.current.off('blurNode', resetHighlight); + networkRef.current.off('click'); + networkRef.current.destroy(); + networkRef.current = null; + } }; - }, [graphData, options, topicColorMap, setStabilizationComplete]); + }, [graphData, options, topicColorMap, setStabilizationComplete, fileId, selectedNodes, editMode, firstClick]); + + useEffect(() => { + if (networkRef.current) { + updateNodeStyles(); + } + }, [selectedNodes, editMode, firstClick]); return ( - {isStabilizing && ( - - - - Stabilizing... {Math.round(stabilizationProgress)}% - - - )} + ); }; export default VisGraph; + diff --git a/Project/frontend/src/components/Graph_page/config.tsx b/Project/frontend/src/components/Graph_page/config.tsx index c099f34..cb5a030 100644 --- a/Project/frontend/src/components/Graph_page/config.tsx +++ b/Project/frontend/src/components/Graph_page/config.tsx @@ -20,7 +20,7 @@ export const physicsOptionsByLayout = { gravitationalConstant: -20000, springLength: 100, springConstant: 0.1, - damping: 0.09, + damping: 0.29, }, forceAtlas2Based: { gravitationalConstant: -50, @@ -32,13 +32,13 @@ export const physicsOptionsByLayout = { gravitationalConstant: 0, springLength: 120, springConstant: 0, - damping: 0, + damping: 0.2, }, repulsion: { gravitationalConstant: 0.2, springLength: 200, springConstant: 0.05, - damping: 0.09, + damping: 0.59, }, hierarchical: { levelSeparation: 150, diff --git a/Project/frontend/src/components/Graph_page/useEditMode.tsx b/Project/frontend/src/components/Graph_page/useEditMode.tsx new file mode 100644 index 0000000..fa602d0 --- /dev/null +++ b/Project/frontend/src/components/Graph_page/useEditMode.tsx @@ -0,0 +1,105 @@ +import { useState, useEffect } from 'react'; + +export const useEditMode = (networkRef, topicColorMap, graphData, updateNodeStyles) => { + const [editMode, setEditMode] = useState(false); + const [selectedNodes, setSelectedNodes] = useState(new Set()); + const [firstClick, setFirstClick] = useState(false); // Track if a node has been clicked in edit mode + + const handleKeyDown = (event) => { + if (event.key === 'Escape') { + setSelectedNodes(new Set()); + setEditMode(false); + setFirstClick(false); // Reset first click status + updateNodeStyles(); + } else if (event.key === 'e') { + setEditMode((prevEditMode) => !prevEditMode); + if (!prevEditMode) { + setFirstClick(false); // Reset first click status when entering 
edit mode + } + updateNodeStyles(); + } + }; + + useEffect(() => { + window.addEventListener('keydown', handleKeyDown); + return () => { + window.removeEventListener('keydown', handleKeyDown); + }; + }, []); + + useEffect(() => { + if (networkRef.current) { + updateNodeStyles(); + } + }, [selectedNodes, editMode, firstClick]); + + const highlightActiveNodes = (params) => { + if (!editMode) return; + + const allNodes = networkRef.current.body.nodes; + const allEdges = networkRef.current.body.edges; + const nodeId = params.node; + const connectedNodes = networkRef.current.getConnectedNodes(nodeId); + const connectedEdges = networkRef.current.getConnectedEdges(nodeId); + + // Dim all nodes and edges and hide labels + for (let nodeId in allNodes) { + allNodes[nodeId].setOptions({ + color: { background: '#2f2f2f', border: '#2f2f2f' }, + opacity: 0.3, + font: { color: '#2f2f2f', strokeWidth: 0 }, + }); + } + for (let edgeId in allEdges) { + allEdges[edgeId].setOptions({ color: '#2f2f2f', opacity: 0.1, font: { color: '#2f2f2f' }, hidden: true }); + } + + // Highlight the hovered node and its neighbors + allNodes[nodeId].setOptions({ + opacity: 1, + color: { background: topicColorMap[graphData.nodes.find((node) => node.id === nodeId)?.topic || ''], border: 'white' }, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, + }); + connectedNodes.forEach((id) => { + allNodes[id].setOptions({ + opacity: 1, + color: { background: topicColorMap[graphData.nodes.find((node) => node.id === id)?.topic || ''], border: 'white' }, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, + }); + }); + connectedEdges.forEach((id) => { + allEdges[id].setOptions({ opacity: 0.5, color: '#ccccff', font: { color: 'white' }, hidden: false }); + }); + }; + + const resetHighlight = () => { + if (!editMode || (editMode && !firstClick)) { + const allNodes = networkRef.current.body.nodes; + const allEdges = networkRef.current.body.edges; + + // Reset all nodes and edges to normal state + for (let nodeId in allNodes) { + allNodes[nodeId].setOptions({ + color: { background: topicColorMap[graphData.nodes.find((node) => node.id === nodeId)?.topic || ''], border: 'white' }, + opacity: 1, + font: { color: 'white', strokeWidth: 3, strokeColor: 'black' }, + }); + } + for (let edgeId in allEdges) { + allEdges[edgeId].setOptions({ color: '#ccccff', opacity: 1, font: { color: 'white' }, hidden: false }); + } + } else { + updateNodeStyles(); + } + }; + + return { + editMode, + selectedNodes, + setSelectedNodes, + highlightActiveNodes, + resetHighlight, + setFirstClick, + firstClick, + }; +}; diff --git a/Project/frontend/src/components/LandingPage/index.tsx b/Project/frontend/src/components/LandingPage/index.tsx index a0d4d33..3756b0d 100644 --- a/Project/frontend/src/components/LandingPage/index.tsx +++ b/Project/frontend/src/components/LandingPage/index.tsx @@ -1,7 +1,5 @@ import React from 'react'; - -import { Button, Container, Stack, Typography } from '@mui/material'; - +import { Button, Container, Stack, Typography, Link } from '@mui/material'; import GraphList from '../GraphList'; import { useNavigate } from 'react-router-dom'; @@ -38,10 +36,18 @@ const LandingPage = () => { -
+ {/* Footer message */} + + Made with 💙 at FAU and TU Berlin. + ); }; diff --git a/Project/frontend/src/components/Navbar/Navbar.tsx b/Project/frontend/src/components/Navbar/Navbar.tsx index 7c4e733..4291dbe 100644 --- a/Project/frontend/src/components/Navbar/Navbar.tsx +++ b/Project/frontend/src/components/Navbar/Navbar.tsx @@ -15,12 +15,9 @@ const Navbar = () => { case '/': setTitle('Home'); break; - case '/upload': + case '/upload': // todo setTitle('Upload'); break; - case '/about': // todo - setTitle('About'); - break; default: setTitle('Graph Masters'); } @@ -28,7 +25,7 @@ const Navbar = () => { }, [location.pathname]); return ( - + { {title} - - - Home - - - - - Upload - - - + About - + diff --git a/Project/frontend/src/components/Snackbar/index.css b/Project/frontend/src/components/Snackbar/index.css deleted file mode 100644 index e69de29..0000000 diff --git a/Project/frontend/src/components/Snackbar/index.tsx b/Project/frontend/src/components/Snackbar/index.tsx index 43bf2df..c7ba5e3 100644 --- a/Project/frontend/src/components/Snackbar/index.tsx +++ b/Project/frontend/src/components/Snackbar/index.tsx @@ -6,7 +6,7 @@ import { messageSeverity } from '../../constant'; interface CustomizedSnackbarsProps { open: boolean; - handleClick: () => void; + handleClick?: () => void; handleClose: (event?: React.SyntheticEvent | Event, reason?: string) => void; severity_value?: messageSeverity; message?: string; diff --git a/Project/frontend/src/components/Upload/index.css b/Project/frontend/src/components/Upload/index.css deleted file mode 100644 index ad52eaf..0000000 --- a/Project/frontend/src/components/Upload/index.css +++ /dev/null @@ -1,3 +0,0 @@ -.upload_wrapper { - width: 500px; -} diff --git a/Project/frontend/src/components/Upload/index.tsx b/Project/frontend/src/components/Upload/index.tsx index 708b5d6..2b7ea7b 100644 --- a/Project/frontend/src/components/Upload/index.tsx +++ b/Project/frontend/src/components/Upload/index.tsx @@ -1,10 +1,16 @@ +import React from 'react'; + import { FilePond, registerPlugin, FilePondProps } from 'react-filepond'; import 'filepond/dist/filepond.min.css'; import FilePondPluginFileValidateType from 'filepond-plugin-file-validate-type'; +import { + GRAPH_DELETE_API_PATH, + UPLOAD_API_PATH, + messageSeverity, +} from '../../constant'; +import CustomizedSnackbars from '../Snackbar'; +import { Notification } from '../GraphList'; -import { GRAPH_DELETE_API_PATH, UPLOAD_API_PATH } from '../../constant'; - -import './index.css'; registerPlugin(FilePondPluginFileValidateType); @@ -23,6 +29,12 @@ type UploadProps = { }; function Upload(props: UploadProps) { + const [notification, setNotification] = React.useState({ + show: false, + severity: messageSeverity.SUCCESS, + message: '', + }); + const server: FilePondProps['server'] = { url: `${import.meta.env.VITE_BACKEND_HOST}`, process: { @@ -54,11 +66,48 @@ function Upload(props: UploadProps) { }, }; - const handleFileProcess: FilePondProps['onprocessfile'] = (error, file) => + const handleFileProcess: FilePondProps['onprocessfile'] = (error, file) => { + console.log(error, file); props.handleAddFile?.(error, file); + }; + + const handleAddFile: FilePondProps['onaddfile'] = (error, file) => { + if (error) { + setNotification({ + show: true, + severity: messageSeverity.ERROR, + message: `${error.main}. 
${error.sub}.`, + }); + } + }; + + const handleClose = ( + event?: React.SyntheticEvent | Event, + reason?: string, + ) => { + if (reason === 'clickaway') { + return; + } + setNotification({ + show: false, + severity: notification.severity, + message: notification.message, + }); + }; + + const renderSnackbar = () => { + return ( + + ); + }; return ( -
+
error.body} + labelFileTypeNotAllowed="Invalid file type" + fileValidateTypeLabelExpectedTypes="Kindly check the info" + onaddfile={handleAddFile} /> + {renderSnackbar()}
); } diff --git a/Project/frontend/src/components/UploadPage/index.css b/Project/frontend/src/components/UploadPage/index.css deleted file mode 100644 index 414ef16..0000000 --- a/Project/frontend/src/components/UploadPage/index.css +++ /dev/null @@ -1,21 +0,0 @@ -.main_wrapper_upload { - display: flex; - flex-direction: column; - align-items: center; - gap: 20px; - max-width: 500px; -} - -.buttons_container { - display: flex; - align-items: center; - justify-content: flex-end; - width: 100%; -} - -.upload_info { - display: flex; - align-items: center; - justify-content: center; - gap: 5px; -} \ No newline at end of file diff --git a/Project/frontend/src/components/UploadPage/index.tsx b/Project/frontend/src/components/UploadPage/index.tsx index ee40fd9..f015101 100644 --- a/Project/frontend/src/components/UploadPage/index.tsx +++ b/Project/frontend/src/components/UploadPage/index.tsx @@ -14,7 +14,6 @@ import InfoIcon from '@mui/icons-material/Info'; // Import InfoIcon for hint but import { GENERATE_API_PATH, GraphStatus } from '../../constant'; import CustomizedSnackbars from '../Snackbar'; import Upload from '../Upload'; -import './index.css'; function UploadPage() { const [fileId, setFileId] = useState(''); @@ -87,19 +86,27 @@ function UploadPage() { return ( -
+ theme.palette.text.secondary }} > Upload a document to generate the graph - {hintText}} placement='top' arrow> - + {hintText}} + placement="top" + arrow + > + - -
+
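
Note on the search flow added in GraphVisualization.tsx: performSearch posts the query to GRAPH_SEARCH_API_PATH and parses result.answer, a JSON-encoded string, into SearchResult[]. Below is a minimal standalone TypeScript sketch of that call; the endpoint, request body, and response shape are taken from the diff, while the fetchSearchResults name, the response.ok guard, and the empty-list fallback are illustrative additions rather than code that exists in the repository.

    interface SearchResult {
      merged_node: string;
      original_nodes: string[];
      similarity: number;
      individual_similarities: { [key: string]: number };
    }

    // Sketch only: assumes the backend wraps the result list as { answer: "<json string>" },
    // matching the JSON.parse(result.answer) call in the diff.
    async function fetchSearchResults(
      apiBase: string,
      searchPath: string, // e.g. GRAPH_SEARCH_API_PATH with ':fileId' already substituted
      query: string,
    ): Promise<SearchResult[]> {
      const response = await fetch(`${apiBase}${searchPath}`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ query }),
      });
      if (!response.ok) {
        throw new Error(`Search request failed with status ${response.status}`);
      }
      const result = await response.json();
      try {
        // result.answer is expected to be a JSON-encoded SearchResult[] string.
        return JSON.parse(result.answer) as SearchResult[];
      } catch {
        // A malformed payload is treated as "no results" instead of crashing the panel.
        return [];
      }
    }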
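
Note on layout caching: the diff stores the fetched graph under the sessionStorage key fileId and the stabilized coordinates under `${fileId}_positions`, then feeds the saved x/y back into the vis-network nodes so a revisit skips the stabilization pass. A small sketch of that persistence pattern follows, assuming the same keys; savePositions and loadPositions are illustrative helper names that do not exist in the diff.

    interface NodePosition { x: number; y: number; }
    type PositionMap = Record<string, NodePosition>;

    // Persist stabilized coordinates so a later visit can start with physics disabled.
    function savePositions(fileId: string, positions: PositionMap): void {
      sessionStorage.setItem(`${fileId}_positions`, JSON.stringify(positions));
    }

    // Restore coordinates; a missing or corrupt cache falls back to a fresh stabilization run.
    function loadPositions(fileId: string): PositionMap {
      const raw = sessionStorage.getItem(`${fileId}_positions`);
      if (!raw) return {};
      try {
        return JSON.parse(raw) as PositionMap;
      } catch {
        return {};
      }
    }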
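
Note on useEditMode.tsx: the 'e'-key branch reads prevEditMode outside the setEditMode updater, where that identifier is not in scope, and the keydown listener is registered once with an empty dependency list, so its closure never sees the current editMode. The sketch below shows one way to express the intended toggle, assuming the same setters declared by the hook (setEditMode, setSelectedNodes, setFirstClick); the explicit updateNodeStyles() calls are omitted because the diff already re-runs it in an effect on [selectedNodes, editMode, firstClick].

    // Sketch of the key handler inside useEditMode, not a drop-in replacement.
    useEffect(() => {
      const handleKeyDown = (event: KeyboardEvent) => {
        if (event.key === 'Escape') {
          setSelectedNodes(new Set());
          setEditMode(false);
          setFirstClick(false);
        } else if (event.key === 'e') {
          const entering = !editMode; // current value is in scope because the effect re-runs on editMode
          setEditMode(entering);
          if (entering) {
            setFirstClick(false); // reset the first-click flag only when entering edit mode
          }
        }
      };
      window.addEventListener('keydown', handleKeyDown);
      return () => window.removeEventListener('keydown', handleKeyDown);
    }, [editMode]); // re-bind so the handler always sees the latest editMode

Re-binding on editMode keeps the closure fresh at the cost of re-registering the listener on each toggle, which is negligible for a single window-level handler.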