Huge News!Announcing our $40M Series B led by Abstract Ventures.Learn More
Socket
Sign inDemoInstall
Socket

@ai-sdk/cohere

Package Overview
Dependencies
Maintainers
0
Versions
43
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

@ai-sdk/cohere - npm Package Compare versions

Comparing version 0.0.27 to 0.0.28

434

./dist/index.js

@@ -54,3 +54,3 @@ "use strict";

case "system": {
messages.push({ role: "SYSTEM", message: content });
messages.push({ role: "system", content });
break;

@@ -60,4 +60,4 @@ }

messages.push({
role: "USER",
message: content.map((part) => {
role: "user",
content: content.map((part) => {
switch (part.type) {

@@ -88,4 +88,8 @@ case "text": {

toolCalls.push({
name: part.toolName,
parameters: part.args
id: part.toolCallId,
type: "function",
function: {
name: part.toolName,
arguments: JSON.stringify(part.args)
}
});

@@ -101,4 +105,7 @@ break;

messages.push({
role: "CHATBOT",
message: text,
role: "assistant",
// note: this is a workaround for a Cohere API bug
// that requires content to be provided
// even if there are tool calls
content: text !== "" ? text : "call tool",
tool_calls: toolCalls.length > 0 ? toolCalls : void 0

@@ -109,20 +116,9 @@ });

case "tool": {
messages.push({
role: "TOOL",
tool_results: content.map((toolResult) => ({
call: {
name: toolResult.toolName,
/*
Note: Currently the tool_results field requires we pass the parameters of the tool results again. It it is blank for two reasons:
1. The parameters are already present in chat_history as a tool message
2. The tool core message of the ai sdk does not include parameters
It is possible to traverse through the chat history and get the parameters by id but it's currently empty since there wasn't any degradation in the output when left blank.
*/
parameters: {}
},
outputs: [toolResult.result]
messages.push(
...content.map((toolResult) => ({
role: "tool",
content: JSON.stringify(toolResult.result),
tool_call_id: toolResult.toolCallId
}))
});
);
break;

@@ -166,3 +162,3 @@ }

if (tools == null) {
return { tools: void 0, force_single_step: void 0, toolWarnings };
return { tools: void 0, tool_choice: void 0, toolWarnings };
}

@@ -174,45 +170,9 @@ const cohereTools = [];

} else {
const { properties, required } = tool.parameters;
const parameterDefinitions = {};
if (properties) {
for (const [key, value] of Object.entries(properties)) {
if (typeof value === "object" && value !== null) {
const { type: JSONType, description } = value;
let type2;
if (typeof JSONType === "string") {
switch (JSONType) {
case "string":
type2 = "str";
break;
case "number":
type2 = "float";
break;
case "integer":
type2 = "int";
break;
case "boolean":
type2 = "bool";
break;
default:
throw new import_provider2.UnsupportedFunctionalityError({
functionality: `Unsupported tool parameter type: ${JSONType}`
});
}
} else {
throw new import_provider2.UnsupportedFunctionalityError({
functionality: `Unsupported tool parameter type: ${JSONType}`
});
}
parameterDefinitions[key] = {
required: required ? required.includes(key) : false,
type: type2,
description
};
}
cohereTools.push({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters
}
}
cohereTools.push({
name: tool.name,
description: tool.description,
parameterDefinitions
});

@@ -223,3 +183,3 @@ }

if (toolChoice == null) {
return { tools: cohereTools, force_single_step: false, toolWarnings };
return { tools: cohereTools, tool_choice: void 0, toolWarnings };
}

@@ -229,13 +189,10 @@ const type = toolChoice.type;

case "auto":
return { tools: cohereTools, force_single_step: false, toolWarnings };
return { tools: cohereTools, tool_choice: type, toolWarnings };
case "none":
return { tools: void 0, tool_choice: "any", toolWarnings };
case "required":
return { tools: cohereTools, force_single_step: true, toolWarnings };
case "none":
return { tools: void 0, force_single_step: false, toolWarnings };
case "tool":
return {
tools: cohereTools.filter((tool) => tool.name === toolChoice.toolName),
force_single_step: true,
toolWarnings
};
throw new import_provider2.UnsupportedFunctionalityError({
functionality: `Unsupported tool choice type: ${type}`
});
default: {

@@ -277,4 +234,2 @@ const _exhaustiveCheck = type;

const chatPrompt = convertToCohereChatPrompt(prompt);
const lastMessage = chatPrompt.at(-1);
const history = chatPrompt.slice(0, -1);
const baseArgs = {

@@ -297,13 +252,10 @@ // model id:

// messages:
chat_history: history,
...(lastMessage == null ? void 0 : lastMessage.role) === "TOOL" ? { tool_results: lastMessage.tool_results } : {},
message: lastMessage ? lastMessage.role === "USER" ? lastMessage.message : void 0 : void 0
messages: chatPrompt
};
switch (type) {
case "regular": {
const { tools, force_single_step, toolWarnings } = prepareTools(mode);
const { tools, tool_choice, toolWarnings } = prepareTools(mode);
return {
...baseArgs,
tools,
force_single_step,
warnings: toolWarnings

@@ -324,9 +276,37 @@ };

const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
throw new import_provider3.UnsupportedFunctionalityError({
functionality: `Unsupported mode: ${_exhaustiveCheck}`
});
}
}
}
concatenateMessageText(messages) {
return messages.filter(
(message) => "content" in message
).map((message) => message.content).join("");
}
/*
Remove `additionalProperties` and `$schema` from the `parameters` object of each tool.
Though these are part of JSON schema, Cohere chokes if we include them in the request.
*/
// TODO(shaper): Look at defining a type to simplify the params here and a couple of other places.
removeJsonSchemaExtras(tools) {
return tools.map((tool) => {
if (tool.type === "function" && tool.function.parameters && typeof tool.function.parameters === "object") {
const { additionalProperties, $schema, ...restParameters } = tool.function.parameters;
return {
...tool,
function: {
...tool.function,
parameters: restParameters
}
};
}
return tool;
});
}
async doGenerate(options) {
var _a;
var _a, _b, _c, _d;
const { warnings, ...args } = this.getArgs(options);
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools);
const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({

@@ -343,10 +323,9 @@ url: `${this.config.baseURL}/chat`,

});
const { chat_history, message, ...rawSettings } = args;
const generateId2 = this.config.generateId;
const { messages, ...rawSettings } = args;
return {
text: response.text,
toolCalls: response.tool_calls ? response.tool_calls.map((toolCall) => ({
toolCallId: generateId2(),
toolName: toolCall.name,
args: JSON.stringify(toolCall.parameters),
text: (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "",
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
toolCallId: toolCall.id,
toolName: toolCall.function.name,
args: toolCall.function.arguments,
toolCallType: "function"

@@ -356,9 +335,8 @@ })) : [],

usage: {
promptTokens: response.meta.tokens.input_tokens,
completionTokens: response.meta.tokens.output_tokens
promptTokens: response.usage.tokens.input_tokens,
completionTokens: response.usage.tokens.output_tokens
},
rawCall: {
rawPrompt: {
chat_history,
message
messages
},

@@ -368,3 +346,3 @@ rawSettings

response: {
id: (_a = response.generation_id) != null ? _a : void 0
id: (_d = response.generation_id) != null ? _d : void 0
},

@@ -378,2 +356,3 @@ rawResponse: { headers: responseHeaders },

const { warnings, ...args } = this.getArgs(options);
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools);
const body = { ...args, stream: true };

@@ -385,3 +364,3 @@ const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({

failedResponseHandler: cohereFailedResponseHandler,
successfulResponseHandler: (0, import_provider_utils2.createJsonStreamResponseHandler)(
successfulResponseHandler: (0, import_provider_utils2.createEventSourceResponseHandler)(
cohereChatChunkSchema

@@ -392,3 +371,3 @@ ),

});
const { chat_history, message, ...rawSettings } = args;
const { messages, ...rawSettings } = args;
let finishReason = "unknown";

@@ -399,4 +378,7 @@ let usage = {

};
const generateId2 = this.config.generateId;
const toolCalls = [];
let pendingToolCallDelta = {
toolCallId: "",
toolName: "",
argsTextDelta: ""
};
return {

@@ -413,65 +395,64 @@ stream: response.pipeThrough(

const value = chunk.value;
const type = value.event_type;
const type = value.type;
switch (type) {
case "text-generation": {
case "content-delta": {
controller.enqueue({
type: "text-delta",
textDelta: value.text
textDelta: value.delta.message.content.text
});
return;
}
case "tool-calls-chunk": {
if (value.tool_call_delta) {
const { index } = value.tool_call_delta;
if (toolCalls[index] === void 0) {
const toolCallId = generateId2();
toolCalls[index] = {
toolCallId,
toolName: ""
};
}
if (value.tool_call_delta.name) {
toolCalls[index].toolName = value.tool_call_delta.name;
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCalls[index].toolCallId,
toolName: toolCalls[index].toolName,
argsTextDelta: ""
});
} else if (value.tool_call_delta.parameters) {
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCalls[index].toolCallId,
toolName: toolCalls[index].toolName,
argsTextDelta: value.tool_call_delta.parameters
});
}
}
case "tool-call-start": {
pendingToolCallDelta = {
toolCallId: value.delta.message.tool_calls.id,
toolName: value.delta.message.tool_calls.function.name,
argsTextDelta: value.delta.message.tool_calls.function.arguments
};
controller.enqueue({
type: "tool-call-delta",
toolCallId: pendingToolCallDelta.toolCallId,
toolName: pendingToolCallDelta.toolName,
toolCallType: "function",
argsTextDelta: pendingToolCallDelta.argsTextDelta
});
return;
}
case "tool-calls-generation": {
for (let index = 0; index < value.tool_calls.length; index++) {
const toolCall = value.tool_calls[index];
controller.enqueue({
type: "tool-call",
toolCallId: toolCalls[index].toolCallId,
toolName: toolCalls[index].toolName,
toolCallType: "function",
args: JSON.stringify(toolCall.parameters)
});
}
case "tool-call-delta": {
pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments;
controller.enqueue({
type: "tool-call-delta",
toolCallId: pendingToolCallDelta.toolCallId,
toolName: pendingToolCallDelta.toolName,
toolCallType: "function",
argsTextDelta: value.delta.message.tool_calls.function.arguments
});
return;
}
case "stream-start": {
case "tool-call-end": {
controller.enqueue({
type: "tool-call",
toolCallId: pendingToolCallDelta.toolCallId,
toolName: pendingToolCallDelta.toolName,
toolCallType: "function",
args: JSON.stringify(
JSON.parse(pendingToolCallDelta.argsTextDelta)
)
});
pendingToolCallDelta = {
toolCallId: "",
toolName: "",
argsTextDelta: ""
};
return;
}
case "message-start": {
controller.enqueue({
type: "response-metadata",
id: (_a = value.generation_id) != null ? _a : void 0
id: (_a = value.id) != null ? _a : void 0
});
return;
}
case "stream-end": {
finishReason = mapCohereFinishReason(value.finish_reason);
const tokens = value.response.meta.tokens;
case "message-end": {
finishReason = mapCohereFinishReason(value.delta.finish_reason);
const tokens = value.delta.usage.tokens;
usage = {

@@ -498,4 +479,3 @@ promptTokens: tokens.input_tokens,

rawPrompt: {
chat_history,
message
messages
},

@@ -512,11 +492,27 @@ rawSettings

generation_id: import_zod2.z.string().nullish(),
text: import_zod2.z.string(),
tool_calls: import_zod2.z.array(
import_zod2.z.object({
name: import_zod2.z.string(),
parameters: import_zod2.z.unknown({})
})
).nullish(),
message: import_zod2.z.object({
role: import_zod2.z.string(),
content: import_zod2.z.array(
import_zod2.z.object({
type: import_zod2.z.string(),
text: import_zod2.z.string()
})
).nullish(),
tool_calls: import_zod2.z.array(
import_zod2.z.object({
id: import_zod2.z.string(),
type: import_zod2.z.literal("function"),
function: import_zod2.z.object({
name: import_zod2.z.string(),
arguments: import_zod2.z.string()
})
})
).nullish()
}),
finish_reason: import_zod2.z.string(),
meta: import_zod2.z.object({
usage: import_zod2.z.object({
billed_units: import_zod2.z.object({
input_tokens: import_zod2.z.number(),
output_tokens: import_zod2.z.number()
}),
tokens: import_zod2.z.object({

@@ -528,43 +524,34 @@ input_tokens: import_zod2.z.number(),

});
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("event_type", [
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [
import_zod2.z.object({
event_type: import_zod2.z.literal("stream-start"),
generation_id: import_zod2.z.string().nullish()
type: import_zod2.z.literal("citation-start")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("search-queries-generation")
type: import_zod2.z.literal("citation-end")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("search-results")
type: import_zod2.z.literal("content-start")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("text-generation"),
text: import_zod2.z.string()
type: import_zod2.z.literal("content-delta"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
content: import_zod2.z.object({
text: import_zod2.z.string()
})
})
})
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("citation-generation")
type: import_zod2.z.literal("content-end")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("tool-calls-generation"),
tool_calls: import_zod2.z.array(
import_zod2.z.object({
name: import_zod2.z.string(),
parameters: import_zod2.z.unknown({})
})
)
type: import_zod2.z.literal("message-start"),
id: import_zod2.z.string().nullish()
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("tool-calls-chunk"),
text: import_zod2.z.string().optional(),
tool_call_delta: import_zod2.z.object({
index: import_zod2.z.number(),
name: import_zod2.z.string().optional(),
parameters: import_zod2.z.string().optional()
}).optional()
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("stream-end"),
finish_reason: import_zod2.z.string(),
response: import_zod2.z.object({
meta: import_zod2.z.object({
type: import_zod2.z.literal("message-end"),
delta: import_zod2.z.object({
finish_reason: import_zod2.z.string(),
usage: import_zod2.z.object({
tokens: import_zod2.z.object({

@@ -576,2 +563,44 @@ input_tokens: import_zod2.z.number(),

})
}),
// https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling
import_zod2.z.object({
type: import_zod2.z.literal("tool-plan-delta"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
tool_plan: import_zod2.z.string()
})
})
}),
import_zod2.z.object({
type: import_zod2.z.literal("tool-call-start"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
tool_calls: import_zod2.z.object({
id: import_zod2.z.string(),
type: import_zod2.z.literal("function"),
function: import_zod2.z.object({
name: import_zod2.z.string(),
arguments: import_zod2.z.string()
})
})
})
})
}),
// A single tool call's `arguments` stream in chunks and must be accumulated
// in a string and so the full tool object info can only be parsed once we see
// `tool-call-end`.
import_zod2.z.object({
type: import_zod2.z.literal("tool-call-delta"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
tool_calls: import_zod2.z.object({
function: import_zod2.z.object({
arguments: import_zod2.z.string()
})
})
})
})
}),
import_zod2.z.object({
type: import_zod2.z.literal("tool-call-end")
})

@@ -615,2 +644,7 @@ ]);

model: this.modelId,
// TODO(shaper): There are other embedding types. Do we need to support them?
// For now we only support 'float' embeddings which are also the only ones
// the Cohere API docs state are supported for all models.
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
embedding_types: ["float"],
texts: values,

@@ -628,3 +662,3 @@ input_type: (_a = this.settings.inputType) != null ? _a : "search_query",

return {
embeddings: response.embeddings,
embeddings: response.embeddings.float,
usage: { tokens: response.meta.billed_units.input_tokens },

@@ -636,3 +670,5 @@ rawResponse: { headers: responseHeaders }

var cohereTextEmbeddingResponseSchema = import_zod3.z.object({
embeddings: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())),
embeddings: import_zod3.z.object({
float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number()))
}),
meta: import_zod3.z.object({

@@ -648,3 +684,3 @@ billed_units: import_zod3.z.object({

var _a;
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1";
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v2";
const getHeaders = () => ({

@@ -658,12 +694,8 @@ Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({

});
const createChatModel = (modelId, settings = {}) => {
var _a2;
return new CohereChatLanguageModel(modelId, settings, {
provider: "cohere.chat",
baseURL,
headers: getHeaders,
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils4.generateId,
fetch: options.fetch
});
};
const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, {
provider: "cohere.chat",
baseURL,
headers: getHeaders,
fetch: options.fetch
});
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, {

@@ -670,0 +702,0 @@ provider: "cohere.textEmbedding",

# @ai-sdk/cohere
## 0.0.28
### Patch Changes
- a7cbdf6: feat (provider/cohere): Use Cohere v2 API
## 0.0.27

@@ -4,0 +10,0 @@

@@ -42,3 +42,3 @@ import { ProviderV1, LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider';

Use a different URL prefix for API calls, e.g. to use proxy servers.
The default prefix is `https://api.cohere.com/v1`.
The default prefix is `https://api.cohere.com/v2`.
*/

@@ -60,3 +60,2 @@ baseURL?: string;

fetch?: FetchFunction;
generateId?: () => string;
}

@@ -63,0 +62,0 @@ /**

@@ -54,3 +54,3 @@ "use strict";

case "system": {
messages.push({ role: "SYSTEM", message: content });
messages.push({ role: "system", content });
break;

@@ -60,4 +60,4 @@ }

messages.push({
role: "USER",
message: content.map((part) => {
role: "user",
content: content.map((part) => {
switch (part.type) {

@@ -88,4 +88,8 @@ case "text": {

toolCalls.push({
name: part.toolName,
parameters: part.args
id: part.toolCallId,
type: "function",
function: {
name: part.toolName,
arguments: JSON.stringify(part.args)
}
});

@@ -101,4 +105,7 @@ break;

messages.push({
role: "CHATBOT",
message: text,
role: "assistant",
// note: this is a workaround for a Cohere API bug
// that requires content to be provided
// even if there are tool calls
content: text !== "" ? text : "call tool",
tool_calls: toolCalls.length > 0 ? toolCalls : void 0

@@ -109,20 +116,9 @@ });

case "tool": {
messages.push({
role: "TOOL",
tool_results: content.map((toolResult) => ({
call: {
name: toolResult.toolName,
/*
Note: Currently the tool_results field requires we pass the parameters of the tool results again. It it is blank for two reasons:
1. The parameters are already present in chat_history as a tool message
2. The tool core message of the ai sdk does not include parameters
It is possible to traverse through the chat history and get the parameters by id but it's currently empty since there wasn't any degradation in the output when left blank.
*/
parameters: {}
},
outputs: [toolResult.result]
messages.push(
...content.map((toolResult) => ({
role: "tool",
content: JSON.stringify(toolResult.result),
tool_call_id: toolResult.toolCallId
}))
});
);
break;

@@ -166,3 +162,3 @@ }

if (tools == null) {
return { tools: void 0, force_single_step: void 0, toolWarnings };
return { tools: void 0, tool_choice: void 0, toolWarnings };
}

@@ -174,45 +170,9 @@ const cohereTools = [];

} else {
const { properties, required } = tool.parameters;
const parameterDefinitions = {};
if (properties) {
for (const [key, value] of Object.entries(properties)) {
if (typeof value === "object" && value !== null) {
const { type: JSONType, description } = value;
let type2;
if (typeof JSONType === "string") {
switch (JSONType) {
case "string":
type2 = "str";
break;
case "number":
type2 = "float";
break;
case "integer":
type2 = "int";
break;
case "boolean":
type2 = "bool";
break;
default:
throw new import_provider2.UnsupportedFunctionalityError({
functionality: `Unsupported tool parameter type: ${JSONType}`
});
}
} else {
throw new import_provider2.UnsupportedFunctionalityError({
functionality: `Unsupported tool parameter type: ${JSONType}`
});
}
parameterDefinitions[key] = {
required: required ? required.includes(key) : false,
type: type2,
description
};
}
cohereTools.push({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters
}
}
cohereTools.push({
name: tool.name,
description: tool.description,
parameterDefinitions
});

@@ -223,3 +183,3 @@ }

if (toolChoice == null) {
return { tools: cohereTools, force_single_step: false, toolWarnings };
return { tools: cohereTools, tool_choice: void 0, toolWarnings };
}

@@ -229,13 +189,10 @@ const type = toolChoice.type;

case "auto":
return { tools: cohereTools, force_single_step: false, toolWarnings };
return { tools: cohereTools, tool_choice: type, toolWarnings };
case "none":
return { tools: void 0, tool_choice: "any", toolWarnings };
case "required":
return { tools: cohereTools, force_single_step: true, toolWarnings };
case "none":
return { tools: void 0, force_single_step: false, toolWarnings };
case "tool":
return {
tools: cohereTools.filter((tool) => tool.name === toolChoice.toolName),
force_single_step: true,
toolWarnings
};
throw new import_provider2.UnsupportedFunctionalityError({
functionality: `Unsupported tool choice type: ${type}`
});
default: {

@@ -277,4 +234,2 @@ const _exhaustiveCheck = type;

const chatPrompt = convertToCohereChatPrompt(prompt);
const lastMessage = chatPrompt.at(-1);
const history = chatPrompt.slice(0, -1);
const baseArgs = {

@@ -297,13 +252,10 @@ // model id:

// messages:
chat_history: history,
...(lastMessage == null ? void 0 : lastMessage.role) === "TOOL" ? { tool_results: lastMessage.tool_results } : {},
message: lastMessage ? lastMessage.role === "USER" ? lastMessage.message : void 0 : void 0
messages: chatPrompt
};
switch (type) {
case "regular": {
const { tools, force_single_step, toolWarnings } = prepareTools(mode);
const { tools, tool_choice, toolWarnings } = prepareTools(mode);
return {
...baseArgs,
tools,
force_single_step,
warnings: toolWarnings

@@ -324,9 +276,37 @@ };

const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
throw new import_provider3.UnsupportedFunctionalityError({
functionality: `Unsupported mode: ${_exhaustiveCheck}`
});
}
}
}
concatenateMessageText(messages) {
return messages.filter(
(message) => "content" in message
).map((message) => message.content).join("");
}
/*
Remove `additionalProperties` and `$schema` from the `parameters` object of each tool.
Though these are part of JSON schema, Cohere chokes if we include them in the request.
*/
// TODO(shaper): Look at defining a type to simplify the params here and a couple of other places.
removeJsonSchemaExtras(tools) {
return tools.map((tool) => {
if (tool.type === "function" && tool.function.parameters && typeof tool.function.parameters === "object") {
const { additionalProperties, $schema, ...restParameters } = tool.function.parameters;
return {
...tool,
function: {
...tool.function,
parameters: restParameters
}
};
}
return tool;
});
}
async doGenerate(options) {
var _a;
var _a, _b, _c, _d;
const { warnings, ...args } = this.getArgs(options);
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools);
const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({

@@ -343,10 +323,9 @@ url: `${this.config.baseURL}/chat`,

});
const { chat_history, message, ...rawSettings } = args;
const generateId2 = this.config.generateId;
const { messages, ...rawSettings } = args;
return {
text: response.text,
toolCalls: response.tool_calls ? response.tool_calls.map((toolCall) => ({
toolCallId: generateId2(),
toolName: toolCall.name,
args: JSON.stringify(toolCall.parameters),
text: (_c = (_b = (_a = response.message.content) == null ? void 0 : _a[0]) == null ? void 0 : _b.text) != null ? _c : "",
toolCalls: response.message.tool_calls ? response.message.tool_calls.map((toolCall) => ({
toolCallId: toolCall.id,
toolName: toolCall.function.name,
args: toolCall.function.arguments,
toolCallType: "function"

@@ -356,9 +335,8 @@ })) : [],

usage: {
promptTokens: response.meta.tokens.input_tokens,
completionTokens: response.meta.tokens.output_tokens
promptTokens: response.usage.tokens.input_tokens,
completionTokens: response.usage.tokens.output_tokens
},
rawCall: {
rawPrompt: {
chat_history,
message
messages
},

@@ -368,3 +346,3 @@ rawSettings

response: {
id: (_a = response.generation_id) != null ? _a : void 0
id: (_d = response.generation_id) != null ? _d : void 0
},

@@ -378,2 +356,3 @@ rawResponse: { headers: responseHeaders },

const { warnings, ...args } = this.getArgs(options);
args.tools = args.tools && this.removeJsonSchemaExtras(args.tools);
const body = { ...args, stream: true };

@@ -385,3 +364,3 @@ const { responseHeaders, value: response } = await (0, import_provider_utils2.postJsonToApi)({

failedResponseHandler: cohereFailedResponseHandler,
successfulResponseHandler: (0, import_provider_utils2.createJsonStreamResponseHandler)(
successfulResponseHandler: (0, import_provider_utils2.createEventSourceResponseHandler)(
cohereChatChunkSchema

@@ -392,3 +371,3 @@ ),

});
const { chat_history, message, ...rawSettings } = args;
const { messages, ...rawSettings } = args;
let finishReason = "unknown";

@@ -399,4 +378,7 @@ let usage = {

};
const generateId2 = this.config.generateId;
const toolCalls = [];
let pendingToolCallDelta = {
toolCallId: "",
toolName: "",
argsTextDelta: ""
};
return {

@@ -413,65 +395,64 @@ stream: response.pipeThrough(

const value = chunk.value;
const type = value.event_type;
const type = value.type;
switch (type) {
case "text-generation": {
case "content-delta": {
controller.enqueue({
type: "text-delta",
textDelta: value.text
textDelta: value.delta.message.content.text
});
return;
}
case "tool-calls-chunk": {
if (value.tool_call_delta) {
const { index } = value.tool_call_delta;
if (toolCalls[index] === void 0) {
const toolCallId = generateId2();
toolCalls[index] = {
toolCallId,
toolName: ""
};
}
if (value.tool_call_delta.name) {
toolCalls[index].toolName = value.tool_call_delta.name;
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCalls[index].toolCallId,
toolName: toolCalls[index].toolName,
argsTextDelta: ""
});
} else if (value.tool_call_delta.parameters) {
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCalls[index].toolCallId,
toolName: toolCalls[index].toolName,
argsTextDelta: value.tool_call_delta.parameters
});
}
}
case "tool-call-start": {
pendingToolCallDelta = {
toolCallId: value.delta.message.tool_calls.id,
toolName: value.delta.message.tool_calls.function.name,
argsTextDelta: value.delta.message.tool_calls.function.arguments
};
controller.enqueue({
type: "tool-call-delta",
toolCallId: pendingToolCallDelta.toolCallId,
toolName: pendingToolCallDelta.toolName,
toolCallType: "function",
argsTextDelta: pendingToolCallDelta.argsTextDelta
});
return;
}
case "tool-calls-generation": {
for (let index = 0; index < value.tool_calls.length; index++) {
const toolCall = value.tool_calls[index];
controller.enqueue({
type: "tool-call",
toolCallId: toolCalls[index].toolCallId,
toolName: toolCalls[index].toolName,
toolCallType: "function",
args: JSON.stringify(toolCall.parameters)
});
}
case "tool-call-delta": {
pendingToolCallDelta.argsTextDelta += value.delta.message.tool_calls.function.arguments;
controller.enqueue({
type: "tool-call-delta",
toolCallId: pendingToolCallDelta.toolCallId,
toolName: pendingToolCallDelta.toolName,
toolCallType: "function",
argsTextDelta: value.delta.message.tool_calls.function.arguments
});
return;
}
case "stream-start": {
case "tool-call-end": {
controller.enqueue({
type: "tool-call",
toolCallId: pendingToolCallDelta.toolCallId,
toolName: pendingToolCallDelta.toolName,
toolCallType: "function",
args: JSON.stringify(
JSON.parse(pendingToolCallDelta.argsTextDelta)
)
});
pendingToolCallDelta = {
toolCallId: "",
toolName: "",
argsTextDelta: ""
};
return;
}
case "message-start": {
controller.enqueue({
type: "response-metadata",
id: (_a = value.generation_id) != null ? _a : void 0
id: (_a = value.id) != null ? _a : void 0
});
return;
}
case "stream-end": {
finishReason = mapCohereFinishReason(value.finish_reason);
const tokens = value.response.meta.tokens;
case "message-end": {
finishReason = mapCohereFinishReason(value.delta.finish_reason);
const tokens = value.delta.usage.tokens;
usage = {

@@ -498,4 +479,3 @@ promptTokens: tokens.input_tokens,

rawPrompt: {
chat_history,
message
messages
},

@@ -512,11 +492,27 @@ rawSettings

generation_id: import_zod2.z.string().nullish(),
text: import_zod2.z.string(),
tool_calls: import_zod2.z.array(
import_zod2.z.object({
name: import_zod2.z.string(),
parameters: import_zod2.z.unknown({})
})
).nullish(),
message: import_zod2.z.object({
role: import_zod2.z.string(),
content: import_zod2.z.array(
import_zod2.z.object({
type: import_zod2.z.string(),
text: import_zod2.z.string()
})
).nullish(),
tool_calls: import_zod2.z.array(
import_zod2.z.object({
id: import_zod2.z.string(),
type: import_zod2.z.literal("function"),
function: import_zod2.z.object({
name: import_zod2.z.string(),
arguments: import_zod2.z.string()
})
})
).nullish()
}),
finish_reason: import_zod2.z.string(),
meta: import_zod2.z.object({
usage: import_zod2.z.object({
billed_units: import_zod2.z.object({
input_tokens: import_zod2.z.number(),
output_tokens: import_zod2.z.number()
}),
tokens: import_zod2.z.object({

@@ -528,43 +524,34 @@ input_tokens: import_zod2.z.number(),

});
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("event_type", [
var cohereChatChunkSchema = import_zod2.z.discriminatedUnion("type", [
import_zod2.z.object({
event_type: import_zod2.z.literal("stream-start"),
generation_id: import_zod2.z.string().nullish()
type: import_zod2.z.literal("citation-start")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("search-queries-generation")
type: import_zod2.z.literal("citation-end")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("search-results")
type: import_zod2.z.literal("content-start")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("text-generation"),
text: import_zod2.z.string()
type: import_zod2.z.literal("content-delta"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
content: import_zod2.z.object({
text: import_zod2.z.string()
})
})
})
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("citation-generation")
type: import_zod2.z.literal("content-end")
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("tool-calls-generation"),
tool_calls: import_zod2.z.array(
import_zod2.z.object({
name: import_zod2.z.string(),
parameters: import_zod2.z.unknown({})
})
)
type: import_zod2.z.literal("message-start"),
id: import_zod2.z.string().nullish()
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("tool-calls-chunk"),
text: import_zod2.z.string().optional(),
tool_call_delta: import_zod2.z.object({
index: import_zod2.z.number(),
name: import_zod2.z.string().optional(),
parameters: import_zod2.z.string().optional()
}).optional()
}),
import_zod2.z.object({
event_type: import_zod2.z.literal("stream-end"),
finish_reason: import_zod2.z.string(),
response: import_zod2.z.object({
meta: import_zod2.z.object({
type: import_zod2.z.literal("message-end"),
delta: import_zod2.z.object({
finish_reason: import_zod2.z.string(),
usage: import_zod2.z.object({
tokens: import_zod2.z.object({

@@ -576,2 +563,44 @@ input_tokens: import_zod2.z.number(),

})
}),
// https://docs.cohere.com/v2/docs/streaming#tool-use-stream-events-for-tool-calling
import_zod2.z.object({
type: import_zod2.z.literal("tool-plan-delta"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
tool_plan: import_zod2.z.string()
})
})
}),
import_zod2.z.object({
type: import_zod2.z.literal("tool-call-start"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
tool_calls: import_zod2.z.object({
id: import_zod2.z.string(),
type: import_zod2.z.literal("function"),
function: import_zod2.z.object({
name: import_zod2.z.string(),
arguments: import_zod2.z.string()
})
})
})
})
}),
// A single tool call's `arguments` stream in chunks and must be accumulated
// in a string and so the full tool object info can only be parsed once we see
// `tool-call-end`.
import_zod2.z.object({
type: import_zod2.z.literal("tool-call-delta"),
delta: import_zod2.z.object({
message: import_zod2.z.object({
tool_calls: import_zod2.z.object({
function: import_zod2.z.object({
arguments: import_zod2.z.string()
})
})
})
})
}),
import_zod2.z.object({
type: import_zod2.z.literal("tool-call-end")
})

@@ -615,2 +644,7 @@ ]);

model: this.modelId,
// TODO(shaper): There are other embedding types. Do we need to support them?
// For now we only support 'float' embeddings which are also the only ones
// the Cohere API docs state are supported for all models.
// https://docs.cohere.com/v2/reference/embed#request.body.embedding_types
embedding_types: ["float"],
texts: values,

@@ -628,3 +662,3 @@ input_type: (_a = this.settings.inputType) != null ? _a : "search_query",

return {
embeddings: response.embeddings,
embeddings: response.embeddings.float,
usage: { tokens: response.meta.billed_units.input_tokens },

@@ -636,3 +670,5 @@ rawResponse: { headers: responseHeaders }

var cohereTextEmbeddingResponseSchema = import_zod3.z.object({
embeddings: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())),
embeddings: import_zod3.z.object({
float: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number()))
}),
meta: import_zod3.z.object({

@@ -648,3 +684,3 @@ billed_units: import_zod3.z.object({

var _a;
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1";
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v2";
const getHeaders = () => ({

@@ -658,12 +694,8 @@ Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({

});
const createChatModel = (modelId, settings = {}) => {
var _a2;
return new CohereChatLanguageModel(modelId, settings, {
provider: "cohere.chat",
baseURL,
headers: getHeaders,
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils4.generateId,
fetch: options.fetch
});
};
const createChatModel = (modelId, settings = {}) => new CohereChatLanguageModel(modelId, settings, {
provider: "cohere.chat",
baseURL,
headers: getHeaders,
fetch: options.fetch
});
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, {

@@ -670,0 +702,0 @@ provider: "cohere.textEmbedding",

{
"name": "@ai-sdk/cohere",
"version": "0.0.27",
"version": "0.0.28",
"license": "Apache-2.0",

@@ -5,0 +5,0 @@ "sideEffects": false,

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

SocketSocket SOC 2 Logo

Product

  • Package Alerts
  • Integrations
  • Docs
  • Pricing
  • FAQ
  • Roadmap
  • Changelog

Packages

npm

Stay in touch

Get open source security insights delivered straight into your inbox.


  • Terms
  • Privacy
  • Security

Made with ⚡️ by Socket Inc