-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeep_research_agent.py
1407 lines (1179 loc) · 65.8 KB
/
deep_research_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import json
from typing import List, Dict, Optional, Set, Tuple
import google.generativeai as genai
from test_search import build
from dotenv import load_dotenv
import time
from collections import defaultdict
import logging
import colorlog
from datetime import datetime
import re
import asyncio
import httplib2
from difflib import SequenceMatcher
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from vertexai.preview import tokenization
from utils.browsing import BrowserManager
class DeepResearchAgent:
def __init__(self):
    """Set up every synchronous collaborator of the agent.

    Loads ``.env`` configuration, configures logging, builds the Gemini models
    and the Custom Search client, and initialises the per-query research state.
    The browser is *not* started here: that happens in ``__aenter__``, so
    ``browser_manager`` stays ``None`` until the agent is entered with
    ``async with``.
    """
    load_dotenv()
    self.setup_logging()

    # Context prepended to every prompt so the models know "now" and "where".
    self.current_date = datetime.now().strftime("%Y-%m-%d")
    self.approximate_location = "UK"
    system_context = (
        f"Current date: {self.current_date}. Location: {self.approximate_location}. "
        "Only use the information from web searches, not your training data. "
        "For queries about current/latest things, use generic search terms without specific versions/numbers."
    )

    genai.configure(api_key=os.getenv('GOOGLE_AI_KEY'))

    # Every harm category is set to BLOCK_NONE; the same mapping is shared by all models.
    self.safety_settings = {
        category: HarmBlockThreshold.BLOCK_NONE
        for category in (
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        )
    }

    # Cheap "lite" model for query generation, ranking and analysis;
    # the fuller flash model for the final report and chat.
    lite_model_id = 'gemini-2.0-flash-lite-preview-02-05'
    flash_model_id = 'gemini-2.0-flash'
    self.model = genai.GenerativeModel(lite_model_id, safety_settings=self.safety_settings)
    self.ranking_model = genai.GenerativeModel(lite_model_id, safety_settings=self.safety_settings)
    self.analysis_model = genai.GenerativeModel(lite_model_id, safety_settings=self.safety_settings)
    self.report_model = genai.GenerativeModel(flash_model_id, safety_settings=self.safety_settings)
    self.chat_model = genai.GenerativeModel(flash_model_id, safety_settings=self.safety_settings)

    # Google Custom Search client (the CSE resource of the customsearch v1 API).
    search_service = build(
        "customsearch", "v1",
        developerKey=os.getenv('GOOGLE_SEARCH_KEY'),
    )
    self.search_engine = search_service.cse()
    self.search_engine_id = os.getenv('GOOGLE_SEARCH_ENGINE_ID')

    # Research state; reset_state() clears the same fields between queries.
    self.previous_queries = set()      # every query already sent to search
    self.all_results = {}              # url -> result dict
    self.high_ranking_urls = {}        # url -> details, for results scored > 0.6
    self.blacklisted_urls = set()
    self.scraped_urls = set()
    self.research_iterations = 0
    self.MAX_ITERATIONS = 5
    self.system_context = system_context
    self.total_tokens = 0
    self.research_data = None          # latest research data, kept for follow-up questions

    # NOTE(review): the tokenizer targets gemini-1.5-flash while generation
    # uses gemini-2.0 models — presumably the closest tokenizer available; confirm.
    self.model_name = "gemini-1.5-flash-002"
    self.tokenizer = tokenization.get_tokenizer_for_model(self.model_name)
    self.token_usage_by_operation = defaultdict(int)
    self.content_tokens = 0            # tokens from stored content, tracked separately

    # Created lazily in __aenter__.
    self.browser_manager = None
async def __aenter__(self):
    """Start a ``BrowserManager`` and return the agent for use inside ``async with``."""
    manager = BrowserManager()
    self.browser_manager = await manager.__aenter__()
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Shut down the browser manager, if one was started.

    The exception details are forwarded unchanged. The method returns ``None``,
    so an exception raised inside the ``async with`` body is never suppressed.
    """
    manager = self.browser_manager
    if not manager:
        return
    await manager.__aexit__(exc_type, exc_val, exc_tb)
def setup_logging(self):
    """Set ``self.logger`` to the shared ``deep_research`` logger with colourised console output.

    ``getLogger`` returns the same process-wide logger object for every agent,
    so the console handler is attached only if the logger has none yet.
    Creating several agents therefore does not print each message several times.
    """
    logger = colorlog.getLogger('deep_research')
    # Guard against stacking one more handler per DeepResearchAgent instance.
    if not logger.handlers:
        handler = colorlog.StreamHandler()
        handler.setFormatter(colorlog.ColoredFormatter(
            '%(log_color)s%(levelname)s:%(reset)s %(message)s',
            log_colors={
                'DEBUG': 'cyan',
                'INFO': 'green',
                'WARNING': 'yellow',
                'ERROR': 'red',
                'CRITICAL': 'red,bg_white',
            }
        ))
        logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    self.logger = logger
def should_skip_url(self, url: str) -> bool:
    """Return True if *url* must not be searched or scraped.

    A URL is skipped when it is blacklisted or when it points at a PDF / Word /
    PowerPoint document. The extension must be followed by the end of the URL
    or by ``?``, ``#``, ``&`` or ``/``. This stops hosts such as
    ``www.docker.com`` from being mistaken for ``.doc`` files. Matching is
    case-insensitive (``REPORT.PDF`` is skipped too).
    """
    if url in self.blacklisted_urls:
        return True
    return re.search(r'\.(?:pdf|docx?|pptx?)(?=$|[?#&/])', url, re.IGNORECASE) is not None
def generate_subqueries(self, main_query: str, research_state: Optional[Dict] = None) -> List[str]:
    """Ask the model to classify *main_query* and expand it into search queries.

    Args:
        main_query: the user's research question.
        research_state: optional snapshot of the research so far. It is only
            added to the prompt when earlier queries exist.

    Returns:
        ``[main_query]`` for queries the model classifies as SIMPLE. For COMPLEX
        queries: up to ``MAX_QUERIES`` de-duplicated model queries, followed by
        *main_query* if it is not already among them. Falls back to
        ``[main_query]`` on an empty response or any model/parsing error.
    """
    self.logger.info("Analyzing query and generating search queries...")
    MAX_QUERIES = 5  # Cap on model-generated queries; main_query is appended on top
    context = ""
    if research_state and self.previous_queries:
        context = f"""
Previously used queries: {json.dumps(list(self.previous_queries), indent=2)}
Current research state: {json.dumps(research_state, indent=2)}
Based on the above context and gaps in current research, """
    prompt = f"""{self.system_context}
{context}Generate comprehensive search queries to gather maximum relevant information about this query:
Query: '{main_query}'
First, determine if this is a SIMPLE query (basic math, unit conversion, single fact lookup) or a COMPLEX query requiring research.
For COMPLEX queries, generate search queries that:
1. Cover all temporal aspects:
- Historical background and development
- Current state and recent developments
- Future predictions and trends
2. Include different information types:
- Core facts and definitions
- Statistics and data
- Expert analysis and opinions
- Case studies and examples
- Comparisons and contrasts
- Problems and solutions
- Impacts and implications
3. Use search optimization techniques:
- Site-specific searches (e.g., site:edu, site:gov)
- Date-range specific queries when relevant
- Include synonyms and related terms
- Combine key concepts in different ways
- Use both broad and specific queries
4. Response Format:
TYPE: [SIMPLE/COMPLEX]
REASON: [One clear sentence explaining the classification]
QUERIES:
[If SIMPLE: Only output the original query
If COMPLEX: Generate up to {MAX_QUERIES} search queries that:
- Start each line with a number
- Ensure broad coverage within the {MAX_QUERIES} query limit]
Example of a COMPLEX query about "impact of remote work":
1. "remote work" impact statistics 2023-2024
2. site:edu research "remote work productivity"
3. "hybrid work model" vs "fully remote" comparison
4. remote work employee mental health studies
5. future "workplace trends" expert predictions"""
    # A query line is numbered ("1.", "2)", "3-") or bulleted ("- ", "* ").
    # Anything else after QUERIES: is treated as commentary, not a query.
    query_line = re.compile(r'^(?:\d+[.)-]\s*|[-*•]\s+)(.+)$')
    try:
        response = self.model.generate_content(prompt)
        if not response or not response.text:
            self.logger.error("Empty response from AI model")
            return [main_query]
        query_type = None
        reason = None
        queries_section_started = False
        subqueries: List[str] = []
        for raw_line in response.text.strip().split('\n'):
            # Drop markdown emphasis so "**TYPE:** COMPLEX" parses like "TYPE: COMPLEX".
            line = raw_line.replace('**', '').strip()
            if not line:
                continue
            if line.startswith('TYPE:'):
                # Models often echo the template's brackets: "TYPE: [SIMPLE]".
                query_type = line.split(':', 1)[1].strip(' []*').upper()
            elif line.startswith('REASON:'):
                reason = line.split(':', 1)[1].strip()
            elif line.startswith('QUERIES:'):
                queries_section_started = True
            elif queries_section_started:
                if query_type == "SIMPLE":
                    break
                match = query_line.match(line)
                if not match:
                    continue
                query = match.group(1).strip()
                # Only validate minimum length (no maximum) and skip duplicates
                if len(query) >= 3 and query not in subqueries:
                    subqueries.append(query)
                    if len(subqueries) >= MAX_QUERIES:
                        break
        if query_type == "SIMPLE":
            subqueries = [main_query]
        elif main_query not in subqueries:
            # Always search the user's own wording as well
            subqueries.append(main_query)
        # Log results
        self.logger.info(f"Query type: {query_type} - {reason}")
        self.logger.info(f"Generated {len(subqueries)} queries:")
        for q in subqueries:
            self.logger.info(f"Query: {q}")
        return subqueries if subqueries else [main_query]
    except Exception as e:
        self.logger.error(f"Error generating queries: {e}")
        return [main_query]
async def batch_web_search(self, queries: List[str], num_results: int = 10) -> List[Dict]:
    """Run *queries* through Google Custom Search concurrently and merge the hits.

    Queries are sent in rounds of 10. At most 5 API calls are in flight at once,
    and there is a short pause between rounds. Each query is retried up to 3
    times with exponential backoff (1 s, 2 s). A query that still fails
    contributes no results.

    Every hit is registered in ``self.all_results`` keyed by URL. The SAME dict
    object is stored there and returned, so later in-place updates (e.g.
    ranking scores) are visible in both places. A URL that is already known
    only gets the new query appended to its ``source_queries`` and is not
    returned again.

    Returns:
        The newly seen results, passed through ``self._deduplicate_results``.
    """
    self.logger.info(f"Batch searching {len(queries)} queries...")
    round_size = 10
    gate = asyncio.Semaphore(5)  # cap on simultaneous Custom Search calls
    attempts = 3
    backoff_base = 1  # seconds; doubled per failed attempt

    def to_result(item: Dict, query: str) -> Optional[Dict]:
        """Convert one API item into a result dict, or None if it must be skipped."""
        link = item.get('link', '')
        if not link or self.should_skip_url(link):
            return None
        link = self.browser_manager.rewrite_url(link)
        return {
            'title': item.get('title', ''),
            'url': link,
            'snippet': item.get('snippet', ''),
            'domain': item.get('displayLink', ''),
            'source_queries': [query]
        }

    async def run_query(query: str) -> List[Dict]:
        """Search one query (with retries) while holding the concurrency gate."""
        async with gate:
            try:
                for attempt in range(attempts):
                    try:
                        # NOTE(review): TLS certificate validation is deliberately
                        # disabled here, which exposes the search traffic to
                        # interception — confirm this is still required.
                        http = httplib2.Http(timeout=30)
                        http.disable_ssl_certificate_validation = True
                        # Build the request synchronously, run the blocking execute() in a thread
                        request = self.search_engine.list(
                            q=query,
                            cx=self.search_engine_id,
                            num=num_results
                        )
                        payload = await asyncio.to_thread(request.execute, http=http)
                        if not payload or 'items' not in payload:
                            self.logger.warning(f"No results found for query: {query}")
                            return []
                        hits = []
                        for item in payload.get('items', []):
                            try:
                                hit = to_result(item, query)
                            except Exception as item_error:
                                self.logger.warning(f"Error processing search result: {item_error}")
                                continue
                            if hit is not None:
                                hits.append(hit)
                        return hits
                    except Exception as retry_error:
                        if attempt == attempts - 1:
                            raise
                        self.logger.warning(f"Search attempt {attempt + 1} failed: {retry_error}")
                        await asyncio.sleep(backoff_base * (2 ** attempt))
            except Exception as e:
                self.logger.error(f"Search error for query '{query}': {str(e)}")
                return []

    fresh: List[Dict] = []
    for start in range(0, len(queries), round_size):
        round_queries = queries[start:start + round_size]
        per_query_hits = await asyncio.gather(*(run_query(q) for q in round_queries))
        for query, hits in zip(round_queries, per_query_hits):
            for hit in hits:
                key = hit['url']
                if key in self.all_results:
                    self.all_results[key]['source_queries'].append(query)
                else:
                    self.all_results[key] = hit
                    fresh.append(hit)
        # Brief pause between rounds to stay clear of rate limits
        if start + round_size < len(queries):
            await asyncio.sleep(0.5)

    fresh = self._deduplicate_results(fresh)
    self.logger.info(f"Found {len(fresh)} unique results across all queries")
    return fresh
def rank_new_results(self, main_query: str, new_results: List[Dict]) -> List[Dict]:
    """Score newly found search results with the ranking model and decide what to scrape.

    Each result the model ranks gets ``relevance_score`` and ``scrape_decision``
    keys added in place. At most five results per call are marked for scraping,
    in the order the model lists them. An unknown scrape level falls back to
    LOW, the same default as ``get_scrape_limit`` and the prompt. Results
    scoring above 0.6 are also recorded in ``self.high_ranking_urls``.

    Args:
        main_query: the user's research question (used in the prompt).
        new_results: result dicts with ``url``, ``title``, ``snippet``,
            ``domain`` and ``source_queries`` keys.

    Returns:
        The ranked results, highest score first. Results the model did not rank
        are omitted. If the model call fails outright, the de-duplicated input is
        returned unranked.
    """
    if not new_results:
        return []
    # Deduplicate results before ranking
    new_results = self._deduplicate_results(new_results)
    self.logger.info(f"Ranking {len(new_results)} new URLs")
    prompt = f"""{self.system_context}
For the query '{main_query}', analyze and rank these search results.
For each URL, determine:
1. Relevance score (0-0.99, or 1.0 for perfect matches)
2. Whether to scrape the content (YES/NO)
3. Scraping level (LOW/MEDIUM/HIGH) - determines how much content to extract:
- LOW: 3000 chars - For basic/overview content (default)
- MEDIUM: 6000 chars - For moderate detail
- HIGH: 10000 chars - For in-depth analysis
Consider these factors:
- Content depth and relevance to query
- Source authority and reliability
- Need for detailed information
- If the request is simple (e.g. 2+2), mark none for scraping
Format response EXACTLY as follows, one entry per line:
[url] | [score] | [YES/NO] | [LOW/MEDIUM/HIGH]
IMPORTANT RULES:
- All scores must be unique (no ties) and between 0 and 1.0
- Only give 1.0 for perfect matches
- Mark YES for scraping only if the content is likely highly relevant
- Scraping level should match content importance and depth
- You MUST rank ALL URLs provided
- Provide scraping decisions for ALL URLs
URLs to analyze:
""" + "\n".join([
        f"{data['url']}\nTitle: {data['title']}\nSnippet: {data['snippet']}"
        for data in new_results
    ])
    max_scrapes = 5
    valid_levels = ('LOW', 'MEDIUM', 'HIGH')
    list_marker = re.compile(r'^(?:\d+[.)]|[-*•])\s+')

    def clean(field: str) -> str:
        # The prompt shows "[url] | [score] | ..." so models often echo the
        # brackets; also drop markdown emphasis, backticks and angle brackets.
        return field.strip().strip('[]<>`*').strip()

    try:
        response = self.ranking_model.generate_content(prompt)
        # url -> first result with that url (O(1) lookup instead of a scan per line)
        by_url: Dict[str, Dict] = {}
        for result in new_results:
            by_url.setdefault(result['url'], result)
        rankings = {}
        scrape_decisions = {}
        scrape_count = 0  # URLs marked for scraping so far
        for line in response.text.strip().split('\n'):
            try:
                parts = line.split('|')
                if len(parts) != 4:
                    continue
                url = clean(list_marker.sub('', parts[0].strip()))
                score = float(clean(parts[1]))
                decision = clean(parts[2]).upper()
                scrape_level = clean(parts[3]).upper()
                rankings[url] = score
                # Only mark for scraping if we haven't hit our limit
                should_scrape = decision == 'YES' and scrape_count < max_scrapes
                if should_scrape:
                    scrape_count += 1
                if scrape_level not in valid_levels:
                    scrape_level = 'LOW'  # Default to LOW if invalid
                scrape_decisions[url] = {
                    'should_scrape': should_scrape,
                    'scrape_level': scrape_level,
                    'original_decision': decision == 'YES'
                }
                # Track high-ranking URLs (score > 0.6)
                if score > 0.6:
                    result = by_url.get(url)
                    if result:
                        self.high_ranking_urls[url] = {
                            'score': score,
                            'title': result['title'],
                            'snippet': result['snippet'],
                            'domain': result['domain'],
                            'source_queries': result['source_queries'],
                            'scrape_decision': scrape_decisions[url]
                        }
            except (ValueError, IndexError):
                continue
        # Update scores and sort results
        ranked_results = []
        for result in new_results:
            if result['url'] in rankings:
                result['relevance_score'] = rankings[result['url']]
                result['scrape_decision'] = scrape_decisions[result['url']]
                ranked_results.append(result)
        ranked_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
        # Log a summary rather than every URL
        self.logger.info(
            f"Ranking summary:\n"
            f"Total URLs: {len(ranked_results)}\n"
            f"URLs marked for scraping: {scrape_count}\n"
            f"High-ranking URLs (score > 0.6): {len(self.high_ranking_urls)}"
        )
        return ranked_results
    except Exception as e:
        self.logger.error(f"Ranking error: {e}")
        return new_results
def get_scrape_limit(self, scrape_level: str) -> int:
    """Translate a scraping level into a character budget.

    The level is case-insensitive: HIGH gives 10000 characters and MEDIUM gives
    6000. LOW and any unrecognised value give 3000.
    """
    normalized = scrape_level.upper()
    if normalized == 'HIGH':
        return 10000
    if normalized == 'MEDIUM':
        return 6000
    return 3000
def rank_all_results(self, main_query: str) -> List[Dict]:
    """Return every non-blacklisted known result, best existing score first.

    No new ranking is done; results without a ``relevance_score`` sort as 0.
    *main_query* is accepted for interface symmetry but not used.
    """
    if not self.all_results:
        return []
    self.logger.info(f"Getting all {len(self.all_results)} ranked results")
    candidates = [
        entry for entry in self.all_results.values()
        if entry['url'] not in self.blacklisted_urls
    ]
    candidates.sort(key=lambda entry: entry.get('relevance_score', 0), reverse=True)
    return candidates
def analyze_research_state(self, main_query: str, current_data: Dict) -> Tuple[bool, str, List[str], str]:
    """Ask the analysis model whether to keep researching, and apply its URL housekeeping.

    The model's answer drives these side effects on agent state:
      * URLs under REMOVE_URLS are dropped from ``all_results``,
        ``high_ranking_urls`` and ``scraped_urls``, and blacklisted.
      * URLs under BLACKLIST are added to ``blacklisted_urls``.
      * URLs under SCRAPE_NEXT that are known and not yet scraped get a new
        ``scrape_decision``.

    Header lines may be decorated with markdown (``**DECISION:** YES``); list
    items may be numbered or bulleted.

    Returns:
        ``(continue_research, explanation, search_queries, report_structure)``.
        ``report_structure`` is "" unless the model emitted a REPORT_STRUCTURE
        section. On any error: ``(False, error_message, [], "")``.
    """
    self.logger.info("Analyzing research state...")
    section_headers = {
        "DECISION", "TYPE", "REASON", "REMOVE_URLS", "BLACKLIST",
        "MISSING", "SEARCH_QUERIES", "SCRAPE_NEXT", "REPORT_STRUCTURE",
    }
    list_marker = re.compile(r'^(?:\d+[.)]|[-*•])\s+')
    valid_levels = ('LOW', 'MEDIUM', 'HIGH')
    try:
        prompt = f"""{self.system_context}
Analyze the research on: '{main_query}'
Current data: {json.dumps(current_data, indent=2)}
Current iteration: {self.research_iterations}
IMPORTANT DECISION GUIDELINES:
1. For simple factual queries (e.g. "2+2", "capital of France"), say NO immediately and mark as SIMPLE
2. For queries needing more depth/verification:
- On iteration 0-2: Say YES if significant information is missing
- On iteration 3: Only say YES if crucial information is missing
- On iteration 4+: Strongly lean towards NO unless absolutely critical information is missing
- If a section called "Further Research" can be written, continue research.
3. Consider the query ANSWERED when you have:
- Sufficient high-quality sources (relevance_score > 0.6)
- Enough information to provide a comprehensive answer
- Cross-verified key information from multiple sources
ANALYSIS STEPS:
1. First, determine if this is a simple factual query requiring no research
2. If research is needed, assess if current findings sufficiently answer the query
3. Review all scraped URLs and identify any that are:
- Not directly relevant to the main query
- Contain tangential or off-topic information
- Duplicate or redundant information
- Low quality or unreliable sources
4. Only continue research if genuinely valuable information is missing
5. Generate a custom report structure based on the query type and findings. If the query is simple, the report should be a simple answer to the query. If the query is complex, the report should be a comprehensive report with all the information needed to answer the query.
6. Mark any unscraped URLs that should be scraped in the next iteration
Format response EXACTLY as follows:
DECISION: [YES (continue research)/NO (produce final report)]
TYPE: [SIMPLE/COMPLEX]
REASON: [One clear sentence explaining the decision, mentioning iteration number if relevant]
REMOVE_URLS: [List URLs to remove from context, one per line, with brief reason after # symbol]
BLACKLIST: [List URLs to blacklist, one per line. These URLs will be ignored in future iterations.]
MISSING: [List missing information aspects, one per line]
SEARCH_QUERIES: [List complete search queries, one per line, max 7. Search formatting and quotes are allowed. These queries should be specific to the information you are looking for.]
SCRAPE_NEXT: [List URLs to scrape in next iteration, one per line, in format: URL | LOW/MEDIUM/HIGH]"""
        response = self.analysis_model.generate_content(prompt)
        # Every accumulator starts defined, so missing sections cannot raise UnboundLocalError.
        decision = False
        query_type = "COMPLEX"  # Default to complex
        explanation = ""
        report_structure = ""
        blacklist: List[str] = []
        missing: List[str] = []
        search_queries: List[str] = []
        urls_to_scrape_next: Dict[str, str] = {}
        urls_to_remove: Dict[str, str] = {}
        current_section = None
        for raw_line in response.text.split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            # Header detection tolerates "**DECISION:** YES" and "## TYPE: ...".
            header_line = line.replace('**', '').lstrip('#').strip()
            name, sep, value = header_line.partition(':')
            name = name.strip()
            if sep and name in section_headers:
                value = value.strip()
                if name == "DECISION":
                    # Look at the answer's first word; a substring test would match e.g. "EYES".
                    decision = value.strip(' []*').upper().startswith("YES")
                    current_section = None
                elif name == "TYPE":
                    query_type = value.strip(' []*').upper()
                    current_section = None
                elif name == "REASON":
                    explanation = value
                    current_section = None
                else:
                    current_section = name
                continue
            # Section content: drop a leading "1." / "-" / "*" list marker.
            content = list_marker.sub('', line)
            if current_section == "REMOVE_URLS":
                if content.startswith('http'):
                    url, _, reason = content.partition('#')
                    urls_to_remove[url.strip()] = reason.strip() or "Not relevant to query"
            elif current_section == "BLACKLIST":
                if content.startswith('http'):
                    # Same "URL # reason" convention as REMOVE_URLS
                    blacklist.append(content.partition('#')[0].strip())
            elif current_section == "MISSING":
                if content != line and content:
                    missing.append(content)
            elif current_section == "SEARCH_QUERIES":
                if content:
                    search_queries.append(content)
            elif current_section == "SCRAPE_NEXT":
                if '|' in content:
                    # Tolerate extra columns (e.g. a trailing reason) and bracketed values
                    fields = [f.strip().strip('[]<>`*').strip() for f in content.split('|')]
                    url, level = fields[0], fields[1].upper()
                    if url and level in valid_levels:
                        urls_to_scrape_next[url] = level
            elif current_section == "REPORT_STRUCTURE":
                report_structure = f"{report_structure}\n{line}" if report_structure else line
        # Process URLs to remove
        for url, reason in urls_to_remove.items():
            if url in self.all_results:
                del self.all_results[url]
                self.logger.info(f"Removed URL from context: {url} (Reason: {reason})")
            if url in self.high_ranking_urls:
                del self.high_ranking_urls[url]
                self.logger.info(f"Removed URL from high-ranking URLs: {url}")
            # Blacklist to prevent re-scraping
            self.blacklisted_urls.add(url)
            self.scraped_urls.discard(url)
        # Update blacklist with additional URLs
        self.blacklisted_urls.update(blacklist)
        # Update scraping decisions in all_results for next iteration
        for url, level in urls_to_scrape_next.items():
            if url in self.all_results and url not in self.scraped_urls:
                self.all_results[url]['scrape_decision'] = {
                    'should_scrape': True,
                    'scrape_level': level,
                    'original_decision': True
                }
                self.logger.info(f"Marked {url} for {level} scraping in next iteration")
        if missing:
            self.logger.info("Missing information:\n- " + "\n- ".join(missing))
        explanation = f"Query type: {query_type}. " + explanation
        if urls_to_remove:
            explanation += f"\nRemoved {len(urls_to_remove)} URLs"
        if search_queries:
            explanation += f"\nGenerated {len(search_queries)} new search queries"
        if urls_to_scrape_next:
            explanation += f"\nMarked {len(urls_to_scrape_next)} URLs for next scraping"
        return decision, explanation, search_queries, report_structure
    except Exception as e:
        self.logger.error(f"Analysis error: {e}")
        return False, str(e), [], ""
def save_report_streaming(self, query: str, report_text, sources_used: str):
    """Write a finished report to ``reports/<slug>-<current_date>.md``.

    The slug comes from ``self.clean_filename(query)``. The report body, if
    non-empty, is written first and its tokens are recorded under
    "Report content". A blank line and the *sources_used* section follow.

    Returns:
        The path written, or ``None`` if anything failed (the error is logged).
    """
    def write_body(handle) -> None:
        if report_text:
            handle.write(report_text)
            self.log_token_usage(report_text, "Report content")
        # Blank line, then the sources section
        handle.writelines(("\n\n", sources_used))

    try:
        os.makedirs('reports', exist_ok=True)
        target = f"reports/{self.clean_filename(query)}-{self.current_date}.md"
        with open(target, 'w', encoding='utf-8') as handle:
            try:
                write_body(handle)
            except Exception as e:
                self.logger.error(f"Error writing report: {e}")
                raise  # surfaces as "Error saving report" below
        self.logger.info(f"Report saved to {target}")
        return target
    except Exception as e:
        self.logger.error(f"Error saving report: {e}")
        return None
def count_tokens(self, text: str) -> int:
    """Count *text*'s tokens with the default model's token counter.

    Empty input counts as 0. Non-string input is converted with ``str()``
    first. If the counter raises, the error is logged and a rough estimate of
    one token per four characters (of the stringified text) is returned.
    """
    if not text:
        return 0
    rendered = text
    try:
        rendered = str(text)
        return self.model.count_tokens(rendered).total_tokens
    except Exception as e:
        self.logger.warning(f"Token counting error: {e}")
        return len(rendered) // 4
def log_token_usage(self, text: str, operation: str):
    """Add *text*'s token count to the running total and to *operation*'s tally.

    Any failure is logged rather than raised, so bookkeeping never interrupts research.
    """
    try:
        used = self.count_tokens(text)
        per_operation = self.token_usage_by_operation
        self.total_tokens = self.total_tokens + used
        per_operation[operation] = per_operation[operation] + used
    except Exception as e:
        self.logger.error(f"Error logging token usage: {e}")
def reset_state(self):
    """Forget everything from the previous query: results, URL sets, counters and token tallies.

    The existing ``token_usage_by_operation`` mapping is cleared in place
    rather than replaced.
    """
    self.previous_queries, self.blacklisted_urls, self.scraped_urls = set(), set(), set()
    self.all_results, self.high_ranking_urls = {}, {}
    self.research_iterations = self.total_tokens = self.content_tokens = 0
    self.token_usage_by_operation.clear()
    self.research_data = None
    self.logger.info("Reset research state and token counter")
def clean_filename(self, query: str, max_length: int = 100) -> str:
    """Turn *query* into a filesystem-safe slug with a ``-HH-MM-SS`` time suffix.

    Steps:
      * Drop characters other than word characters, whitespace and hyphens.
      * Lowercase, and collapse runs of whitespace/hyphens to a single hyphen.
      * Cut slugs longer than *max_length* back to whole hyphen-separated words.
      * Trim leading/trailing hyphens.
      * Fall back to ``"report"`` if nothing is left, so the name never starts
        with a hyphen.
    """
    slug = re.sub(r'[^\w\s-]', '', query).strip().lower()
    slug = re.sub(r'[-\s]+', '-', slug).strip('-')
    if len(slug) > max_length:
        cut = slug[:max_length]
        # Only drop the last fragment when the cut actually split a word.
        if slug[max_length] != '-':
            cut = cut.rsplit('-', 1)[0]
        slug = cut.strip('-')
    if not slug:
        slug = "report"
    current_time = datetime.now().strftime("%H-%M-%S")
    return f"{slug}-{current_time}"
async def generate_report(self, main_query: str, research_data: Dict, report_structure: str) -> Tuple[str, str]:
    """Write the final Markdown report for *main_query* from the gathered research.

    Args:
        main_query: the user's original question.
        research_data: dict with ``final_sources`` (findings carrying
            ``source`` / ``content``) and ``iterations`` (each with
            ``queries_used``). It is de-duplicated first via
            ``self._deduplicate_research_data``.
        report_structure: optional outline from the analysis step. When
            non-empty it is passed to the model as guidance.

    Each attempt also saves the full prompt under ``prompts/`` (UTF-8). Up to
    three attempts are made, with 2 s and 4 s back-off between them.

    Returns:
        ``(report_markdown, sources_section)`` on success, or ``(None, "")``
        once every attempt has failed.
    """
    max_retries = 3
    base_delay = 2  # Base delay in seconds, doubled per failed attempt
    # Deduplicate research data before generating report
    research_data = self._deduplicate_research_data(research_data)
    for attempt in range(max_retries):
        try:
            # Index scraped content by URL once (first finding wins) instead of
            # rescanning final_sources for every high-ranking URL.
            content_by_url = {}
            for finding in research_data['final_sources']:
                content_by_url.setdefault(finding.get('source'), finding.get('content'))
            high_ranking_sources = {}
            for url, data in self.high_ranking_urls.items():
                high_ranking_sources[url] = {
                    'title': data['title'],
                    'snippet': data['snippet'],
                    'score': data['score'],
                    'content': content_by_url.get(url),  # Full content if available
                    'domain': data['domain'],
                    'queries_used': data['source_queries']
                }
            # Sort sources by score for easy reference
            sorted_sources = sorted(
                high_ranking_sources.items(),
                key=lambda x: x[1]['score'],
                reverse=True
            )
            # Numbered references used for citations
            source_references = {
                url: f"[{i+1}]"
                for i, (url, _) in enumerate(sorted_sources)
            }
            # Source list appended to the report
            sources_used = "\n\n## Sources Used\n\n"
            for i, (url, data) in enumerate(sorted_sources, 1):
                sources_used += f"[{i}] {data['title']}\n"
                sources_used += f"- URL: {url}\n"
                sources_used += f"- Relevance Score: {data['score']:.2f}\n"
                sources_used += f"- Domain: {data['domain']}\n\n"
            report_context = {
                'main_query': main_query,
                'final_sources': research_data['final_sources'],  # Only include final sources
                'high_ranking_sources': dict(sorted_sources),  # Ordered by score
                'source_references': source_references,  # For citations
                'total_sources_analyzed': len(self.all_results),
                'total_high_ranking_sources': len(self.high_ranking_urls),
                'research_iterations': self.research_iterations,
                'total_queries_used': len(self.previous_queries),
                'queries_by_iteration': [
                    iter_data['queries_used']
                    for iter_data in research_data['iterations']
                ]
            }
            # Pass the analysis step's outline to the model instead of ignoring it.
            structure_note = ""
            if report_structure:
                outline = str(report_structure).strip()
                if outline:
                    structure_note = (
                        "Suggested report structure (adapt it where the sources warrant):\n"
                        f"{outline}\n"
                    )
            # ensure_ascii=False keeps non-English source text readable (and shorter) for the model.
            prompt = f"""Generate a comprehensive research report on: '{main_query}'
# For simple queries (mathematical, factual, or definitional):
- Use # for the main title
- Use ## for main sections
- Use ### for subsections if needed
- Provide a clear, direct answer
- Include a brief explanation of the concept if relevant
- Keep additional context minimal and focused
# For complex queries:
- Create a title with # heading
- Use ## for main sections
- Use ### for subsections
- Use #### for detailed subsection breakdowns where needed
- Include comprehensive analysis of all relevant information
- Address any contradictions or nuances in the sources
- Provide thorough explanations and context
# General Guidelines:
- The report should be detailed and include all relevant information from sources
- Always use proper heading hierarchy (# → ## → ### → ####)
- Use **bold** for emphasis on key points
- Format numbers naturally with proper thousands separators
- Use [1][2][3] format for references, DO NOT DO [1, 2, 3]
- Mention when using knowledge beyond the sources and note potential for hallucination
- Use LaTeX for ALL math expressions by wrapping them in $$. Examples:
- For inline math: $$x^2$$ or $$x_2$$
- For display math on its own line:
$$
x^2 + y^2 = z^2
$$
- DO NOT use single $ for math. DO NOT use HTML formatting like <sup> or <sub>
# The report should be comprehensive and thorough:
- Aim for a length of at least 16,000 words to ensure complete coverage
- Include extensive analysis and discussion of all relevant aspects
- Break down complex topics into detailed subsections
- Provide rich examples and case studies where applicable
- Explore historical context, current state, and future implications
- Address multiple perspectives and viewpoints
- Support all claims with evidence from the sources
- Use clear topic transitions and maintain logical flow
- Ensure proper citation of sources throughout
- Provide tables if relevant.
Start the report immediately without any additional formatting or preamble.
Format in clean Markdown without code blocks (unless showing code snippets).
DO NOT include a sources section - it will be added automatically.
{structure_note}Using the following information:
{json.dumps(report_context, indent=2, ensure_ascii=False)}"""
            # Save the prompt for debugging. The prompt contains non-ASCII text
            # ("→", source content), so the encoding must be explicit or
            # platform defaults such as cp1252 raise UnicodeEncodeError.
            os.makedirs("prompts", exist_ok=True)
            prompt_file = f"prompts/prompt-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
            with open(prompt_file, "w", encoding="utf-8") as f:
                f.write(prompt)
            # NOTE(review): 8192 output tokens cannot hold the "16,000 words"
            # the prompt asks for — one of the two should be reconciled.
            response = self.report_model.generate_content(
                prompt,
                generation_config={
                    'temperature': 1,
                    'max_output_tokens': 8192,
                },
                safety_settings=self.safety_settings
            )
            if not response:
                raise ValueError("Invalid response from model")
            # prompt_feedback is a single message (not a list); a non-zero
            # block_reason means the prompt itself was blocked.
            feedback = getattr(response, 'prompt_feedback', None)
            if getattr(feedback, 'block_reason', None):
                raise ValueError(f"Prompt blocked: {feedback}")
            if not response.text:
                raise ValueError("Invalid response from model")
            self.logger.info("Report generation completed")
            return response.text, sources_used
        except Exception as e:
            delay = base_delay * (2 ** attempt)
            if attempt < max_retries - 1:
                self.logger.warning(
                    f"Report generation attempt {attempt + 1} failed: {str(e)}\n"
                    f"Retrying in {delay} seconds..."
                )
                await asyncio.sleep(delay)
            else:
                self.logger.error(f"All report generation attempts failed: {str(e)}")
                return None, ""
    return None, ""
async def research(self, query: str) -> str:
"""Main research function that coordinates the entire research process."""
self.reset_state()
self.logger.info(f"Starting research: {query}")
research_data = {
'main_query': query,
'iterations': [],
'final_sources': []
}
while self.research_iterations < self.MAX_ITERATIONS:
# Combine iteration logs
self.logger.info(
f"Iteration {self.research_iterations + 1}: "
f"Processing {len(self.previous_queries)} queries"
)
# Get search queries for this iteration
if self.research_iterations == 0:
# First iteration: generate initial queries
search_queries = self.generate_subqueries(query, research_data)
self.previous_queries.update(search_queries)
else:
if not new_queries:
self.logger.warning("No additional search queries provided")
break
search_queries = [q for q in new_queries if q not in self.previous_queries]
if not search_queries:
self.logger.warning("No new unique queries to process")
break
self.previous_queries.update(search_queries)
self.logger.info(f"Processing {len(search_queries)} queries for iteration {self.research_iterations + 1}")
# Parallel processing of search and content extraction
async def process_search_batch():
# Perform searches in parallel
search_results = await self.batch_web_search(search_queries)
if not search_results:
return None, []
# Rank results in parallel
ranked_results = self.rank_new_results(query, search_results)
# Remove duplicates with parallel processing
seen_urls = set()
unique_ranked_results = []
async def process_result(result):
url = result['url']
rewritten_url = self.browser_manager.rewrite_url(url)
if rewritten_url not in seen_urls:
seen_urls.add(rewritten_url)
result['url'] = rewritten_url
return result
return None
# Process results in parallel
tasks = [process_result(result) for result in ranked_results]
processed_results = await asyncio.gather(*tasks)
unique_ranked_results = [r for r in processed_results if r is not None]
return unique_ranked_results, seen_urls
# Execute search batch processing
unique_ranked_results, seen_urls = await process_search_batch()
if not unique_ranked_results:
self.logger.warning("No valid results found in this iteration")
break
# Prepare URLs for scraping with parallel processing
url_to_result = {}
async def process_scraping_candidate(result):
url = result['url']
if (url not in self.scraped_urls and
result.get('scrape_decision', {}).get('should_scrape', False)):
return url, result
return None
# Process scraping candidates in parallel
tasks = [
process_scraping_candidate(result)
for result in list(self.all_results.values()) + unique_ranked_results