Skip to content

Commit

Permalink
Check in Dev's gist (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
pgarbacki authored Feb 9, 2024
1 parent fb3158a commit e1d9f79
Showing 1 changed file with 197 additions and 0 deletions.
197 changes: 197 additions & 0 deletions examples/function_calling/fireworks_function_calling_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1bfdf168",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "eb849e97",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"content\": null,\n",
" \"role\": \"assistant\",\n",
" \"function_call\": null,\n",
" \"tool_calls\": [\n",
" {\n",
" \"id\": \"call_GKEtwoiSicLsGwT60Pkd4Xpz\",\n",
" \"function\": {\n",
" \"arguments\": \"{\\\"metric\\\": \\\"net_income\\\", \\\"financial_year\\\": 2022, \\\"company\\\": \\\"Nike\\\"}\",\n",
" \"name\": \"get_financial_data\"\n",
" },\n",
" \"type\": \"function\",\n",
" \"index\": 0\n",
" }\n",
" ]\n",
"}\n"
]
}
],
"source": [
"client = openai.OpenAI(\n",
" base_url = \"https://api.fireworks.ai/inference/v1\",\n",
" api_key = \"<YOUR_FIREWORKS_API_KEY>\"\n",
")\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": f\"You are a helpful assistant with access to functions. Use them if required.\"},\n",
" {\"role\": \"user\", \"content\": \"What are Nike's net income in 2022?\"}\n",
"]\n",
"\n",
"# Describe the functions available to the agent in great detail in JSON Schema\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_financial_data\",\n",
" \"description\": \"Get financial data for a company given the metric and year.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"metric\": {\n",
" \"type\": \"enum\",\n",
" \"choices\": [\"net_income\", \"revenue\", \"ebdita\"],\n",
" },\n",
" \"financial_year\": {\"type\": \"int\", \n",
" \"description\": \"Year for which we want to get financial data.\"\n",
" },\n",
" \"company\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Name of the company for which we want to get financial data.\"\n",
" }\n",
" },\n",
" \"required\": [\"metric\", \"financial_year\", \"company\"],\n",
" },\n",
" },\n",
" }\n",
"]\n",
"\n",
"chat_completion = client.chat.completions.create(\n",
" model=\"accounts/fireworks/models/fw-function-call-34b-v0\",\n",
" messages=messages,\n",
" tools=tools,\n",
" temperature=0.1\n",
")\n",
"print(chat_completion.choices[0].message.model_dump_json(indent=4))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3b8f1629",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"metric='net_income' financial_year=2022 company='Nike'\n",
"{'net_income': 6046000000}\n"
]
}
],
"source": [
"# Locally available tool implementation. Can be replaced by any remote API\n",
"def get_financial_data(metric: str, financial_year: int, company: str):\n",
" print(f\"{metric=} {financial_year=} {company=}\")\n",
" if metric == \"net_income\" and financial_year == 2022 and company == \"Nike\":\n",
" return {\"net_income\": 6_046_000_000}\n",
" else:\n",
" raise NotImplementedError()\n",
"\n",
"function_call = chat_completion.choices[0].message.tool_calls[0].function\n",
"tool_response = locals()[function_call.name](**json.loads(function_call.arguments))\n",
"print(tool_response)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2db51424",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nike's net income for the year 2022 was $6,046,000,000. \n"
]
}
],
"source": [
"agent_response = chat_completion.choices[0].message\n",
"\n",
"# Append the response from the agent\n",
"messages.append(\n",
" {\n",
" \"role\": agent_response.role, \n",
" \"content\": \"\",\n",
" \"tool_calls\": [\n",
" tool_call.model_dump()\n",
" for tool_call in chat_completion.choices[0].message.tool_calls\n",
" ]\n",
" }\n",
")\n",
"\n",
"# Append the response from the tool \n",
"messages.append(\n",
" {\n",
" \"role\": \"function\",\n",
" \"content\": json.dumps(tool_response)\n",
" }\n",
")\n",
"\n",
"next_chat_completion = client.chat.completions.create(\n",
" model=\"accounts/fireworks/models/fw-function-call-34b-v0\",\n",
" messages=messages,\n",
" tools=tools,\n",
" temperature=0.1\n",
")\n",
"\n",
"print(next_chat_completion.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a6b85b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit e1d9f79

Please sign in to comment.