@langchain/openai
Advanced tools
Comparing version 0.0.27 to 0.0.28
import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; | ||
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; | ||
import { type BaseMessage } from "@langchain/core/messages"; | ||
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages"; | ||
import { ChatGenerationChunk, type ChatResult } from "@langchain/core/outputs"; | ||
@@ -9,3 +9,3 @@ import { type StructuredToolInterface } from "@langchain/core/tools"; | ||
import { z } from "zod"; | ||
import { Runnable } from "@langchain/core/runnables"; | ||
import { Runnable, RunnableInterface } from "@langchain/core/runnables"; | ||
import type { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput, OpenAICoreRequestOptions, LegacyOpenAIInput } from "./types.js"; | ||
@@ -66,3 +66,3 @@ export type { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput }; | ||
*/ | ||
export declare class ChatOpenAI<CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions> extends BaseChatModel<CallOptions> implements OpenAIChatInput, AzureOpenAIInput { | ||
export declare class ChatOpenAI<CallOptions extends ChatOpenAICallOptions = ChatOpenAICallOptions> extends BaseChatModel<CallOptions, AIMessageChunk> implements OpenAIChatInput, AzureOpenAIInput { | ||
static lc_name(): string; | ||
@@ -107,2 +107,3 @@ get callKeys(): string[]; | ||
configuration?: ClientOptions & LegacyOpenAIInput); | ||
bindTools(tools: (Record<string, unknown> | StructuredToolInterface)[], kwargs?: Partial<CallOptions>): RunnableInterface<BaseLanguageModelInput, AIMessageChunk, CallOptions>; | ||
/** | ||
@@ -109,0 +110,0 @@ * Get the parameters used to invoke the model |
import { OpenAI as OpenAIClient } from "openai"; | ||
import { AIMessage, AIMessageChunk, ChatMessage, ChatMessageChunk, FunctionMessageChunk, HumanMessageChunk, SystemMessageChunk, ToolMessageChunk, } from "@langchain/core/messages"; | ||
import { AIMessage, AIMessageChunk, ChatMessage, ChatMessageChunk, FunctionMessageChunk, HumanMessageChunk, SystemMessageChunk, ToolMessageChunk, isAIMessage, } from "@langchain/core/messages"; | ||
import { ChatGenerationChunk, } from "@langchain/core/outputs"; | ||
@@ -9,3 +9,3 @@ import { getEnvironmentVariable } from "@langchain/core/utils/env"; | ||
import { JsonOutputParser, StructuredOutputParser, } from "@langchain/core/output_parsers"; | ||
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools"; | ||
import { JsonOutputKeyToolsParser, convertLangChainToolCallToOpenAI, makeInvalidToolCall, parseToolCall, } from "@langchain/core/output_parsers/openai_tools"; | ||
import { zodToJsonSchema } from "zod-to-json-schema"; | ||
@@ -48,8 +48,26 @@ import { getEndpoint } from "./utils/azure.js"; | ||
function openAIResponseToChatMessage(message) { | ||
const rawToolCalls = message.tool_calls; | ||
switch (message.role) { | ||
case "assistant": | ||
return new AIMessage(message.content || "", { | ||
function_call: message.function_call, | ||
tool_calls: message.tool_calls, | ||
case "assistant": { | ||
const toolCalls = []; | ||
const invalidToolCalls = []; | ||
for (const rawToolCall of rawToolCalls ?? []) { | ||
try { | ||
toolCalls.push(parseToolCall(rawToolCall, { returnId: true })); | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
} | ||
catch (e) { | ||
invalidToolCalls.push(makeInvalidToolCall(rawToolCall, e.message)); | ||
} | ||
} | ||
return new AIMessage({ | ||
content: message.content || "", | ||
tool_calls: toolCalls, | ||
invalid_tool_calls: invalidToolCalls, | ||
additional_kwargs: { | ||
function_call: message.function_call, | ||
tool_calls: rawToolCalls, | ||
}, | ||
}); | ||
} | ||
default: | ||
@@ -82,3 +100,18 @@ return new ChatMessage(message.content || "", message.role ?? "unknown"); | ||
else if (role === "assistant") { | ||
return new AIMessageChunk({ content, additional_kwargs }); | ||
const toolCallChunks = []; | ||
if (Array.isArray(delta.tool_calls)) { | ||
for (const rawToolCall of delta.tool_calls) { | ||
toolCallChunks.push({ | ||
name: rawToolCall.function?.name, | ||
args: rawToolCall.function?.arguments, | ||
id: rawToolCall.id, | ||
index: rawToolCall.index, | ||
}); | ||
} | ||
} | ||
return new AIMessageChunk({ | ||
content, | ||
tool_call_chunks: toolCallChunks, | ||
additional_kwargs, | ||
}); | ||
} | ||
@@ -120,7 +153,12 @@ else if (role === "system") { | ||
} | ||
if (message.additional_kwargs.tool_calls != null) { | ||
completionParam.tool_calls = message.additional_kwargs.tool_calls; | ||
if (isAIMessage(message) && !!message.tool_calls?.length) { | ||
completionParam.tool_calls = message.tool_calls.map(convertLangChainToolCallToOpenAI); | ||
} | ||
if (message.tool_call_id != null) { | ||
completionParam.tool_call_id = message.tool_call_id; | ||
else { | ||
if (message.additional_kwargs.tool_calls != null) { | ||
completionParam.tool_calls = message.additional_kwargs.tool_calls; | ||
} | ||
if (message.tool_call_id != null) { | ||
completionParam.tool_call_id = message.tool_call_id; | ||
} | ||
} | ||
@@ -440,2 +478,8 @@ return completionParam; | ||
} | ||
bindTools(tools, kwargs) { | ||
return this.bind({ | ||
tools: tools.map(convertToOpenAITool), | ||
...kwargs, | ||
}); | ||
} | ||
/** | ||
@@ -442,0 +486,0 @@ * Get the parameters used to invoke the model |
import { test, expect, jest } from "@jest/globals"; | ||
import { HumanMessage, ToolMessage } from "@langchain/core/messages"; | ||
import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; | ||
import { InMemoryCache } from "@langchain/core/caches"; | ||
@@ -63,3 +63,6 @@ import { ChatOpenAI } from "../chat_models.js"; | ||
console.log(JSON.stringify(res)); | ||
expect(res.additional_kwargs.tool_calls?.length).toBeGreaterThan(1); | ||
expect(res.additional_kwargs.tool_calls?.length).toEqual(3); | ||
expect(res.tool_calls?.[0].args).toEqual(JSON.parse(res.additional_kwargs.tool_calls?.[0].function.arguments ?? "{}")); | ||
expect(res.tool_calls?.[1].args).toEqual(JSON.parse(res.additional_kwargs.tool_calls?.[1].function.arguments ?? "{}")); | ||
expect(res.tool_calls?.[2].args).toEqual(JSON.parse(res.additional_kwargs.tool_calls?.[2].function.arguments ?? "{}")); | ||
}); | ||
@@ -207,1 +210,52 @@ test("Test ChatOpenAI streaming logprobs", async () => { | ||
}); | ||
test("Few shotting with tool calls", async () => { | ||
const chat = new ChatOpenAI({ | ||
modelName: "gpt-3.5-turbo-1106", | ||
temperature: 1, | ||
}).bind({ | ||
tools: [ | ||
{ | ||
type: "function", | ||
function: { | ||
name: "get_current_weather", | ||
description: "Get the current weather in a given location", | ||
parameters: { | ||
type: "object", | ||
properties: { | ||
location: { | ||
type: "string", | ||
description: "The city and state, e.g. San Francisco, CA", | ||
}, | ||
unit: { type: "string", enum: ["celsius", "fahrenheit"] }, | ||
}, | ||
required: ["location"], | ||
}, | ||
}, | ||
}, | ||
], | ||
tool_choice: "auto", | ||
}); | ||
const res = await chat.invoke([ | ||
new HumanMessage("What is the weather in SF?"), | ||
new AIMessage({ | ||
content: "", | ||
tool_calls: [ | ||
{ | ||
id: "12345", | ||
name: "get_current_weather", | ||
args: { | ||
location: "SF", | ||
}, | ||
}, | ||
], | ||
}), | ||
new ToolMessage({ | ||
tool_call_id: "12345", | ||
content: "It is currently 24 degrees with hail in SF.", | ||
}), | ||
new AIMessage("It is currently 24 degrees in SF with hail in SF."), | ||
new HumanMessage("What did you say the weather was?"), | ||
]); | ||
console.log(res); | ||
expect(res.content).toContain("24"); | ||
}); |
{ | ||
"name": "@langchain/openai", | ||
"version": "0.0.27", | ||
"version": "0.0.28", | ||
"description": "OpenAI integrations for LangChain.js", | ||
@@ -42,3 +42,3 @@ "type": "module", | ||
"dependencies": { | ||
"@langchain/core": "~0.1.45", | ||
"@langchain/core": "~0.1.56", | ||
"js-tiktoken": "^1.0.7", | ||
@@ -45,0 +45,0 @@ "openai": "^4.32.1", |
Sorry, the diff of this file is not supported yet
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
301577
7803
Updated@langchain/core@~0.1.56