diff --git a/README.md b/README.md
index 55d23cd..1388f63 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ following:
from contentmap.sitemap import SitemapToContentDatabase
database = SitemapToContentDatabase(
- sitemap_url="https://yourblog.com/sitemap.xml",
+ sitemap_sources=["https://yourblog.com/sitemap.xml"],
concurrency=10,
include_vss=True
)
diff --git a/contentmap/sitemap.py b/contentmap/sitemap.py
index 3e1cddf..5874aa4 100644
--- a/contentmap/sitemap.py
+++ b/contentmap/sitemap.py
@@ -1,6 +1,8 @@
import asyncio
import logging
+from typing import Literal
import requests
+import os
import aiohttp
import trafilatura
@@ -11,10 +13,17 @@
class SitemapToContentDatabase:
+ SOURCE_TYPE_URL: Literal['url'] = 'url'
+ SOURCE_TYPE_DISK: Literal['disk'] = 'disk'
+ SourceType = Literal['url', 'disk']
- def __init__(self, sitemap_url, seconds_timeout=10, concurrency=None,
+ def __init__(self, sitemap_sources: list,
+ source_type: SourceType = SOURCE_TYPE_URL,
+ seconds_timeout=10,
+ concurrency=None,
include_vss=False):
- self.sitemap_url = sitemap_url
+ self.sitemap_sources = sitemap_sources
+ self.source_type = source_type
self.semaphore = asyncio.Semaphore(concurrency) if concurrency is not None else None
self.timeout = aiohttp.ClientTimeout(
sock_connect=seconds_timeout,
@@ -30,13 +39,34 @@ def build(self):
cm.build()
def get_urls(self):
- r = requests.get(self.sitemap_url)
+ all_urls = []
+ if self.source_type == self.SOURCE_TYPE_URL:
+ for sitemap_url in self.sitemap_sources:
+ urls = self._get_urls_from_url(sitemap_url)
+ all_urls.extend(urls)
+ elif self.source_type == self.SOURCE_TYPE_DISK:
+ for directory in self.sitemap_sources:
+ for filename in os.listdir(directory):
+ if filename.endswith('.xml'):
+ filepath = os.path.join(directory, filename)
+ urls = self._get_urls_from_disk(filepath)
+ all_urls.extend(urls)
+ return all_urls
+
+ def _get_urls_from_url(self, sitemap_url):
+ r = requests.get(sitemap_url)
tree = etree.fromstring(r.content)
- urls = [
+ return self._extract_urls_from_tree(tree)
+
+ def _get_urls_from_disk(self, filepath):
+ tree = etree.parse(filepath)
+ return self._extract_urls_from_tree(tree)
+
+ def _extract_urls_from_tree(self, tree):
+ return [
url.text for url
in tree.findall(".//{http://www.sitemaps.org/schemas/sitemap/0.9}loc")
]
- return urls
async def get_contents(self, urls):
async with aiohttp.ClientSession(timeout=self.timeout) as session:
diff --git a/tests/fixtures/sitemap_folder_a/sitemap_a.xml b/tests/fixtures/sitemap_folder_a/sitemap_a.xml
new file mode 100644
index 0000000..c056f48
--- /dev/null
+++ b/tests/fixtures/sitemap_folder_a/sitemap_a.xml
@@ -0,0 +1,9 @@
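+<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+  <url>
+    <loc>https://www.example.com/docs/en/example/?topic=testing</loc>
+  </url>
+  <url>
+    <loc>https://www.example.com/docs/en/example/?topic=contact-us</loc>
+  </url>
+</urlset>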
\ No newline at end of file
diff --git a/tests/fixtures/sitemap_folder_b/sitemap_b.xml b/tests/fixtures/sitemap_folder_b/sitemap_b.xml
new file mode 100644
index 0000000..00b7c97
--- /dev/null
+++ b/tests/fixtures/sitemap_folder_b/sitemap_b.xml
@@ -0,0 +1,9 @@
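+<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+  <url>
+    <loc>https://www.example.com/docs/en/example/?topic=library-overview</loc>
+  </url>
+  <url>
+    <loc>https://www.example.com/docs/en/example/?topic=about-this-content</loc>
+  </url>
+</urlset>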
\ No newline at end of file
diff --git a/tests/test_sitemap.py b/tests/test_sitemap.py
new file mode 100644
index 0000000..c14113a
--- /dev/null
+++ b/tests/test_sitemap.py
@@ -0,0 +1,71 @@
+import os
+import unittest
+import pytest
+
+from unittest.mock import patch, MagicMock
+from contentmap.sitemap import SitemapToContentDatabase
+
+
+class TestSitemapToContentDatabase(unittest.TestCase):
+ def create_mock_response(self, content):
+ mock_response = MagicMock()
+ mock_response.content = content
+ return mock_response
+
+ def generate_sample_sitemap_xml(self, url):
+ return f'''
+        <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+            <url>
+                <loc>{url}</loc>
+            </url>
+        </urlset>'''
+ @patch('contentmap.sitemap.requests.get')
+ def test_get_urls_given_one_sitemap_url(self, mock_get):
+ mock_get.return_value = self.create_mock_response(self.generate_sample_sitemap_xml('https://www.example.com/docs/en/example/?topic=testing'))
+
+ sitemap_db = SitemapToContentDatabase(sitemap_sources=['https://example.com/sitemap.xml'], source_type='url')
+ urls = sitemap_db.get_urls()
+
+ self.assertEqual(urls, ['https://www.example.com/docs/en/example/?topic=testing'])
+ mock_get.assert_called_once_with('https://example.com/sitemap.xml')
+
+
+ @patch('contentmap.sitemap.requests.get')
+ def test_get_urls_given_multiple_sitemap_urls(self, mock_get):
+ mock_get.side_effect = [
+ self.create_mock_response(self.generate_sample_sitemap_xml('https://www.example.com/docs/en/example/?topic=testing')),
+ self.create_mock_response(self.generate_sample_sitemap_xml('https://www.anotherexample.com/docs/en/example/?topic=contact-us'))
+ ]
+
+ sitemap_db = SitemapToContentDatabase(sitemap_sources=['https://example.com/sitemap.xml', 'https://anotherexample.com/sitemap.xml'], source_type='url')
+ urls = sitemap_db.get_urls()
+
+ self.assertEqual(urls, [
+ 'https://www.example.com/docs/en/example/?topic=testing',
+ 'https://www.anotherexample.com/docs/en/example/?topic=contact-us'
+ ])
+ mock_get.assert_any_call('https://example.com/sitemap.xml')
+ mock_get.assert_any_call('https://anotherexample.com/sitemap.xml')
+ self.assertEqual(mock_get.call_count, 2)
+
+ def test_get_urls_given_one_location_on_disk(self):
+ sitemap_folder_a_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'sitemap_folder_a')
+ sitemap_db = SitemapToContentDatabase(sitemap_sources=[sitemap_folder_a_path], source_type='disk')
+ urls = sitemap_db.get_urls()
+
+ self.assertEqual(urls, ['https://www.example.com/docs/en/example/?topic=testing',
+ 'https://www.example.com/docs/en/example/?topic=contact-us'
+ ])
+
+
+ def test_get_urls_given_multiple_locations_on_disk(self):
+ sitemap_folder_a_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'sitemap_folder_a')
+ sitemap_folder_b_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'sitemap_folder_b')
+ sitemap_db = SitemapToContentDatabase(sitemap_sources=[sitemap_folder_a_path, sitemap_folder_b_path], source_type='disk')
+ urls = sitemap_db.get_urls()
+
+ self.assertEqual(urls, ['https://www.example.com/docs/en/example/?topic=testing',
+ 'https://www.example.com/docs/en/example/?topic=contact-us',
+ 'https://www.example.com/docs/en/example/?topic=library-overview',
+ 'https://www.example.com/docs/en/example/?topic=about-this-content'
+ ])
\ No newline at end of file