Add Cohere provider
alexrs-cohere committed Feb 14, 2025
1 parent e15c809 commit 37f2b11
Showing 11 changed files with 359 additions and 32 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -50,6 +50,7 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_COHERE_KEY: dummy

browser:
runs-on: ubuntu-latest
12 changes: 6 additions & 6 deletions README.md
@@ -23,7 +23,7 @@ await uploadFile({
// Can work with native File in browsers
file: {
path: "pytorch_model.bin",
content: new Blob(...)
}
});

@@ -39,7 +39,7 @@ await inference.chatCompletion({
],
max_tokens: 512,
temperature: 0.5,
provider: "sambanova", // or together, fal-ai, replicate, …
provider: "sambanova", // or together, fal-ai, replicate, cohere
});

await inference.textToImage({
@@ -146,12 +146,12 @@ for await (const chunk of inference.chatCompletionStream({
console.log(chunk.choices[0].delta.content);
}

/// Using a third-party provider:
await inference.chatCompletion({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
max_tokens: 512,
provider: "sambanova", // or together, fal-ai, replicate, …
provider: "sambanova", // or together, fal-ai, replicate, cohere
})

await inference.textToImage({
@@ -211,7 +211,7 @@ await uploadFile({
// Can work with native File in browsers
file: {
path: "pytorch_model.bin",
content: new Blob(...)
}
});

@@ -244,7 +244,7 @@ console.log(messages); // contains the data

// or you can run the code directly, however you can't check that the code is safe to execute this way, use at your own risk.
const messages = await agent.run("Draw a picture of a cat wearing a top hat. Then caption the picture and read it out loud.")
console.log(messages);
```

There are more features of course, check each library's README!
2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -56,6 +56,7 @@ Currently, we support the following providers:
- [Sambanova](https://sambanova.ai)
- [Together](https://together.xyz)
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
```ts
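// NOTE: the README's full example is collapsed in this hunk; the following is a
// minimal sketch of the documented usage. The token and model ID are illustrative.
import { HfInference } from "@huggingface/inference";

const client = new HfInference("hf_***"); // your Hugging Face access token

await client.chatCompletion({
  model: "CohereForAI/c4ai-command-r-plus", // an HF model ID mapped to a Cohere model
  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
  provider: "cohere", // route the request to the Cohere API
  max_tokens: 512,
});
```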
@@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You can check the list of supported models here:
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

**Important note:** To be compatible, the third-party API must adhere to the "standard" API shape we expect on HF model pages for each pipeline task type.
10 changes: 10 additions & 0 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -8,6 +8,7 @@ import { NOVITA_API_BASE_URL } from "../providers/novita";
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
import { COHERE_API_BASE_URL } from "../providers/cohere";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
@@ -256,6 +257,15 @@ function makeUrl(params: {
}
return baseUrl;
}
case "cohere": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: COHERE_API_BASE_URL;
if (params.taskHint === "text-generation") {
return `${baseUrl}/v2/chat`;
}
return baseUrl;
}
default: {
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
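Taken on its own, the new `cohere` case picks between the Hugging Face proxy and Cohere's own API, then appends the chat route. A standalone sketch of that branching, with a placeholder proxy template (the real `HF_HUB_INFERENCE_PROXY_TEMPLATE` value lives elsewhere in the library):

```ts
// Placeholder template — not the library's actual constant.
const PROXY_TEMPLATE = "https://<hf-inference-proxy>/{{PROVIDER}}";
const COHERE_API_BASE_URL = "https://api.cohere.com";

function cohereChatUrl(shouldProxy: boolean): string {
  const baseUrl = shouldProxy
    ? PROXY_TEMPLATE.replace("{{PROVIDER}}", "cohere")
    : COHERE_API_BASE_URL;
  // Chat completion is the only task routed for Cohere; it maps to /v2/chat.
  return `${baseUrl}/v2/chat`;
}

console.log(cohereChatUrl(false)); // https://api.cohere.com/v2/chat
```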
18 changes: 18 additions & 0 deletions packages/inference/src/providers/cohere.ts
@@ -0,0 +1,18 @@
export const COHERE_API_BASE_URL = "https://api.cohere.com";

/**
* See the registered mapping of HF model ID => Cohere model ID here:
*
* https://huggingface.co/api/partners/cohere/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
* and we will tag Cohere team members.
*
* Thanks!
*/
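As the comment above suggests, a model can be tried locally before it is registered on huggingface.co by adding it to `HARDCODED_MODEL_ID_MAPPING` in consts.ts. A hypothetical dev-only entry (both IDs are illustrative, not a registered mapping):

```ts
// Dev-only sketch: map an HF model ID to a Cohere model ID for local testing.
HARDCODED_MODEL_ID_MAPPING.cohere = {
  "CohereForAI/c4ai-command-r-plus": "command-r-plus", // HF model ID => Cohere model ID
};
```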
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -26,4 +26,5 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, string>> = {
sambanova: {},
together: {},
novita: {},
cohere: {},
};
167 changes: 146 additions & 21 deletions packages/inference/src/tasks/nlp/chatCompletion.ts
@@ -3,33 +3,158 @@ import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";

export type CohereTextGenerationOutputFinishReason =
| "COMPLETE"
| "STOP_SEQUENCE"
| "MAX_TOKENS"
| "TOOL_CALL"
| "ERROR";

interface CohereChatCompletionOutput {
id: string;
finish_reason: CohereTextGenerationOutputFinishReason;
message: CohereMessage;
usage: CohereChatCompletionOutputUsage;
logprobs?: CohereLogprob[]; // Optional field for log probabilities
}

interface CohereMessage {
role: string;
content: Array<{
type: string;
text: string;
}>;
tool_calls?: CohereToolCall[]; // Optional field for tool calls
}

interface CohereChatCompletionOutputUsage {
billed_units: CohereInputOutputTokens;
tokens: CohereInputOutputTokens;
}

interface CohereInputOutputTokens {
input_tokens: number;
output_tokens: number;
}

interface CohereLogprob {
logprob: number;
token: string;
top_logprobs: CohereTopLogprob[];
}

interface CohereTopLogprob {
logprob: number;
token: string;
}

interface CohereToolCall {
function: CohereFunctionDefinition;
id: string;
type: string;
}

interface CohereFunctionDefinition {
arguments: unknown;
description?: string;
name: string;
}
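To make the shapes above concrete, here is a hypothetical response that satisfies these interfaces (all values are made up for illustration, not captured from the API):

```ts
const sampleCohereResponse: CohereChatCompletionOutput = {
  id: "chat_0123456789",
  finish_reason: "COMPLETE",
  message: {
    role: "assistant",
    content: [{ type: "text", text: "Hello! How can I help you today?" }],
  },
  usage: {
    billed_units: { input_tokens: 12, output_tokens: 9 },
    tokens: { input_tokens: 12, output_tokens: 9 },
  },
};
```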

function convertCohereToChatCompletionOutput(res: CohereChatCompletionOutput): ChatCompletionOutput {
// Create a ChatCompletionOutput object from the CohereChatCompletionOutput
return {
id: res.id,
created: Date.now(),
model: "cohere-model",
system_fingerprint: "cohere-fingerprint",
usage: {
completion_tokens: res.usage.tokens.output_tokens,
prompt_tokens: res.usage.tokens.input_tokens,
total_tokens: res.usage.tokens.input_tokens + res.usage.tokens.output_tokens,
},
choices: [
{
finish_reason: res.finish_reason,
index: 0,
message: {
role: res.message.role,
content: res.message.content.map((c) => c.text).join(" "),
tool_calls: res.message.tool_calls?.map((toolCall) => ({
function: {
arguments: toolCall.function.arguments,
description: toolCall.function.description,
name: toolCall.function.name,
},
id: toolCall.id,
type: toolCall.type,
})),
},
logprobs: res.logprobs
? {
content: res.logprobs.map((logprob) => ({
logprob: logprob.logprob,
token: logprob.token,
top_logprobs: logprob.top_logprobs.map((topLogprob) => ({
logprob: topLogprob.logprob,
token: topLogprob.token,
})),
})),
}
: undefined,
},
],
};
}
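Running a response like the sample above through the converter yields the OpenAI-style shape the rest of the library expects:

```ts
const converted = convertCohereToChatCompletionOutput(sampleCohereResponse);
console.log(converted.choices[0].message.content); // "Hello! How can I help you today?"
console.log(converted.usage.total_tokens); // 21 (12 prompt + 9 completion)
```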

/**
* Use the chat completion endpoint to generate a response to a prompt, using the OpenAI-compatible message completion API (non-streaming)
*/
export async function chatCompletion(
args: BaseArgs & ChatCompletionInput,
options?: Options
): Promise<ChatCompletionOutput> {
if (args.provider === "cohere") {
const res = await request<CohereChatCompletionOutput>(args, {
...options,
taskHint: "text-generation",
chatCompletion: true,
});

const isValidOutput =
typeof res === "object" &&
typeof res?.id === "string" &&
typeof res?.finish_reason === "string" &&
typeof res?.message === "object" &&
Array.isArray(res?.message.content) &&
typeof res?.usage === "object";

if (!isValidOutput) {
throw new InferenceOutputError("Expected CohereChatCompletionOutput");
}

return convertCohereToChatCompletionOutput(res);
} else {
const res = await request<ChatCompletionOutput>(args, {
...options,
taskHint: "text-generation",
chatCompletion: true,
});

const isValidOutput =
typeof res === "object" &&
Array.isArray(res?.choices) &&
typeof res?.created === "number" &&
typeof res?.id === "string" &&
typeof res?.model === "string" &&
/// Together.ai and Nebius do not output a system_fingerprint
(res.system_fingerprint === undefined ||
res.system_fingerprint === null ||
typeof res.system_fingerprint === "string") &&
typeof res?.usage === "object";

if (!isValidOutput) {
throw new InferenceOutputError("Expected ChatCompletionOutput");
}
return res;
}
}
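With this branch in place, callers opt into Cohere purely through the `provider` argument; a hypothetical end-to-end call (token and model ID are illustrative):

```ts
const output = await chatCompletion({
  accessToken: "hf_***",
  model: "CohereForAI/c4ai-command-r-plus",
  provider: "cohere",
  messages: [{ role: "user", content: "Summarize this diff in one sentence." }],
  max_tokens: 64,
});
console.log(output.choices[0].message.content);
```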