Skip to content

Commit

Permalink
[FIX] Task using correct db sevice name
Browse files Browse the repository at this point in the history
  • Loading branch information
josep-tecnativa committed Feb 5, 2025
1 parent 843e062 commit bdc8ded
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
43 changes: 37 additions & 6 deletions tasks_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,35 @@
PROJECT_ROOT = Path(__file__).parent.absolute()
SRC_PATH = PROJECT_ROOT / "odoo" / "custom" / "src"

_key = os.environ.get("_key", "")
DB_SERVICE = os.environ.get("DB_HOST", f"{_key}-db")
# _key = os.environ.get("_key", "").strip()
# DB_SERVICE = os.environ.get("DB_HOST") or (_key + "-db" if _key else "db")


def get_db_service_name():
"""
Return the database service name that ends with '-db'. If not found,
fall back to any service containing 'postgres' or 'db'. As a last
resort, return 'db'.
"""
for filename in ("devel.yaml", "devel.yml", "docker-compose.yml"):
compose_file = PROJECT_ROOT / filename
if compose_file.exists():
with open(compose_file) as f:
try:
compose_data = yaml.safe_load(f) or {}
services = compose_data.get("services", {})
for svc in services:
if svc.lower().endswith("-db"):
return svc
for svc in services:
if "postgres" in svc.lower() or "db" in svc.lower():
return svc
except yaml.YAMLError:
pass
return "db"


DB_SERVICE = get_db_service_name()

UID_ENV = {
"GID": os.environ.get("DOODBA_GID", str(os.getgid())),
Expand Down Expand Up @@ -987,7 +1014,9 @@ def snapshot(
if not destination_db:
destination_db = f"{source_db}-{datetime.now().strftime('%Y_%m_%d-%H_%M')}"
with c.cd(str(PROJECT_ROOT)):
cur_state = c.run(f"{DOCKER_COMPOSE_CMD} stop odoo db", pty=True).stdout
cur_state = c.run(
f"{DOCKER_COMPOSE_CMD} stop odoo {DB_SERVICE}", pty=True
).stdout
_logger.info("Snapshoting current %s DB to %s", (source_db, destination_db))
_run = f"{DOCKER_COMPOSE_CMD} run --rm -l traefik.enable=false odoo"
c.run(
Expand All @@ -997,7 +1026,7 @@ def snapshot(
)
if "Stopping" in cur_state:
# Restart services if they were previously active
c.run(f"{DOCKER_COMPOSE_CMD} start odoo db", pty=True)
c.run(f"{DOCKER_COMPOSE_CMD} start odoo {DB_SERVICE}", pty=True)


@task(
Expand All @@ -1018,7 +1047,9 @@ def restore_snapshot(
Uses click-odoo-copydb behind the scenes to restore a DB snapshot.
"""
with c.cd(str(PROJECT_ROOT)):
cur_state = c.run(f"{DOCKER_COMPOSE_CMD} stop odoo db", pty=True).stdout
cur_state = c.run(
f"{DOCKER_COMPOSE_CMD} stop odoo {DB_SERVICE}", pty=True
).stdout
if not snapshot_name:
# List DBs
res = c.run(
Expand Down Expand Up @@ -1059,4 +1090,4 @@ def restore_snapshot(
pty=True,
)
if "Stopping" in cur_state:
c.run(f"{DOCKER_COMPOSE_CMD} start odoo db", pty=True)
c.run(f"{DOCKER_COMPOSE_CMD} start odoo {DB_SERVICE}", pty=True)
25 changes: 24 additions & 1 deletion tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ def _get_db_service_name(dc: DockerClient) -> str:
return "db"


def _get_db_service_name(dc: DockerClient) -> str:
"""
Use python-on-whales to retrieve the final Compose configuration
and return the DB service name that ends with '-db'. Fallback to
any service containing 'db' or 'postgres'. Finally, fallback to 'db'.
"""
config_data = dc.compose.config() # Returns a dict with { "services": {...}, ... }
services_dict = config_data["services"]
# First pass: Look for a service name that exactly ends with '-db'
for svc_name in services_dict:
if svc_name.lower().endswith("-db"):
return svc_name

# Second pass: Fallback to any name containing 'db' or 'postgres'
for svc_name in services_dict:
if "postgres" in svc_name.lower() or "db" in svc_name.lower():
return svc_name

# Final fallback
return "db"


@pytest.mark.parametrize("dbver", ("oldest", "latest"))
def test_postgresql_client_versions(
cloned_template: Path,
Expand Down Expand Up @@ -73,6 +95,7 @@ def test_postgresql_client_versions(
)
try:
dc_prod.compose.build()
db_svc = _get_db_service_name(dc_prod)
odoo_pgdump_stdout = dc_prod.compose.run(
"odoo",
command=["pg_dump", "--version"],
Expand All @@ -83,7 +106,7 @@ def test_postgresql_client_versions(
odoo_pgdump_stdout.splitlines()[-1].strip().split(" ")[2].split(".")[0]
)
db_pgdump_stdout = dc_prod.compose.run(
"db",
db_svc,
command=["pg_dump", "--version"],
remove=True,
tty=False,
Expand Down

0 comments on commit bdc8ded

Please sign in to comment.