diff --git a/README.md b/README.md index 55d23cd..1388f63 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ following: from contentmap.sitemap import SitemapToContentDatabase database = SitemapToContentDatabase( - sitemap_url="https://yourblog.com/sitemap.xml", + sitemap_sources=["https://yourblog.com/sitemap.xml"], concurrency=10, include_vss=True ) diff --git a/contentmap/sitemap.py b/contentmap/sitemap.py index 3e1cddf..5874aa4 100644 --- a/contentmap/sitemap.py +++ b/contentmap/sitemap.py @@ -1,6 +1,8 @@ import asyncio import logging +from typing import Literal import requests +import os import aiohttp import trafilatura @@ -11,10 +13,17 @@ class SitemapToContentDatabase: + SOURCE_TYPE_URL: Literal['url'] = 'url' + SOURCE_TYPE_DISK: Literal['disk'] = 'disk' + SourceType = Literal['url', 'disk'] - def __init__(self, sitemap_url, seconds_timeout=10, concurrency=None, + def __init__(self, sitemap_sources: list, + source_type: SourceType = SOURCE_TYPE_URL, + seconds_timeout=10, + concurrency=None, include_vss=False): - self.sitemap_url = sitemap_url + self.sitemap_sources = sitemap_sources + self.source_type = source_type self.semaphore = asyncio.Semaphore(concurrency) if concurrency is not None else None self.timeout = aiohttp.ClientTimeout( sock_connect=seconds_timeout, @@ -30,13 +39,34 @@ def build(self): cm.build() def get_urls(self): - r = requests.get(self.sitemap_url) + all_urls = [] + if self.source_type == self.SOURCE_TYPE_URL: + for sitemap_url in self.sitemap_sources: + urls = self._get_urls_from_url(sitemap_url) + all_urls.extend(urls) + elif self.source_type == self.SOURCE_TYPE_DISK: + for directory in self.sitemap_sources: + for filename in os.listdir(directory): + if filename.endswith('.xml'): + filepath = os.path.join(directory, filename) + urls = self._get_urls_from_disk(filepath) + all_urls.extend(urls) + return all_urls + + def _get_urls_from_url(self, sitemap_url): + r = requests.get(sitemap_url) tree = etree.fromstring(r.content) - urls = [ + return self._extract_urls_from_tree(tree) + + def _get_urls_from_disk(self, filepath): + tree = etree.parse(filepath) + return self._extract_urls_from_tree(tree) + + def _extract_urls_from_tree(self, tree): + return [ url.text for url in tree.findall(".//{http://www.sitemaps.org/schemas/sitemap/0.9}loc") ] - return urls async def get_contents(self, urls): async with aiohttp.ClientSession(timeout=self.timeout) as session: diff --git a/tests/fixtures/sitemap_folder_a/sitemap_a.xml b/tests/fixtures/sitemap_folder_a/sitemap_a.xml new file mode 100644 index 0000000..c056f48 --- /dev/null +++ b/tests/fixtures/sitemap_folder_a/sitemap_a.xml @@ -0,0 +1,9 @@ + + + + https://www.example.com/docs/en/example/?topic=testing + + + https://www.example.com/docs/en/example/?topic=contact-us + + \ No newline at end of file diff --git a/tests/fixtures/sitemap_folder_b/sitemap_b.xml b/tests/fixtures/sitemap_folder_b/sitemap_b.xml new file mode 100644 index 0000000..00b7c97 --- /dev/null +++ b/tests/fixtures/sitemap_folder_b/sitemap_b.xml @@ -0,0 +1,9 @@ + + + + https://www.example.com/docs/en/example/?topic=library-overview + + + https://www.example.com/docs/en/example/?topic=about-this-content + + \ No newline at end of file diff --git a/tests/test_sitemap.py b/tests/test_sitemap.py new file mode 100644 index 0000000..c14113a --- /dev/null +++ b/tests/test_sitemap.py @@ -0,0 +1,71 @@ +import os +import unittest +import pytest + +from unittest.mock import patch, MagicMock +from contentmap.sitemap import SitemapToContentDatabase + + +class TestSitemapToContentDatabase(unittest.TestCase): + def create_mock_response(self, content): + mock_response = MagicMock() + mock_response.content = content + return mock_response + + def generate_sample_sitemap_xml(self, url): + return f''' + + + {url} + + ''' + @patch('contentmap.sitemap.requests.get') + def test_get_urls_given_one_sitemap_url(self, mock_get): + mock_get.return_value = self.create_mock_response(self.generate_sample_sitemap_xml('https://www.example.com/docs/en/example/?topic=testing')) + + sitemap_db = SitemapToContentDatabase(sitemap_sources=['https://example.com/sitemap.xml'], source_type='url') + urls = sitemap_db.get_urls() + + self.assertEqual(urls, ['https://www.example.com/docs/en/example/?topic=testing']) + mock_get.assert_called_once_with('https://example.com/sitemap.xml') + + + @patch('contentmap.sitemap.requests.get') + def test_get_urls_given_multiple_sitemap_urls(self, mock_get): + mock_get.side_effect = [ + self.create_mock_response(self.generate_sample_sitemap_xml('https://www.example.com/docs/en/example/?topic=testing')), + self.create_mock_response(self.generate_sample_sitemap_xml('https://www.anotherexample.com/docs/en/example/?topic=contact-us')) + ] + + sitemap_db = SitemapToContentDatabase(sitemap_sources=['https://example.com/sitemap.xml', 'https://anotherexample.com/sitemap.xml'], source_type='url') + urls = sitemap_db.get_urls() + + self.assertEqual(urls, [ + 'https://www.example.com/docs/en/example/?topic=testing', + 'https://www.anotherexample.com/docs/en/example/?topic=contact-us' + ]) + mock_get.assert_any_call('https://example.com/sitemap.xml') + mock_get.assert_any_call('https://anotherexample.com/sitemap.xml') + self.assertEqual(mock_get.call_count, 2) + + def test_get_urls_given_one_location_on_disk(self): + sitemap_folder_a_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'sitemap_folder_a') + sitemap_db = SitemapToContentDatabase(sitemap_sources=[sitemap_folder_a_path], source_type='disk') + urls = sitemap_db.get_urls() + + self.assertEqual(urls, ['https://www.example.com/docs/en/example/?topic=testing', + 'https://www.example.com/docs/en/example/?topic=contact-us' + ]) + + + def test_get_urls_given_multiple_locations_on_disk(self): + sitemap_folder_a_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'sitemap_folder_a') + sitemap_folder_b_path = os.path.join(os.path.dirname(__file__), 'fixtures', 'sitemap_folder_b') + sitemap_db = SitemapToContentDatabase(sitemap_sources=[sitemap_folder_a_path, sitemap_folder_b_path], source_type='disk') + urls = sitemap_db.get_urls() + + self.assertEqual(urls, ['https://www.example.com/docs/en/example/?topic=testing', + 'https://www.example.com/docs/en/example/?topic=contact-us', + 'https://www.example.com/docs/en/example/?topic=library-overview', + 'https://www.example.com/docs/en/example/?topic=about-this-content' + ]) \ No newline at end of file