Skip to content

Python code example standards

Rachel Hagerman edited this page Aug 20, 2024 · 8 revisions

This document summarizes important points for writing and reviewing code examples that use the AWS SDK for Python (Boto3).

General Structure

  • Each service folder should include a wrapper class, a test folder, and scenario files, with each scenario in its own file.
  • Include a requirements.txt file with all dependencies.
  • Scenario files should:
    • Import demo_tools for shared functionality.
    • Include an `__init__` method that takes the service wrapper as a parameter.
    • Include a run_scenario method.
    • Begin with a descriptive comment block.
"""
Purpose:

Shows how to use the AWS SDK for Python (Boto3) with Amazon Keyspaces (for Apache Cassandra)
to do the following:

* Create a keyspace.
* Create a table in the keyspace.
* Connect to the keyspace.
* Query the table.
* Update the table.
* Restore the table.
* Delete the table and keyspace.
"""

from datetime import datetime
import logging
import os
from pprint import pp
import sys

import boto3
import requests

from query import QueryManager
from keyspace import KeyspaceWrapper

# Add relative path to include demo_tools
sys.path.append("../..")
from demo_tools import demo_func
import demo_tools.question as q
from demo_tools.retries import wait

logger = logging.getLogger(__name__)

class KeyspaceScenario:
    """Runs an interactive scenario that shows how to get started using Amazon Keyspaces."""

    def __init__(self, ks_wrapper):
        """
        :param ks_wrapper: A KeyspaceWrapper that performs the keyspace and table actions.
        """
        self.ks_wrapper = ks_wrapper

    @staticmethod
    def _default_cert_path():
        """
        Returns the default certificate location: QueryManager.DEFAULT_CERT_FILE in the
        same directory as this script.
        """
        return os.path.join(os.path.dirname(__file__), QueryManager.DEFAULT_CERT_FILE)

    @demo_func
    def create_keyspace(self):
        """
        Asks for a keyspace name and creates the keyspace if it doesn't already exist.
        A new keyspace is polled every 3 seconds until it can be found. Afterward,
        up to 10 keyspaces are listed.
        """
        print("Let's create a keyspace.")
        ks_name = q.ask("Enter a name for your new keyspace: ", q.non_empty)
        if self.ks_wrapper.exists_keyspace(ks_name):
            print(f"A keyspace named {ks_name} exists.")
        else:
            ks_arn = self.ks_wrapper.create_keyspace(ks_name)
            ks_exists = False
            while not ks_exists:
                wait(3)
                ks_exists = self.ks_wrapper.exists_keyspace(ks_name)
            print(f"Created a new keyspace: {ks_arn}")
        self.ks_wrapper.list_keyspaces(10)

    @demo_func
    def create_table(self):
        """
        Asks for a table name and creates a movie table in the current keyspace if one
        with that name doesn't already exist. A new table is polled every 5 seconds
        until its status is ACTIVE, and then its schema is printed. Finally, the tables
        in the keyspace are listed.
        """
        print("Let's create a table for movies in your keyspace.")
        table_name = q.ask("Enter a name for your table: ", q.non_empty)
        table = self.ks_wrapper.get_table(table_name)
        if table:
            print(f"A table named {table_name} already exists.")
        else:
            table_arn = self.ks_wrapper.create_table(table_name)
            print(f"Created table {table_name}: {table_arn}")
            table = {"status": None}
            while table["status"] != "ACTIVE":
                wait(5)
                table = self.ks_wrapper.get_table(table_name)
            print(f"Your table is {table['status']}. Its schema is:")
            pp(table["schemaDefinition"])
        self.ks_wrapper.list_tables()

    @demo_func
    def ensure_tls_cert(self):
        """
        Makes sure a TLS certificate is available for connecting to the keyspace.

        If no certificate exists at the default path, the user either enters the path
        to an existing certificate or presses Enter to download one from
        QueryManager.CERT_URL to the default path.

        :return: The path of the certificate to use.
        """
        print("To connect to your keyspace, you must have a TLS certificate.")
        cert_path = self._default_cert_path()
        if not os.path.exists(cert_path):
            cert_choice = q.ask(
                "Press enter to download a certificate or enter the full path to your certificate: "
            )
            if cert_choice:
                cert_path = cert_choice
            else:
                # A timeout keeps the demo from hanging forever on an unresponsive server.
                response = requests.get(QueryManager.CERT_URL, timeout=30)
                # Fail loudly rather than saving an HTTP error page as the certificate.
                response.raise_for_status()
                with open(cert_path, "w") as cert_file:
                    cert_file.write(response.text)
        else:
            q.ask(f"Certificate {cert_path} found. Press Enter to continue.")
        print(f"Certificate {cert_path} will be used.")
        return cert_path

    @demo_func
    def query_table(self, qm, movie_file):
        """
        Adds the movies in movie_file to the current table. The user then picks one
        of the stored movies, and its details are printed.

        :param qm: The QueryManager used to run queries against the keyspace.
        :param movie_file: Path to the file of movie data to load.
        """
        qm.add_movies(self.ks_wrapper.table_name, movie_file)
        movies = qm.get_movies(self.ks_wrapper.table_name)
        print(f"Added {len(movies)} movies to the table.")
        sel = q.choose("Pick one to learn more about it: ", [m.title for m in movies])
        movie_choice = qm.get_movie(self.ks_wrapper.table_name, movies[sel].title, movies[sel].year)
        print(movie_choice.title)
        print(f"\tReleased: {movie_choice.release_date}")
        print(f"\tPlot: {movie_choice.plot}")

    @demo_func
    def update_and_restore_table(self, qm):
        """
        Adds a 'watched' column to the current table and marks the first 10 movies
        as watched. If the user agrees, it then restores the table to a new table,
        as of a moment just before the update, and prints the restored movies.

        :param qm: The QueryManager used to run queries against the keyspace.
        """
        # Local import: the module imports only the `datetime` class from datetime.
        from datetime import timezone

        print("Let's add a column to record which movies you've watched.")
        # Timezone-aware UTC time; datetime.utcnow() is deprecated and returns a naive value.
        pre_update_timestamp = datetime.now(timezone.utc)
        self.ks_wrapper.update_table()
        table = {"status": "UPDATING"}
        while table["status"] != "ACTIVE":
            wait(5)
            table = self.ks_wrapper.get_table(self.ks_wrapper.table_name)
        print("Column 'watched' added.")
        movies = qm.get_movies(self.ks_wrapper.table_name)
        for movie in movies[:10]:
            qm.watched_movie(self.ks_wrapper.table_name, movie.title, movie.year)
            print(f"Marked {movie.title} as watched.")
        if q.ask("Restore table to the way it was? (y/n) ", q.is_yesno):
            table_name_restored = self.ks_wrapper.restore_table(pre_update_timestamp)
            table = {"status": "RESTORING"}
            while table["status"] != "ACTIVE":
                wait(10)
                table = self.ks_wrapper.get_table(table_name_restored)
            print(f"Restored to {table_name_restored}.")
            movies = qm.get_movies(table_name_restored)
            for movie in movies:
                print(movie.title)

    def cleanup(self, cert_path):
        """
        If the user agrees, deletes the current table and the keyspace. Also removes
        the certificate when it is the one at the default path.

        :param cert_path: The certificate path returned by ensure_tls_cert.
        """
        if q.ask("Delete your table and keyspace? (y/n) ", q.is_yesno):
            self.ks_wrapper.delete_table()
            self.ks_wrapper.delete_keyspace()
            # Only the default-location certificate is removed; a certificate at a
            # user-supplied path elsewhere is left alone.
            if cert_path == self._default_cert_path() and os.path.exists(cert_path):
                os.remove(cert_path)
                print("Removed certificate.")

    def run_scenario(self):
        """Runs each step of the scenario in order, then offers to clean up."""
        logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
        self.create_keyspace()
        self.create_table()
        cert_file_path = self.ensure_tls_cert()
        with QueryManager(cert_file_path, boto3.DEFAULT_SESSION, self.ks_wrapper.ks_name) as qm:
            self.query_table(qm, "../../../resources/sample_files/movies.json")
            self.update_and_restore_table(qm)
        self.cleanup(cert_file_path)

if __name__ == "__main__":
    # Any failure during the demo is logged with its traceback instead of propagating.
    try:
        wrapper = KeyspaceWrapper.from_client()
        KeyspaceScenario(wrapper).run_scenario()
    except Exception:
        logging.exception("Something went wrong with the demo.")

Wrapper classes

  • Methods should add context around each service call (for example, logging and error handling) and should not use request/response classes directly.
  • Include a `from_client` class method that creates the Boto3 service client and returns a wrapper instance that uses it.
  • Use a class declaration snippet for Code Library metadata.
import logging
import boto3
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)

class KeyspaceWrapper:
    """Encapsulates Amazon Keyspaces (for Apache Cassandra) keyspace and table actions."""

    def __init__(self, keyspaces_client):
        """
        :param keyspaces_client: A Boto3 Amazon Keyspaces client.
        """
        self.keyspaces_client = keyspaces_client
        # Name and ARN of the keyspace the wrapper currently works with.
        self.ks_name = None
        self.ks_arn = None
        # Name of the table most recently found by get_table.
        self.table_name = None

    @classmethod
    def from_client(cls) -> "KeyspaceWrapper":
        """
        Creates a KeyspaceWrapper that uses a new Boto3 Amazon Keyspaces client.

        :return: A new KeyspaceWrapper instance.
        """
        keyspaces_client = boto3.client("keyspaces")
        return cls(keyspaces_client)

    def create_keyspace(self, name: str) -> str:
        """
        Creates a keyspace and makes it the wrapper's current keyspace.

        :param name: The name to give the keyspace.
        :return: The Amazon Resource Name (ARN) of the new keyspace.
        """
        try:
            response = self.keyspaces_client.create_keyspace(keyspaceName=name)
            self.ks_name = name
            self.ks_arn = response["resourceArn"]
        except ClientError as err:
            logger.error(f"Couldn't create {name}: {err}")
            raise
        else:
            return self.ks_arn

    def exists_keyspace(self, name: str) -> bool:
        """
        Checks whether a keyspace exists. When it does, it becomes the wrapper's
        current keyspace.

        :param name: The name of the keyspace to look up.
        :return: True when the keyspace exists. False when the service reports
                 ResourceNotFoundException. Other errors are logged and re-raised.
        """
        try:
            response = self.keyspaces_client.get_keyspace(keyspaceName=name)
            self.ks_name = response["keyspaceName"]
            self.ks_arn = response["resourceArn"]
            return True
        except ClientError as err:
            if err.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.info(f"Keyspace {name} does not exist.")
                return False
            else:
                logger.error(f"Couldn't verify {name} exists: {err}")
                raise

    def list_keyspaces(self, limit: int) -> None:
        """
        Prints the names of up to `limit` keyspaces.

        :param limit: The maximum number of keyspaces to list.
        """
        try:
            ks_paginator = self.keyspaces_client.get_paginator("list_keyspaces")
            for page in ks_paginator.paginate(PaginationConfig={"MaxItems": limit}):
                for ks in page["keyspaces"]:
                    print(ks["keyspaceName"])
        except ClientError as err:
            logger.error(f"Couldn't list keyspaces: {err}")
            raise

    def create_table(self, table_name: str) -> str:
        """
        Creates a movie table in the current keyspace.

        The table has the columns title, year, release_date and plot, uses year and
        title as partition keys, and has point-in-time recovery enabled.

        :param table_name: The name to give the table.
        :return: The Amazon Resource Name (ARN) of the new table.
        """
        try:
            response = self.keyspaces_client.create_table(
                keyspaceName=self.ks_name,
                tableName=table_name,
                schemaDefinition={
                    "allColumns": [
                        {"name": "title", "type": "text"},
                        {"name": "year", "type": "int"},
                        {"name": "release_date", "type": "timestamp"},
                        {"name": "plot", "type": "text"},
                    ],
                    "partitionKeys": [{"name": "year"}, {"name": "title"}],
                },
                pointInTimeRecovery={"status": "ENABLED"},
            )
        except ClientError as err:
            logger.error(f"Couldn't create table {table_name}: {err}")
            raise
        else:
            return response["resourceArn"]

    def get_table(self, table_name: str) -> "dict | None":
        """
        Gets information about a table in the current keyspace. When the table is
        found, it becomes the wrapper's current table.

        :param table_name: The name of the table to look up.
        :return: The get_table response, or None when the service reports
                 ResourceNotFoundException. Other errors are logged and re-raised.
        """
        try:
            response = self.keyspaces_client.get_table(keyspaceName=self.ks_name, tableName=table_name)
            self.table_name = table_name
        except ClientError as err:
            if err.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.info(f"Table {table_name} does not exist.")
                return None
            else:
                logger.error(f"Couldn't verify {table_name} exists: {err}")
                raise
        else:
            return response

    def list_tables(self) -> None:
        """Prints the names of the tables in the current keyspace."""
        try:
            table_paginator = self.keyspaces_client.get_paginator("list_tables")
            for page in table_paginator.paginate(keyspaceName=self.ks_name):
                for table in page["tables"]:
                    print(table["tableName"])
        except ClientError as err:
            logger.error(f"Couldn't list tables in keyspace {self.ks_name}: {err}")
            raise

    def update_table(self) -> None:
        """Adds a boolean 'watched' column to the current table."""
        try:
            self.keyspaces_client.update_table(
                keyspaceName=self.ks_name,
                tableName=self.table_name,
                addColumns=[{"name": "watched", "type": "boolean"}],
            )
        except ClientError as err:
            logger.error(f"Couldn't update table {self.table_name}: {err}")
            raise

    def restore_table(self, restore_timestamp) -> str:
        """
        Restores the current table, as of `restore_timestamp`, to a new table in the
        same keyspace named '<table_name>_restored'.

        :param restore_timestamp: The point in time to restore to. The scenario passes
                                  a datetime.
        :return: The name of the restored table.
        """
        try:
            restored_table_name = f"{self.table_name}_restored"
            self.keyspaces_client.restore_table(
                sourceKeyspaceName=self.ks_name,
                sourceTableName=self.table_name,
                targetKeyspaceName=self.ks_name,
                targetTableName=restored_table_name,
                restoreTimestamp=restore_timestamp,
            )
        except ClientError as err:
            logger.error(f"Couldn't restore table to {restore_timestamp}: {err}")
            raise
        else:
            return restored_table_name

    def delete_table(self) -> None:
        """Deletes the current table and clears the wrapper's table_name."""
        try:
            self.keyspaces_client.delete_table(keyspaceName=self.ks_name, tableName=self.table_name)
            self.table_name = None
        except ClientError as err:
            logger.error(f"Couldn't delete table {self.table_name}: {err}")
            raise

    def delete_keyspace(self) -> None:
        """Deletes the current keyspace and clears the wrapper's ks_name."""
        try:
            self.keyspaces_client.delete_keyspace(keyspaceName=self.ks_name)
            self.ks_name = None
        except ClientError as err:
            logger.error(f"Couldn't delete keyspace {self.ks_name}: {err}")
            raise

Language Features

  • Comments:

    • Include descriptive comment blocks for all classes and methods.
    • Use `:param` tags to document each parameter and a `:return:` tag to document the method's return value.
  • Functions:

    • Include type hints for function parameters and return values.

  • Logging:

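    • Use the logging module rather than print statements for diagnostic and error messages; scenarios may still use print for interactive, user-facing output.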
    • Set up a logger at the module level.
  • Error Handling:

    • Wrap service calls in try and except blocks that catch botocore.exceptions.ClientError.
    • Log error messages.
    • Raise exceptions after logging errors.
  • Testing:

    • Ensure test coverage for wrapper methods and scenario methods.
    • Use established service method stubbing patterns.
    • Mark integration tests with @pytest.mark.integ.
from botocore.exceptions import ClientError
import pytest

class MockManager:
    """Holds test data and stub setup for the create-keyspace scenario step."""

    def __init__(self, stub_runner, scenario_data, input_mocker):
        self.scenario_data = scenario_data
        self.ks_exists = False
        self.ks_name = "test-ks"
        self.ks_arn = "arn:aws:cassandra:test-region:111122223333:/keyspace/test-ks"
        self.keyspaces = [
            {"keyspaceName": f"ks-{ind}", "resourceArn": self.ks_arn}
            for ind in range(1, 4)
        ]
        # The scenario step asks for exactly one answer: the keyspace name.
        input_mocker.mock_answers([self.ks_name])
        self.stub_runner = stub_runner

    def setup_stubs(self, error, stop_on, stubber):
        """Queues the service stubs that create_keyspace is expected to call, in order."""
        name, arn = self.ks_name, self.ks_arn
        with self.stub_runner(error, stop_on) as runner:
            if not self.ks_exists:
                # Missing keyspace: the lookup fails, then the keyspace is created.
                runner.add(stubber.stub_get_keyspace, name, arn, error_code="ResourceNotFoundException")
                runner.add(stubber.stub_create_keyspace, name, arn)
            # Either the existing keyspace is found or the new one is confirmed.
            runner.add(stubber.stub_get_keyspace, name, arn)
            runner.add(stubber.stub_list_keyspaces, self.keyspaces)

@pytest.fixture
def mock_mgr(stub_runner, scenario_data, input_mocker):
    """Provides a new MockManager built from the stub_runner, scenario_data, and input_mocker fixtures."""
    manager = MockManager(stub_runner, scenario_data, input_mocker)
    return manager

@pytest.mark.parametrize("ks_exists", [True, False])
def test_create_keyspace(mock_mgr, capsys, ks_exists):
    """Runs create_keyspace against stubs and checks that it prints the expected keyspace names."""
    mock_mgr.ks_exists = ks_exists
    stubber = mock_mgr.scenario_data.stubber
    mock_mgr.setup_stubs(None, None, stubber)

    mock_mgr.scenario_data.scenario.create_keyspace()

    printed = capsys.readouterr().out
    expected_names = [mock_mgr.ks_name] + [ks["keyspaceName"] for ks in mock_mgr.keyspaces]
    for name in expected_names:
        assert name in printed

@pytest.mark.parametrize(
    "error, stop_on_index",
    [
        ("TESTERROR-stub_get_keyspace", 0),
        ("TESTERROR-stub_create_keyspace", 1),
        # Index 2 is the follow-up stub_get_keyspace that confirms creation, so
        # stub_list_keyspaces is the fourth stub added (index 3).
        ("TESTERROR-stub_list_keyspaces", 3),
    ],
)
def test_create_keyspace_error(mock_mgr, caplog, error, stop_on_index):
    """Checks that a service error at each stub surfaces as a logged ClientError."""
    # The stop_on_index values above assume the keyspace-doesn't-exist stub sequence.
    mock_mgr.ks_exists = False
    mock_mgr.setup_stubs(error, stop_on_index, mock_mgr.scenario_data.stubber)

    with pytest.raises(ClientError) as exc_info:
        mock_mgr.scenario_data.scenario.create_keyspace()
    assert exc_info.value.response["Error"]["Code"] == error

    assert error in caplog.text
Clone this wiki locally