Skip to content

Commit

Permalink
Merge pull request #186 from amosproj/feat/graph-coloring-ambiguity-i…
Browse files Browse the repository at this point in the history
…ssue

Feat/graph coloring ambiguity issue
  • Loading branch information
nikolas-rauscher authored Jul 9, 2024
2 parents 3f2599e + 2c6c2f1 commit c9ddd4a
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 27 deletions.
15 changes: 11 additions & 4 deletions Project/backend/codebase/graph_creator/graph_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,24 @@ def add_relations_to_data(entity_and_relation_df, new_relations):
return entity_and_relation_df


def add_topic(data: pd.DataFrame) -> pd.DataFrame:
def add_topic(data: pd.DataFrame, max_topics: int = 25) -> pd.DataFrame:
documents = list(set(data["node_1"]).union(set(data["node_2"])))

topic_model = BERTopic()
topics, probabilities = topic_model.fit_transform(documents)
topic_info = topic_model.get_topic_info()

# Keep only the top given number of topics
top_topics = topic_model.get_topic_info().head(max_topics)["Topic"].tolist()

topic_name_info = {
row["Topic"]: row["Name"] for _, row in topic_model.get_topic_info().iterrows()
row["Topic"]: row["Name"] for _, row in topic_info.iterrows()
}
doc_topic_map = {doc: topic for doc, topic in zip(documents, topics)}

# Create a mapping for "other" topics
doc_topic_map = {doc: (topic if topic in top_topics else "other") for doc, topic in zip(documents, topics)}
doc_topic_strings_map = {
doc: topic_name_info.get(topic, "no_topic")
doc: (topic_name_info.get(topic, "other") if topic != "other" else "other")
for doc, topic in doc_topic_map.items()
}

Expand Down
128 changes: 105 additions & 23 deletions Project/frontend/src/components/Graph/index_visjs.tsx
Original file line number Diff line number Diff line change
@@ -1,32 +1,83 @@
import React, { useEffect, useState, useRef } from 'react';
import React, { useEffect, useRef, useState } from 'react';
import { Network, Options } from 'vis-network/standalone/esm/vis-network';
import { useParams } from 'react-router-dom';
import './index.css';
import { KEYWORDS_API_PATH, VISUALIZE_API_PATH } from '../../constant';
import SearchIcon from '@mui/icons-material/Search';
import {
TextField,
InputAdornment,
Chip,
Box,
Chip,
CircularProgress,
Typography,
InputAdornment,
Stack,
useTheme,
TextField,
Typography,
useMediaQuery,
useTheme,
} from '@mui/material';
import FloatingControlCard from './FloatingControlCard';
import FloatingControlCard from './FloatingControlCard.jsx';
import * as d3 from 'd3';

type ITopicColourMap = Record<string, string>;

interface GraphData {
nodes: Array<{ id: string; label?: string; topic: string; pages: string; [key: string]: any }>;
nodes: Array<{
id: string;
label?: string;
topic: string;
pages: string;
[key: string]: any;
}>;
edges: Array<{ source: string; target: string; [key: string]: any }>;
document_name: string;
graph_created_at: string;
}

interface ITopicColourMap {
[key: string]: string;
}
const Legend: React.FC<{ topicColorMap: ITopicColourMap }> = ({
topicColorMap,
}) => {
return (
<Box
sx={{
padding: '16px',
backgroundColor: '#121826',
borderRadius: '10px',
color: 'white',
maxHeight: '250px',
overflowY: 'auto',
maxWidth: '500px',
position: 'absolute',
left: '16px',
top: '16px',
}}
>
<Box component="ul" sx={{ padding: 0, margin: 0, listStyle: 'none' }}>
{Object.entries(topicColorMap).map(([topic, color]) => (
<Box
component="li"
key={topic}
sx={{
display: 'flex',
marginBottom: '8px',
}}
>
<Box
sx={{
width: '20px',
height: '20px',
backgroundColor: color,
marginRight: '8px',
}}
/>
<span style={{ fontSize: '0.875rem' }}>
{topic.substring(topic.indexOf('_') + 1)}
</span>
</Box>
))}
</Box>
</Box>
);
};

const VisGraph: React.FC<{
graphData: GraphData;
Expand All @@ -50,7 +101,8 @@ const VisGraph: React.FC<{
shape: 'dot',
size: 25,
...node,
title: `Found in pages: ${node.pages}`,
title: `Found in pages: ${node.pages}
Topic: ${node.topic.substring(node.topic.indexOf('_') + 1)}`,
color: {
background: topicColorMap[node.topic],
border: 'white',
Expand All @@ -59,6 +111,7 @@ const VisGraph: React.FC<{
border: '#508e7f',
},
},
borderWidth: 0.5,
})),
edges: graphData.edges.map((edge) => ({
from: edge.source,
Expand Down Expand Up @@ -180,14 +233,38 @@ const GraphVisualization: React.FC = () => {
const data = await response.json();
setGraphData(data);

// Generate and set topic color map
const newTopicColorMap = data.nodes.reduce((acc: ITopicColourMap, curr: any) => {
if (!acc[curr.topic]) {
acc[curr.topic] = '#' + Math.floor(Math.random() * 16777215).toString(16);
}
return acc;
}, {});
setTopicColorMap(newTopicColorMap);
// Get the list of unique topics
const uniqueTopics = Array.from(
new Set(data.nodes.map((node) => node.topic)),
);

// Create color scheme for the topics
const colorSchemes = [
d3.schemeCategory10,
d3.schemePaired,
d3.schemeSet1,
];
const uniqueColors = Array.from(new Set(colorSchemes.flat()));

const otherIndex = uniqueTopics.indexOf('other');
if (otherIndex !== -1) {
uniqueTopics.splice(otherIndex, 1);
}

const topicColorMap: ITopicColourMap = uniqueTopics.reduce(
(acc: ITopicColourMap, topic, index) => {
acc[topic] = uniqueColors[index % uniqueColors.length];
return acc;
},
{},
);

if (otherIndex !== -1) {
topicColorMap['other'] =
uniqueColors[uniqueTopics.length % uniqueColors.length];
}

setTopicColorMap(topicColorMap);
} catch (error) {
console.error('Error fetching graph data:', error);
} finally {
Expand Down Expand Up @@ -397,8 +474,12 @@ const GraphVisualization: React.FC = () => {
);
}

const formattedDate = new Date(graphData.graph_created_at).toLocaleDateString();
const formattedTime = new Date(graphData.graph_created_at).toLocaleTimeString();
const formattedDate = new Date(
graphData.graph_created_at,
).toLocaleDateString();
const formattedTime = new Date(
graphData.graph_created_at,
).toLocaleTimeString();

return (
<Stack
Expand Down Expand Up @@ -500,6 +581,7 @@ const GraphVisualization: React.FC = () => {
topicColorMap={topicColorMap}
/>
)}
<Legend topicColorMap={topicColorMap} />
</Box>
</Stack>
<FloatingControlCard
Expand All @@ -519,4 +601,4 @@ const GraphVisualization: React.FC = () => {
);
};

export default GraphVisualization;
export default GraphVisualization;

0 comments on commit c9ddd4a

Please sign in to comment.