From e1d9f79796f8dc0d5866df9b2a7b8bed47940678 Mon Sep 17 00:00:00 2001 From: Pawel Garbacki Date: Fri, 9 Feb 2024 11:56:28 -0800 Subject: [PATCH] Check in Dev's gist (#67) --- .../fireworks_function_calling_demo.ipynb | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 examples/function_calling/fireworks_function_calling_demo.ipynb diff --git a/examples/function_calling/fireworks_function_calling_demo.ipynb b/examples/function_calling/fireworks_function_calling_demo.ipynb new file mode 100644 index 0000000..c95de2b --- /dev/null +++ b/examples/function_calling/fireworks_function_calling_demo.ipynb @@ -0,0 +1,197 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1bfdf168", + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "eb849e97", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"content\": null,\n", + " \"role\": \"assistant\",\n", + " \"function_call\": null,\n", + " \"tool_calls\": [\n", + " {\n", + " \"id\": \"call_GKEtwoiSicLsGwT60Pkd4Xpz\",\n", + " \"function\": {\n", + " \"arguments\": \"{\\\"metric\\\": \\\"net_income\\\", \\\"financial_year\\\": 2022, \\\"company\\\": \\\"Nike\\\"}\",\n", + " \"name\": \"get_financial_data\"\n", + " },\n", + " \"type\": \"function\",\n", + " \"index\": 0\n", + " }\n", + " ]\n", + "}\n" + ] + } + ], + "source": [ + "client = openai.OpenAI(\n", + " base_url = \"https://api.fireworks.ai/inference/v1\",\n", + " api_key = \"\"\n", + ")\n", + "\n", + "messages = [\n", + " {\"role\": \"system\", \"content\": f\"You are a helpful assistant with access to functions. Use them if required.\"},\n", + " {\"role\": \"user\", \"content\": \"What are Nike's net income in 2022?\"}\n", + "]\n", + "\n", + "# Describe the functions available to the agent in great detail in JSON Schema\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_financial_data\",\n", + " \"description\": \"Get financial data for a company given the metric and year.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"metric\": {\n", + " \"type\": \"enum\",\n", + " \"choices\": [\"net_income\", \"revenue\", \"ebdita\"],\n", + " },\n", + " \"financial_year\": {\"type\": \"int\", \n", + " \"description\": \"Year for which we want to get financial data.\"\n", + " },\n", + " \"company\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Name of the company for which we want to get financial data.\"\n", + " }\n", + " },\n", + " \"required\": [\"metric\", \"financial_year\", \"company\"],\n", + " },\n", + " },\n", + " }\n", + "]\n", + "\n", + "chat_completion = client.chat.completions.create(\n", + " model=\"accounts/fireworks/models/fw-function-call-34b-v0\",\n", + " messages=messages,\n", + " tools=tools,\n", + " temperature=0.1\n", + ")\n", + "print(chat_completion.choices[0].message.model_dump_json(indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3b8f1629", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "metric='net_income' financial_year=2022 company='Nike'\n", + "{'net_income': 6046000000}\n" + ] + } + ], + "source": [ + "# Locally available tool implementation. Can be replaced by any remote API\n", + "def get_financial_data(metric: str, financial_year: int, company: str):\n", + " print(f\"{metric=} {financial_year=} {company=}\")\n", + " if metric == \"net_income\" and financial_year == 2022 and company == \"Nike\":\n", + " return {\"net_income\": 6_046_000_000}\n", + " else:\n", + " raise NotImplementedError()\n", + "\n", + "function_call = chat_completion.choices[0].message.tool_calls[0].function\n", + "tool_response = locals()[function_call.name](**json.loads(function_call.arguments))\n", + "print(tool_response)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2db51424", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nike's net income for the year 2022 was $6,046,000,000. \n" + ] + } + ], + "source": [ + "agent_response = chat_completion.choices[0].message\n", + "\n", + "# Append the response from the agent\n", + "messages.append(\n", + " {\n", + " \"role\": agent_response.role, \n", + " \"content\": \"\",\n", + " \"tool_calls\": [\n", + " tool_call.model_dump()\n", + " for tool_call in chat_completion.choices[0].message.tool_calls\n", + " ]\n", + " }\n", + ")\n", + "\n", + "# Append the response from the tool \n", + "messages.append(\n", + " {\n", + " \"role\": \"function\",\n", + " \"content\": json.dumps(tool_response)\n", + " }\n", + ")\n", + "\n", + "next_chat_completion = client.chat.completions.create(\n", + " model=\"accounts/fireworks/models/fw-function-call-34b-v0\",\n", + " messages=messages,\n", + " tools=tools,\n", + " temperature=0.1\n", + ")\n", + "\n", + "print(next_chat_completion.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a6b85b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}