Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d139a89
initial commit of ACL porting changes
mattgotteiner Oct 20, 2025
3ff66db
update app init for ACL porting changes
mattgotteiner Oct 20, 2025
0cd85f5
updating auth helper tests
mattgotteiner Oct 20, 2025
7c2f7e0
test update
mattgotteiner Oct 20, 2025
3d586db
updating auth helper tests
mattgotteiner Oct 20, 2025
059d5ca
add back oid
mattgotteiner Oct 20, 2025
cb690d5
update mocks
mattgotteiner Oct 20, 2025
158e71d
update scope
mattgotteiner Oct 20, 2025
abda838
fix coverage; update tests; make sure oids, groups fields are properl…
mattgotteiner Oct 20, 2025
67300f4
refactor
mattgotteiner Oct 20, 2025
6fb1bb7
WIP - adding back controls
mattgotteiner Oct 21, 2025
0eb211d
1st round of test fixes
mattgotteiner Oct 26, 2025
0fe8445
more test fixes
mattgotteiner Oct 26, 2025
503bb79
more test fixes
mattgotteiner Oct 27, 2025
d62ffe4
more test fixes
mattgotteiner Oct 27, 2025
4397a6a
update
mattgotteiner Oct 27, 2025
0f5b1e9
update tests and add acl command to enable global access
mattgotteiner Oct 27, 2025
256a57f
admin consent + add back graph grants
mattgotteiner Oct 28, 2025
642ed8e
fix env vars
mattgotteiner Oct 28, 2025
a2d43ad
docs update
mattgotteiner Oct 28, 2025
a546111
remove oids, groups filter checkbox
mattgotteiner Oct 28, 2025
4ffc9c1
fix markdown lint issues
mattgotteiner Oct 28, 2025
7b035a2
addressing feedback
mattgotteiner Oct 28, 2025
356823a
enforce access control on user upload
mattgotteiner Oct 28, 2025
d7b4fa0
update vision mocks to pass through PNG bytes instead of test content
mattgotteiner Oct 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions app/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,6 @@ async def setup_clients():
AZURE_TENANT_ID = os.getenv("AZURE_TENANT_ID")
AZURE_USE_AUTHENTICATION = os.getenv("AZURE_USE_AUTHENTICATION", "").lower() == "true"
AZURE_ENFORCE_ACCESS_CONTROL = os.getenv("AZURE_ENFORCE_ACCESS_CONTROL", "").lower() == "true"
AZURE_ENABLE_GLOBAL_DOCUMENT_ACCESS = os.getenv("AZURE_ENABLE_GLOBAL_DOCUMENT_ACCESS", "").lower() == "true"
AZURE_ENABLE_UNAUTHENTICATED_ACCESS = os.getenv("AZURE_ENABLE_UNAUTHENTICATED_ACCESS", "").lower() == "true"
AZURE_SERVER_APP_ID = os.getenv("AZURE_SERVER_APP_ID")
AZURE_SERVER_APP_SECRET = os.getenv("AZURE_SERVER_APP_SECRET")
Expand Down Expand Up @@ -543,8 +542,7 @@ async def setup_clients():
server_app_secret=AZURE_SERVER_APP_SECRET,
client_app_id=AZURE_CLIENT_APP_ID,
tenant_id=AZURE_AUTH_TENANT_ID,
require_access_control=AZURE_ENFORCE_ACCESS_CONTROL,
enable_global_documents=AZURE_ENABLE_GLOBAL_DOCUMENT_ACCESS,
enforce_access_control=AZURE_ENFORCE_ACCESS_CONTROL,
enable_unauthenticated_access=AZURE_ENABLE_UNAUTHENTICATED_ACCESS,
)

Expand Down Expand Up @@ -578,6 +576,8 @@ async def setup_clients():
raise ValueError(
"AZURE_USERSTORAGE_ACCOUNT and AZURE_USERSTORAGE_CONTAINER must be set when USE_USER_UPLOAD is true"
)
if not AZURE_ENFORCE_ACCESS_CONTROL:
raise ValueError("AZURE_ENFORCE_ACCESS_CONTROL must be true when USE_USER_UPLOAD is true")
user_blob_manager = AdlsBlobManager(
endpoint=f"https://{AZURE_USERSTORAGE_ACCOUNT}.dfs.core.windows.net",
container=AZURE_USERSTORAGE_CONTAINER,
Expand Down Expand Up @@ -676,7 +676,6 @@ async def setup_clients():
agent_deployment=AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT,
agent_client=agent_client,
openai_client=openai_client,
auth_helper=auth_helper,
chatgpt_model=OPENAI_CHATGPT_MODEL,
chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
embedding_model=OPENAI_EMB_MODEL,
Expand All @@ -703,7 +702,6 @@ async def setup_clients():
agent_deployment=AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT,
agent_client=agent_client,
openai_client=openai_client,
auth_helper=auth_helper,
chatgpt_model=OPENAI_CHATGPT_MODEL,
chatgpt_deployment=AZURE_OPENAI_CHATGPT_DEPLOYMENT,
embedding_model=OPENAI_EMB_MODEL,
Expand Down
15 changes: 7 additions & 8 deletions app/backend/approaches/approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
)

from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper
from prepdocslib.blobmanager import AdlsBlobManager, BlobManager
from prepdocslib.embeddings import ImageEmbeddings

Expand Down Expand Up @@ -152,7 +151,6 @@ def __init__(
self,
search_client: SearchClient,
openai_client: AsyncOpenAI,
auth_helper: AuthenticationHelper,
query_language: Optional[str],
query_speller: Optional[str],
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
Expand All @@ -169,7 +167,6 @@ def __init__(
):
self.search_client = search_client
self.openai_client = openai_client
self.auth_helper = auth_helper
self.query_language = query_language
self.query_speller = query_speller
self.embedding_deployment = embedding_deployment
Expand All @@ -185,17 +182,14 @@ def __init__(
self.global_blob_manager = global_blob_manager
self.user_blob_manager = user_blob_manager

def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Optional[str]:
def build_filter(self, overrides: dict[str, Any]) -> Optional[str]:
include_category = overrides.get("include_category")
exclude_category = overrides.get("exclude_category")
security_filter = self.auth_helper.build_security_filters(overrides, auth_claims)
filters = []
if include_category:
filters.append("category eq '{}'".format(include_category.replace("'", "''")))
if exclude_category:
filters.append("category ne '{}'".format(exclude_category.replace("'", "''")))
if security_filter:
filters.append(security_filter)
return None if len(filters) == 0 else " and ".join(filters)

async def search(
Expand All @@ -211,6 +205,7 @@ async def search(
minimum_search_score: Optional[float] = None,
minimum_reranker_score: Optional[float] = None,
use_query_rewriting: Optional[bool] = None,
access_token: Optional[str] = None,
) -> list[Document]:
search_text = query_text if use_text_search else ""
search_vectors = vectors if use_vector_search else []
Expand All @@ -227,13 +222,15 @@ async def search(
query_speller=self.query_speller,
semantic_configuration_name="default",
semantic_query=query_text,
x_ms_query_source_authorization=access_token,
)
else:
results = await self.search_client.search(
search_text=search_text,
filter=filter,
top=top,
vector_queries=search_vectors,
x_ms_query_source_authorization=access_token,
)

documents: list[Document] = []
Expand Down Expand Up @@ -275,6 +272,7 @@ async def run_agentic_retrieval(
filter_add_on: Optional[str] = None,
minimum_reranker_score: Optional[float] = None,
results_merge_strategy: Optional[str] = None,
access_token: Optional[str] = None,
) -> tuple[KnowledgeAgentRetrievalResponse, list[Document]]:
# STEP 1: Invoke agentic retrieval
response = await agent_client.retrieve(
Expand All @@ -292,7 +290,8 @@ async def run_agentic_retrieval(
filter_add_on=filter_add_on,
)
],
)
),
x_ms_query_source_authorization=access_token,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: Ensure that tests were added to verify that token is passed in for all calls to search()/retrieve()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)

# Map activity id -> agent's internal search query
Expand Down
11 changes: 6 additions & 5 deletions app/backend/approaches/chatreadretrieveread.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ThoughtStep,
)
from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper
from prepdocslib.blobmanager import AdlsBlobManager, BlobManager
from prepdocslib.embeddings import ImageEmbeddings

Expand All @@ -42,7 +41,6 @@ def __init__(
agent_model: Optional[str],
agent_deployment: Optional[str],
agent_client: KnowledgeAgentRetrievalClient,
auth_helper: AuthenticationHelper,
openai_client: AsyncOpenAI,
chatgpt_model: str,
chatgpt_deployment: Optional[str], # Not needed for non-Azure OpenAI
Expand All @@ -67,7 +65,6 @@ def __init__(
self.agent_deployment = agent_deployment
self.agent_client = agent_client
self.openai_client = openai_client
self.auth_helper = auth_helper
self.chatgpt_model = chatgpt_model
self.chatgpt_deployment = chatgpt_deployment
self.embedding_deployment = embedding_deployment
Expand Down Expand Up @@ -279,7 +276,8 @@ async def run_search_approach(
top = overrides.get("top", 3)
minimum_search_score = overrides.get("minimum_search_score", 0.0)
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
search_index_filter = self.build_filter(overrides, auth_claims)
search_index_filter = self.build_filter(overrides)
access_token = auth_claims.get("access_token")
send_text_sources = overrides.get("send_text_sources", True)
send_image_sources = overrides.get("send_image_sources", self.multimodal_enabled) and self.multimodal_enabled
search_text_embeddings = overrides.get("search_text_embeddings", True)
Expand Down Expand Up @@ -337,6 +335,7 @@ async def run_search_approach(
minimum_search_score,
minimum_reranker_score,
use_query_rewriting,
access_token,
)

# STEP 3: Generate a contextual and content specific answer using the search results and chat history
Expand Down Expand Up @@ -388,7 +387,8 @@ async def run_agentic_retrieval_approach(
overrides: dict[str, Any],
auth_claims: dict[str, Any],
):
search_index_filter = self.build_filter(overrides, auth_claims)
search_index_filter = self.build_filter(overrides)
access_token = auth_claims.get("access_token")
minimum_reranker_score = overrides.get("minimum_reranker_score", 0)
top = overrides.get("top", 3)
results_merge_strategy = overrides.get("results_merge_strategy", "interleaved")
Expand All @@ -403,6 +403,7 @@ async def run_agentic_retrieval_approach(
filter_add_on=search_index_filter,
minimum_reranker_score=minimum_reranker_score,
results_merge_strategy=results_merge_strategy,
access_token=access_token,
)

data_points = await self.get_sources_content(
Expand Down
11 changes: 6 additions & 5 deletions app/backend/approaches/retrievethenread.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
ThoughtStep,
)
from approaches.promptmanager import PromptManager
from core.authentication import AuthenticationHelper
from prepdocslib.blobmanager import AdlsBlobManager, BlobManager
from prepdocslib.embeddings import ImageEmbeddings

Expand All @@ -32,7 +31,6 @@ def __init__(
agent_model: Optional[str],
agent_deployment: Optional[str],
agent_client: KnowledgeAgentRetrievalClient,
auth_helper: AuthenticationHelper,
openai_client: AsyncOpenAI,
chatgpt_model: str,
chatgpt_deployment: Optional[str], # Not needed for non-Azure OpenAI
Expand All @@ -58,7 +56,6 @@ def __init__(
self.agent_client = agent_client
self.chatgpt_deployment = chatgpt_deployment
self.openai_client = openai_client
self.auth_helper = auth_helper
self.chatgpt_model = chatgpt_model
self.embedding_model = embedding_model
self.embedding_dimensions = embedding_dimensions
Expand Down Expand Up @@ -155,7 +152,8 @@ async def run_search_approach(
top = overrides.get("top", 3)
minimum_search_score = overrides.get("minimum_search_score", 0.0)
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
filter = self.build_filter(overrides, auth_claims)
filter = self.build_filter(overrides)
access_token = auth_claims.get("access_token")
q = str(messages[-1]["content"])
send_text_sources = overrides.get("send_text_sources", True)
send_image_sources = overrides.get("send_image_sources", self.multimodal_enabled) and self.multimodal_enabled
Expand Down Expand Up @@ -183,6 +181,7 @@ async def run_search_approach(
minimum_search_score,
minimum_reranker_score,
use_query_rewriting,
access_token,
)

data_points = await self.get_sources_content(
Expand Down Expand Up @@ -225,7 +224,8 @@ async def run_agentic_retrieval_approach(
auth_claims: dict[str, Any],
) -> ExtraInfo:
minimum_reranker_score = overrides.get("minimum_reranker_score", 0)
search_index_filter = self.build_filter(overrides, auth_claims)
search_index_filter = self.build_filter(overrides)
access_token = auth_claims.get("access_token")
top = overrides.get("top", 3)
results_merge_strategy = overrides.get("results_merge_strategy", "interleaved")
send_text_sources = overrides.get("send_text_sources", True)
Expand All @@ -239,6 +239,7 @@ async def run_agentic_retrieval_approach(
filter_add_on=search_index_filter,
minimum_reranker_score=minimum_reranker_score,
results_merge_strategy=results_merge_strategy,
access_token=access_token,
)

data_points = await self.get_sources_content(
Expand Down
Loading