Skip to content

Commit

Permalink
Update code to py36 syntax
Browse files Browse the repository at this point in the history
Work done by pyupgrade pre-commit plugin. Included:
- conversion to f-strings where possible
- removal of redundant syntax
- application of Exception hierarchy changes in py 3.3
  (IOError is now a subclass of OSError, and need not be listed
  separately)

This was done to get a baseline before updating to py38 syntax
  • Loading branch information
Hal Wine committed Sep 24, 2020
1 parent c341d4e commit 9716ea9
Show file tree
Hide file tree
Showing 19 changed files with 98 additions and 149 deletions.
29 changes: 14 additions & 15 deletions aws/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

import os
from collections import namedtuple
import functools
Expand Down Expand Up @@ -34,15 +32,16 @@ def get_session(profile=None):

@functools.lru_cache()
def get_client(profile, region, service):
"""Returns a new or cached botocore service client for the AWS profile, region, and service.
"""Returns a new or cached botocore service client for the AWS profile,
region, and service.
Warns when a service is not available for a region, which means we
need to update botocore or skip that call for that region.
"""
session = get_session(profile)

if region not in session.get_available_regions(service):
warnings.warn("service {} not available in {}".format(service, region))
warnings.warn(f"service {service} not available in {region}")

return session.create_client(service, region_name=region)

Expand All @@ -63,7 +62,10 @@ def get_available_services(profile=None):


def full_results(client, method, args, kwargs):
"""Returns JSON results for an AWS botocore call. Flattens paginated results (if any)."""
"""Returns JSON results for an AWS botocore call.
Flattens paginated results (if any).
"""
if client.can_paginate(method):
paginator = client.get_paginator(method)
return paginator.paginate(*args, **kwargs).build_full_result()
Expand Down Expand Up @@ -96,7 +98,7 @@ def cache_key(call):
str(call.service),
str(call.method),
",".join(call.args),
",".join("{}={}".format(k, v) for (k, v) in call.kwargs.items()),
",".join(f"{k}={v}" for (k, v) in call.kwargs.items()),
]
)
+ ".json"
Expand All @@ -115,9 +117,8 @@ def get_aws_resource(
debug_calls=False,
debug_cache=False,
):
"""
Fetches and yields AWS API JSON responses for all profiles and regions (list params)
"""
"""Fetches and yields AWS API JSON responses for all profiles and regions
(list params)"""
for profile, region in itertools.product(profiles, regions):
call = default_call._replace(
profile=profile,
Expand Down Expand Up @@ -224,7 +225,7 @@ def get(
return self

def values(self):
"""Returns the wrapped value
"""Returns the wrapped value.
>>> c = BotocoreClient([None], None, None, None, offline=True)
>>> c.results = []
Expand All @@ -234,9 +235,8 @@ def values(self):
return self.results

def extract_key(self, key, default=None):
"""
From an iterable of dicts returns the value with the given
keys discarding other values:
"""From an iterable of dicts returns the value with the given keys
discarding other values:
>>> c = BotocoreClient([None], None, None, None, offline=True)
>>> c.results = [{'id': 1}, {'id': 2}]
Expand Down Expand Up @@ -299,8 +299,7 @@ def extract_key(self, key, default=None):
return self

def flatten(self):
"""
Flattens one level of a nested list:
"""Flattens one level of a nested list:
>>> c = BotocoreClient([None], None, None, None, offline=True)
>>> c.results = [['A', 1], ['B']]
Expand Down
57 changes: 22 additions & 35 deletions aws/ec2/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
def ip_permission_opens_all_ports(ipp):
"""
Returns True if an EC2 security group IP permission opens all
ports and False otherwise.
"""Returns True if an EC2 security group IP permission opens all ports and
False otherwise.
>>> ip_permission_opens_all_ports({'FromPort': 1, 'ToPort': 65535})
True
Expand Down Expand Up @@ -31,10 +30,8 @@ def ip_permission_opens_all_ports(ipp):


def ip_permission_cidr_allows_all_ips(ipp):
"""
Returns True if any IPv4 or IPv6 range for an EC2 security group
IP permission opens allows access to or from all IPs and False
otherwise.
"""Returns True if any IPv4 or IPv6 range for an EC2 security group IP
permission opens allows access to or from all IPs and False otherwise.
>>> ip_permission_cidr_allows_all_ips({'IpRanges': [{'CidrIp': '0.0.0.0/0'}]})
True
Expand All @@ -60,9 +57,8 @@ def ip_permission_cidr_allows_all_ips(ipp):


def ip_permission_grants_access_to_group_with_id(ipp, security_group_id):
"""
Returns True if an EC2 security group IP permission opens access to
a security with the given ID and False otherwise.
"""Returns True if an EC2 security group IP permission opens access to a
security with the given ID and False otherwise.
>>> ip_permission_grants_access_to_group_with_id(
... {'UserIdGroupPairs': [{'GroupId': 'test-sgid'}]}, 'test-sgid')
Expand All @@ -81,10 +77,8 @@ def ip_permission_grants_access_to_group_with_id(ipp, security_group_id):


def ec2_security_group_opens_all_ports(ec2_security_group):
"""
Returns True if an ec2 security group includes a permission
allowing inbound access on all ports and False otherwise
or if protocol is ICMP.
"""Returns True if an ec2 security group includes a permission allowing
inbound access on all ports and False otherwise or if protocol is ICMP.
>>> ec2_security_group_opens_all_ports(
... {'IpPermissions': [{}, {'FromPort': -1,'ToPort': 65536}]})
Expand All @@ -109,10 +103,8 @@ def ec2_security_group_opens_all_ports(ec2_security_group):


def ec2_security_group_opens_all_ports_to_self(ec2_security_group):
"""
Returns True if an ec2 security group includes a permission
allowing all IPs inbound access on all ports and False otherwise
or if protocol is ICMP.
"""Returns True if an ec2 security group includes a permission allowing all
IPs inbound access on all ports and False otherwise or if protocol is ICMP.
>>> ec2_security_group_opens_all_ports_to_self({
... 'GroupId': 'test-sgid',
Expand Down Expand Up @@ -166,10 +158,8 @@ def ec2_security_group_opens_all_ports_to_self(ec2_security_group):


def ec2_security_group_opens_all_ports_to_all(ec2_security_group):
"""
Returns True if an ec2 security group includes a permission
allowing all IPs inbound access on all ports and False otherwise
or if protocol is ICMP.
"""Returns True if an ec2 security group includes a permission allowing all
IPs inbound access on all ports and False otherwise or if protocol is ICMP.
>>> ec2_security_group_opens_all_ports_to_all({'IpPermissions': [
... {'FromPort': -1,'ToPort': 65535,'IpRanges': [{'CidrIp': '0.0.0.0/0'}]},
Expand Down Expand Up @@ -208,10 +198,9 @@ def ec2_security_group_opens_all_ports_to_all(ec2_security_group):
def ec2_security_group_opens_specific_ports_to_all(
ec2_security_group, whitelisted_ports=None
):
"""
Returns True if an ec2 security group includes a permission
allowing all IPs inbound access on specific unsafe ports and False
otherwise or if protocol is ICMP.
"""Returns True if an ec2 security group includes a permission allowing all
IPs inbound access on specific unsafe ports and False otherwise or if
protocol is ICMP.
>>> ec2_security_group_opens_specific_ports_to_all({'IpPermissions': [
... {'FromPort': 22,'ToPort': 22,'IpRanges': [{'CidrIp': '0.0.0.0/0'}]},
Expand Down Expand Up @@ -260,18 +249,17 @@ def ec2_security_group_opens_specific_ports_to_all(


def ec2_instance_test_id(ec2_instance):
"""A getter fn for test ids for EC2 instances"""
"""A getter fn for test ids for EC2 instances."""
return "{0[InstanceId]}".format(ec2_instance)


def ec2_security_group_test_id(ec2_security_group):
"""A getter fn for test ids for EC2 security groups"""
"""A getter fn for test ids for EC2 security groups."""
return "{0[GroupId]} {0[GroupName]}".format(ec2_security_group)


def is_ebs_volume_encrypted(ebs):
"""
Checks the EBS volume 'Encrypted' value.
"""Checks the EBS volume 'Encrypted' value.
>>> is_ebs_volume_encrypted({'Encrypted': True})
True
Expand All @@ -294,8 +282,8 @@ def is_ebs_volume_encrypted(ebs):


def is_ebs_snapshot_public(ebs_snapshot):
"""
Checks if the EBS snapshot's 'CreateVolumePermissions' attribute allows for public creation.
"""Checks if the EBS snapshot's 'CreateVolumePermissions' attribute allows
for public creation.
>>> is_ebs_snapshot_public({'CreateVolumePermissions':[{'Group': 'all'}]})
True
Expand All @@ -315,8 +303,7 @@ def is_ebs_snapshot_public(ebs_snapshot):


def ec2_instance_missing_tag_names(ec2_instance, required_tag_names):
"""
Returns any tag names that are missing from an EC2 Instance.
"""Returns any tag names that are missing from an EC2 Instance.
>>> ec2_instance_missing_tag_names({'Tags': [{'Key': 'Name'}]}, frozenset(['Name']))
frozenset()
Expand All @@ -325,5 +312,5 @@ def ec2_instance_missing_tag_names(ec2_instance, required_tag_names):
frozenset({'Name'})
"""
tags = ec2_instance.get("Tags", [])
instance_tag_names = set(tag["Key"] for tag in tags if "Key" in tag)
instance_tag_names = {tag["Key"] for tag in tags if "Key" in tag}
return required_tag_names - instance_tag_names
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
def test_rds_db_instance_not_publicly_accessible_by_vpc_security_group(
rds_db_instance, ec2_security_groups
):
"""
Checks whether any VPC/EC2 security groups that are attached to an RDS instance
allow for access from the public internet.
"""
"""Checks whether any VPC/EC2 security groups that are attached to an RDS
instance allow for access from the public internet."""
if not ec2_security_groups:
assert not rds_db_instance["VpcSecurityGroups"]
else:
assert set(sg["GroupId"] for sg in ec2_security_groups) == set(
assert {sg["GroupId"] for sg in ec2_security_groups} == {
sg["VpcSecurityGroupId"] for sg in rds_db_instance["VpcSecurityGroups"]
)
}

assert not any(
does_vpc_security_group_grant_public_access(sg)
Expand Down
24 changes: 10 additions & 14 deletions cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Patch for pytest cache to serialize datetime.datetime
"""
"""Patch for pytest cache to serialize datetime.datetime."""

import datetime
import functools
Expand All @@ -11,17 +9,16 @@


def json_iso_datetimes(obj):
"""JSON serializer for objects not serializable by default json
module."""
"""JSON serializer for objects not serializable by default json module."""
if isinstance(obj, datetime.datetime):
return obj.isoformat()

raise TypeError("Unserializable type %s" % type(obj))


def json_iso_datetime_string_to_datetime(obj):
"""JSON object hook that converts object vals from ISO datetime
strings to python datetime.datetime`s if possible."""
"""JSON object hook that converts object vals from ISO datetime strings to
python datetime.datetime`s if possible."""

for k, v in obj.items():
if not isinstance(v, str):
Expand All @@ -36,7 +33,7 @@ def json_iso_datetime_string_to_datetime(obj):


def datetime_encode_set(self, key, value):
""" save value for the given key.
"""save value for the given key.
:param key: must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
Expand All @@ -47,12 +44,12 @@ def datetime_encode_set(self, key, value):
path = self._getvaluepath(key)
try:
path.parent.mkdir(exist_ok=True, parents=True)
except (IOError, OSError):
except OSError:
self.warn("could not create cache path {path}", path=path)
return
try:
f = path.open("w")
except (IOError, OSError):
except OSError:
self.warn("cache could not write path {path}", path=path)
else:
with f:
Expand All @@ -61,9 +58,8 @@ def datetime_encode_set(self, key, value):


def datetime_encode_get(self, key, default):
""" return cached value for the given key. If no value
was yet cached or the value cannot be read, the specified
default is returned.
"""return cached value for the given key. If no value was yet cached or
the value cannot be read, the specified default is returned.
:param key: must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
Expand All @@ -74,7 +70,7 @@ def datetime_encode_get(self, key, default):
try:
with path.open("r") as f:
return json.load(f, object_hook=json_iso_datetime_string_to_datetime)
except (ValueError, IOError, OSError):
except (ValueError, OSError):
return default


Expand Down
2 changes: 1 addition & 1 deletion custom_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_whitelisted_ports_from_test_id(self, test_id):
if test_id == rule["test_param_id"]:
return set(rule["ports"])

return set([])
return set()

def get_access_key_expiration_date(self):
if self.access_key_expires_after is None:
Expand Down
6 changes: 2 additions & 4 deletions gcp/iam/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ def service_account_keys(service_account):

def all_service_account_keys():
for sa in service_accounts():
for key in service_account_keys(sa):
yield key
yield from service_account_keys(sa)


def project_iam_bindings():
policy = gcp_client.get_project_iam_policy()
for binding in policy.get("bindings", []):
yield binding
yield from policy.get("bindings", [])
13 changes: 7 additions & 6 deletions gcp/iam/test_only_allowed_org_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ def allowed_org_domains(pytestconfig):
@pytest.mark.gcp_iam
@pytest.mark.parametrize("iam_binding", project_iam_bindings(), ids=lambda r: r["role"])
def test_only_allowed_org_accounts(iam_binding, allowed_org_domains):
"""
Only allow specified org domains as members within this project, with a few exceptions.
* Service Accounts are excluded
* The following roles are excluded:
- roles/logging.viewer
"""Only allow specified org domains as members within this project, with a
few exceptions.
* Service Accounts are excluded
* The following roles are excluded:
- roles/logging.viewer
"""
if len(allowed_org_domains) == 0:
assert False, "No allowed org domains specified"
Expand All @@ -28,4 +29,4 @@ def test_only_allowed_org_accounts(iam_binding, allowed_org_domains):
if not member.startswith("serviceAccount"):
assert (
member.split("@")[-1] in allowed_org_domains
), "{} was found and is not in the allowed_org_domains".format(member)
), f"{member} was found and is not in the allowed_org_domains"
Loading

0 comments on commit 9716ea9

Please sign in to comment.