@ai-sdk/cohere
Advanced tools
Comparing version 0.0.21 to 0.0.22
@@ -29,4 +29,3 @@ "use strict"; | ||
// src/cohere-provider.ts | ||
var import_provider3 = require("@ai-sdk/provider"); | ||
var import_provider_utils3 = require("@ai-sdk/provider-utils"); | ||
var import_provider_utils4 = require("@ai-sdk/provider-utils"); | ||
@@ -528,8 +527,70 @@ // src/cohere-chat-language-model.ts | ||
// src/cohere-embedding-model.ts | ||
var import_provider3 = require("@ai-sdk/provider"); | ||
var import_provider_utils3 = require("@ai-sdk/provider-utils"); | ||
var import_zod3 = require("zod"); | ||
var CohereEmbeddingModel = class { | ||
constructor(modelId, settings, config) { | ||
this.specificationVersion = "v1"; | ||
this.maxEmbeddingsPerCall = 96; | ||
this.supportsParallelCalls = true; | ||
this.modelId = modelId; | ||
this.settings = settings; | ||
this.config = config; | ||
} | ||
get provider() { | ||
return this.config.provider; | ||
} | ||
async doEmbed({ | ||
values, | ||
headers, | ||
abortSignal | ||
}) { | ||
var _a; | ||
if (values.length > this.maxEmbeddingsPerCall) { | ||
throw new import_provider3.TooManyEmbeddingValuesForCallError({ | ||
provider: this.provider, | ||
modelId: this.modelId, | ||
maxEmbeddingsPerCall: this.maxEmbeddingsPerCall, | ||
values | ||
}); | ||
} | ||
const { responseHeaders, value: response } = await (0, import_provider_utils3.postJsonToApi)({ | ||
url: `${this.config.baseURL}/embed`, | ||
headers: (0, import_provider_utils3.combineHeaders)(this.config.headers(), headers), | ||
body: { | ||
model: this.modelId, | ||
texts: values, | ||
input_type: (_a = this.settings.inputType) != null ? _a : "search_query", | ||
truncate: this.settings.truncate | ||
}, | ||
failedResponseHandler: cohereFailedResponseHandler, | ||
successfulResponseHandler: (0, import_provider_utils3.createJsonResponseHandler)( | ||
cohereTextEmbeddingResponseSchema | ||
), | ||
abortSignal, | ||
fetch: this.config.fetch | ||
}); | ||
return { | ||
embeddings: response.embeddings, | ||
usage: { tokens: response.meta.billed_units.input_tokens }, | ||
rawResponse: { headers: responseHeaders } | ||
}; | ||
} | ||
}; | ||
var cohereTextEmbeddingResponseSchema = import_zod3.z.object({ | ||
embeddings: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())), | ||
meta: import_zod3.z.object({ | ||
billed_units: import_zod3.z.object({ | ||
input_tokens: import_zod3.z.number() | ||
}) | ||
}) | ||
}); | ||
// src/cohere-provider.ts | ||
function createCohere(options = {}) { | ||
var _a; | ||
const baseURL = (_a = (0, import_provider_utils3.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1"; | ||
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1"; | ||
const getHeaders = () => ({ | ||
Authorization: `Bearer ${(0, import_provider_utils3.loadApiKey)({ | ||
Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({ | ||
apiKey: options.apiKey, | ||
@@ -547,6 +608,12 @@ environmentVariableName: "COHERE_API_KEY", | ||
headers: getHeaders, | ||
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils3.generateId, | ||
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils4.generateId, | ||
fetch: options.fetch | ||
}); | ||
}; | ||
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, { | ||
provider: "cohere.textEmbedding", | ||
baseURL, | ||
headers: getHeaders, | ||
fetch: options.fetch | ||
}); | ||
const provider = function(modelId, settings) { | ||
@@ -561,5 +628,4 @@ if (new.target) { | ||
provider.languageModel = createChatModel; | ||
provider.textEmbeddingModel = (modelId) => { | ||
throw new import_provider3.NoSuchModelError({ modelId, modelType: "textEmbeddingModel" }); | ||
}; | ||
provider.embedding = createTextEmbeddingModel; | ||
provider.textEmbeddingModel = createTextEmbeddingModel; | ||
return provider; | ||
@@ -566,0 +632,0 @@ } |
@@ -1,2 +0,2 @@ | ||
import { ProviderV1, LanguageModelV1 } from '@ai-sdk/provider'; | ||
import { ProviderV1, LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider'; | ||
import { FetchFunction } from '@ai-sdk/provider-utils'; | ||
@@ -8,2 +8,24 @@ | ||
type CohereEmbeddingModelId = 'embed-english-v3.0' | 'embed-multilingual-v3.0' | 'embed-english-light-v3.0' | 'embed-multilingual-light-v3.0' | 'embed-english-v2.0' | 'embed-english-light-v2.0' | 'embed-multilingual-v2.0' | (string & {}); | ||
interface CohereEmbeddingSettings { | ||
/** | ||
* Specifies the type of input passed to the model. Default is `search_query`. | ||
* | ||
* - "search_document": Used for embeddings stored in a vector database for search use-cases. | ||
* - "search_query": Used for embeddings of search queries run against a vector DB to find relevant documents. | ||
* - "classification": Used for embeddings passed through a text classifier. | ||
* - "clustering": Used for embeddings run through a clustering algorithm. | ||
*/ | ||
inputType?: 'search_document' | 'search_query' | 'classification' | 'clustering'; | ||
/** | ||
* Specifies how the API will handle inputs longer than the maximum token length. | ||
* Default is `END`. | ||
* | ||
* - "NONE": If selected, when the input exceeds the maximum input token length will return an error. | ||
* - "START": Will discard the start of the input until the remaining input is exactly the maximum input token length for the model. | ||
* - "END": Will discard the end of the input until the remaining input is exactly the maximum input token length for the model. | ||
*/ | ||
truncate?: 'NONE' | 'START' | 'END'; | ||
} | ||
interface CohereProvider extends ProviderV1 { | ||
@@ -15,2 +37,4 @@ (modelId: CohereChatModelId, settings?: CohereChatSettings): LanguageModelV1; | ||
languageModel(modelId: CohereChatModelId, settings?: CohereChatSettings): LanguageModelV1; | ||
embedding(modelId: CohereEmbeddingModelId, settings?: CohereEmbeddingSettings): EmbeddingModelV1<string>; | ||
textEmbeddingModel(modelId: CohereEmbeddingModelId, settings?: CohereEmbeddingSettings): EmbeddingModelV1<string>; | ||
} | ||
@@ -17,0 +41,0 @@ interface CohereProviderSettings { |
@@ -29,4 +29,3 @@ "use strict"; | ||
// src/cohere-provider.ts | ||
var import_provider3 = require("@ai-sdk/provider"); | ||
var import_provider_utils3 = require("@ai-sdk/provider-utils"); | ||
var import_provider_utils4 = require("@ai-sdk/provider-utils"); | ||
@@ -528,8 +527,70 @@ // src/cohere-chat-language-model.ts | ||
// src/cohere-embedding-model.ts | ||
var import_provider3 = require("@ai-sdk/provider"); | ||
var import_provider_utils3 = require("@ai-sdk/provider-utils"); | ||
var import_zod3 = require("zod"); | ||
var CohereEmbeddingModel = class { | ||
constructor(modelId, settings, config) { | ||
this.specificationVersion = "v1"; | ||
this.maxEmbeddingsPerCall = 96; | ||
this.supportsParallelCalls = true; | ||
this.modelId = modelId; | ||
this.settings = settings; | ||
this.config = config; | ||
} | ||
get provider() { | ||
return this.config.provider; | ||
} | ||
async doEmbed({ | ||
values, | ||
headers, | ||
abortSignal | ||
}) { | ||
var _a; | ||
if (values.length > this.maxEmbeddingsPerCall) { | ||
throw new import_provider3.TooManyEmbeddingValuesForCallError({ | ||
provider: this.provider, | ||
modelId: this.modelId, | ||
maxEmbeddingsPerCall: this.maxEmbeddingsPerCall, | ||
values | ||
}); | ||
} | ||
const { responseHeaders, value: response } = await (0, import_provider_utils3.postJsonToApi)({ | ||
url: `${this.config.baseURL}/embed`, | ||
headers: (0, import_provider_utils3.combineHeaders)(this.config.headers(), headers), | ||
body: { | ||
model: this.modelId, | ||
texts: values, | ||
input_type: (_a = this.settings.inputType) != null ? _a : "search_query", | ||
truncate: this.settings.truncate | ||
}, | ||
failedResponseHandler: cohereFailedResponseHandler, | ||
successfulResponseHandler: (0, import_provider_utils3.createJsonResponseHandler)( | ||
cohereTextEmbeddingResponseSchema | ||
), | ||
abortSignal, | ||
fetch: this.config.fetch | ||
}); | ||
return { | ||
embeddings: response.embeddings, | ||
usage: { tokens: response.meta.billed_units.input_tokens }, | ||
rawResponse: { headers: responseHeaders } | ||
}; | ||
} | ||
}; | ||
var cohereTextEmbeddingResponseSchema = import_zod3.z.object({ | ||
embeddings: import_zod3.z.array(import_zod3.z.array(import_zod3.z.number())), | ||
meta: import_zod3.z.object({ | ||
billed_units: import_zod3.z.object({ | ||
input_tokens: import_zod3.z.number() | ||
}) | ||
}) | ||
}); | ||
// src/cohere-provider.ts | ||
function createCohere(options = {}) { | ||
var _a; | ||
const baseURL = (_a = (0, import_provider_utils3.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1"; | ||
const baseURL = (_a = (0, import_provider_utils4.withoutTrailingSlash)(options.baseURL)) != null ? _a : "https://api.cohere.com/v1"; | ||
const getHeaders = () => ({ | ||
Authorization: `Bearer ${(0, import_provider_utils3.loadApiKey)({ | ||
Authorization: `Bearer ${(0, import_provider_utils4.loadApiKey)({ | ||
apiKey: options.apiKey, | ||
@@ -547,6 +608,12 @@ environmentVariableName: "COHERE_API_KEY", | ||
headers: getHeaders, | ||
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils3.generateId, | ||
generateId: (_a2 = options.generateId) != null ? _a2 : import_provider_utils4.generateId, | ||
fetch: options.fetch | ||
}); | ||
}; | ||
const createTextEmbeddingModel = (modelId, settings = {}) => new CohereEmbeddingModel(modelId, settings, { | ||
provider: "cohere.textEmbedding", | ||
baseURL, | ||
headers: getHeaders, | ||
fetch: options.fetch | ||
}); | ||
const provider = function(modelId, settings) { | ||
@@ -561,5 +628,4 @@ if (new.target) { | ||
provider.languageModel = createChatModel; | ||
provider.textEmbeddingModel = (modelId) => { | ||
throw new import_provider3.NoSuchModelError({ modelId, modelType: "textEmbeddingModel" }); | ||
}; | ||
provider.embedding = createTextEmbeddingModel; | ||
provider.textEmbeddingModel = createTextEmbeddingModel; | ||
return provider; | ||
@@ -566,0 +632,0 @@ } |
{ | ||
"name": "@ai-sdk/cohere", | ||
"version": "0.0.21", | ||
"version": "0.0.22", | ||
"license": "Apache-2.0", | ||
@@ -10,3 +10,4 @@ "sideEffects": false, | ||
"files": [ | ||
"dist/**/*" | ||
"dist/**/*", | ||
"CHANGELOG.md" | ||
], | ||
@@ -13,0 +14,0 @@ "exports": { |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
146140
11
1928