@ai-sdk/cohere
Advanced tools
Comparing version 0.0.27 to 0.0.28
@@ -54,3 +54,3 @@ "use strict"; | ||
case "system": { | ||
messages.push({ role: "SYSTEM", message: content }); | ||
messages.push({ role: "system", content }); | ||
break; | ||
@@ -60,4 +60,4 @@ } | ||
messages.push({ | ||
role: "USER", | ||
message: content.map((part) => { | ||
role: "user", | ||
content: content.map((part) => { | ||
switch (part.type) { | ||
@@ -88,4 +88,8 @@ case "text": { | ||
toolCalls.push({ | ||
name: part.toolName, | ||
parameters: part.args | ||
id: part.toolCallId, | ||
type: "function", | ||
function: { | ||
name: part.toolName, | ||
arguments: JSON.stringify(part.args) | ||
} | ||
}); | ||
@@ -101,4 +105,7 @@ break; | ||
messages.push({ | ||
role: "CHATBOT", | ||
message: text, | ||
role: "assistant", | ||
// note: this is a workaround for a Cohere API bug | ||
// that requires content to be provided | ||
// even if there are tool calls | ||
content: text !== "" ? text : "call tool", | ||
tool_calls: toolCalls.length > 0 ? toolCalls : void 0 | ||
@@ -109,20 +116,9 @@ }); | ||
case "tool": { | ||
messages.push({ | ||
role: "TOOL", | ||
tool_results: content.map((toolResult) => ({ | ||
call: { | ||
name: toolResult.toolName, | ||
/* | ||
Note: Currently the tool_results field requires we pass the parameters of the tool results again. It it is blank for two reasons: | ||
1. The parameters are already present in chat_history as a tool message | ||
2. The tool core message of the ai sdk does not include parameters | ||
It is possible to traverse through the chat history and get the parameters by id but it's currently empty since there wasn't any degradation in the output when left blank. | ||
*/ | ||
parameters: {} | ||
}, | ||
outputs: [toolResult.result] | ||
messages.push( | ||
...content.map((toolResult) => ({ | ||
role: "tool", | ||
content: JSON.stringify(toolResult.result), | ||
tool_call_id: toolResult.toolCallId | ||
})) | ||
}); | ||
); | ||
break; | ||
@@ -166,3 +162,3 @@ } | ||
if (tools == null) { | ||
return { tools: void 0, force_single_step: void 0, toolWarnings }; | ||
return { tools: void 0, tool_choice: void 0, toolWarnings }; | ||
} | ||
@@ -174,45 +170,9 @@ const cohereTools = []; | ||
} else { | ||
const { properties, required } = tool.parameters; | ||
const parameterDefinitions = {}; | ||
if (properties) { | ||
for (const [key, value] of Object.entries(properties)) { | ||
if (typeof value === "object" && value !== null) { | ||
const { type: JSONType, description } = value; | ||
let type2; | ||
if (typeof JSONType === "string") { | ||
switch (JSONType) { | ||
case "string": | ||
type2 = "str"; | ||
break; | ||
case "number": | ||
type2 = "float"; | ||
break; | ||
case "integer": | ||
type2 = "int"; | ||
break; | ||
case "boolean": | ||
type2 = "bool"; | ||
break; | ||
default: | ||
throw new import_provider2.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported tool parameter type: ${JSONType}` | ||
}); | ||
} | ||
} else { | ||
throw new import_provider2.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported tool parameter type: ${JSONType}` | ||
}); | ||
} | ||
parameterDefinitions[key] = { | ||
required: required ? required.includes(key) : false, | ||
type: type2, | ||
description | ||
}; | ||
} | ||
cohereTools.push({ | ||
type: "function", | ||
function: { | ||
name: tool.name, | ||
description: tool.description, | ||
parameters: tool.parameters | ||
} | ||
} | ||
cohereTools.push({ | ||
name: tool.name, | ||
description: tool.description, | ||
parameterDefinitions | ||
}); | ||
@@ -223,3 +183,3 @@ } | ||
if (toolChoice == null) { | ||
return { tools: cohereTools, force_single_step: false, toolWarnings }; | ||
return { tools: cohereTools, tool_choice: void 0, toolWarnings }; | ||
} | ||
@@ -229,13 +189,10 @@ const type = toolChoice.type; | ||
case "auto": | ||
return { tools: cohereTools, force_single_step: false, toolWarnings }; | ||
return { tools: cohereTools, tool_choice: type, toolWarnings }; | ||
case "none": | ||
return { tools: void 0, tool_choice: "any", toolWarnings }; | ||
case "required": | ||
return { tools: cohereTools, force_single_step: true, toolWarnings }; | ||
case "none": | ||
return { tools: void 0, force_single_step: false, toolWarnings }; | ||
case "tool": | ||
return { | ||
tools: cohereTools.filter((tool) => tool.name === toolChoice.toolName), | ||
force_single_step: true, | ||
toolWarnings | ||
}; | ||
throw new import_provider2.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported tool choice type: ${type}` | ||
}); | ||
default: { | ||
@@ -277,4 +234,2 @@ const _exhaustiveCheck = type; | ||
const chatPrompt = convertToCohereChatPrompt(prompt); | ||
const lastMessage = chatPrompt.at(-1); | ||
const history = chatPrompt.slice(0, -1); | ||
const baseArgs = { | ||
@@ -297,13 +252,10 @@ // model id: | ||
// messages: | ||
chat_history: history, | ||
...(lastMessage == null ? void 0 : lastMessage.role) === "TOOL" ? { tool_results: lastMessage.tool_results } : {}, | ||
message: lastMessage ? lastMessage.role === "USER" ? lastMessage.message : void 0 : void 0 | ||
messages: chatPrompt | ||
}; | ||
switch (type) { | ||
case "regular": { | ||
const { tools, force_single_step, toolWarnings } = prepareTools(mode); | ||
const { tools, tool_choice, toolWarnings } = prepareTools(mode); | ||
return { | ||
...baseArgs, | ||
tools, | ||
force_single_step, | ||
warnings: toolWarnings | ||
@@ -324,9 +276,37 @@ }; | ||
const _exhaustiveCheck = type; | ||
throw new Error(`Unsupported type: ${_exhaustiveCheck}`); | ||
throw new import_provider3.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported mode: ${_exhaustiveCheck}` | ||
}); | ||
} | ||
} | ||
} | ||
concatenateMessageText(messages) { | ||
return messages.filter( | ||
(message) => "content" in message | ||
).map((message) => message.content).join(""); | ||
} | ||
/* | ||
Remove `additionalProperties` and `$schema` from the `parameters` object of each tool. | ||
Though these are part of JSON schema, Cohere chokes if we include them in the request. | ||
*/ | ||
// TODO(shaper): Look at defining a type to simplify the params here and a couple of other places. | ||
removeJsonSchemaExtras(tools) { | ||
return tools.map((tool) => { | ||
if (tool.type === "function" && tool.function.parameters && typeof tool.function.parameters === "object") { | ||
const { additionalProperties, $schema, ...restParameters } = tool.function.parameters; | ||
return { | ||
...tool, | ||
function: { | ||
...tool.function, | ||
parameters: restParameters | ||
} | ||
}; | ||
} | ||
return tool; | ||
}); | ||
} | ||
async doGenerate(options) { | ||
var _a; | ||
var _a, _b, _c, _d; | ||
const { warnings, ...args } = this.getArgs(options); | ||
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools); | ||
const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({ | ||
@@ -343,10 +323,9 @@ url: `${this.config.baseURL}/chat`, | ||
}); | ||
const { chat_history, message, ...rawSettings } = args; | ||
const generateId2 = this.config.generateId; | ||
const { messages, ...rawSettings } = args; | ||
return { | ||
text: response.text, | ||
toolCalls: response.tool_calls ? response.tool_calls.map((toolCall) => ({ | ||
toolCallId: generateId2(), | ||
toolName: toolCall.name, | ||
args: JSON.stringify(toolCall.parameters), | ||
text: (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "", | ||
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({ | ||
toolCallId: toolCall.id, | ||
toolName: toolCall.function.name, | ||
args: toolCall.function.arguments, | ||
toolCallType: "function" | ||
@@ -356,9 +335,8 @@ })) : [], | ||
usage: { | ||
promptTokens: response.meta.tokens.input_tokens, | ||
completionTokens: response.meta.tokens.output_tokens | ||
promptTokens: response.usage.tokens.input_tokens, | ||
completionTokens: response.usage.tokens.output_tokens | ||
}, | ||
rawCall: { | ||
rawPrompt: { | ||
chat_history, | ||
message | ||
messages | ||
}, | ||
@@ -368,3 +346,3 @@ rawSettings | ||
response: { | ||
id: (_a = response.generation_id) != null ? _a : void 0 | ||
id: (_d = response.generation_id) != null ? _d : void 0 | ||
}, | ||
@@ -378,2 +356,3 @@ rawResponse: { headers: responseHeaders }, | ||
const { warnings, ...args } = this.getArgs(options); | ||
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools); | ||
const body = { ...args, stream: true }; | ||
@@ -385,3 +364,3 @@ const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({ | ||
failedResponseHandler: cohereFailedResponseHandler, | ||
successfulResponseHandler: (0, import_provider_utils2.createJsonStreamResponseHandler)( | ||
successfulResponseHandler: (0, import_provider_utils2.createEventSourceResponseHandler)( | ||
cohereChatChunkSchema | ||
@@ -392,3 +371,3 @@ ), | ||
}); | ||
const { chat_history, message, ...rawSettings } = args; | ||
const { messages, ...rawSettings } = args; | ||
let finishReason = "unknown"; | ||
@@ -399,4 +378,7 @@ let usage = { | ||
}; | ||
const generateId2 = this.config.generateId; | ||
const toolCalls = []; | ||
let pendingToolCallDelta = { | ||
toolCallId: "", | ||
toolName: "", | ||
argsTextDelta: "" | ||
}; | ||
return { | ||
@@ -413,65 +395,64 @@ stream: response.pipeThrough( | ||
const value = chunk.value; | ||
const type = value.event_type; | ||
const type = value.type; | ||
switch (type) { | ||
case "text-generation": { | ||
case "content-delta": { | ||
controller.enqueue({ | ||
type: "text-delta", | ||
textDelta: value.text | ||
textDelta: value.delta.message.content.text | ||
}); | ||
return; | ||
} | ||
case "tool-calls-chunk": { | ||
if (value.tool_call_delta) { | ||
const { index } = value.tool_call_delta; | ||
if (toolCalls[index] === void 0) { | ||
const toolCallId = generateId2(); | ||
toolCalls[index] = { | ||
toolCallId, | ||
toolName: "" | ||
}; | ||
} | ||
if (value.tool_call_delta.name) { | ||
toolCalls[index].toolName = value.tool_call_delta.name; | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallType: "function", | ||
toolCallId: toolCalls[index].toolCallId, | ||
toolName: toolCalls[index].toolName, | ||
argsTextDelta: "" | ||
}); | ||
} else if (value.tool_call_delta.parameters) { | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallType: "function", | ||
toolCallId: toolCalls[index].toolCallId, | ||
toolName: toolCalls[index].toolName, | ||
argsTextDelta: value.tool_call_delta.parameters | ||
}); | ||
} | ||
} | ||
case "tool-call-start": { | ||
pendingToolCallDelta = { | ||
toolCallId: value.delta.message.tool_calls.id, | ||
toolName: value.delta.message.tool_calls.function.name, | ||
argsTextDelta: value.delta.message.tool_calls.function.arguments | ||
}; | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallId: pendingToolCallDelta.toolCallId, | ||
toolName: pendingToolCallDelta.toolName, | ||
toolCallType: "function", | ||
argsTextDelta: pendingToolCallDelta.argsTextDelta | ||
}); | ||
return; | ||
} | ||
case "tool-calls-generation": { | ||
for (let index = 0; index < value.tool_calls.length; index++) { | ||
const toolCall = value.tool_calls[index]; | ||
controller.enqueue({ | ||
type: "tool-call", | ||
toolCallId: toolCalls[index].toolCallId, | ||
toolName: toolCalls[index].toolName, | ||
toolCallType: "function", | ||
args: JSON.stringify(toolCall.parameters) | ||
}); | ||
} | ||
case "tool-call-delta": { | ||
pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments; | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallId: pendingToolCallDelta.toolCallId, | ||
toolName: pendingToolCallDelta.toolName, | ||
toolCallType: "function", | ||
argsTextDelta: value.delta.message.tool_calls.function.arguments | ||
}); | ||
return; | ||
} | ||
case "stream-start": { | ||
case "tool-call-end": { | ||
controller.enqueue({ | ||
type: "tool-call", | ||
toolCallId: pendingToolCallDelta.toolCallId, | ||
toolName: pendingToolCallDelta.toolName, | ||
toolCallType: "function", | ||
args: JSON.stringify( | ||
JSON.parse(pendingToolCallDelta.argsTextDelta) | ||
) | ||
}); | ||
pendingToolCallDelta = { | ||
toolCallId: "", | ||
toolName: "", | ||
argsTextDelta: "" | ||
}; | ||
return; | ||
} | ||
case "message-start": { | ||
controller.enqueue({ | ||
type: "response-metadata", | ||
id: (_a = value.generation_id) != null ? _a : void 0 | ||
id: (_a = value.id) != null ? _a : void 0 | ||
}); | ||
return; | ||
} | ||
case "stream-end": { | ||
finishReason = mapCohereFinishReason(value.finish_reason); | ||
const tokens = value.response.meta.tokens; | ||
case "message-end": { | ||
finishReason = mapCohereFinishReason(value.delta.finish_reason); | ||
const tokens = value.delta.usage.tokens; | ||
usage = { | ||
@@ -498,4 +479,3 @@ promptTokens: tokens.input_tokens, | ||
rawPrompt: { | ||
chat_history, | ||
message | ||
messages | ||
}, | ||
@@ -512,11 +492,27 @@ rawSettings | ||
generation_id: import_zod2.z.string().nullish(), | ||
text: import_zod2.z.string(), | ||
tool_calls: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
parameters: import_zod2.z.unknown({}) | ||
}) | ||
).nullish(), | ||
message: import_zod2.z.object({ | ||
role: import_zod2.z.string(), | ||
content: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
type: import_zod2.z.string(), | ||
text: import_zod2.z.string() | ||
}) | ||
).nullish(), | ||
tool_calls: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
id: import_zod2.z.string(), | ||
type: import_zod2.z.literal("function"), | ||
function: import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
arguments: import_zod2.z.string() | ||
}) | ||
}) | ||
).nullish() | ||
}), | ||
finish_reason: import_zod2.z.string(), | ||
meta: import_zod2.z.object({ | ||
usage: import_zod2.z.object({ | ||
billed_units: import_zod2.z.object({ | ||
input_tokens: import_zod2.z.number(), | ||
output_tokens: import_zod2.z.number() | ||
}), | ||
tokens: import_zod2.z.object({ | ||
@@ -528,43 +524,34 @@ input_tokens: import_zod2.z.number(), | ||
}); | ||
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("event_type", [ | ||
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [ | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("stream-start"), | ||
generation_id: import_zod2.z.string().nullish() | ||
type: import_zod2.z.literal("citation-start") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("search-queries-generation") | ||
type: import_zod2.z.literal("citation-end") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("search-results") | ||
type: import_zod2.z.literal("content-start") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("text-generation"), | ||
text: import_zod2.z.string() | ||
type: import_zod2.z.literal("content-delta"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
content: import_zod2.z.object({ | ||
text: import_zod2.z.string() | ||
}) | ||
}) | ||
}) | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("citation-generation") | ||
type: import_zod2.z.literal("content-end") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("tool-calls-generation"), | ||
tool_calls: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
parameters: import_zod2.z.unknown({}) | ||
}) | ||
) | ||
type: import_zod2.z.literal("message-start"), | ||
id: import_zod2.z.string().nullish() | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("tool-calls-chunk"), | ||
text: import_zod2.z.string().optional(), | ||
tool_call_delta: import_zod2.z.object({ | ||
index: import_zod2.z.number(), | ||
name: import_zod2.z.string().optional(), | ||
parameters: import_zod2.z.string().optional() | ||
}).optional() | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("stream-end"), | ||
finish_reason: import_zod2.z.string(), | ||
response: import_zod2.z.object({ | ||
meta: import_zod2.z.object({ | ||
type: import_zod2.z.literal("message-end"), | ||
delta: import_zod2.z.object({ | ||
finish_reason: import_zod2.z.string(), | ||
usage: import_zod2.z.object({ | ||
tokens: import_zod2.z.object({ | ||
@@ -576,2 +563,44 @@ input_tokens: import_zod2.z.number(), | ||
}) | ||
}), | ||
// https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-plan-delta"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
tool_plan: import_zod2.z.string() | ||
}) | ||
}) | ||
}), | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-call-start"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
tool_calls: import_zod2.z.object({ | ||
id: import_zod2.z.string(), | ||
type: import_zod2.z.literal("function"), | ||
function: import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
arguments: import_zod2.z.string() | ||
}) | ||
}) | ||
}) | ||
}) | ||
}), | ||
// A single tool call's `arguments` stream in chunks and must be accumulated | ||
// in a string and so the full tool object info can only be parsed once we see | ||
// `tool-call-end`. | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-call-delta"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
tool_calls: import_zod2.z.object({ | ||
function: import_zod2.z.object({ | ||
arguments: import_zod2.z.string() | ||
}) | ||
}) | ||
}) | ||
}) | ||
}), | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-call-end") | ||
}) | ||
@@ -615,2 +644,7 @@ ]); | ||
model: this.modelId, | ||
// TODO(shaper): There are other embedding types. Do we need to support them? | ||
// For now we only support 'float' embeddings which are also the only ones | ||
// the Cohere API docs state are supported for all models. | ||
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types | ||
embedding_types: ["float"], | ||
texts: values, | ||
@@ -628,3 +662,3 @@ input_type: (_a = this.settings.inputType) != null ? _a : "search_query", | ||
return { | ||
embeddings: response.embeddings, | ||
embeddings: response.embeddings.float, | ||
usage: { tokens: response.meta.billed_units.input_tokens }, | ||
@@ -636,3 +670,5 @@ rawResponse: { headers: responseHeaders } | ||
var cohereTextEmbeddingResponseSchema = import_zod3.z.object({ | ||
embeddings: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())), | ||
embeddings: import_zod3.z.object({ | ||
float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())) | ||
}), | ||
meta: import_zod3.z.object({ | ||
@@ -648,3 +684,3 @@ billed_units: import_zod3.z.object({ | ||
var _a; | ||
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1"; | ||
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v2"; | ||
const getHeaders = () => ({ | ||
@@ -658,12 +694,8 @@ Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({ | ||
}); | ||
const createChatModel = (modelId, settings = {}) => { | ||
var _a2; | ||
return new CohereChatLanguageModel(modelId, settings, { | ||
provider: "cohere.chat", | ||
baseURL, | ||
headers: getHeaders, | ||
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils4.generateId, | ||
fetch: options.fetch | ||
}); | ||
}; | ||
const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, { | ||
provider: "cohere.chat", | ||
baseURL, | ||
headers: getHeaders, | ||
fetch: options.fetch | ||
}); | ||
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, { | ||
@@ -670,0 +702,0 @@ provider: "cohere.textEmbedding", |
# @ai-sdk/cohere | ||
## 0.0.28 | ||
### Patch Changes | ||
- a7cbdf6: feat (provider/cohere): Use Cohere v2 API | ||
## 0.0.27 | ||
@@ -4,0 +10,0 @@ |
@@ -42,3 +42,3 @@ import { ProviderV1, LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider'; | ||
Use a different URL prefix for API calls, e.g. to use proxy servers. | ||
The default prefix is `https://api.cohere.com/v1`. | ||
The default prefix is `https://api.cohere.com/v2`. | ||
*/ | ||
@@ -60,3 +60,2 @@ baseURL?: string; | ||
fetch?: FetchFunction; | ||
generateId?: () => string; | ||
} | ||
@@ -63,0 +62,0 @@ /** |
@@ -54,3 +54,3 @@ "use strict"; | ||
case "system": { | ||
messages.push({ role: "SYSTEM", message: content }); | ||
messages.push({ role: "system", content }); | ||
break; | ||
@@ -60,4 +60,4 @@ } | ||
messages.push({ | ||
role: "USER", | ||
message: content.map((part) => { | ||
role: "user", | ||
content: content.map((part) => { | ||
switch (part.type) { | ||
@@ -88,4 +88,8 @@ case "text": { | ||
toolCalls.push({ | ||
name: part.toolName, | ||
parameters: part.args | ||
id: part.toolCallId, | ||
type: "function", | ||
function: { | ||
name: part.toolName, | ||
arguments: JSON.stringify(part.args) | ||
} | ||
}); | ||
@@ -101,4 +105,7 @@ break; | ||
messages.push({ | ||
role: "CHATBOT", | ||
message: text, | ||
role: "assistant", | ||
// note: this is a workaround for a Cohere API bug | ||
// that requires content to be provided | ||
// even if there are tool calls | ||
content: text !== "" ? text : "call tool", | ||
tool_calls: toolCalls.length > 0 ? toolCalls : void 0 | ||
@@ -109,20 +116,9 @@ }); | ||
case "tool": { | ||
messages.push({ | ||
role: "TOOL", | ||
tool_results: content.map((toolResult) => ({ | ||
call: { | ||
name: toolResult.toolName, | ||
/* | ||
Note: Currently the tool_results field requires we pass the parameters of the tool results again. It it is blank for two reasons: | ||
1. The parameters are already present in chat_history as a tool message | ||
2. The tool core message of the ai sdk does not include parameters | ||
It is possible to traverse through the chat history and get the parameters by id but it's currently empty since there wasn't any degradation in the output when left blank. | ||
*/ | ||
parameters: {} | ||
}, | ||
outputs: [toolResult.result] | ||
messages.push( | ||
...content.map((toolResult) => ({ | ||
role: "tool", | ||
content: JSON.stringify(toolResult.result), | ||
tool_call_id: toolResult.toolCallId | ||
})) | ||
}); | ||
); | ||
break; | ||
@@ -166,3 +162,3 @@ } | ||
if (tools == null) { | ||
return { tools: void 0, force_single_step: void 0, toolWarnings }; | ||
return { tools: void 0, tool_choice: void 0, toolWarnings }; | ||
} | ||
@@ -174,45 +170,9 @@ const cohereTools = []; | ||
} else { | ||
const { properties, required } = tool.parameters; | ||
const parameterDefinitions = {}; | ||
if (properties) { | ||
for (const [key, value] of Object.entries(properties)) { | ||
if (typeof value === "object" && value !== null) { | ||
const { type: JSONType, description } = value; | ||
let type2; | ||
if (typeof JSONType === "string") { | ||
switch (JSONType) { | ||
case "string": | ||
type2 = "str"; | ||
break; | ||
case "number": | ||
type2 = "float"; | ||
break; | ||
case "integer": | ||
type2 = "int"; | ||
break; | ||
case "boolean": | ||
type2 = "bool"; | ||
break; | ||
default: | ||
throw new import_provider2.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported tool parameter type: ${JSONType}` | ||
}); | ||
} | ||
} else { | ||
throw new import_provider2.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported tool parameter type: ${JSONType}` | ||
}); | ||
} | ||
parameterDefinitions[key] = { | ||
required: required ? required.includes(key) : false, | ||
type: type2, | ||
description | ||
}; | ||
} | ||
cohereTools.push({ | ||
type: "function", | ||
function: { | ||
name: tool.name, | ||
description: tool.description, | ||
parameters: tool.parameters | ||
} | ||
} | ||
cohereTools.push({ | ||
name: tool.name, | ||
description: tool.description, | ||
parameterDefinitions | ||
}); | ||
@@ -223,3 +183,3 @@ } | ||
if (toolChoice == null) { | ||
return { tools: cohereTools, force_single_step: false, toolWarnings }; | ||
return { tools: cohereTools, tool_choice: void 0, toolWarnings }; | ||
} | ||
@@ -229,13 +189,10 @@ const type = toolChoice.type; | ||
case "auto": | ||
return { tools: cohereTools, force_single_step: false, toolWarnings }; | ||
return { tools: cohereTools, tool_choice: type, toolWarnings }; | ||
case "none": | ||
return { tools: void 0, tool_choice: "any", toolWarnings }; | ||
case "required": | ||
return { tools: cohereTools, force_single_step: true, toolWarnings }; | ||
case "none": | ||
return { tools: void 0, force_single_step: false, toolWarnings }; | ||
case "tool": | ||
return { | ||
tools: cohereTools.filter((tool) => tool.name === toolChoice.toolName), | ||
force_single_step: true, | ||
toolWarnings | ||
}; | ||
throw new import_provider2.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported tool choice type: ${type}` | ||
}); | ||
default: { | ||
@@ -277,4 +234,2 @@ const _exhaustiveCheck = type; | ||
const chatPrompt = convertToCohereChatPrompt(prompt); | ||
const lastMessage = chatPrompt.at(-1); | ||
const history = chatPrompt.slice(0, -1); | ||
const baseArgs = { | ||
@@ -297,13 +252,10 @@ // model id: | ||
// messages: | ||
chat_history: history, | ||
...(lastMessage == null ? void 0 : lastMessage.role) === "TOOL" ? { tool_results: lastMessage.tool_results } : {}, | ||
message: lastMessage ? lastMessage.role === "USER" ? lastMessage.message : void 0 : void 0 | ||
messages: chatPrompt | ||
}; | ||
switch (type) { | ||
case "regular": { | ||
const { tools, force_single_step, toolWarnings } = prepareTools(mode); | ||
const { tools, tool_choice, toolWarnings } = prepareTools(mode); | ||
return { | ||
...baseArgs, | ||
tools, | ||
force_single_step, | ||
warnings: toolWarnings | ||
@@ -324,9 +276,37 @@ }; | ||
const _exhaustiveCheck = type; | ||
throw new Error(`Unsupported type: ${_exhaustiveCheck}`); | ||
throw new import_provider3.UnsupportedFunctionalityError({ | ||
functionality: `Unsupported mode: ${_exhaustiveCheck}` | ||
}); | ||
} | ||
} | ||
} | ||
concatenateMessageText(messages) { | ||
return messages.filter( | ||
(message) => "content" in message | ||
).map((message) => message.content).join(""); | ||
} | ||
/* | ||
Remove `additionalProperties` and `$schema` from the `parameters` object of each tool. | ||
Though these are part of JSON schema, Cohere chokes if we include them in the request. | ||
*/ | ||
// TODO(shaper): Look at defining a type to simplify the params here and a couple of other places. | ||
removeJsonSchemaExtras(tools) { | ||
return tools.map((tool) => { | ||
if (tool.type === "function" && tool.function.parameters && typeof tool.function.parameters === "object") { | ||
const { additionalProperties, $schema, ...restParameters } = tool.function.parameters; | ||
return { | ||
...tool, | ||
function: { | ||
...tool.function, | ||
parameters: restParameters | ||
} | ||
}; | ||
} | ||
return tool; | ||
}); | ||
} | ||
async doGenerate(options) { | ||
var _a; | ||
var _a, _b, _c, _d; | ||
const { warnings, ...args } = this.getArgs(options); | ||
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools); | ||
const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({ | ||
@@ -343,10 +323,9 @@ url: `${this.config.baseURL}/chat`, | ||
}); | ||
const { chat_history, message, ...rawSettings } = args; | ||
const generateId2 = this.config.generateId; | ||
const { messages, ...rawSettings } = args; | ||
return { | ||
text: response.text, | ||
toolCalls: response.tool_calls ? response.tool_calls.map((toolCall) => ({ | ||
toolCallId: generateId2(), | ||
toolName: toolCall.name, | ||
args: JSON.stringify(toolCall.parameters), | ||
text: (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "", | ||
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({ | ||
toolCallId: toolCall.id, | ||
toolName: toolCall.function.name, | ||
args: toolCall.function.arguments, | ||
toolCallType: "function" | ||
@@ -356,9 +335,8 @@ })) : [], | ||
usage: { | ||
promptTokens: response.meta.tokens.input_tokens, | ||
completionTokens: response.meta.tokens.output_tokens | ||
promptTokens: response.usage.tokens.input_tokens, | ||
completionTokens: response.usage.tokens.output_tokens | ||
}, | ||
rawCall: { | ||
rawPrompt: { | ||
chat_history, | ||
message | ||
messages | ||
}, | ||
@@ -368,3 +346,3 @@ rawSettings | ||
response: { | ||
id: (_a = response.generation_id) != null ? _a : void 0 | ||
id: (_d = response.generation_id) != null ? _d : void 0 | ||
}, | ||
@@ -378,2 +356,3 @@ rawResponse: { headers: responseHeaders }, | ||
const { warnings, ...args } = this.getArgs(options); | ||
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools); | ||
const body = { ...args, stream: true }; | ||
@@ -385,3 +364,3 @@ const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({ | ||
failedResponseHandler: cohereFailedResponseHandler, | ||
successfulResponseHandler: (0, import_provider_utils2.createJsonStreamResponseHandler)( | ||
successfulResponseHandler: (0, import_provider_utils2.createEventSourceResponseHandler)( | ||
cohereChatChunkSchema | ||
@@ -392,3 +371,3 @@ ), | ||
}); | ||
const { chat_history, message, ...rawSettings } = args; | ||
const { messages, ...rawSettings } = args; | ||
let finishReason = "unknown"; | ||
@@ -399,4 +378,7 @@ let usage = { | ||
}; | ||
const generateId2 = this.config.generateId; | ||
const toolCalls = []; | ||
let pendingToolCallDelta = { | ||
toolCallId: "", | ||
toolName: "", | ||
argsTextDelta: "" | ||
}; | ||
return { | ||
@@ -413,65 +395,64 @@ stream: response.pipeThrough( | ||
const value = chunk.value; | ||
const type = value.event_type; | ||
const type = value.type; | ||
switch (type) { | ||
case "text-generation": { | ||
case "content-delta": { | ||
controller.enqueue({ | ||
type: "text-delta", | ||
textDelta: value.text | ||
textDelta: value.delta.message.content.text | ||
}); | ||
return; | ||
} | ||
case "tool-calls-chunk": { | ||
if (value.tool_call_delta) { | ||
const { index } = value.tool_call_delta; | ||
if (toolCalls[index] === void 0) { | ||
const toolCallId = generateId2(); | ||
toolCalls[index] = { | ||
toolCallId, | ||
toolName: "" | ||
}; | ||
} | ||
if (value.tool_call_delta.name) { | ||
toolCalls[index].toolName = value.tool_call_delta.name; | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallType: "function", | ||
toolCallId: toolCalls[index].toolCallId, | ||
toolName: toolCalls[index].toolName, | ||
argsTextDelta: "" | ||
}); | ||
} else if (value.tool_call_delta.parameters) { | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallType: "function", | ||
toolCallId: toolCalls[index].toolCallId, | ||
toolName: toolCalls[index].toolName, | ||
argsTextDelta: value.tool_call_delta.parameters | ||
}); | ||
} | ||
} | ||
case "tool-call-start": { | ||
pendingToolCallDelta = { | ||
toolCallId: value.delta.message.tool_calls.id, | ||
toolName: value.delta.message.tool_calls.function.name, | ||
argsTextDelta: value.delta.message.tool_calls.function.arguments | ||
}; | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallId: pendingToolCallDelta.toolCallId, | ||
toolName: pendingToolCallDelta.toolName, | ||
toolCallType: "function", | ||
argsTextDelta: pendingToolCallDelta.argsTextDelta | ||
}); | ||
return; | ||
} | ||
case "tool-calls-generation": { | ||
for (let index = 0; index < value.tool_calls.length; index++) { | ||
const toolCall = value.tool_calls[index]; | ||
controller.enqueue({ | ||
type: "tool-call", | ||
toolCallId: toolCalls[index].toolCallId, | ||
toolName: toolCalls[index].toolName, | ||
toolCallType: "function", | ||
args: JSON.stringify(toolCall.parameters) | ||
}); | ||
} | ||
case "tool-call-delta": { | ||
pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments; | ||
controller.enqueue({ | ||
type: "tool-call-delta", | ||
toolCallId: pendingToolCallDelta.toolCallId, | ||
toolName: pendingToolCallDelta.toolName, | ||
toolCallType: "function", | ||
argsTextDelta: value.delta.message.tool_calls.function.arguments | ||
}); | ||
return; | ||
} | ||
case "stream-start": { | ||
case "tool-call-end": { | ||
controller.enqueue({ | ||
type: "tool-call", | ||
toolCallId: pendingToolCallDelta.toolCallId, | ||
toolName: pendingToolCallDelta.toolName, | ||
toolCallType: "function", | ||
args: JSON.stringify( | ||
JSON.parse(pendingToolCallDelta.argsTextDelta) | ||
) | ||
}); | ||
pendingToolCallDelta = { | ||
toolCallId: "", | ||
toolName: "", | ||
argsTextDelta: "" | ||
}; | ||
return; | ||
} | ||
case "message-start": { | ||
controller.enqueue({ | ||
type: "response-metadata", | ||
id: (_a = value.generation_id) != null ? _a : void 0 | ||
id: (_a = value.id) != null ? _a : void 0 | ||
}); | ||
return; | ||
} | ||
case "stream-end": { | ||
finishReason = mapCohereFinishReason(value.finish_reason); | ||
const tokens = value.response.meta.tokens; | ||
case "message-end": { | ||
finishReason = mapCohereFinishReason(value.delta.finish_reason); | ||
const tokens = value.delta.usage.tokens; | ||
usage = { | ||
@@ -498,4 +479,3 @@ promptTokens: tokens.input_tokens, | ||
rawPrompt: { | ||
chat_history, | ||
message | ||
messages | ||
}, | ||
@@ -512,11 +492,27 @@ rawSettings | ||
generation_id: import_zod2.z.string().nullish(), | ||
text: import_zod2.z.string(), | ||
tool_calls: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
parameters: import_zod2.z.unknown({}) | ||
}) | ||
).nullish(), | ||
message: import_zod2.z.object({ | ||
role: import_zod2.z.string(), | ||
content: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
type: import_zod2.z.string(), | ||
text: import_zod2.z.string() | ||
}) | ||
).nullish(), | ||
tool_calls: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
id: import_zod2.z.string(), | ||
type: import_zod2.z.literal("function"), | ||
function: import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
arguments: import_zod2.z.string() | ||
}) | ||
}) | ||
).nullish() | ||
}), | ||
finish_reason: import_zod2.z.string(), | ||
meta: import_zod2.z.object({ | ||
usage: import_zod2.z.object({ | ||
billed_units: import_zod2.z.object({ | ||
input_tokens: import_zod2.z.number(), | ||
output_tokens: import_zod2.z.number() | ||
}), | ||
tokens: import_zod2.z.object({ | ||
@@ -528,43 +524,34 @@ input_tokens: import_zod2.z.number(), | ||
}); | ||
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("event_type", [ | ||
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [ | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("stream-start"), | ||
generation_id: import_zod2.z.string().nullish() | ||
type: import_zod2.z.literal("citation-start") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("search-queries-generation") | ||
type: import_zod2.z.literal("citation-end") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("search-results") | ||
type: import_zod2.z.literal("content-start") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("text-generation"), | ||
text: import_zod2.z.string() | ||
type: import_zod2.z.literal("content-delta"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
content: import_zod2.z.object({ | ||
text: import_zod2.z.string() | ||
}) | ||
}) | ||
}) | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("citation-generation") | ||
type: import_zod2.z.literal("content-end") | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("tool-calls-generation"), | ||
tool_calls: import_zod2.z.array( | ||
import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
parameters: import_zod2.z.unknown({}) | ||
}) | ||
) | ||
type: import_zod2.z.literal("message-start"), | ||
id: import_zod2.z.string().nullish() | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("tool-calls-chunk"), | ||
text: import_zod2.z.string().optional(), | ||
tool_call_delta: import_zod2.z.object({ | ||
index: import_zod2.z.number(), | ||
name: import_zod2.z.string().optional(), | ||
parameters: import_zod2.z.string().optional() | ||
}).optional() | ||
}), | ||
import_zod2.z.object({ | ||
event_type: import_zod2.z.literal("stream-end"), | ||
finish_reason: import_zod2.z.string(), | ||
response: import_zod2.z.object({ | ||
meta: import_zod2.z.object({ | ||
type: import_zod2.z.literal("message-end"), | ||
delta: import_zod2.z.object({ | ||
finish_reason: import_zod2.z.string(), | ||
usage: import_zod2.z.object({ | ||
tokens: import_zod2.z.object({ | ||
@@ -576,2 +563,44 @@ input_tokens: import_zod2.z.number(), | ||
}) | ||
}), | ||
// https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-plan-delta"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
tool_plan: import_zod2.z.string() | ||
}) | ||
}) | ||
}), | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-call-start"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
tool_calls: import_zod2.z.object({ | ||
id: import_zod2.z.string(), | ||
type: import_zod2.z.literal("function"), | ||
function: import_zod2.z.object({ | ||
name: import_zod2.z.string(), | ||
arguments: import_zod2.z.string() | ||
}) | ||
}) | ||
}) | ||
}) | ||
}), | ||
// A single tool call's `arguments` stream in chunks and must be accumulated | ||
// in a string and so the full tool object info can only be parsed once we see | ||
// `tool-call-end`. | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-call-delta"), | ||
delta: import_zod2.z.object({ | ||
message: import_zod2.z.object({ | ||
tool_calls: import_zod2.z.object({ | ||
function: import_zod2.z.object({ | ||
arguments: import_zod2.z.string() | ||
}) | ||
}) | ||
}) | ||
}) | ||
}), | ||
import_zod2.z.object({ | ||
type: import_zod2.z.literal("tool-call-end") | ||
}) | ||
@@ -615,2 +644,7 @@ ]); | ||
model: this.modelId, | ||
// TODO(shaper): There are other embedding types. Do we need to support them? | ||
// For now we only support 'float' embeddings which are also the only ones | ||
// the Cohere API docs state are supported for all models. | ||
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types | ||
embedding_types: ["float"], | ||
texts: values, | ||
@@ -628,3 +662,3 @@ input_type: (_a = this.settings.inputType) != null ? _a : "search_query", | ||
return { | ||
embeddings: response.embeddings, | ||
embeddings: response.embeddings.float, | ||
usage: { tokens: response.meta.billed_units.input_tokens }, | ||
@@ -636,3 +670,5 @@ rawResponse: { headers: responseHeaders } | ||
var cohereTextEmbeddingResponseSchema = import_zod3.z.object({ | ||
embeddings: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())), | ||
embeddings: import_zod3.z.object({ | ||
float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())) | ||
}), | ||
meta: import_zod3.z.object({ | ||
@@ -648,3 +684,3 @@ billed_units: import_zod3.z.object({ | ||
var _a; | ||
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1"; | ||
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v2"; | ||
const getHeaders = () => ({ | ||
@@ -658,12 +694,8 @@ Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({ | ||
}); | ||
const createChatModel = (modelId, settings = {}) => { | ||
var _a2; | ||
return new CohereChatLanguageModel(modelId, settings, { | ||
provider: "cohere.chat", | ||
baseURL, | ||
headers: getHeaders, | ||
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils4.generateId, | ||
fetch: options.fetch | ||
}); | ||
}; | ||
const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, { | ||
provider: "cohere.chat", | ||
baseURL, | ||
headers: getHeaders, | ||
fetch: options.fetch | ||
}); | ||
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, { | ||
@@ -670,0 +702,0 @@ provider: "cohere.textEmbedding", |
{ | ||
"name": "@ai-sdk/cohere", | ||
"version": "0.0.27", | ||
"version": "0.0.28", | ||
"license": "Apache-2.0", | ||
@@ -5,0 +5,0 @@ "sideEffects": false, |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
162529
2117