Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/graph coloring ambiguity issue #186

Merged
merged 5 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions Project/backend/codebase/graph_creator/graph_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,24 @@ def add_relations_to_data(entity_and_relation_df, new_relations):
return entity_and_relation_df


def add_topic(data: pd.DataFrame) -> pd.DataFrame:
def add_topic(data: pd.DataFrame, max_topics: int = 25) -> pd.DataFrame:
documents = list(set(data["node_1"]).union(set(data["node_2"])))

topic_model = BERTopic()
topics, probabilities = topic_model.fit_transform(documents)
topic_info = topic_model.get_topic_info()

# Keep only the top given number of topics
top_topics = topic_model.get_topic_info().head(max_topics)["Topic"].tolist()

topic_name_info = {
row["Topic"]: row["Name"] for _, row in topic_model.get_topic_info().iterrows()
row["Topic"]: row["Name"] for _, row in topic_info.iterrows()
}
doc_topic_map = {doc: topic for doc, topic in zip(documents, topics)}

# Create a mapping for "other" topics
doc_topic_map = {doc: (topic if topic in top_topics else "other") for doc, topic in zip(documents, topics)}
doc_topic_strings_map = {
doc: topic_name_info.get(topic, "no_topic")
doc: (topic_name_info.get(topic, "other") if topic != "other" else "other")
for doc, topic in doc_topic_map.items()
}

Expand Down
128 changes: 105 additions & 23 deletions Project/frontend/src/components/Graph/index_visjs.tsx
Original file line number Diff line number Diff line change
@@ -1,32 +1,83 @@
import React, { useEffect, useState, useRef } from 'react';
import React, { useEffect, useRef, useState } from 'react';
import { Network, Options } from 'vis-network/standalone/esm/vis-network';
import { useParams } from 'react-router-dom';
import './index.css';
import { KEYWORDS_API_PATH, VISUALIZE_API_PATH } from '../../constant';
import SearchIcon from '@mui/icons-material/Search';
import {
TextField,
InputAdornment,
Chip,
Box,
Chip,
CircularProgress,
Typography,
InputAdornment,
Stack,
useTheme,
TextField,
Typography,
useMediaQuery,
useTheme,
} from '@mui/material';
import FloatingControlCard from './FloatingControlCard';
import FloatingControlCard from './FloatingControlCard.jsx';
import * as d3 from 'd3';

type ITopicColourMap = Record<string, string>;

interface GraphData {
nodes: Array<{ id: string; label?: string; topic: string; pages: string; [key: string]: any }>;
nodes: Array<{
id: string;
label?: string;
topic: string;
pages: string;
[key: string]: any;
}>;
edges: Array<{ source: string; target: string; [key: string]: any }>;
document_name: string;
graph_created_at: string;
}

interface ITopicColourMap {
[key: string]: string;
}
const Legend: React.FC<{ topicColorMap: ITopicColourMap }> = ({
topicColorMap,
}) => {
return (
<Box
sx={{
padding: '16px',
backgroundColor: '#121826',
borderRadius: '10px',
color: 'white',
maxHeight: '250px',
overflowY: 'auto',
maxWidth: '500px',
position: 'absolute',
left: '16px',
top: '16px',
}}
>
<Box component="ul" sx={{ padding: 0, margin: 0, listStyle: 'none' }}>
{Object.entries(topicColorMap).map(([topic, color]) => (
<Box
component="li"
key={topic}
sx={{
display: 'flex',
marginBottom: '8px',
}}
>
<Box
sx={{
width: '20px',
height: '20px',
backgroundColor: color,
marginRight: '8px',
}}
/>
<span style={{ fontSize: '0.875rem' }}>
{topic.substring(topic.indexOf('_') + 1)}
</span>
</Box>
))}
</Box>
</Box>
);
};

const VisGraph: React.FC<{
graphData: GraphData;
Expand All @@ -50,7 +101,8 @@ const VisGraph: React.FC<{
shape: 'dot',
size: 25,
...node,
title: `Found in pages: ${node.pages}`,
title: `Found in pages: ${node.pages}
Topic: ${node.topic.substring(node.topic.indexOf('_') + 1)}`,
color: {
background: topicColorMap[node.topic],
border: 'white',
Expand All @@ -59,6 +111,7 @@ const VisGraph: React.FC<{
border: '#508e7f',
},
},
borderWidth: 0.5,
})),
edges: graphData.edges.map((edge) => ({
from: edge.source,
Expand Down Expand Up @@ -180,14 +233,38 @@ const GraphVisualization: React.FC = () => {
const data = await response.json();
setGraphData(data);

// Generate and set topic color map
const newTopicColorMap = data.nodes.reduce((acc: ITopicColourMap, curr: any) => {
if (!acc[curr.topic]) {
acc[curr.topic] = '#' + Math.floor(Math.random() * 16777215).toString(16);
}
return acc;
}, {});
setTopicColorMap(newTopicColorMap);
// Get the list of unique topics
const uniqueTopics = Array.from(
new Set(data.nodes.map((node) => node.topic)),
);

// Create color scheme for the topics
const colorSchemes = [
d3.schemeCategory10,
d3.schemePaired,
d3.schemeSet1,
];
const uniqueColors = Array.from(new Set(colorSchemes.flat()));

const otherIndex = uniqueTopics.indexOf('other');
if (otherIndex !== -1) {
uniqueTopics.splice(otherIndex, 1);
}

const topicColorMap: ITopicColourMap = uniqueTopics.reduce(
(acc: ITopicColourMap, topic, index) => {
acc[topic] = uniqueColors[index % uniqueColors.length];
return acc;
},
{},
);

if (otherIndex !== -1) {
topicColorMap['other'] =
uniqueColors[uniqueTopics.length % uniqueColors.length];
}

setTopicColorMap(topicColorMap);
} catch (error) {
console.error('Error fetching graph data:', error);
} finally {
Expand Down Expand Up @@ -397,8 +474,12 @@ const GraphVisualization: React.FC = () => {
);
}

const formattedDate = new Date(graphData.graph_created_at).toLocaleDateString();
const formattedTime = new Date(graphData.graph_created_at).toLocaleTimeString();
const formattedDate = new Date(
graphData.graph_created_at,
).toLocaleDateString();
const formattedTime = new Date(
graphData.graph_created_at,
).toLocaleTimeString();

return (
<Stack
Expand Down Expand Up @@ -500,6 +581,7 @@ const GraphVisualization: React.FC = () => {
topicColorMap={topicColorMap}
/>
)}
<Legend topicColorMap={topicColorMap} />
</Box>
</Stack>
<FloatingControlCard
Expand All @@ -519,4 +601,4 @@ const GraphVisualization: React.FC = () => {
);
};

export default GraphVisualization;
export default GraphVisualization;
Loading