Missing support for tool calling in Llama models (was: Tool Calling for AzureML Managed Deployments) #39391

Closed
Luke-Dornburgh opened this issue Jan 24, 2025 · 8 comments

@Luke-Dornburgh

  • Package Name: azure-ai-inference
  • Package Version: 1.0.0b7
  • Operating System: Windows
  • Python Version: 3.11.2

Describe the bug
Tool calling does not function as expected for Meta-Llama-3.1-70B-Instruct or Meta-Llama-3.3-70B-Instruct deployed to Managed Compute in Azure Machine Learning. The code below works just fine for Azure OpenAI models, but it fails against the two Azure ML real-time managed deployments we have.

import json  
import asyncio  
from azure.ai.inference.aio import ChatCompletionsClient  # Use async client  
from azure.ai.inference.models import (  
    AssistantMessage,  
    ChatCompletionsToolCall,  
    ChatCompletionsToolDefinition,  
    FunctionCall,  
    FunctionDefinition,  
    SystemMessage,  
    ToolMessage,  
    UserMessage,
    ChatCompletionsNamedToolChoice,
    ChatCompletionsNamedToolChoiceFunction
)  
from azure.core.credentials import AzureKeyCredential  
  
async def get_flight_info(origin_city: str, destination_city: str):  
    await asyncio.sleep(0)  # Simulate async operation  
    if origin_city == "Seattle" and destination_city == "Miami":  
        return json.dumps({  
            "airline": "Delta", "flight_number": "DL123", "flight_date": "May 7th, 2024", "flight_time": "10:00AM"  
        })
    if origin_city == "Tampa" and destination_city == "Richmond":  
        return json.dumps({  
            "airline": "American", "flight_number": "AM123", "flight_date": "May 4th, 2024", "flight_time": "9:00PM"  
        })
    return json.dumps({"error": "No flights found between the cities"})  
  
async def get_train_info(origin_city: str, destination_city: str):  
    await asyncio.sleep(0)  # Simulate async operation  
    if origin_city == "Seattle" and destination_city == "Miami":  
        return json.dumps({  
            "train": "Pennsylvania", "train_number": "T123", "train_date": "May 9th, 2024", "train_time": "11:00AM"  
        })  
    return json.dumps({"error": "No trains found between the cities"})  
  
flight_info = ChatCompletionsToolDefinition(  
    function=FunctionDefinition(  
        name="get_flight_info",  
        description="Returns information about the next flight between two cities.",  
        parameters={  
            "type": "object",  
            "properties": {  
                "origin_city": {"type": "string", "description": "The name of the city where the flight originates"},  
                "destination_city": {"type": "string", "description": "The flight destination city"},  
            },  
            "required": ["origin_city", "destination_city"],  
        },  
    )  
)  
  
train_info = ChatCompletionsToolDefinition(  
    function=FunctionDefinition(  
        name="get_train_info",  
        description="Returns information about the next train rides between two cities.",  
        parameters={  
            "type": "object",  
            "properties": {  
                "origin_city": {"type": "string", "description": "The name of the city where the train originates"},  
                "destination_city": {"type": "string", "description": "The train destination city"},  
            },  
            "required": ["origin_city", "destination_city"],  
        },  
    )  
)  
  
function_map = {  
    "get_flight_info": get_flight_info,  
    "get_train_info": get_train_info,  
}  
  
tool_map = {  
    "get_flight_info": flight_info,  
    "get_train_info": train_info,  
}  
  
async def handle_tool_calls(messages, tool_calls, client):  
    # Streamed tool calls arrive as fragments: the first delta for each index
    # carries the id and function name, and later deltas carry argument chunks,
    # so first stitch the argument strings back together keyed by index.
    accumulated_arguments = {}  
    for tool_call in tool_calls:  
        tool_function = tool_call.get('function', {})  
        tool_index = tool_call.get('index')  
        if tool_index is not None:  
            if tool_index not in accumulated_arguments:  
                accumulated_arguments[tool_index] = ""  
            accumulated_arguments[tool_index] += tool_function.get('arguments', '')  
    # print(f"accumulated_arguments: \n{accumulated_arguments}\n")
    new_messages = messages[:]  
    chat_completions_tool_calls = []  
    tool_messages_list = []  
    for tool_call in tool_calls:  
        tool_function = tool_call.get('function', {})  
        tool_index = tool_call.get('index')  
        function_name = tool_function.get('name')  
        if function_name and tool_index in accumulated_arguments:  
            function_args = accumulated_arguments[tool_index]  
            print(f"Function call detected: {function_name} with arguments: {function_args}")  
  
            callable_func = function_map.get(function_name)  
            if callable_func is None:  
                print(f"Warning: Function `{function_name}` not found in the function map.")  
                continue  
  
            try:  
                function_args_mapping = json.loads(function_args)  
            except json.JSONDecodeError as e:  
                print(f"Error decoding JSON for function `{function_name}`: {e}")  
                continue  
  
            function_response = await callable_func(**function_args_mapping)  
            print(f"Function response = {function_response}")  
  
            tool_call_id = tool_call.get('id')  
            tool_messages_list.append(ToolMessage(content=function_response, tool_call_id=tool_call_id))  
            chat_completions_tool_calls.append(ChatCompletionsToolCall(id=tool_call_id, function=FunctionCall(name=function_name, arguments=function_args)))  
  
    new_messages.append(AssistantMessage(tool_calls=chat_completions_tool_calls))  
    new_messages.extend(tool_messages_list)  
    response = await client.complete(messages=new_messages, tools=list(tool_map.values()), stream=True)  
    async for update in response:  
        yield update  
  
async def sample_chat_completions_streaming_with_tools():  
    use_azure_openai_endpoint = False

    # Fill in your own endpoint and key where the <> placeholders are
    if use_azure_openai_endpoint:  
        endpoint = "<my-azure-openai-endpoint>"  
        key = "<my-api-key>"  
    else:  
        endpoint = "https://<deployment-name>.<region>.inference.ml.azure.com"
        key = "<my-api-key>"  
  
    client = ChatCompletionsClient(  
        endpoint=endpoint,  
        credential=AzureKeyCredential(key),  
        api_version="2024-06-01",  
    )  
  
    messages = [  
        SystemMessage(content="You are an assistant that helps users find travel information."),  
        UserMessage(content="when is the next flight from Seattle to Miami, also when is the next flight from Tampa to Richmond?"),  
    ]  
  
    response = await client.complete(messages=messages, tools=list(tool_map.values()), stream=True)  
  
    tool_calls = []  
    async for update in response:  
        if update.choices[0].delta.tool_calls is not None:  
            tool_calls.extend(update.choices[0].delta.tool_calls)  
        else:  
            # Directly stream back the content if no tool calls are detected  
            if update.choices[0].delta.content is not None:
                await asyncio.sleep(0.05)
                print(update.choices[0].delta.content or "", end="", flush=True)  
  
    if tool_calls:
        response = handle_tool_calls(messages, tool_calls, client)  
        async for update in response:  
            await asyncio.sleep(0.05)  
            print(update.choices[0].delta.content or "", end="", flush=True)  
    
    await client.close()   
  
async def main():
    await sample_chat_completions_streaming_with_tools()


if __name__ == "__main__":
    asyncio.run(main())

To Reproduce
Steps to reproduce the behavior:

  1. Copy the code
  2. Add your own deployment information where the <> brackets are
  3. Run it first with an Azure OpenAI chat model to see expected output
  4. Run it with the Azure ML deployment to see error output

Expected behavior
Here is the output that the Azure OpenAI models return:

Function call detected: get_flight_info with arguments: {"origin_city": "Seattle", "destination_city": "Miami"}
Function response = {"airline": "Delta", "flight_number": "DL123", "flight_date": "May 7th, 2024", "flight_time": "10:00AM"}
Function call detected: get_flight_info with arguments: {"origin_city": "Tampa", "destination_city": "Richmond"}
Function response = {"airline": "American", "flight_number": "AM123", "flight_date": "May 4th, 2024", "flight_time": "9:00PM"}
Here are the details for the next flights:

- **Seattle to Miami**:  
  - **Airline**: Delta  
  - **Flight Number**: DL123  
  - **Date**: May 7th, 2024  
  - **Time**: 10:00 AM  

- **Tampa to Richmond**:  
  - **Airline**: American  
  - **Flight Number**: AM123  
  - **Date**: May 4th, 2024  
  - **Time**: 9:00 PM

Screenshots - Error using Azure ML Managed Deployment:

[screenshot of the error attached to the original issue]

Additional context
This code is heavily based on this sample from the repo: sdk/ai/azure-ai-inference/samples/sample_chat_completions_streaming_with_tools.py
In that sample, the comment clearly states:
DESCRIPTION:
This sample demonstrates how to do chat completions using a synchronous client,
with the assistance of tools, with streaming service response. In this sample,
we use a mock function tool to retrieve flight information in order to answer
a query about the next flight between two cities. Make sure that the AI model
you use supports tools. The sample supports either Serverless API endpoint /
Managed Compute endpoint, or Azure OpenAI endpoint.
Set the boolean variable
use_azure_openai_endpoint to select between the two. API key authentication
is used for both endpoints in this sample.


Thanks for the feedback! We are routing this to the appropriate team for follow-up. cc @dargilco @jhakulin @trangevi.

@dargilco
Member

Hi @Luke-Dornburgh thank you for opening this issue. I'm taking a look at it now.

@dargilco
Member

@Luke-Dornburgh it seems that even though the model supports tool calls, the Azure service hosting this model has not yet been updated to expose tool calling via the Azure AI Model Inference REST APIs. That's why this page https://learn.microsoft.com/azure/ai-foundry/model-inference/concepts/models#meta still shows Llama models as not supporting tool calls. I'm trying to get information from the appropriate team about when this will be fixed. In the meantime you can double-check what your endpoint reports about the deployed model, as shown below.
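
A minimal sketch of that check, assuming a Managed Compute endpoint and a placeholder key (get_model_info() is supported on Serverless API and Managed Compute endpoints; note that ModelInfo only identifies the model, it does not carry a tool-calling capability flag):

import asyncio
from azure.ai.inference.aio import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

async def print_model_info():
    # Placeholder endpoint and key - substitute your deployment details
    client = ChatCompletionsClient(
        endpoint="https://<deployment-name>.<region>.inference.ml.azure.com",
        credential=AzureKeyCredential("<my-api-key>"),
    )
    info = await client.get_model_info()  # queries the endpoint's /info route
    print(f"Model name: {info.model_name}")
    print(f"Model type: {info.model_type}")
    print(f"Provider: {info.model_provider_name}")
    await client.close()

asyncio.run(print_model_info())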

@Luke-Dornburgh
Author

@dargilco, thanks for the update on this. I was not aware of this capabilities documentation, so it will be helpful in the meantime for knowing what is possible. I assume these restrictions apply to both Managed and Serverless deployments? Please keep me updated on what the Azure product group says about when this could change.

@Luke-Dornburgh
Author

@dargilco Any news from the internal team about future support for this? I have continued testing the SDK, and it turns out that when I query a client using a gpt model that invokes tool calling, then create a new client for a Meta model and pass the message history to it, the new Meta-based client has issues handling the AssistantMessage that describes the suggested tool calls.

So this seems to be more than just a missing-support issue; it may also be a bug that hinders the ability to switch models mid-conversation. Roughly what I tried is sketched below.
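
A minimal sketch of the handoff (the endpoint, key, tool call id, and message contents are placeholders standing in for a real gpt exchange):

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import (
    AssistantMessage,
    ChatCompletionsToolCall,
    FunctionCall,
    SystemMessage,
    ToolMessage,
    UserMessage,
)
from azure.core.credentials import AzureKeyCredential

# Tool-calling exchange previously produced by a gpt model
history = [
    SystemMessage(content="You are an assistant that helps users find travel information."),
    UserMessage(content="When is the next flight from Seattle to Miami?"),
    AssistantMessage(tool_calls=[
        ChatCompletionsToolCall(
            id="call_abc123",  # placeholder for the id the gpt model returned
            function=FunctionCall(
                name="get_flight_info",
                arguments='{"origin_city": "Seattle", "destination_city": "Miami"}',
            ),
        )
    ]),
    ToolMessage(
        content='{"airline": "Delta", "flight_number": "DL123"}',
        tool_call_id="call_abc123",
    ),
]

# New client pointed at the Meta model deployment
llama_client = ChatCompletionsClient(
    endpoint="https://<llama-deployment>.<region>.inference.ml.azure.com",
    credential=AzureKeyCredential("<my-api-key>"),
)

# This is where the Meta-based client has trouble: it fails to handle the
# AssistantMessage(tool_calls=...) entry in the history
response = llama_client.complete(messages=history)
print(response.choices[0].message.content)
llama_client.close()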

@dargilco
Member

dargilco commented Jan 31, 2025

@Luke-Dornburgh I have not yet heard back; I reached out again.

I'd like to learn more about what you tried next, since from the point of view of the new client you created to handle the tool response, this is just a regular chat completions operation with history. Can you please open a new GitHub issue and share your code? For example, can you modify this sample https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py to take the tool response you got from the gpt model and pass it as chat history to a new client that uses a Meta model? I want to be sure I understand what you did.

Update: I ran this sample for chat completions with history https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_history.py with model Meta-Llama-3.1-70B-Instruct and it worked fine. Can you try this sample as-is, on your model, and let me know if that works? Is this not what you were trying to do in your second client? Please reply in a new GitHub issue per the above. For contrast, the history-only request that worked for me looks roughly like the sketch below.
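
A minimal sketch of that history-only request (no tool calls involved; the endpoint, key, and conversation content here are placeholders, not the sample's exact text):

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import AssistantMessage, SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

client = ChatCompletionsClient(
    endpoint="https://<llama-deployment>.<region>.inference.ml.azure.com",
    credential=AzureKeyCredential("<my-api-key>"),
)

# Plain multi-turn history: content-only messages, no tool_calls entries
messages = [
    SystemMessage(content="You are an assistant that helps users find travel information."),
    UserMessage(content="When is the next flight from Seattle to Miami?"),
    AssistantMessage(content="The next flight is Delta DL123 on May 7th, 2024 at 10:00AM."),
    UserMessage(content="And the next train between those two cities?"),
]

response = client.complete(messages=messages)
print(response.choices[0].message.content)
client.close()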

Thanks!

@dargilco dargilco changed the title Tool Calling for AzureML Managed Deployments Missing support for tool calling in Llama models (was: Tool Calling for AzureML Managed Deployments) Jan 31, 2025
@Luke-Dornburgh
Author

@dargilco here is a new issue opened with specific details behind what I am seeing: #39509

@dargilco
Member

dargilco commented Feb 7, 2025

Since we are tracking this work request internally, I'm going to close this issue. I do not yet have an estimated time for when Llama models will support tool calls.

@dargilco dargilco closed this as completed Feb 7, 2025