Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Application api key added #75 #94

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions paig-client/src/paig_client/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,25 @@ def __init__(self, **kwargs):
self.request_kwargs = kwargs.get('request_kwargs', {})
if 'timeout' not in self.request_kwargs:
self.request_kwargs['timeout'] = Timeout(connect=2.0, read=7.0)
if kwargs['encryption_keys_info']:
self._load_plugin_access_request_encryptor(kwargs['encryption_keys_info'])

def load_config_from_json(self, config):
"""
Load the configuration from a JSON string.

Args:
config (str): JSON string containing the configuration.
"""
self.tenant_id = config.get('tenantId')
self.base_url = config.get('shieldServerUrl')
self.api_key = config.get('apiKey')
self._load_plugin_access_request_encryptor(config.get('encryptionKeysInfo'))


def _load_plugin_access_request_encryptor(self, encryption_keys_info):
self.plugin_access_request_encryptor = PluginAccessRequestEncryptor(self.tenant_id,
kwargs["encryption_keys_info"])
encryption_keys_info)

def get_plugin_access_request_encryptor(self):
return self.plugin_access_request_encryptor
Expand Down Expand Up @@ -545,17 +561,20 @@ def is_access_allowed(self, request: ShieldAccessRequest) -> ShieldAccessResult:
_logger.error(error_message)
raise Exception(error_message)

def init_shield_server(self, application_key) -> None:
def init_shield_server(self, application_key, application_api_key=None):
"""
Initialize shield server for the tenant id.
"""

if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Initializing shield server for tenant: tenant_id={self.tenant_id}")

request = {"shieldServerKeyId": self.plugin_access_request_encryptor.shield_server_key_id,
"shieldPluginKeyId": self.plugin_access_request_encryptor.shield_plugin_key_id,
"applicationKey": application_key}
request = dict()
headers = self.get_default_headers()
if application_api_key is None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Initializing shield server for tenant: tenant_id={self.tenant_id}")
request = {"shieldServerKeyId": self.plugin_access_request_encryptor.shield_server_key_id,
"shieldPluginKeyId": self.plugin_access_request_encryptor.shield_plugin_key_id,
"applicationKey": application_key}
else:
headers['x-application-api-key'] = application_api_key

error_message = ""
init_success = False
Expand All @@ -564,7 +583,7 @@ def init_shield_server(self, application_key) -> None:
try:
response = HttpTransport.get_http().request(method="POST",
url=self.base_url + "/shield/init",
headers=self.get_default_headers(),
headers=headers,
json=json.dumps(request),
**self.request_kwargs)

Expand All @@ -575,7 +594,11 @@ def init_shield_server(self, application_key) -> None:

if response_status == 200:
init_success = True
_logger.info(f"Shield server initialized for tenant: tenant_id={self.tenant_id}")
if application_api_key:
_logger.info(f"Shield server initialized")
else:
_logger.info(f"Shield server initialized for tenant: tenant_id={self.tenant_id}")
return response.json()
else:
if response_status == 400 or response_status == 404:
error_message = str(response.data)
Expand Down
3 changes: 3 additions & 0 deletions paig-client/src/paig_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def setup(**kwargs):
connection_timeout (float): The connection timeout for the access request. This is optional.

read_timeout (float): The read timeout for the access request. This is optional.

application_api_key (str): The API key for the application. This is optional.
server_url (str): The URL of the Shield server. This is optional.
"""
core.setup(**kwargs)

Expand Down
126 changes: 92 additions & 34 deletions paig-client/src/paig_client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def __init__(self, **kwargs):
connection_timeout (float): The connection timeout for the access request.

read_timeout (float): The read timeout for the access request.

application_api_key (str): The API key for the application. This is optional.

server_url (str): The URL of the Shield server. This is optional.
"""

try:
Expand Down Expand Up @@ -484,47 +488,74 @@ def __init__(self, **kwargs):
you want to override the values that are in the application config file.
"""

plugin_app_config_dict = self.get_plugin_app_config(kwargs)
self.application_api_key = None

# Init from the loaded file
self.client_application_key = plugin_app_config_dict.get("clientApplicationKey", "*")
self.application_id = plugin_app_config_dict.get("applicationId")
self.application_key = plugin_app_config_dict.get("applicationKey")
self.tenant_id = plugin_app_config_dict.get("tenantId")
if os.environ.get("PAIG_APPLICATION_API_KEY"):
self.application_api_key = os.environ.get("PAIG_APPLICATION_API_KEY")

is_self_hosted_shield_server = False
shield_server_url = plugin_app_config_dict.get("shieldServerUrl")
if shield_server_url is not None:
self.shield_base_url = shield_server_url
is_self_hosted_shield_server = True
else:
self.shield_base_url = plugin_app_config_dict.get("apiServerUrl")
if "application_api_key" in kwargs and kwargs["application_api_key"]:
self.application_api_key = kwargs["application_api_key"]

self.api_key = plugin_app_config_dict.get("apiKey")
self.shield_server_key_id = plugin_app_config_dict.get("shieldServerKeyId")
self.shield_server_public_key = plugin_app_config_dict.get("shieldServerPublicKey")
self.shield_plugin_key_id = plugin_app_config_dict.get("shieldPluginKeyId")
self.shield_plugin_private_key = plugin_app_config_dict.get("shieldPluginPrivateKey")
self.audit_spool_dir = plugin_app_config_dict.get("auditSpoolDir", "spool/audits/")

# Allow override from kwargs
for key, value in kwargs.items():
if key in self.__dict__:
self.__dict__[key] = value
# Init from the loaded file
self.client_application_key = None
self.application_id = None
self.application_key = None
self.tenant_id = None
self.shield_base_url = "http://127.0.0.1:4545"
self.is_self_hosted_shield_server = True
self.api_key = None
self.shield_server_key_id = None
self.shield_server_public_key = None
self.shield_plugin_key_id = None
self.shield_plugin_private_key = None
self.audit_spool_dir = None
encryption_keys_info = None


# Decode and load server url from application key
if self.application_api_key:
decoded_application_api_key = util.decode_application_api_key(self.application_api_key)
self.shield_base_url = util.fetch_server_url_from_key(decoded_application_api_key)

encryption_keys_info = {
"shield_server_key_id": self.shield_server_key_id,
"shield_server_public_key": self.shield_server_public_key,
"shield_plugin_key_id": self.shield_plugin_key_id,
"shield_plugin_private_key": self.shield_plugin_private_key
}
# Allow override from kwargs
if 'server_url' in kwargs and kwargs['server_url']:
self.shield_base_url = kwargs['server_url']

if not self.application_api_key:
plugin_app_config_dict = self.get_plugin_app_config(kwargs)
self._load_application_config(plugin_app_config_dict)
# Allow override from kwargs
for key, value in kwargs.items():
if key in self.__dict__:
self.__dict__[key] = value
encryption_keys_info = {
"shield_server_key_id": self.shield_server_key_id,
"shield_server_public_key": self.shield_server_public_key,
"shield_plugin_key_id": self.shield_plugin_key_id,
"shield_plugin_private_key": self.shield_plugin_private_key
}

self.shield_client = ShieldRestHttpClient(base_url=self.shield_base_url, tenant_id=self.tenant_id,
api_key=self.api_key, encryption_keys_info=encryption_keys_info,
request_kwargs=kwargs.get("request_kwargs", {}),
is_self_hosted_shield_server=is_self_hosted_shield_server)

self.shield_client.init_shield_server(self.application_key)
api_key=self.api_key, encryption_keys_info=encryption_keys_info,
request_kwargs=kwargs.get("request_kwargs", {}),
is_self_hosted_shield_server=self.is_self_hosted_shield_server
)

resp = self.shield_client.init_shield_server(self.application_key, self.application_api_key)
if self.application_api_key:
self._load_application_config(resp)
for key, value in kwargs.items():
if key in self.__dict__:
self.__dict__[key] = value
encryption_keys_info = {
"shield_server_key_id": self.shield_server_key_id,
"shield_server_public_key": self.shield_server_public_key,
"shield_plugin_key_id": self.shield_plugin_key_id,
"shield_plugin_private_key": self.shield_plugin_private_key
}
resp['encryptionKeysInfo'] = encryption_keys_info
self.shield_client.load_config_from_json(resp)

self.llm_stream_audit_logger = None

Expand Down Expand Up @@ -553,6 +584,33 @@ def get_plugin_app_config(self, kwargs):
plugin_app_config_dict = self.read_options_from_app_config(kwargs.get("application_config_file"))
return plugin_app_config_dict

def _load_application_config(self, application_config_data):
# Init from the loaded file
self.client_application_key = application_config_data.get("clientApplicationKey", "*")
self.application_id = application_config_data.get("applicationId")
self.application_key = application_config_data.get("applicationKey")
self.tenant_id = application_config_data.get("tenantId")

shield_server_url = application_config_data.get("shieldServerUrl")
if shield_server_url is not None:
self.shield_base_url = shield_server_url
self.is_self_hosted_shield_server = True
else:
self.is_self_hosted_shield_server = False
self.shield_base_url = application_config_data.get("apiServerUrl")

self.api_key = application_config_data.get("apiKey")
self.shield_server_key_id = application_config_data.get("shieldServerKeyId")
self.shield_server_public_key = application_config_data.get("shieldServerPublicKey")
self.shield_plugin_key_id = application_config_data.get("shieldPluginKeyId")
self.shield_plugin_private_key = application_config_data.get("shieldPluginPrivateKey")
self.audit_spool_dir = application_config_data.get("auditSpoolDir", "spool/audits/")






def read_options_from_app_config(self, application_config_file=None):
"""
Read the options from the application config file.
Expand Down
10 changes: 10 additions & 0 deletions paig-client/src/paig_client/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import logging
import socket
import threading
Expand Down Expand Up @@ -32,6 +33,15 @@ def get_my_hostname():
return socket.gethostname()


def decode_application_api_key(api_key):
return (base64.b64decode(api_key.encode('utf-8'))).decode('utf-8')


def fetch_server_url_from_key(key):
components = key.split(":")
return ":".join(components[2:])


@lru_cache(maxsize=None)
def get_my_ip_address():
ip_address = ""
Expand Down
1 change: 1 addition & 0 deletions paig-server/backend/paig/alembic_db/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from api.governance.database.db_models import ai_app_config_model
from api.governance.database.db_models import metadata_key_model, metadata_value_model
from api.governance.database.db_models import tag_model
from api.governance.database.db_models import ai_app_apikey_model
from api.user.database.db_models import user_model, groups_model
from api.audit.RDS_service.db_models import access_audit_model
from api.encryption.database.db_models import encryption_master_key_model, encryption_key_model
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""encryption_key_api_key_tables_added

Revision ID: bfcd4658b3bc
Revises: 5e805b526efa
Create Date: 2024-10-20 18:26:49.316781

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import core.db_models.utils


# revision identifiers, used by Alembic.
revision: str = 'bfcd4658b3bc'
down_revision: Union[str, None] = '5e805b526efa'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('ai_application_api_key',
sa.Column('scope', sa.Integer(), nullable=False),
sa.Column('api_key', sa.String(length=255), nullable=False),
sa.Column('expiry_time', sa.DateTime(), nullable=True),
sa.Column('application_id', sa.Integer(), nullable=False),
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('status', sa.Integer(), nullable=False),
sa.Column('create_time', sa.DateTime(), nullable=False),
sa.Column('update_time', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['application_id'], ['ai_application.id'], name='fk_ai_application_policy_application_id', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_ai_application_api_key_create_time'), 'ai_application_api_key', ['create_time'], unique=False)
op.create_index(op.f('ix_ai_application_api_key_expiry_time'), 'ai_application_api_key', ['expiry_time'], unique=False)
op.create_index(op.f('ix_ai_application_api_key_id'), 'ai_application_api_key', ['id'], unique=False)
op.create_index(op.f('ix_ai_application_api_key_update_time'), 'ai_application_api_key', ['update_time'], unique=False)
op.create_table('ai_application_encryption_key',
sa.Column('key', sa.String(length=255), nullable=False),
sa.Column('application_id', sa.Integer(), nullable=False),
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('status', sa.Integer(), nullable=False),
sa.Column('create_time', sa.DateTime(), nullable=False),
sa.Column('update_time', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['application_id'], ['ai_application.id'], name='fk_ai_application_policy_application_id', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_ai_application_encryption_key_create_time'), 'ai_application_encryption_key', ['create_time'], unique=False)
op.create_index(op.f('ix_ai_application_encryption_key_id'), 'ai_application_encryption_key', ['id'], unique=False)
op.create_index(op.f('ix_ai_application_encryption_key_update_time'), 'ai_application_encryption_key', ['update_time'], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_ai_application_encryption_key_update_time'), table_name='ai_application_encryption_key')
op.drop_index(op.f('ix_ai_application_encryption_key_id'), table_name='ai_application_encryption_key')
op.drop_index(op.f('ix_ai_application_encryption_key_create_time'), table_name='ai_application_encryption_key')
op.drop_table('ai_application_encryption_key')
op.drop_index(op.f('ix_ai_application_api_key_update_time'), table_name='ai_application_api_key')
op.drop_index(op.f('ix_ai_application_api_key_id'), table_name='ai_application_api_key')
op.drop_index(op.f('ix_ai_application_api_key_expiry_time'), table_name='ai_application_api_key')
op.drop_index(op.f('ix_ai_application_api_key_create_time'), table_name='ai_application_api_key')
op.drop_table('ai_application_api_key')
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AIApplicationView(BaseView):
name: str = Field(default=None, description="The name of the AI application")
description: Optional[str] = Field(default=None, description="The description of the AI application")
application_key: Optional[str] = Field(None, description="The application key", alias="applicationKey")
application_api_key: Optional[str] = Field(None, description="The application key", alias="applicationAPIKey")
vector_dbs: Optional[List[str]] = Field([], description="The vector databases associated with the AI application", alias="vectorDBs")

model_config = BaseView.model_config
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from api.governance.services.ai_app_apikey_service import AIAppAPIKeyService
from api.governance.services.ai_app_service import AIAppService
from core.utils import SingletonDepends
from core.db_session import Transactional, Propagation



class AIAppAPIKeyController:

def __init__(self,
ai_app_apikey_service: AIAppAPIKeyService = SingletonDepends(AIAppAPIKeyService),
ai_app_service: AIAppService = SingletonDepends(AIAppService),
):
self.ai_app_apikey_service = ai_app_apikey_service
self.ai_app_service = ai_app_service

@Transactional(propagation=Propagation.REQUIRED)
async def generate_api_key(self, app_id: int):
shield_server_url = await self.ai_app_service.get_shield_server_url()
return await self.ai_app_apikey_service.generate_api_key(app_id, shield_server_url)

async def validate_api_key(self, api_key: str):
app_id = await self.ai_app_apikey_service.validate_api_key(api_key)
if app_id:
return {"message": "API key is valid", "app_id": app_id}
else:
return {"message": "API key is invalid"}
Loading