Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(RHINENG-15555): Fix infinite export when a host is deleted #2236

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions api/host_query_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Query
from sqlalchemy.orm import load_only
from sqlalchemy.sql.expression import ColumnElement
Expand Down Expand Up @@ -235,7 +236,7 @@ def _find_hosts_entities_query(


def _find_hosts_model_query(columns: list[ColumnElement] | None = None, identity: Any = None) -> Query:
query_base = select(Host).join(HostGroupAssoc, isouter=True).join(Group, isouter=True)
query_base = db.session.query(Host).join(HostGroupAssoc, isouter=True).join(Group, isouter=True)
query = query_base.filter(Host.org_id == identity.org_id)

# In this case, return a list of Hosts
Expand Down Expand Up @@ -625,10 +626,15 @@ def get_hosts_to_export(
export_host_query = _find_hosts_model_query(identity=identity, columns=columns).filter(*q_filters)
export_host_query = export_host_query.execution_options(yield_per=batch_size)

num_hosts = select(func.count()).select_from(export_host_query.subquery())
logger.debug(f"Number of hosts to be exported: {num_hosts}")
try:
num_hosts_query = select(func.count()).select_from(export_host_query.subquery())
num_hosts = db.session.scalar(num_hosts_query)
logger.debug(f"Number of hosts to be exported: {num_hosts}")

for host in db.session.scalars(export_host_query):
yield serialize_host_for_export_svc(host, staleness_timestamps=st_timestamps, staleness=staleness)

for host in db.session.scalars(export_host_query):
yield serialize_host_for_export_svc(host, staleness_timestamps=st_timestamps, staleness=staleness)
except SQLAlchemyError as e: # Most likely ObjectDeletedError, but catching all DB errors
raise InventoryException(title="DB Error", detail=str(e)) from e

db.session.close()
18 changes: 14 additions & 4 deletions app/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
# mypy: disallow-untyped-defs

from __future__ import annotations


class InventoryException(Exception):
def __init__(self, status=400, title=None, detail=None, type="about:blank"):
def __init__(
self, status: int = 400, title: str | None = None, detail: str | None = None, type: str = "about:blank"
):
self.status = status
self.title = title
self.detail = detail
self.type = type

def to_json(self):
def __str__(self) -> str:
return str(self.to_json())

def to_json(self) -> dict[str, str | int | None]:
return {
"detail": self.detail,
"status": self.status,
Expand All @@ -15,10 +25,10 @@ def to_json(self):


class InputFormatException(InventoryException):
def __init__(self, detail):
def __init__(self, detail: str):
InventoryException.__init__(self, title="Bad Request", detail=detail)


class ValidationException(InventoryException):
def __init__(self, detail):
def __init__(self, detail: str):
InventoryException.__init__(self, title="Validation Error", detail=detail)
4 changes: 2 additions & 2 deletions app/queue/export_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ def _handle_export_error(
response = session.post(
url=request_url,
headers=request_headers,
data=json.dumps({"message": str(error_message), "error": status_code}),
data=json.dumps({"message": error_message, "error": status_code}),
)
_handle_export_response(response, exportUUID, exportFormat)


# This function is used by create_export, needs improvement
def _handle_export_response(response: Response, exportUUID: UUID, exportFormat: str):
if response.status_code != HTTPStatus.ACCEPTED:
raise InventoryException(response.text)
raise InventoryException(detail=response.text)
elif response.text != "":
logger.info(f"{response.text} for export ID {str(exportUUID)} in {exportFormat.upper()} format")

Expand Down
13 changes: 13 additions & 0 deletions tests/test_export_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
from marshmallow.exceptions import ValidationError
from sqlalchemy.orm.exc import ObjectDeletedError

from api.staleness_query import get_sys_default_staleness
from app.auth.identity import Identity
Expand Down Expand Up @@ -221,3 +222,15 @@ def test_export_one_host(flask_app, db_create_host, inventory_config):
host_list = get_host_list(identity=identity, rbac_filter=None, inventory_config=inventory_config)

assert len(host_list) == 1


@mock.patch("api.host_query_db.db.session.scalars", side_effect=ObjectDeletedError(None))
def test_export_catches_db_error(flask_app, inventory_config, mocker):
with flask_app.app.app_context():
handle_export_error_mock = mocker.patch("app.queue.export_service._handle_export_error")

validated_msg = parse_export_service_message(es_utils.create_export_message_mock())
base64_x_rh_identity = validated_msg["data"]["resource_request"]["x_rh_identity"]

create_export(validated_msg, base64_x_rh_identity, inventory_config)
handle_export_error_mock.assert_called_once()
6 changes: 3 additions & 3 deletions tests/test_host_mq_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,10 @@ def test_add_host_with_wrong_owner(mocker, mq_create_or_update_host):
)

with pytest.raises(ValidationException) as ve:
key, event, headers = mq_create_or_update_host(
mq_create_or_update_host(
host, return_all_data=True, notification_event_producer=mock_notification_event_producer
)
assert str(ve.value) == "The owner in host does not match the owner in identity"
assert ve.value.detail == "The owner in host does not match the owner in identity"
mock_notification_event_producer.write_event.assert_called_once()


Expand Down Expand Up @@ -1618,7 +1618,7 @@ def test_owner_id_different_from_cn(mocker):

with pytest.raises(ValidationException) as ve:
handle_message(json.dumps(message), mock_notification_event_producer)
assert str(ve.value) == "The owner in host does not match the owner in identity"
assert ve.value.detail == "The owner in host does not match the owner in identity"
mock_notification_event_producer.write_event.assert_called_once()


Expand Down