@lancedb/lancedb
Advanced tools
Comparing version 0.5.0 to 0.5.1
@@ -51,3 +51,3 @@ { | ||
"noUnusedLabels": "error", | ||
"noUnusedVariables": "error", | ||
"noUnusedVariables": "warn", | ||
"useIsNan": "error", | ||
@@ -105,3 +105,9 @@ "useValidForDirection": "error", | ||
{ | ||
"include": ["**/*.ts", "**/*.tsx", "**/*.mts", "**/*.cts"], | ||
"include": [ | ||
"**/*.ts", | ||
"**/*.tsx", | ||
"**/*.mts", | ||
"**/*.cts", | ||
"__test__/*.test.ts" | ||
], | ||
"linter": { | ||
@@ -108,0 +114,0 @@ "rules": { |
/// <reference types="node" /> | ||
import { Table as ArrowTable, type Float, Schema } from "apache-arrow"; | ||
import { Table as ArrowTable, Binary, Bool, DataType, Date_, Decimal, Duration, FixedSizeBinary, FixedSizeList, Float, Int, Interval, LargeBinary, LargeUtf8, List, Null, Schema, Struct, Time, Timestamp, Union, Utf8 } from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function"; | ||
import { EmbeddingFunctionConfig } from "./embedding/registry"; | ||
export * from "apache-arrow"; | ||
export declare function isArrowTable(value: object): value is ArrowTable; | ||
export declare function isDataType(value: unknown): value is DataType; | ||
export declare function isNull(value: unknown): value is Null; | ||
export declare function isInt(value: unknown): value is Int; | ||
export declare function isFloat(value: unknown): value is Float; | ||
export declare function isBinary(value: unknown): value is Binary; | ||
export declare function isLargeBinary(value: unknown): value is LargeBinary; | ||
export declare function isUtf8(value: unknown): value is Utf8; | ||
/** Type guard for the Arrow `LargeUtf8` data type. */
export declare function isLargeUtf8(value: unknown): value is LargeUtf8;
/** Type guard for the Arrow `Bool` data type. */
export declare function isBool(value: unknown): value is Bool;
/** Type guard for the Arrow `Decimal` data type. */
export declare function isDecimal(value: unknown): value is Decimal;
/** Type guard for the Arrow `Date_` data type. */
export declare function isDate(value: unknown): value is Date_;
/** Type guard for the Arrow `Time` data type. */
export declare function isTime(value: unknown): value is Time;
/** Type guard for the Arrow `Timestamp` data type. */
export declare function isTimestamp(value: unknown): value is Timestamp;
/** Type guard for the Arrow `Interval` data type. */
export declare function isInterval(value: unknown): value is Interval;
/** Type guard for the Arrow `Duration` data type. */
export declare function isDuration(value: unknown): value is Duration;
export declare function isList(value: unknown): value is List; | ||
export declare function isStruct(value: unknown): value is Struct; | ||
/** Type guard for the Arrow `Union` data type. */
export declare function isUnion(value: unknown): value is Union;
export declare function isFixedSizeBinary(value: unknown): value is FixedSizeBinary; | ||
export declare function isFixedSizeList(value: unknown): value is FixedSizeList; | ||
/** Data type accepted by NodeJS SDK */ | ||
@@ -120,7 +143,7 @@ export type Data = Record<string, unknown>[] | ArrowTable; | ||
*/ | ||
export declare function makeArrowTable(data: Array<Record<string, unknown>>, options?: Partial<MakeArrowTableOptions>): ArrowTable; | ||
export declare function makeArrowTable(data: Array<Record<string, unknown>>, options?: Partial<MakeArrowTableOptions>, metadata?: Map<string, string>): ArrowTable; | ||
/** | ||
* Create an empty Arrow table with the provided schema | ||
*/ | ||
export declare function makeEmptyTable(schema: Schema): ArrowTable; | ||
export declare function makeEmptyTable(schema: Schema, metadata?: Map<string, string>): ArrowTable; | ||
/** | ||
@@ -144,3 +167,5 @@ * Convert an Array of records into an Arrow Table, optionally applying an | ||
*/ | ||
export declare function convertToTable<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, makeTableOptions?: Partial<MakeArrowTableOptions>): Promise<ArrowTable>; | ||
export declare function convertToTable(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunctionConfig, makeTableOptions?: Partial<MakeArrowTableOptions>): Promise<ArrowTable>; | ||
/** Creates the Arrow Type for a Vector column with dimension `dim` */ | ||
export declare function newVectorType<T extends Float>(dim: number, innerType: T): FixedSizeList<T>; | ||
/** | ||
@@ -153,3 +178,3 @@ * Serialize an Array of records into a buffer using the Arrow IPC File serialization | ||
*/ | ||
export declare function fromRecordsToBuffer<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromRecordsToBuffer(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -162,3 +187,3 @@ * Serialize an Array of records into a buffer using the Arrow IPC Stream serialization | ||
*/ | ||
export declare function fromRecordsToStreamBuffer<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromRecordsToStreamBuffer(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -172,3 +197,3 @@ * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization | ||
*/ | ||
export declare function fromTableToBuffer<T>(table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromTableToBuffer(table: ArrowTable, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -182,3 +207,3 @@ * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization | ||
*/ | ||
export declare function fromDataToBuffer<T>(data: Data, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromDataToBuffer(data: Data, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -192,3 +217,3 @@ * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization | ||
*/ | ||
export declare function fromTableToStreamBuffer<T>(table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromTableToStreamBuffer(table: ArrowTable, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -195,0 +220,0 @@ * Create an empty table with the given schema |
@@ -15,6 +15,129 @@ "use strict"; | ||
// limitations under the License. | ||
// Re-export helper: exposes `m[k]` on `o` under the name `k2` (defaults to `k`).
// An existing `this.__createBinding` is reused when one is present.
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
var desc = Object.getOwnPropertyDescriptor(m, k); | ||
// Fall back to a forwarding getter when `m` has no own descriptor for `k`,
// or when the existing descriptor fails the check below.
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { | ||
desc = { enumerable: true, get: function() { return m[k]; } }; | ||
} | ||
Object.defineProperty(o, k2, desc); | ||
// Without `Object.create`, copy the current value with a plain assignment.
}) : (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
o[k2] = m[k]; | ||
})); | ||
// Implements `export * from ...`: binds every key of `m` onto `exports`,
// skipping "default" and any name `exports` already owns.
var __exportStar = (this && this.__exportStar) || function(m, exports) { | ||
for (var p in m) if (p !== "default" && !Object.prototype.hasOwnProperty.call(exports, p)) __createBinding(exports, m, p); | ||
}; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.createEmptyTable = exports.fromTableToStreamBuffer = exports.fromDataToBuffer = exports.fromTableToBuffer = exports.fromRecordsToStreamBuffer = exports.fromRecordsToBuffer = exports.convertToTable = exports.makeEmptyTable = exports.makeArrowTable = exports.MakeArrowTableOptions = exports.VectorColumnOptions = void 0; | ||
exports.createEmptyTable = exports.fromTableToStreamBuffer = exports.fromDataToBuffer = exports.fromTableToBuffer = exports.fromRecordsToStreamBuffer = exports.fromRecordsToBuffer = exports.newVectorType = exports.convertToTable = exports.makeEmptyTable = exports.makeArrowTable = exports.MakeArrowTableOptions = exports.VectorColumnOptions = exports.isFixedSizeList = exports.isFixedSizeBinary = exports.isUnion = exports.isStruct = exports.isList = exports.isDuration = exports.isInterval = exports.isTimestamp = exports.isTime = exports.isDate = exports.isDecimal = exports.isBool = exports.isLargeUtf8 = exports.isUtf8 = exports.isLargeBinary = exports.isBinary = exports.isFloat = exports.isInt = exports.isNull = exports.isDataType = exports.isArrowTable = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const registry_1 = require("./embedding/registry"); | ||
const sanitize_1 = require("./sanitize"); | ||
__exportStar(require("apache-arrow"), exports); | ||
/**
 * Check whether `value` is an Arrow table.
 *
 * Real `Table` instances are accepted directly; any other object is accepted
 * when it carries both a `schema` and a `batches` property.
 */
function isArrowTable(value) {
    const isTableInstance = value instanceof apache_arrow_1.Table;
    return isTableInstance ? true : "schema" in value && "batches" in value;
}
exports.isArrowTable = isArrowTable; | ||
/**
 * Check whether `value` is any Arrow data type.
 *
 * A `DataType` instance is accepted immediately; otherwise each static
 * `DataType.isXxx` predicate is tried in turn until one matches.
 */
function isDataType(value) {
    if (value instanceof apache_arrow_1.DataType) {
        return true;
    }
    const predicateNames = [
        "isNull",
        "isInt",
        "isFloat",
        "isBinary",
        "isLargeBinary",
        "isUtf8",
        "isLargeUtf8",
        "isBool",
        "isDecimal",
        "isDate",
        "isTime",
        "isTimestamp",
        "isInterval",
        "isDuration",
        "isList",
        "isStruct",
        "isUnion",
        "isFixedSizeBinary",
        "isFixedSizeList",
        "isMap",
        "isDictionary",
    ];
    return predicateNames.some((name) => apache_arrow_1.DataType[name](value));
}
exports.isDataType = isDataType; | ||
/** Check whether `value` is the Arrow `Null` data type. */
function isNull(value) {
    if (value instanceof apache_arrow_1.Null) {
        return true;
    }
    return apache_arrow_1.DataType.isNull(value);
}
exports.isNull = isNull; | ||
/** Check whether `value` is an Arrow `Int` data type. */
function isInt(value) {
    if (value instanceof apache_arrow_1.Int) {
        return true;
    }
    return apache_arrow_1.DataType.isInt(value);
}
exports.isInt = isInt; | ||
/** Check whether `value` is an Arrow `Float` data type. */
function isFloat(value) {
    if (value instanceof apache_arrow_1.Float) {
        return true;
    }
    return apache_arrow_1.DataType.isFloat(value);
}
exports.isFloat = isFloat; | ||
/** Check whether `value` is the Arrow `Binary` data type. */
function isBinary(value) {
    if (value instanceof apache_arrow_1.Binary) {
        return true;
    }
    return apache_arrow_1.DataType.isBinary(value);
}
exports.isBinary = isBinary; | ||
/** Check whether `value` is the Arrow `LargeBinary` data type. */
function isLargeBinary(value) {
    if (value instanceof apache_arrow_1.LargeBinary) {
        return true;
    }
    return apache_arrow_1.DataType.isLargeBinary(value);
}
exports.isLargeBinary = isLargeBinary; | ||
/** Check whether `value` is the Arrow `Utf8` data type. */
function isUtf8(value) {
    if (value instanceof apache_arrow_1.Utf8) {
        return true;
    }
    return apache_arrow_1.DataType.isUtf8(value);
}
exports.isUtf8 = isUtf8; | ||
/**
 * Check whether `value` is the Arrow `LargeUtf8` data type.
 *
 * The `instanceof` test must target `LargeUtf8`; testing against `Utf8`
 * would report every plain `Utf8` type as a large string type.
 */
function isLargeUtf8(value) {
    return value instanceof apache_arrow_1.LargeUtf8 || apache_arrow_1.DataType.isLargeUtf8(value);
}
exports.isLargeUtf8 = isLargeUtf8; | ||
/**
 * Check whether `value` is the Arrow `Bool` data type.
 *
 * The `instanceof` test targets `Bool` so that a `Utf8` instance is not
 * misreported as a boolean type.
 */
function isBool(value) {
    return value instanceof apache_arrow_1.Bool || apache_arrow_1.DataType.isBool(value);
}
exports.isBool = isBool; | ||
/**
 * Check whether `value` is the Arrow `Decimal` data type.
 *
 * The `instanceof` test targets `Decimal` so that a `Utf8` instance is not
 * misreported as a decimal type.
 */
function isDecimal(value) {
    return value instanceof apache_arrow_1.Decimal || apache_arrow_1.DataType.isDecimal(value);
}
exports.isDecimal = isDecimal; | ||
/**
 * Check whether `value` is an Arrow date data type.
 *
 * apache-arrow exports the date base class as `Date_`. The `instanceof` test
 * targets it so that a `Utf8` instance is not misreported as a date type.
 */
function isDate(value) {
    return value instanceof apache_arrow_1.Date_ || apache_arrow_1.DataType.isDate(value);
}
exports.isDate = isDate; | ||
/**
 * Check whether `value` is an Arrow `Time` data type.
 *
 * The `instanceof` test targets `Time` so that a `Utf8` instance is not
 * misreported as a time type.
 */
function isTime(value) {
    return value instanceof apache_arrow_1.Time || apache_arrow_1.DataType.isTime(value);
}
exports.isTime = isTime; | ||
/**
 * Check whether `value` is an Arrow `Timestamp` data type.
 *
 * The `instanceof` test targets `Timestamp` so that a `Utf8` instance is not
 * misreported as a timestamp type.
 */
function isTimestamp(value) {
    return value instanceof apache_arrow_1.Timestamp || apache_arrow_1.DataType.isTimestamp(value);
}
exports.isTimestamp = isTimestamp; | ||
/**
 * Check whether `value` is an Arrow `Interval` data type.
 *
 * The `instanceof` test targets `Interval` so that a `Utf8` instance is not
 * misreported as an interval type.
 */
function isInterval(value) {
    return value instanceof apache_arrow_1.Interval || apache_arrow_1.DataType.isInterval(value);
}
exports.isInterval = isInterval; | ||
/**
 * Check whether `value` is an Arrow `Duration` data type.
 *
 * The `instanceof` test targets `Duration` so that a `Utf8` instance is not
 * misreported as a duration type.
 */
function isDuration(value) {
    return value instanceof apache_arrow_1.Duration || apache_arrow_1.DataType.isDuration(value);
}
exports.isDuration = isDuration; | ||
/** Check whether `value` is the Arrow `List` data type. */
function isList(value) {
    if (value instanceof apache_arrow_1.List) {
        return true;
    }
    return apache_arrow_1.DataType.isList(value);
}
exports.isList = isList; | ||
/** Check whether `value` is the Arrow `Struct` data type. */
function isStruct(value) {
    if (value instanceof apache_arrow_1.Struct) {
        return true;
    }
    return apache_arrow_1.DataType.isStruct(value);
}
exports.isStruct = isStruct; | ||
/**
 * Check whether `value` is an Arrow `Union` data type.
 *
 * The `instanceof` test targets `Union` so that a `Struct` instance is not
 * misreported as a union type.
 */
function isUnion(value) {
    return value instanceof apache_arrow_1.Union || apache_arrow_1.DataType.isUnion(value);
}
exports.isUnion = isUnion; | ||
/** Check whether `value` is the Arrow `FixedSizeBinary` data type. */
function isFixedSizeBinary(value) {
    if (value instanceof apache_arrow_1.FixedSizeBinary) {
        return true;
    }
    return apache_arrow_1.DataType.isFixedSizeBinary(value);
}
exports.isFixedSizeBinary = isFixedSizeBinary; | ||
/** Check whether `value` is the Arrow `FixedSizeList` data type. */
function isFixedSizeList(value) {
    if (value instanceof apache_arrow_1.FixedSizeList) {
        return true;
    }
    return apache_arrow_1.DataType.isFixedSizeList(value);
}
exports.isFixedSizeList = isFixedSizeList; | ||
/* | ||
@@ -172,3 +295,3 @@ * Options to control how a column should be converted to a vector array | ||
*/ | ||
function makeArrowTable(data, options) { | ||
function makeArrowTable(data, options, metadata) { | ||
if (data.length === 0 && | ||
@@ -251,10 +374,27 @@ (options?.schema === undefined || options?.schema === null)) { | ||
const firstTable = new apache_arrow_1.Table(columns); | ||
const batchesFixed = firstTable.batches.map( | ||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion | ||
(batch) => new apache_arrow_1.RecordBatch(opt.schema, batch.data)); | ||
return new apache_arrow_1.Table(opt.schema, batchesFixed); | ||
const batchesFixed = firstTable.batches.map((batch) => new apache_arrow_1.RecordBatch(opt.schema, batch.data)); | ||
let schema; | ||
if (metadata !== undefined) { | ||
let schemaMetadata = opt.schema.metadata; | ||
if (schemaMetadata.size === 0) { | ||
schemaMetadata = metadata; | ||
} | ||
else { | ||
for (const [key, entry] of schemaMetadata.entries()) { | ||
schemaMetadata.set(key, entry); | ||
} | ||
} | ||
schema = new apache_arrow_1.Schema(opt.schema.fields, schemaMetadata); | ||
} | ||
else { | ||
schema = opt.schema; | ||
} | ||
return new apache_arrow_1.Table(schema, batchesFixed); | ||
} | ||
else { | ||
return new apache_arrow_1.Table(columns); | ||
const tbl = new apache_arrow_1.Table(columns); | ||
if (metadata !== undefined) { | ||
// biome-ignore lint/suspicious/noExplicitAny: <explanation> | ||
tbl.schema.metadata = metadata; | ||
} | ||
return tbl; | ||
} | ||
@@ -265,4 +405,4 @@ exports.makeArrowTable = makeArrowTable; | ||
*/ | ||
function makeEmptyTable(schema) { | ||
return makeArrowTable([], { schema }); | ||
function makeEmptyTable(schema, metadata) { | ||
return makeArrowTable([], { schema }, metadata); | ||
} | ||
@@ -329,5 +469,48 @@ exports.makeEmptyTable = makeEmptyTable; | ||
} | ||
/** Helper function to apply embeddings from metadata to an input table */ | ||
/**
 * Apply the embedding functions described in `schema.metadata` to `table`.
 *
 * The registry parses the function entries from the schema metadata. For each
 * entry the source column is embedded and the result is stored under the
 * entry's vector column (default `"vector"`). The resulting table is then
 * aligned to `schema`.
 *
 * @param table - Arrow table holding the source columns.
 * @param schema - Target schema; its metadata lists the embedding functions
 *   and its fields supply the vector column types.
 * @returns The table with the computed vector columns, aligned to `schema`.
 * @throws Error when a source column is missing, the vector column already
 *   exists in the data, the table has more than one batch, the vector column
 *   is absent from `schema` or is not a FixedSizeList, or the embedding
 *   function returns a different number of vectors than inputs.
 */
async function applyEmbeddingsFromMetadata(table, schema) {
    const registry = (0, registry_1.getRegistry)();
    const functions = registry.parseFunctions(schema.metadata);
    const columns = Object.fromEntries(table.schema.fields.map((field) => [
        field.name,
        table.getChild(field.name),
    ]));
    for (const functionEntry of functions.values()) {
        const sourceColumn = columns[functionEntry.sourceColumn];
        const destColumn = functionEntry.vectorColumn ?? "vector";
        if (sourceColumn === undefined) {
            throw new Error(`Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`);
        }
        if (columns[destColumn] !== undefined) {
            throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`);
        }
        if (table.batches.length > 1) {
            throw new Error("Internal error: `makeArrowTable` unexpectedly created a table with more than one batch");
        }
        // Resolve and validate the destination type before computing any
        // embeddings, so a schema problem fails fast with a clear message
        // instead of an opaque TypeError after the embedding call.
        const destField = schema.fields.find((f) => f.name === destColumn);
        if (destField === undefined) {
            throw new Error(`Cannot apply embedding function because the vector column '${destColumn}' is not defined in the schema`);
        }
        const dtype = destField.type;
        if (!isFixedSizeList(dtype)) {
            throw new Error("Expected FixedSizeList as datatype for vector field, instead got: " +
                dtype);
        }
        const destType = (0, sanitize_1.sanitizeType)(dtype);
        const values = sourceColumn.toArray();
        const vectors = await functionEntry.function.computeSourceEmbeddings(values);
        if (vectors.length !== values.length) {
            throw new Error("Embedding function did not return an embedding for each input element");
        }
        columns[destColumn] = makeVector(vectors, destType);
    }
    const newTable = new apache_arrow_1.Table(columns);
    return alignTable(newTable, schema);
}
/** Helper function to apply embeddings to an input table */ | ||
async function applyEmbeddings(table, embeddings, schema) { | ||
if (embeddings == null) { | ||
if (schema?.metadata.has("embedding_functions")) { | ||
return applyEmbeddingsFromMetadata(table, schema); | ||
} | ||
else if (embeddings == null || embeddings === undefined) { | ||
return table; | ||
@@ -347,4 +530,4 @@ } | ||
const sourceColumn = newColumns[embeddings.sourceColumn]; | ||
const destColumn = embeddings.destColumn ?? "vector"; | ||
const innerDestType = embeddings.embeddingDataType ?? new apache_arrow_1.Float32(); | ||
const destColumn = embeddings.vectorColumn ?? "vector"; | ||
const innerDestType = embeddings.function.embeddingDataType() ?? new apache_arrow_1.Float32(); | ||
if (sourceColumn === undefined) { | ||
@@ -360,4 +543,5 @@ throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`); | ||
} | ||
if (embeddings.embeddingDimension !== undefined) { | ||
const destType = newVectorType(embeddings.embeddingDimension, innerDestType); | ||
const dimensions = embeddings.function.ndims(); | ||
if (dimensions !== undefined) { | ||
const destType = newVectorType(dimensions, innerDestType); | ||
newColumns[destColumn] = makeVector([], destType); | ||
@@ -386,3 +570,3 @@ } | ||
const values = sourceColumn.toArray(); | ||
const vectors = await embeddings.embed(values); | ||
const vectors = await embeddings.function.computeSourceEmbeddings(values); | ||
if (vectors.length !== values.length) { | ||
@@ -430,5 +614,6 @@ throw new Error("Embedding function did not return an embedding for each input element"); | ||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements | ||
const children = new apache_arrow_1.Field("item", innerType, true); | ||
const children = new apache_arrow_1.Field("item", (0, sanitize_1.sanitizeType)(innerType), true); | ||
return new apache_arrow_1.FixedSizeList(dim, children); | ||
} | ||
exports.newVectorType = newVectorType; | ||
/** | ||
@@ -495,8 +680,8 @@ * Serialize an Array of records into a buffer using the Arrow IPC File serialization | ||
} | ||
if (data instanceof apache_arrow_1.Table) { | ||
if (isArrowTable(data)) { | ||
return fromTableToBuffer(data, embeddings, schema); | ||
} | ||
else { | ||
const table = await convertToTable(data); | ||
return fromTableToBuffer(table, embeddings, schema); | ||
const table = await convertToTable(data, embeddings, { schema }); | ||
return fromTableToBuffer(table); | ||
} | ||
@@ -561,6 +746,18 @@ } | ||
// if they are not, we throw an error | ||
for (const field of schema.fields) { | ||
if (field.type instanceof apache_arrow_1.FixedSizeList) { | ||
for (let field of schema.fields) { | ||
if (isFixedSizeList(field.type)) { | ||
field = (0, sanitize_1.sanitizeField)(field); | ||
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { | ||
missingEmbeddingFields.push(field); | ||
if (schema.metadata.has("embedding_functions")) { | ||
const embeddings = JSON.parse(schema.metadata.get("embedding_functions")); | ||
if ( | ||
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f` | ||
embeddings.find((f) => f["vectorColumn"] === field.name) === | ||
undefined) { | ||
missingEmbeddingFields.push(field); | ||
} | ||
} | ||
else { | ||
missingEmbeddingFields.push(field); | ||
} | ||
} | ||
@@ -567,0 +764,0 @@ else { |
@@ -1,2 +0,3 @@ | ||
import { Table as ArrowTable, Schema } from "apache-arrow"; | ||
import { Table as ArrowTable, Schema } from "./arrow"; | ||
import { EmbeddingFunctionConfig } from "./embedding/registry"; | ||
import { ConnectionOptions, Connection as LanceDbConnection } from "./native"; | ||
@@ -42,2 +43,4 @@ import { Table } from "./table"; | ||
storageOptions?: Record<string, string>; | ||
schema?: Schema; | ||
embeddingFunction?: EmbeddingFunctionConfig; | ||
} | ||
@@ -44,0 +47,0 @@ export interface OpenTableOptions { |
@@ -17,4 +17,4 @@ "use strict"; | ||
exports.Connection = exports.connect = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const arrow_1 = require("./arrow"); | ||
const registry_1 = require("./embedding/registry"); | ||
const native_1 = require("./native"); | ||
@@ -113,9 +113,9 @@ const table_1 = require("./table"); | ||
let table; | ||
if (data instanceof apache_arrow_1.Table) { | ||
if ((0, arrow_1.isArrowTable)(data)) { | ||
table = data; | ||
} | ||
else { | ||
table = (0, arrow_1.makeArrowTable)(data); | ||
table = (0, arrow_1.makeArrowTable)(data, options); | ||
} | ||
const buf = await (0, arrow_1.fromTableToBuffer)(table); | ||
const buf = await (0, arrow_1.fromTableToBuffer)(table, options?.embeddingFunction, options?.schema); | ||
const innerTable = await this.inner.createTable(name, buf, mode, cleanseStorageOptions(options?.storageOptions)); | ||
@@ -135,3 +135,9 @@ return new table_1.Table(innerTable); | ||
} | ||
const table = (0, arrow_1.makeEmptyTable)(schema); | ||
let metadata = undefined; | ||
if (options?.embeddingFunction !== undefined) { | ||
const embeddingFunction = options.embeddingFunction; | ||
const registry = (0, registry_1.getRegistry)(); | ||
metadata = registry.getTableMetadata([embeddingFunction]); | ||
} | ||
const table = (0, arrow_1.makeEmptyTable)(schema, metadata); | ||
const buf = await (0, arrow_1.fromTableToBuffer)(table); | ||
@@ -138,0 +144,0 @@ const innerTable = await this.inner.createEmptyTable(name, buf, mode, cleanseStorageOptions(options?.storageOptions)); |
@@ -1,45 +0,71 @@ | ||
import { type Float } from "apache-arrow"; | ||
import "reflect-metadata"; | ||
import { DataType, Float } from "../arrow"; | ||
/** | ||
* Options for a given embedding function | ||
*/ | ||
export interface FunctionOptions { | ||
[key: string]: any; | ||
} | ||
/** | ||
* An embedding function that automatically creates vector representation for a given column. | ||
*/ | ||
export interface EmbeddingFunction<T> { | ||
export declare abstract class EmbeddingFunction<T = any, M extends FunctionOptions = FunctionOptions> { | ||
/** | ||
* The name of the column that will be used as input for the Embedding Function. | ||
*/ | ||
sourceColumn: string; | ||
/** | ||
* The data type of the embedding | ||
* Convert the embedding function to a JSON object | ||
* It is used to serialize the embedding function to the schema | ||
* It's important that any object returned by this method contains all the necessary | ||
* information to recreate the embedding function | ||
* | ||
* The embedding function should return `number`. This will be converted into | ||
* an Arrow float array. By default this will be Float32 but this property can | ||
* be used to control the conversion. | ||
*/ | ||
embeddingDataType?: Float; | ||
/** | ||
* The dimension of the embedding | ||
* It should return the same object that was passed to the constructor | ||
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly | ||
* | ||
* This is optional, normally this can be determined by looking at the results of | ||
* `embed`. If this is not specified, and there is an attempt to apply the embedding | ||
* to an empty table, then that process will fail. | ||
* @example | ||
* ```ts | ||
* class MyEmbeddingFunction extends EmbeddingFunction { | ||
* constructor(options: {model: string, timeout: number}) { | ||
* super(); | ||
* this.model = options.model; | ||
* this.timeout = options.timeout; | ||
* } | ||
* toJSON() { | ||
* return { | ||
* model: this.model, | ||
* timeout: this.timeout, | ||
* }; | ||
* } | ||
* } | ||
* ``` | ||
*/ | ||
embeddingDimension?: number; | ||
abstract toJSON(): Partial<M>; | ||
/** | ||
* The name of the column that will contain the embedding | ||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* By default this is "vector" | ||
* @param optionsOrDatatype - The options for the field or the datatype | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
destColumn?: string; | ||
sourceField(optionsOrDatatype: Partial<FieldOptions> | DataType): [DataType, Map<string, EmbeddingFunction>]; | ||
/** | ||
* Should the source column be excluded from the resulting table | ||
* vectorField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* By default the source column is included. Set this to true and | ||
* only the embedding will be stored. | ||
* @param options - The options for the field | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
excludeSource?: boolean; | ||
vectorField(options?: Partial<FieldOptions>): [DataType, Map<string, EmbeddingFunction>]; | ||
/** The number of dimensions of the embeddings */ | ||
ndims(): number | undefined; | ||
/** The datatype of the embeddings */ | ||
abstract embeddingDataType(): Float; | ||
/** | ||
* Creates a vector representation for the given values. | ||
*/ | ||
embed: (data: T[]) => Promise<number[][]>; | ||
abstract computeSourceEmbeddings(data: T[]): Promise<number[][] | Float32Array[] | Float64Array[]>; | ||
/** | ||
Compute the embeddings for a single query | ||
*/ | ||
computeQueryEmbeddings(data: T): Promise<number[] | Float32Array | Float64Array>; | ||
} | ||
/** Test if the input seems to be an embedding function */ | ||
export declare function isEmbeddingFunction<T>(value: unknown): value is EmbeddingFunction<T>; | ||
export interface FieldOptions<T extends DataType = DataType> { | ||
datatype: T; | ||
dims?: number; | ||
} |
"use strict"; | ||
// Copyright 2023 Lance Developers. | ||
// Copyright 2024 Lance Developers. | ||
// | ||
@@ -16,13 +16,74 @@ // Licensed under the Apache License, Version 2.0 (the "License"); | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.isEmbeddingFunction = void 0; | ||
/** Test if the input seems to be an embedding function */ | ||
function isEmbeddingFunction(value) { | ||
if (typeof value !== "object" || value === null) { | ||
return false; | ||
exports.EmbeddingFunction = void 0; | ||
require("reflect-metadata"); | ||
const arrow_1 = require("../arrow"); | ||
const sanitize_1 = require("../sanitize"); | ||
/** | ||
* An embedding function that automatically creates vector representation for a given column. | ||
*/ | ||
class EmbeddingFunction { | ||
/** | ||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* @param optionsOrDatatype - The options for the field or the datatype | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
sourceField(optionsOrDatatype) { | ||
let datatype = (0, arrow_1.isDataType)(optionsOrDatatype) | ||
? optionsOrDatatype | ||
: optionsOrDatatype?.datatype; | ||
if (!datatype) { | ||
throw new Error("Datatype is required"); | ||
} | ||
datatype = (0, sanitize_1.sanitizeType)(datatype); | ||
const metadata = new Map(); | ||
metadata.set("source_column_for", this); | ||
return [datatype, metadata]; | ||
} | ||
if (!("sourceColumn" in value) || !("embed" in value)) { | ||
return false; | ||
/** | ||
* vectorField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* @param options - The options for the field | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
vectorField(options) { | ||
let dtype; | ||
const dims = this.ndims() ?? options?.dims; | ||
if (!options?.datatype) { | ||
if (dims === undefined) { | ||
throw new Error("ndims is required for vector field"); | ||
} | ||
dtype = new arrow_1.FixedSizeList(dims, new arrow_1.Field("item", new arrow_1.Float32(), true)); | ||
} | ||
else { | ||
if ((0, arrow_1.isFixedSizeList)(options.datatype)) { | ||
dtype = options.datatype; | ||
} | ||
else if ((0, arrow_1.isFloat)(options.datatype)) { | ||
if (dims === undefined) { | ||
throw new Error("ndims is required for vector field"); | ||
} | ||
dtype = (0, arrow_1.newVectorType)(dims, options.datatype); | ||
} | ||
else { | ||
throw new Error("Expected FixedSizeList or Float as datatype for vector field"); | ||
} | ||
} | ||
const metadata = new Map(); | ||
metadata.set("vector_column_for", this); | ||
return [dtype, metadata]; | ||
} | ||
return (typeof value.sourceColumn === "string" && typeof value.embed === "function"); | ||
/** The number of dimensions of the embeddings */ | ||
ndims() { | ||
// Base implementation reports no fixed dimension; `vectorField` then falls
// back to `options.dims`.
return undefined; | ||
} | ||
/** | ||
Compute the embeddings for a single query | ||
*/ | ||
async computeQueryEmbeddings(data) { | ||
return this.computeSourceEmbeddings([data]).then((embeddings) => embeddings[0]); | ||
} | ||
} | ||
exports.isEmbeddingFunction = isEmbeddingFunction; | ||
exports.EmbeddingFunction = EmbeddingFunction; |
@@ -1,2 +0,28 @@ | ||
export { EmbeddingFunction, isEmbeddingFunction } from "./embedding_function"; | ||
export { OpenAIEmbeddingFunction } from "./openai"; | ||
import { Schema } from "../arrow"; | ||
import { EmbeddingFunction } from "./embedding_function"; | ||
export { EmbeddingFunction } from "./embedding_function"; | ||
export * from "./openai"; | ||
export * from "./registry"; | ||
/** | ||
* Create a schema with embedding functions. | ||
* | ||
* @param fields | ||
* @returns Schema | ||
* @example | ||
* ```ts | ||
* class MyEmbeddingFunction extends EmbeddingFunction { | ||
* // ... | ||
* } | ||
* const func = new MyEmbeddingFunction(); | ||
* const schema = LanceSchema({ | ||
* id: new Int32(), | ||
* text: func.sourceField(new Utf8()), | ||
* vector: func.vectorField(), | ||
* // optional: specify the datatype and/or dimensions | ||
* vector2: func.vectorField({ datatype: new Float32(), dims: 3}), | ||
* }); | ||
* | ||
* const table = await db.createTable("my_table", data, { schema }); | ||
* ``` | ||
*/ | ||
export declare function LanceSchema(fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>): Schema; |
"use strict"; | ||
// Copyright 2023 Lance Developers. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
// Re-export helper: exposes `m[k]` on `o` under the name `k2` (defaults to `k`).
// An existing `this.__createBinding` is reused when one is present.
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
var desc = Object.getOwnPropertyDescriptor(m, k); | ||
// Fall back to a forwarding getter when `m` has no own descriptor for `k`,
// or when the existing descriptor fails the check below.
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { | ||
desc = { enumerable: true, get: function() { return m[k]; } }; | ||
} | ||
Object.defineProperty(o, k2, desc); | ||
// Without `Object.create`, copy the current value with a plain assignment.
}) : (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
o[k2] = m[k]; | ||
})); | ||
// Implements `export * from ...`: binds every key of `m` onto `exports`,
// skipping "default" and any name `exports` already owns.
var __exportStar = (this && this.__exportStar) || function(m, exports) { | ||
for (var p in m) if (p !== "default" && !Object.prototype.hasOwnProperty.call(exports, p)) __createBinding(exports, m, p); | ||
}; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.OpenAIEmbeddingFunction = exports.isEmbeddingFunction = void 0; | ||
exports.LanceSchema = exports.EmbeddingFunction = void 0; | ||
const arrow_1 = require("../arrow"); | ||
const arrow_2 = require("../arrow"); | ||
const sanitize_1 = require("../sanitize"); | ||
const registry_1 = require("./registry"); | ||
var embedding_function_1 = require("./embedding_function"); | ||
Object.defineProperty(exports, "isEmbeddingFunction", { enumerable: true, get: function () { return embedding_function_1.isEmbeddingFunction; } }); | ||
var openai_1 = require("./openai"); | ||
Object.defineProperty(exports, "OpenAIEmbeddingFunction", { enumerable: true, get: function () { return openai_1.OpenAIEmbeddingFunction; } }); | ||
Object.defineProperty(exports, "EmbeddingFunction", { enumerable: true, get: function () { return embedding_function_1.EmbeddingFunction; } }); | ||
// We need to explicitly export '*' so that the `register` decorator actually registers the class. | ||
__exportStar(require("./openai"), exports); | ||
__exportStar(require("./registry"), exports); | ||
/**
 * Build an Arrow schema whose table metadata describes the embedding
 * functions attached to its fields.
 *
 * Each entry of `fields` is either a data type, or a `[dataType, metadata]`
 * pair where `metadata` is a Map carrying a `source_column_for` or
 * `vector_column_for` entry. Every resulting field is created as nullable.
 *
 * @param fields mapping from column name to a data type or a
 *   `[dataType, metadata]` pair
 * @returns Schema containing one field per entry plus the metadata produced
 *   by the embedding registry for the collected functions
 * @example
 * ```ts
 * class MyEmbeddingFunction extends EmbeddingFunction {
 *   // ...
 * }
 * const func = new MyEmbeddingFunction();
 * const schema = LanceSchema({
 *   id: new Int32(),
 *   text: func.sourceField(new Utf8()),
 *   vector: func.vectorField(),
 *   // optional: specify the datatype and/or dimensions
 *   vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
 * });
 *
 * const table = await db.createTable("my_table", data, { schema });
 * ```
 */
function LanceSchema(fields) {
    const embeddingFunctions = new Map();
    const arrowFields = [];
    for (const [name, spec] of Object.entries(fields)) {
        // Plain data type: no embedding metadata to collect.
        if ((0, arrow_2.isDataType)(spec)) {
            arrowFields.push(new arrow_1.Field(name, (0, sanitize_1.sanitizeType)(spec), true));
            continue;
        }
        // Otherwise the entry is a [dataType, metadata] pair.
        const [dtype, fieldMetadata] = spec;
        arrowFields.push(new arrow_1.Field(name, (0, sanitize_1.sanitizeType)(dtype), true));
        parseEmbeddingFunctions(embeddingFunctions, name, fieldMetadata);
    }
    const tableMetadata = (0, registry_1.getRegistry)().getTableMetadata([...embeddingFunctions.values()]);
    return new arrow_1.Schema(arrowFields, tableMetadata);
}
exports.LanceSchema = LanceSchema; | ||
/**
 * Record the role a column plays for the embedding function named in its
 * field metadata.
 *
 * `source_column_for` takes precedence over `vector_column_for`; when neither
 * key is present nothing is recorded. The entry stored in
 * `embeddingFunctions` is keyed by the embedding function itself and is
 * replaced with a new object on every update.
 *
 * @param embeddingFunctions Map from embedding function to its config entry (mutated)
 * @param key name of the column being processed
 * @param metadata Map of field metadata for that column
 */
function parseEmbeddingFunctions(embeddingFunctions, key, metadata) {
    let metadataKey;
    let columnProp;
    if (metadata.has("source_column_for")) {
        metadataKey = "source_column_for";
        columnProp = "sourceColumn";
    }
    else if (metadata.has("vector_column_for")) {
        metadataKey = "vector_column_for";
        columnProp = "vectorColumn";
    }
    else {
        return;
    }
    const embedFunction = metadata.get(metadataKey);
    const previous = embeddingFunctions.get(embedFunction);
    // A fresh entry lists the column first and the function second; an
    // existing entry keeps its properties and gains/overwrites the column.
    const updated = previous === undefined
        ? { [columnProp]: key, function: embedFunction }
        : { ...previous, [columnProp]: key };
    embeddingFunctions.set(embedFunction, updated);
}
@@ -1,8 +0,17 @@ | ||
import { type EmbeddingFunction } from "./embedding_function"; | ||
export declare class OpenAIEmbeddingFunction implements EmbeddingFunction<string> { | ||
private readonly _openai; | ||
private readonly _modelName; | ||
constructor(sourceColumn: string, openAIKey: string, modelName?: string); | ||
embed(data: string[]): Promise<number[][]>; | ||
sourceColumn: string; | ||
import { Float } from "../arrow"; | ||
import { EmbeddingFunction } from "./embedding_function"; | ||
export type OpenAIOptions = { | ||
apiKey?: string; | ||
model?: string; | ||
}; | ||
export declare class OpenAIEmbeddingFunction extends EmbeddingFunction<string, OpenAIOptions> { | ||
#private; | ||
constructor(options?: OpenAIOptions); | ||
toJSON(): { | ||
model: string; | ||
}; | ||
ndims(): number; | ||
embeddingDataType(): Float; | ||
computeSourceEmbeddings(data: string[]): Promise<number[][]>; | ||
computeQueryEmbeddings(data: string): Promise<number[]>; | ||
} |
@@ -15,8 +15,26 @@ "use strict"; | ||
// limitations under the License. | ||
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) { | ||
var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d; | ||
if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc); | ||
else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r; | ||
return c > 3 && r && Object.defineProperty(target, key, r), r; | ||
}; | ||
var __metadata = (this && this.__metadata) || function (k, v) { | ||
if (typeof Reflect === "object" && typeof Reflect.metadata === "function") return Reflect.metadata(k, v); | ||
}; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.OpenAIEmbeddingFunction = void 0; | ||
class OpenAIEmbeddingFunction { | ||
_openai; | ||
_modelName; | ||
constructor(sourceColumn, openAIKey, modelName = "text-embedding-ada-002") { | ||
const arrow_1 = require("../arrow"); | ||
const embedding_function_1 = require("./embedding_function"); | ||
const registry_1 = require("./registry"); | ||
let OpenAIEmbeddingFunction = class OpenAIEmbeddingFunction extends embedding_function_1.EmbeddingFunction { | ||
#openai; | ||
#modelName; | ||
constructor(options = { model: "text-embedding-ada-002" }) { | ||
super(); | ||
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; | ||
if (!openAIKey) { | ||
throw new Error("OpenAI API key is required"); | ||
} | ||
const modelName = options?.model ?? "text-embedding-ada-002"; | ||
/** | ||
@@ -34,12 +52,31 @@ * @type {import("openai").default} | ||
} | ||
this.sourceColumn = sourceColumn; | ||
const configuration = { | ||
apiKey: openAIKey, | ||
}; | ||
this._openai = new Openai(configuration); | ||
this._modelName = modelName; | ||
this.#openai = new Openai(configuration); | ||
this.#modelName = modelName; | ||
} | ||
async embed(data) { | ||
const response = await this._openai.embeddings.create({ | ||
model: this._modelName, | ||
toJSON() { | ||
return { | ||
model: this.#modelName, | ||
}; | ||
} | ||
ndims() { | ||
switch (this.#modelName) { | ||
case "text-embedding-ada-002": | ||
return 1536; | ||
case "text-embedding-3-large": | ||
return 3072; | ||
case "text-embedding-3-small": | ||
return 1536; | ||
default: | ||
return null; | ||
} | ||
} | ||
embeddingDataType() { | ||
return new arrow_1.Float32(); | ||
} | ||
async computeSourceEmbeddings(data) { | ||
const response = await this.#openai.embeddings.create({ | ||
model: this.#modelName, | ||
input: data, | ||
@@ -53,4 +90,17 @@ }); | ||
} | ||
sourceColumn; | ||
} | ||
async computeQueryEmbeddings(data) { | ||
if (typeof data !== "string") { | ||
throw new Error("Data must be a string"); | ||
} | ||
const response = await this.#openai.embeddings.create({ | ||
model: this.#modelName, | ||
input: data, | ||
}); | ||
return response.data[0].embedding; | ||
} | ||
}; | ||
exports.OpenAIEmbeddingFunction = OpenAIEmbeddingFunction; | ||
exports.OpenAIEmbeddingFunction = OpenAIEmbeddingFunction = __decorate([ | ||
(0, registry_1.register)("openai"), | ||
__metadata("design:paramtypes", [Object]) | ||
], OpenAIEmbeddingFunction); |
@@ -1,2 +0,2 @@ | ||
import { Table as ArrowTable, RecordBatch } from "apache-arrow"; | ||
import { Table as ArrowTable, RecordBatch } from "./arrow"; | ||
import { RecordBatchIterator as NativeBatchIterator, Query as NativeQuery, Table as NativeTable, VectorQuery as NativeVectorQuery } from "./native"; | ||
@@ -3,0 +3,0 @@ export declare class RecordBatchIterator implements AsyncIterator<RecordBatch> { |
@@ -17,3 +17,3 @@ "use strict"; | ||
exports.Query = exports.VectorQuery = exports.QueryBase = exports.RecordBatchIterator = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const arrow_1 = require("./arrow"); | ||
class RecordBatchIterator { | ||
@@ -38,3 +38,3 @@ promisedInner; | ||
} | ||
const tbl = (0, apache_arrow_1.tableFromIPC)(n); | ||
const tbl = (0, arrow_1.tableFromIPC)(n); | ||
if (tbl.batches.length != 1) { | ||
@@ -153,3 +153,3 @@ throw new Error("Expected only one batch"); | ||
} | ||
return new apache_arrow_1.Table(batches); | ||
return new arrow_1.Table(batches); | ||
} | ||
@@ -156,0 +156,0 @@ /** Collect the results as an array of objects. */ |
@@ -1,2 +0,23 @@ | ||
import { Schema } from "apache-arrow"; | ||
import type { TKeys } from "apache-arrow/type"; | ||
import { DataType, Date_, Decimal, DenseUnion, Dictionary, Duration, Field, FixedSizeBinary, FixedSizeList, Float, Int, Interval, List, Map_, Schema, SparseUnion, Struct, Time, Timestamp, TimestampMicrosecond, TimestampMillisecond, TimestampNanosecond, TimestampSecond, Type, Union } from "./arrow"; | ||
export declare function sanitizeMetadata(metadataLike?: unknown): Map<string, string> | undefined; | ||
export declare function sanitizeInt(typeLike: object): Int<Type.Int | Type.Int8 | Type.Int16 | Type.Int32 | Type.Int64 | Type.Uint8 | Type.Uint16 | Type.Uint32 | Type.Uint64>; | ||
export declare function sanitizeFloat(typeLike: object): Float<Type.Float | Type.Float16 | Type.Float32 | Type.Float64>; | ||
export declare function sanitizeDecimal(typeLike: object): Decimal; | ||
export declare function sanitizeDate(typeLike: object): Date_<import("apache-arrow/type").Dates>; | ||
export declare function sanitizeTime(typeLike: object): Time<Type.Time | Type.TimeSecond | Type.TimeMillisecond | Type.TimeMicrosecond | Type.TimeNanosecond>; | ||
export declare function sanitizeTimestamp(typeLike: object): Timestamp<Type.Timestamp | Type.TimestampSecond | Type.TimestampMillisecond | Type.TimestampMicrosecond | Type.TimestampNanosecond>; | ||
export declare function sanitizeTypedTimestamp(typeLike: object, Datatype: typeof TimestampNanosecond | typeof TimestampMicrosecond | typeof TimestampMillisecond | typeof TimestampSecond): TimestampSecond | TimestampMillisecond | TimestampMicrosecond | TimestampNanosecond; | ||
export declare function sanitizeInterval(typeLike: object): Interval<Type.Interval | Type.IntervalDayTime | Type.IntervalYearMonth>; | ||
export declare function sanitizeList(typeLike: object): List<any>; | ||
export declare function sanitizeStruct(typeLike: object): Struct<any>; | ||
export declare function sanitizeUnion(typeLike: object): Union<Type.Union | Type.DenseUnion | Type.SparseUnion>; | ||
export declare function sanitizeTypedUnion(typeLike: object, UnionType: typeof DenseUnion | typeof SparseUnion): SparseUnion | DenseUnion; | ||
export declare function sanitizeFixedSizeBinary(typeLike: object): FixedSizeBinary; | ||
export declare function sanitizeFixedSizeList(typeLike: object): FixedSizeList<any>; | ||
export declare function sanitizeMap(typeLike: object): Map_<any, any>; | ||
export declare function sanitizeDuration(typeLike: object): Duration<Type.Duration | Type.DurationSecond | Type.DurationMillisecond | Type.DurationMicrosecond | Type.DurationNanosecond>; | ||
export declare function sanitizeDictionary(typeLike: object): Dictionary<DataType<any, any>, TKeys>; | ||
export declare function sanitizeType(typeLike: unknown): DataType<any>; | ||
export declare function sanitizeField(fieldLike: unknown): Field; | ||
/** | ||
@@ -3,0 +24,0 @@ * Convert something schemaLike into a Schema instance |
@@ -16,11 +16,4 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.sanitizeSchema = void 0; | ||
// The utilities in this file help sanitize data from the user's arrow
// library into the types expected by vectordb's arrow library. Node
// generally allows for multiple versions of the same library (and sometimes
// even multiple copies of the same version) to be installed at the same
// time. However, arrow-js uses instanceof, which expects that the input
// comes from the exact same library instance. This is not always the case
// and so we must sanitize the input to ensure that it is compatible.
const apache_arrow_1 = require("apache-arrow"); | ||
exports.sanitizeSchema = exports.sanitizeField = exports.sanitizeType = exports.sanitizeDictionary = exports.sanitizeDuration = exports.sanitizeMap = exports.sanitizeFixedSizeList = exports.sanitizeFixedSizeBinary = exports.sanitizeTypedUnion = exports.sanitizeUnion = exports.sanitizeStruct = exports.sanitizeList = exports.sanitizeInterval = exports.sanitizeTypedTimestamp = exports.sanitizeTimestamp = exports.sanitizeTime = exports.sanitizeDate = exports.sanitizeDecimal = exports.sanitizeFloat = exports.sanitizeInt = exports.sanitizeMetadata = void 0; | ||
const arrow_1 = require("./arrow"); | ||
function sanitizeMetadata(metadataLike) { | ||
@@ -40,2 +33,3 @@ if (metadataLike === undefined || metadataLike === null) { | ||
} | ||
exports.sanitizeMetadata = sanitizeMetadata; | ||
function sanitizeInt(typeLike) { | ||
@@ -48,4 +42,5 @@ if (!("bitWidth" in typeLike) || | ||
} | ||
return new apache_arrow_1.Int(typeLike.isSigned, typeLike.bitWidth); | ||
return new arrow_1.Int(typeLike.isSigned, typeLike.bitWidth); | ||
} | ||
exports.sanitizeInt = sanitizeInt; | ||
function sanitizeFloat(typeLike) { | ||
@@ -55,4 +50,5 @@ if (!("precision" in typeLike) || typeof typeLike.precision !== "number") { | ||
} | ||
return new apache_arrow_1.Float(typeLike.precision); | ||
return new arrow_1.Float(typeLike.precision); | ||
} | ||
exports.sanitizeFloat = sanitizeFloat; | ||
function sanitizeDecimal(typeLike) { | ||
@@ -67,4 +63,5 @@ if (!("scale" in typeLike) || | ||
} | ||
return new apache_arrow_1.Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); | ||
return new arrow_1.Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); | ||
} | ||
exports.sanitizeDecimal = sanitizeDecimal; | ||
function sanitizeDate(typeLike) { | ||
@@ -74,4 +71,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Date_(typeLike.unit); | ||
return new arrow_1.Date_(typeLike.unit); | ||
} | ||
exports.sanitizeDate = sanitizeDate; | ||
function sanitizeTime(typeLike) { | ||
@@ -84,4 +82,5 @@ if (!("unit" in typeLike) || | ||
} | ||
return new apache_arrow_1.Time(typeLike.unit, typeLike.bitWidth); | ||
return new arrow_1.Time(typeLike.unit, typeLike.bitWidth); | ||
} | ||
exports.sanitizeTime = sanitizeTime; | ||
function sanitizeTimestamp(typeLike) { | ||
@@ -95,4 +94,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Timestamp(typeLike.unit, timezone); | ||
return new arrow_1.Timestamp(typeLike.unit, timezone); | ||
} | ||
exports.sanitizeTimestamp = sanitizeTimestamp; | ||
function sanitizeTypedTimestamp(typeLike, | ||
@@ -107,2 +107,3 @@ // eslint-disable-next-line @typescript-eslint/naming-convention | ||
} | ||
exports.sanitizeTypedTimestamp = sanitizeTypedTimestamp; | ||
function sanitizeInterval(typeLike) { | ||
@@ -112,4 +113,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Interval(typeLike.unit); | ||
return new arrow_1.Interval(typeLike.unit); | ||
} | ||
exports.sanitizeInterval = sanitizeInterval; | ||
function sanitizeList(typeLike) { | ||
@@ -122,4 +124,5 @@ if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
} | ||
return new apache_arrow_1.List(sanitizeField(typeLike.children[0])); | ||
return new arrow_1.List(sanitizeField(typeLike.children[0])); | ||
} | ||
exports.sanitizeList = sanitizeList; | ||
function sanitizeStruct(typeLike) { | ||
@@ -129,4 +132,5 @@ if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
} | ||
return new apache_arrow_1.Struct(typeLike.children.map((child) => sanitizeField(child))); | ||
return new arrow_1.Struct(typeLike.children.map((child) => sanitizeField(child))); | ||
} | ||
exports.sanitizeStruct = sanitizeStruct; | ||
function sanitizeUnion(typeLike) { | ||
@@ -141,6 +145,7 @@ if (!("typeIds" in typeLike) || | ||
} | ||
return new apache_arrow_1.Union(typeLike.mode, | ||
return new arrow_1.Union(typeLike.mode, | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
typeLike.typeIds, typeLike.children.map((child) => sanitizeField(child))); | ||
} | ||
exports.sanitizeUnion = sanitizeUnion; | ||
function sanitizeTypedUnion(typeLike, | ||
@@ -157,2 +162,3 @@ // eslint-disable-next-line @typescript-eslint/naming-convention | ||
} | ||
exports.sanitizeTypedUnion = sanitizeTypedUnion; | ||
function sanitizeFixedSizeBinary(typeLike) { | ||
@@ -162,4 +168,5 @@ if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") { | ||
} | ||
return new apache_arrow_1.FixedSizeBinary(typeLike.byteWidth); | ||
return new arrow_1.FixedSizeBinary(typeLike.byteWidth); | ||
} | ||
exports.sanitizeFixedSizeBinary = sanitizeFixedSizeBinary; | ||
function sanitizeFixedSizeList(typeLike) { | ||
@@ -175,4 +182,5 @@ if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") { | ||
} | ||
return new apache_arrow_1.FixedSizeList(typeLike.listSize, sanitizeField(typeLike.children[0])); | ||
return new arrow_1.FixedSizeList(typeLike.listSize, sanitizeField(typeLike.children[0])); | ||
} | ||
exports.sanitizeFixedSizeList = sanitizeFixedSizeList; | ||
function sanitizeMap(typeLike) { | ||
@@ -185,6 +193,7 @@ if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
} | ||
return new apache_arrow_1.Map_( | ||
return new arrow_1.Map_( | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
typeLike.children.map((field) => sanitizeField(field)), typeLike.keysSorted); | ||
} | ||
exports.sanitizeMap = sanitizeMap; | ||
function sanitizeDuration(typeLike) { | ||
@@ -194,4 +203,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Duration(typeLike.unit); | ||
return new arrow_1.Duration(typeLike.unit); | ||
} | ||
exports.sanitizeDuration = sanitizeDuration; | ||
function sanitizeDictionary(typeLike) { | ||
@@ -210,4 +220,5 @@ if (!("id" in typeLike) || typeof typeLike.id !== "number") { | ||
} | ||
return new apache_arrow_1.Dictionary(sanitizeType(typeLike.dictionary), sanitizeType(typeLike.indices), typeLike.id, typeLike.isOrdered); | ||
return new arrow_1.Dictionary(sanitizeType(typeLike.dictionary), sanitizeType(typeLike.indices), typeLike.id, typeLike.isOrdered); | ||
} | ||
exports.sanitizeDictionary = sanitizeDictionary; | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
@@ -232,100 +243,100 @@ function sanitizeType(typeLike) { | ||
switch (typeId) { | ||
case apache_arrow_1.Type.NONE: | ||
case arrow_1.Type.NONE: | ||
throw Error("Received a Type with a typeId of NONE"); | ||
case apache_arrow_1.Type.Null: | ||
return new apache_arrow_1.Null(); | ||
case apache_arrow_1.Type.Int: | ||
case arrow_1.Type.Null: | ||
return new arrow_1.Null(); | ||
case arrow_1.Type.Int: | ||
return sanitizeInt(typeLike); | ||
case apache_arrow_1.Type.Float: | ||
case arrow_1.Type.Float: | ||
return sanitizeFloat(typeLike); | ||
case apache_arrow_1.Type.Binary: | ||
return new apache_arrow_1.Binary(); | ||
case apache_arrow_1.Type.Utf8: | ||
return new apache_arrow_1.Utf8(); | ||
case apache_arrow_1.Type.Bool: | ||
return new apache_arrow_1.Bool(); | ||
case apache_arrow_1.Type.Decimal: | ||
case arrow_1.Type.Binary: | ||
return new arrow_1.Binary(); | ||
case arrow_1.Type.Utf8: | ||
return new arrow_1.Utf8(); | ||
case arrow_1.Type.Bool: | ||
return new arrow_1.Bool(); | ||
case arrow_1.Type.Decimal: | ||
return sanitizeDecimal(typeLike); | ||
case apache_arrow_1.Type.Date: | ||
case arrow_1.Type.Date: | ||
return sanitizeDate(typeLike); | ||
case apache_arrow_1.Type.Time: | ||
case arrow_1.Type.Time: | ||
return sanitizeTime(typeLike); | ||
case apache_arrow_1.Type.Timestamp: | ||
case arrow_1.Type.Timestamp: | ||
return sanitizeTimestamp(typeLike); | ||
case apache_arrow_1.Type.Interval: | ||
case arrow_1.Type.Interval: | ||
return sanitizeInterval(typeLike); | ||
case apache_arrow_1.Type.List: | ||
case arrow_1.Type.List: | ||
return sanitizeList(typeLike); | ||
case apache_arrow_1.Type.Struct: | ||
case arrow_1.Type.Struct: | ||
return sanitizeStruct(typeLike); | ||
case apache_arrow_1.Type.Union: | ||
case arrow_1.Type.Union: | ||
return sanitizeUnion(typeLike); | ||
case apache_arrow_1.Type.FixedSizeBinary: | ||
case arrow_1.Type.FixedSizeBinary: | ||
return sanitizeFixedSizeBinary(typeLike); | ||
case apache_arrow_1.Type.FixedSizeList: | ||
case arrow_1.Type.FixedSizeList: | ||
return sanitizeFixedSizeList(typeLike); | ||
case apache_arrow_1.Type.Map: | ||
case arrow_1.Type.Map: | ||
return sanitizeMap(typeLike); | ||
case apache_arrow_1.Type.Duration: | ||
case arrow_1.Type.Duration: | ||
return sanitizeDuration(typeLike); | ||
case apache_arrow_1.Type.Dictionary: | ||
case arrow_1.Type.Dictionary: | ||
return sanitizeDictionary(typeLike); | ||
case apache_arrow_1.Type.Int8: | ||
return new apache_arrow_1.Int8(); | ||
case apache_arrow_1.Type.Int16: | ||
return new apache_arrow_1.Int16(); | ||
case apache_arrow_1.Type.Int32: | ||
return new apache_arrow_1.Int32(); | ||
case apache_arrow_1.Type.Int64: | ||
return new apache_arrow_1.Int64(); | ||
case apache_arrow_1.Type.Uint8: | ||
return new apache_arrow_1.Uint8(); | ||
case apache_arrow_1.Type.Uint16: | ||
return new apache_arrow_1.Uint16(); | ||
case apache_arrow_1.Type.Uint32: | ||
return new apache_arrow_1.Uint32(); | ||
case apache_arrow_1.Type.Uint64: | ||
return new apache_arrow_1.Uint64(); | ||
case apache_arrow_1.Type.Float16: | ||
return new apache_arrow_1.Float16(); | ||
case apache_arrow_1.Type.Float32: | ||
return new apache_arrow_1.Float32(); | ||
case apache_arrow_1.Type.Float64: | ||
return new apache_arrow_1.Float64(); | ||
case apache_arrow_1.Type.DateMillisecond: | ||
return new apache_arrow_1.DateMillisecond(); | ||
case apache_arrow_1.Type.DateDay: | ||
return new apache_arrow_1.DateDay(); | ||
case apache_arrow_1.Type.TimeNanosecond: | ||
return new apache_arrow_1.TimeNanosecond(); | ||
case apache_arrow_1.Type.TimeMicrosecond: | ||
return new apache_arrow_1.TimeMicrosecond(); | ||
case apache_arrow_1.Type.TimeMillisecond: | ||
return new apache_arrow_1.TimeMillisecond(); | ||
case apache_arrow_1.Type.TimeSecond: | ||
return new apache_arrow_1.TimeSecond(); | ||
case apache_arrow_1.Type.TimestampNanosecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampNanosecond); | ||
case apache_arrow_1.Type.TimestampMicrosecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampMicrosecond); | ||
case apache_arrow_1.Type.TimestampMillisecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampMillisecond); | ||
case apache_arrow_1.Type.TimestampSecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampSecond); | ||
case apache_arrow_1.Type.DenseUnion: | ||
return sanitizeTypedUnion(typeLike, apache_arrow_1.DenseUnion); | ||
case apache_arrow_1.Type.SparseUnion: | ||
return sanitizeTypedUnion(typeLike, apache_arrow_1.SparseUnion); | ||
case apache_arrow_1.Type.IntervalDayTime: | ||
return new apache_arrow_1.IntervalDayTime(); | ||
case apache_arrow_1.Type.IntervalYearMonth: | ||
return new apache_arrow_1.IntervalYearMonth(); | ||
case apache_arrow_1.Type.DurationNanosecond: | ||
return new apache_arrow_1.DurationNanosecond(); | ||
case apache_arrow_1.Type.DurationMicrosecond: | ||
return new apache_arrow_1.DurationMicrosecond(); | ||
case apache_arrow_1.Type.DurationMillisecond: | ||
return new apache_arrow_1.DurationMillisecond(); | ||
case apache_arrow_1.Type.DurationSecond: | ||
return new apache_arrow_1.DurationSecond(); | ||
case arrow_1.Type.Int8: | ||
return new arrow_1.Int8(); | ||
case arrow_1.Type.Int16: | ||
return new arrow_1.Int16(); | ||
case arrow_1.Type.Int32: | ||
return new arrow_1.Int32(); | ||
case arrow_1.Type.Int64: | ||
return new arrow_1.Int64(); | ||
case arrow_1.Type.Uint8: | ||
return new arrow_1.Uint8(); | ||
case arrow_1.Type.Uint16: | ||
return new arrow_1.Uint16(); | ||
case arrow_1.Type.Uint32: | ||
return new arrow_1.Uint32(); | ||
case arrow_1.Type.Uint64: | ||
return new arrow_1.Uint64(); | ||
case arrow_1.Type.Float16: | ||
return new arrow_1.Float16(); | ||
case arrow_1.Type.Float32: | ||
return new arrow_1.Float32(); | ||
case arrow_1.Type.Float64: | ||
return new arrow_1.Float64(); | ||
case arrow_1.Type.DateMillisecond: | ||
return new arrow_1.DateMillisecond(); | ||
case arrow_1.Type.DateDay: | ||
return new arrow_1.DateDay(); | ||
case arrow_1.Type.TimeNanosecond: | ||
return new arrow_1.TimeNanosecond(); | ||
case arrow_1.Type.TimeMicrosecond: | ||
return new arrow_1.TimeMicrosecond(); | ||
case arrow_1.Type.TimeMillisecond: | ||
return new arrow_1.TimeMillisecond(); | ||
case arrow_1.Type.TimeSecond: | ||
return new arrow_1.TimeSecond(); | ||
case arrow_1.Type.TimestampNanosecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampNanosecond); | ||
case arrow_1.Type.TimestampMicrosecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampMicrosecond); | ||
case arrow_1.Type.TimestampMillisecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampMillisecond); | ||
case arrow_1.Type.TimestampSecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampSecond); | ||
case arrow_1.Type.DenseUnion: | ||
return sanitizeTypedUnion(typeLike, arrow_1.DenseUnion); | ||
case arrow_1.Type.SparseUnion: | ||
return sanitizeTypedUnion(typeLike, arrow_1.SparseUnion); | ||
case arrow_1.Type.IntervalDayTime: | ||
return new arrow_1.IntervalDayTime(); | ||
case arrow_1.Type.IntervalYearMonth: | ||
return new arrow_1.IntervalYearMonth(); | ||
case arrow_1.Type.DurationNanosecond: | ||
return new arrow_1.DurationNanosecond(); | ||
case arrow_1.Type.DurationMicrosecond: | ||
return new arrow_1.DurationMicrosecond(); | ||
case arrow_1.Type.DurationMillisecond: | ||
return new arrow_1.DurationMillisecond(); | ||
case arrow_1.Type.DurationSecond: | ||
return new arrow_1.DurationSecond(); | ||
default: | ||
@@ -335,4 +346,5 @@ throw new Error("Unrecoginized type id in schema: " + typeId); | ||
} | ||
exports.sanitizeType = sanitizeType; | ||
function sanitizeField(fieldLike) { | ||
if (fieldLike instanceof apache_arrow_1.Field) { | ||
if (fieldLike instanceof arrow_1.Field) { | ||
return fieldLike; | ||
@@ -361,4 +373,5 @@ } | ||
} | ||
return new apache_arrow_1.Field(name, type, nullable, metadata); | ||
return new arrow_1.Field(name, type, nullable, metadata); | ||
} | ||
exports.sanitizeField = sanitizeField; | ||
/** | ||
@@ -372,3 +385,3 @@ * Convert something schemaLike into a Schema instance | ||
function sanitizeSchema(schemaLike) { | ||
if (schemaLike instanceof apache_arrow_1.Schema) { | ||
if (schemaLike instanceof arrow_1.Schema) { | ||
return schemaLike; | ||
@@ -390,4 +403,4 @@ } | ||
const sanitizedFields = schemaLike.fields.map((field) => sanitizeField(field)); | ||
return new apache_arrow_1.Schema(sanitizedFields, metadata); | ||
return new arrow_1.Schema(sanitizedFields, metadata); | ||
} | ||
exports.sanitizeSchema = sanitizeSchema; |
@@ -1,3 +0,2 @@ | ||
import { Schema } from "apache-arrow"; | ||
import { Data } from "./arrow"; | ||
import { Data, Schema } from "./arrow"; | ||
import { IndexOptions } from "./indices"; | ||
@@ -4,0 +3,0 @@ import { AddColumnsSql, ColumnAlteration, IndexConfig, OptimizeStats, Table as _NativeTable } from "./native"; |
@@ -17,4 +17,4 @@ "use strict"; | ||
exports.Table = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const arrow_1 = require("./arrow"); | ||
const registry_1 = require("./embedding/registry"); | ||
const query_1 = require("./query"); | ||
@@ -60,3 +60,3 @@ /** | ||
const schemaBuf = await this.inner.schema(); | ||
const tbl = (0, apache_arrow_1.tableFromIPC)(schemaBuf); | ||
const tbl = (0, arrow_1.tableFromIPC)(schemaBuf); | ||
return tbl.schema; | ||
@@ -70,3 +70,6 @@ } | ||
const mode = options?.mode ?? "append"; | ||
const buffer = await (0, arrow_1.fromDataToBuffer)(data); | ||
const schema = await this.schema(); | ||
const registry = (0, registry_1.getRegistry)(); | ||
const functions = registry.parseFunctions(schema.metadata); | ||
const buffer = await (0, arrow_1.fromDataToBuffer)(data, functions.values().next().value); | ||
await this.inner.add(buffer, mode); | ||
@@ -73,0 +76,0 @@ } |
@@ -20,6 +20,10 @@ // Copyright 2023 Lance Developers. | ||
Field, | ||
FixedSizeBinary, | ||
FixedSizeList, | ||
type Float, | ||
Float, | ||
Float32, | ||
Int, | ||
LargeBinary, | ||
List, | ||
Null, | ||
RecordBatch, | ||
@@ -38,4 +42,96 @@ RecordBatchFileWriter, | ||
import { type EmbeddingFunction } from "./embedding/embedding_function"; | ||
import { sanitizeSchema } from "./sanitize"; | ||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; | ||
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize"; | ||
export * from "apache-arrow"; | ||
/**
 * Check whether `value` can be treated as an Arrow table.
 *
 * Accepts real `ArrowTable` instances as well as any object exposing both a
 * `schema` and a `batches` property (presumably tables created by a
 * different copy of the arrow library — TODO confirm).
 */
export function isArrowTable(value: object): value is ArrowTable {
  return (
    value instanceof ArrowTable || ("schema" in value && "batches" in value)
  );
}
/**
 * Check whether `value` is an Arrow data type of any kind.
 *
 * A value qualifies when it is an instance of `DataType` or when any of the
 * per-kind `DataType.isX` checks accepts it. The checks run in the listed
 * order and stop at the first match.
 */
export function isDataType(value: unknown): value is DataType {
  if (value instanceof DataType) {
    return true;
  }
  const kindChecks = [
    "isNull",
    "isInt",
    "isFloat",
    "isBinary",
    "isLargeBinary",
    "isUtf8",
    "isLargeUtf8",
    "isBool",
    "isDecimal",
    "isDate",
    "isTime",
    "isTimestamp",
    "isInterval",
    "isDuration",
    "isList",
    "isStruct",
    "isUnion",
    "isFixedSizeBinary",
    "isFixedSizeList",
    "isMap",
    "isDictionary",
  ] as const;
  return kindChecks.some((check) => DataType[check](value));
}
/** True when `value` is a `Null` instance or `DataType.isNull` accepts it. */
export function isNull(value: unknown): value is Null {
  if (value instanceof Null) {
    return true;
  }
  return DataType.isNull(value);
}
/** True when `value` is an `Int` instance or `DataType.isInt` accepts it. */
export function isInt(value: unknown): value is Int {
  if (value instanceof Int) {
    return true;
  }
  return DataType.isInt(value);
}
/** True when `value` is a `Float` instance or `DataType.isFloat` accepts it. */
export function isFloat(value: unknown): value is Float {
  if (value instanceof Float) {
    return true;
  }
  return DataType.isFloat(value);
}
/** True when `value` is a `Binary` instance or `DataType.isBinary` accepts it. */
export function isBinary(value: unknown): value is Binary {
  const isInstance = value instanceof Binary;
  return isInstance ? true : DataType.isBinary(value);
}
/**
 * True when `value` is a `LargeBinary` instance or `DataType.isLargeBinary`
 * accepts it.
 */
export function isLargeBinary(value: unknown): value is LargeBinary {
  const isInstance = value instanceof LargeBinary;
  return isInstance ? true : DataType.isLargeBinary(value);
}
/** True when `value` is a `Utf8` instance or `DataType.isUtf8` accepts it. */
export function isUtf8(value: unknown): value is Utf8 {
  const isInstance = value instanceof Utf8;
  return isInstance ? true : DataType.isUtf8(value);
}
/**
 * True when `DataType.isLargeUtf8` accepts `value`.
 *
 * No `instanceof Utf8` test is used here: a plain `Utf8` type is not a
 * LargeUtf8 type and must not pass this check.
 *
 * NOTE(review): the `value is Utf8` predicate is kept so the exported
 * signature stays unchanged, but it presumably should narrow to the
 * LargeUtf8 class — confirm and tighten once that class is imported.
 */
export function isLargeUtf8(value: unknown): value is Utf8 {
  return DataType.isLargeUtf8(value);
}
/**
 * True when `DataType.isBool` accepts `value`.
 *
 * No `instanceof Utf8` test is used here: a `Utf8` type is not a Bool type
 * and must not pass this check.
 *
 * NOTE(review): the `value is Utf8` predicate is kept so the exported
 * signature stays unchanged, but it presumably should narrow to the Bool
 * class — confirm and tighten once that class is imported.
 */
export function isBool(value: unknown): value is Utf8 {
  return DataType.isBool(value);
}
/**
 * True when `DataType.isDecimal` accepts `value`.
 *
 * No `instanceof Utf8` test is used here: a `Utf8` type is not a Decimal
 * type and must not pass this check.
 *
 * NOTE(review): the `value is Utf8` predicate is kept so the exported
 * signature stays unchanged, but it presumably should narrow to the Decimal
 * class — confirm and tighten once that class is imported.
 */
export function isDecimal(value: unknown): value is Utf8 {
  return DataType.isDecimal(value);
}
/**
 * True when `DataType.isDate` accepts `value`.
 *
 * No `instanceof Utf8` test is used here: a `Utf8` type is not a Date type
 * and must not pass this check.
 *
 * NOTE(review): the `value is Utf8` predicate is kept so the exported
 * signature stays unchanged, but it presumably should narrow to the Arrow
 * date class — confirm and tighten once that class is imported.
 */
export function isDate(value: unknown): value is Utf8 {
  return DataType.isDate(value);
}
/**
 * True when `DataType.isTime` accepts `value`.
 *
 * No `instanceof Utf8` test is used here: a `Utf8` type is not a Time type
 * and must not pass this check.
 *
 * NOTE(review): the `value is Utf8` predicate is kept so the exported
 * signature stays unchanged, but it presumably should narrow to the Time
 * class — confirm and tighten once that class is imported.
 */
export function isTime(value: unknown): value is Utf8 {
  return DataType.isTime(value);
}
/**
 * True when `DataType.isTimestamp` accepts `value`.
 *
 * No `instanceof Utf8` test is used here: a `Utf8` type is not a Timestamp
 * type and must not pass this check.
 *
 * NOTE(review): the `value is Utf8` predicate is kept so the exported
 * signature stays unchanged, but it presumably should narrow to the
 * Timestamp class — confirm and tighten once that class is imported.
 */
export function isTimestamp(value: unknown): value is Utf8 {
  return DataType.isTimestamp(value);
}
/**
 * Type guard for Arrow `Interval` data types.
 *
 * Relies solely on `DataType.isInterval`. An `instanceof Utf8` shortcut must
 * not be used here: it would make every `Utf8` type report as an Interval.
 *
 * NOTE(review): the `value is Utf8` predicate is kept unchanged for interface
 * compatibility even though it does not describe an Interval type.
 *
 * @param value - candidate to inspect
 * @returns `true` only when `DataType.isInterval` accepts `value`
 */
export function isInterval(value: unknown): value is Utf8 {
  return DataType.isInterval(value);
}
/**
 * Type guard for Arrow `Duration` data types.
 *
 * Relies solely on `DataType.isDuration`. An `instanceof Utf8` shortcut must
 * not be used here: it would make every `Utf8` type report as a Duration.
 *
 * NOTE(review): the `value is Utf8` predicate is kept unchanged for interface
 * compatibility even though it does not describe a Duration type.
 *
 * @param value - candidate to inspect
 * @returns `true` only when `DataType.isDuration` accepts `value`
 */
export function isDuration(value: unknown): value is Utf8 {
  return DataType.isDuration(value);
}
/**
 * Type guard for Arrow `List` data types.
 *
 * @param value - candidate to inspect
 * @returns `true` when `value` is an instance of `List` or passes `DataType.isList`
 */
export function isList(value: unknown): value is List {
  if (value instanceof List) {
    return true;
  }
  return DataType.isList(value);
}
/**
 * Type guard for Arrow `Struct` data types.
 *
 * @param value - candidate to inspect
 * @returns `true` when `value` is an instance of `Struct` or passes `DataType.isStruct`
 */
export function isStruct(value: unknown): value is Struct {
  if (value instanceof Struct) {
    return true;
  }
  return DataType.isStruct(value);
}
/**
 * Type guard for Arrow `Union` data types.
 *
 * Relies solely on `DataType.isUnion`. An `instanceof Struct` shortcut must
 * not be used here: it would make every `Struct` type report as a Union.
 *
 * NOTE(review): the `value is Struct` predicate is kept unchanged for
 * interface compatibility even though it does not describe a Union type.
 *
 * @param value - candidate to inspect
 * @returns `true` only when `DataType.isUnion` accepts `value`
 */
export function isUnion(value: unknown): value is Struct {
  return DataType.isUnion(value);
}
/**
 * Type guard for Arrow `FixedSizeBinary` data types.
 *
 * @param value - candidate to inspect
 * @returns `true` when `value` is an instance of `FixedSizeBinary` or passes
 *   `DataType.isFixedSizeBinary`
 */
export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary {
  if (value instanceof FixedSizeBinary) {
    return true;
  }
  return DataType.isFixedSizeBinary(value);
}
/**
 * Type guard for Arrow `FixedSizeList` data types.
 *
 * @param value - candidate to inspect
 * @returns `true` when `value` is an instance of `FixedSizeList` or passes
 *   `DataType.isFixedSizeList`
 */
export function isFixedSizeList(value: unknown): value is FixedSizeList {
  if (value instanceof FixedSizeList) {
    return true;
  }
  return DataType.isFixedSizeList(value);
}
/** Data type accepted by NodeJS SDK */ | ||
@@ -203,2 +299,3 @@ export type Data = Record<string, unknown>[] | ArrowTable; | ||
options?: Partial<MakeArrowTableOptions>, | ||
metadata?: Map<string, string>, | ||
): ArrowTable { | ||
@@ -296,9 +393,27 @@ if ( | ||
const batchesFixed = firstTable.batches.map( | ||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion | ||
(batch) => new RecordBatch(opt.schema!, batch.data), | ||
); | ||
return new ArrowTable(opt.schema, batchesFixed); | ||
} else { | ||
return new ArrowTable(columns); | ||
let schema: Schema; | ||
if (metadata !== undefined) { | ||
let schemaMetadata = opt.schema.metadata; | ||
if (schemaMetadata.size === 0) { | ||
schemaMetadata = metadata; | ||
} else { | ||
for (const [key, entry] of schemaMetadata.entries()) { | ||
schemaMetadata.set(key, entry); | ||
} | ||
} | ||
schema = new Schema(opt.schema.fields, schemaMetadata); | ||
} else { | ||
schema = opt.schema; | ||
} | ||
return new ArrowTable(schema, batchesFixed); | ||
} | ||
const tbl = new ArrowTable(columns); | ||
if (metadata !== undefined) { | ||
// biome-ignore lint/suspicious/noExplicitAny: <explanation> | ||
(<any>tbl.schema).metadata = metadata; | ||
} | ||
return tbl; | ||
} | ||
@@ -309,4 +424,7 @@ | ||
*/ | ||
export function makeEmptyTable(schema: Schema): ArrowTable { | ||
return makeArrowTable([], { schema }); | ||
export function makeEmptyTable( | ||
schema: Schema, | ||
metadata?: Map<string, string>, | ||
): ArrowTable { | ||
return makeArrowTable([], { schema }, metadata); | ||
} | ||
@@ -383,9 +501,71 @@ | ||
/**
 * Helper function to apply embeddings from metadata to an input table.
 *
 * The embedding-function entries are parsed from `schema.metadata` through the
 * registry. For each entry the source column is read from `table`, handed to
 * the entry's `computeSourceEmbeddings`, and the result is stored as a new
 * vector column (named by the entry's `vectorColumn`, default "vector"). The
 * resulting columns are assembled into a new table that is aligned to `schema`.
 *
 * @param table - input table; must contain every source column, must not
 *   already contain a vector column, and must hold at most one batch
 * @param schema - schema carrying the embedding-function metadata and
 *   declaring a FixedSizeList field for each vector column
 * @returns a new table with the vector columns added, aligned to `schema`
 * @throws Error when a source column is missing, a vector column already
 *   exists, the table has more than one batch, the vector column is not
 *   declared in `schema` or is not a FixedSizeList, or the embedding function
 *   returns a different number of vectors than inputs
 */
async function applyEmbeddingsFromMetadata(
  table: ArrowTable,
  schema: Schema,
): Promise<ArrowTable> {
  const registry = getRegistry();
  const functions = registry.parseFunctions(schema.metadata);

  const columns = Object.fromEntries(
    table.schema.fields.map((field) => [
      field.name,
      table.getChild(field.name)!,
    ]),
  );

  for (const functionEntry of functions.values()) {
    const sourceColumn = columns[functionEntry.sourceColumn];
    const destColumn = functionEntry.vectorColumn ?? "vector";
    if (sourceColumn === undefined) {
      throw new Error(
        `Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`,
      );
    }
    if (columns[destColumn] !== undefined) {
      throw new Error(
        `Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
      );
    }
    if (table.batches.length > 1) {
      throw new Error(
        "Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
      );
    }

    // Resolve and validate the destination type before invoking the embedding
    // function, so schema problems surface without paying for the embedding
    // call. A missing field gets an explicit error instead of a TypeError
    // from dereferencing `undefined`.
    const destField = schema.fields.find((f) => f.name === destColumn);
    if (destField === undefined) {
      throw new Error(
        `Cannot apply embedding function because the vector column '${destColumn}' is not declared in the schema`,
      );
    }
    const dtype = destField.type;
    if (!isFixedSizeList(dtype)) {
      throw new Error(
        "Expected FixedSizeList as datatype for vector field, instead got: " +
          dtype,
      );
    }
    const destType: DataType = sanitizeType(dtype);

    const values = sourceColumn.toArray();
    const vectors =
      await functionEntry.function.computeSourceEmbeddings(values);
    if (vectors.length !== values.length) {
      throw new Error(
        "Embedding function did not return an embedding for each input element",
      );
    }

    columns[destColumn] = makeVector(vectors, destType);
  }

  const newTable = new ArrowTable(columns);
  return alignTable(newTable, schema);
}
/** Helper function to apply embeddings to an input table */ | ||
async function applyEmbeddings<T>( | ||
table: ArrowTable, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
schema?: Schema, | ||
): Promise<ArrowTable> { | ||
if (embeddings == null) { | ||
if (schema?.metadata.has("embedding_functions")) { | ||
return applyEmbeddingsFromMetadata(table, schema!); | ||
} else if (embeddings == null || embeddings === undefined) { | ||
return table; | ||
@@ -408,4 +588,5 @@ } | ||
const sourceColumn = newColumns[embeddings.sourceColumn]; | ||
const destColumn = embeddings.destColumn ?? "vector"; | ||
const innerDestType = embeddings.embeddingDataType ?? new Float32(); | ||
const destColumn = embeddings.vectorColumn ?? "vector"; | ||
const innerDestType = | ||
embeddings.function.embeddingDataType() ?? new Float32(); | ||
if (sourceColumn === undefined) { | ||
@@ -424,7 +605,5 @@ throw new Error( | ||
} | ||
if (embeddings.embeddingDimension !== undefined) { | ||
const destType = newVectorType( | ||
embeddings.embeddingDimension, | ||
innerDestType, | ||
); | ||
const dimensions = embeddings.function.ndims(); | ||
if (dimensions !== undefined) { | ||
const destType = newVectorType(dimensions, innerDestType); | ||
newColumns[destColumn] = makeVector([], destType); | ||
@@ -457,3 +636,5 @@ } else if (schema != null) { | ||
const values = sourceColumn.toArray(); | ||
const vectors = await embeddings.embed(values as T[]); | ||
const vectors = await embeddings.function.computeSourceEmbeddings( | ||
values as T[], | ||
); | ||
if (vectors.length !== values.length) { | ||
@@ -498,5 +679,5 @@ throw new Error( | ||
*/ | ||
export async function convertToTable<T>( | ||
export async function convertToTable( | ||
data: Array<Record<string, unknown>>, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
makeTableOptions?: Partial<MakeArrowTableOptions>, | ||
@@ -509,3 +690,3 @@ ): Promise<ArrowTable> { | ||
/** Creates the Arrow Type for a Vector column with dimension `dim` */ | ||
function newVectorType<T extends Float>( | ||
export function newVectorType<T extends Float>( | ||
dim: number, | ||
@@ -516,3 +697,3 @@ innerType: T, | ||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements | ||
const children = new Field<T>("item", innerType, true); | ||
const children = new Field("item", <T>sanitizeType(innerType), true); | ||
return new FixedSizeList(dim, children); | ||
@@ -528,5 +709,5 @@ } | ||
*/ | ||
export async function fromRecordsToBuffer<T>( | ||
export async function fromRecordsToBuffer( | ||
data: Array<Record<string, unknown>>, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
schema?: Schema, | ||
@@ -549,5 +730,5 @@ ): Promise<Buffer> { | ||
*/ | ||
export async function fromRecordsToStreamBuffer<T>( | ||
export async function fromRecordsToStreamBuffer( | ||
data: Array<Record<string, unknown>>, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
schema?: Schema, | ||
@@ -571,5 +752,5 @@ ): Promise<Buffer> { | ||
*/ | ||
export async function fromTableToBuffer<T>( | ||
export async function fromTableToBuffer( | ||
table: ArrowTable, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
schema?: Schema, | ||
@@ -593,5 +774,5 @@ ): Promise<Buffer> { | ||
*/ | ||
export async function fromDataToBuffer<T>( | ||
export async function fromDataToBuffer( | ||
data: Data, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
schema?: Schema, | ||
@@ -602,7 +783,7 @@ ): Promise<Buffer> { | ||
} | ||
if (data instanceof ArrowTable) { | ||
if (isArrowTable(data)) { | ||
return fromTableToBuffer(data, embeddings, schema); | ||
} else { | ||
const table = await convertToTable(data); | ||
return fromTableToBuffer(table, embeddings, schema); | ||
const table = await convertToTable(data, embeddings, { schema }); | ||
return fromTableToBuffer(table); | ||
} | ||
@@ -619,5 +800,5 @@ } | ||
*/ | ||
export async function fromTableToStreamBuffer<T>( | ||
export async function fromTableToStreamBuffer( | ||
table: ArrowTable, | ||
embeddings?: EmbeddingFunction<T>, | ||
embeddings?: EmbeddingFunctionConfig, | ||
schema?: Schema, | ||
@@ -685,6 +866,21 @@ ): Promise<Buffer> { | ||
// if they are not, we throw an error | ||
for (const field of schema.fields) { | ||
if (field.type instanceof FixedSizeList) { | ||
for (let field of schema.fields) { | ||
if (isFixedSizeList(field.type)) { | ||
field = sanitizeField(field); | ||
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { | ||
missingEmbeddingFields.push(field); | ||
if (schema.metadata.has("embedding_functions")) { | ||
const embeddings = JSON.parse( | ||
schema.metadata.get("embedding_functions")!, | ||
); | ||
if ( | ||
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f` | ||
embeddings.find((f: any) => f["vectorColumn"] === field.name) === | ||
undefined | ||
) { | ||
missingEmbeddingFields.push(field); | ||
} | ||
} else { | ||
missingEmbeddingFields.push(field); | ||
} | ||
} else { | ||
@@ -691,0 +887,0 @@ fields.push(field); |
@@ -15,4 +15,10 @@ // Copyright 2024 Lance Developers. | ||
import { Table as ArrowTable, Schema } from "apache-arrow"; | ||
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow"; | ||
import { Table as ArrowTable, Schema } from "./arrow"; | ||
import { | ||
fromTableToBuffer, | ||
isArrowTable, | ||
makeArrowTable, | ||
makeEmptyTable, | ||
} from "./arrow"; | ||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; | ||
import { ConnectionOptions, Connection as LanceDbConnection } from "./native"; | ||
@@ -69,2 +75,4 @@ import { Table } from "./table"; | ||
storageOptions?: Record<string, string>; | ||
schema?: Schema; | ||
embeddingFunction?: EmbeddingFunctionConfig; | ||
} | ||
@@ -179,2 +187,3 @@ | ||
); | ||
return new Table(innerTable); | ||
@@ -202,8 +211,13 @@ } | ||
let table: ArrowTable; | ||
if (data instanceof ArrowTable) { | ||
if (isArrowTable(data)) { | ||
table = data; | ||
} else { | ||
table = makeArrowTable(data); | ||
table = makeArrowTable(data, options); | ||
} | ||
const buf = await fromTableToBuffer(table); | ||
const buf = await fromTableToBuffer( | ||
table, | ||
options?.embeddingFunction, | ||
options?.schema, | ||
); | ||
const innerTable = await this.inner.createTable( | ||
@@ -215,2 +229,3 @@ name, | ||
); | ||
return new Table(innerTable); | ||
@@ -235,4 +250,10 @@ } | ||
} | ||
let metadata: Map<string, string> | undefined = undefined; | ||
if (options?.embeddingFunction !== undefined) { | ||
const embeddingFunction = options.embeddingFunction; | ||
const registry = getRegistry(); | ||
metadata = registry.getTableMetadata([embeddingFunction]); | ||
} | ||
const table = makeEmptyTable(schema); | ||
const table = makeEmptyTable(schema, metadata); | ||
const buf = await fromTableToBuffer(table); | ||
@@ -239,0 +260,0 @@ const innerTable = await this.inner.createEmptyTable( |
@@ -1,2 +0,2 @@ | ||
// Copyright 2023 Lance Developers. | ||
// Copyright 2024 Lance Developers. | ||
// | ||
@@ -15,65 +15,149 @@ // Licensed under the Apache License, Version 2.0 (the "License"); | ||
import { type Float } from "apache-arrow"; | ||
import "reflect-metadata"; | ||
import { | ||
DataType, | ||
Field, | ||
FixedSizeList, | ||
Float, | ||
Float32, | ||
isDataType, | ||
isFixedSizeList, | ||
isFloat, | ||
newVectorType, | ||
} from "../arrow"; | ||
import { sanitizeType } from "../sanitize"; | ||
/** | ||
* Options for a given embedding function | ||
*/ | ||
export interface FunctionOptions { | ||
// biome-ignore lint/suspicious/noExplicitAny: options can be anything | ||
[key: string]: any; | ||
} | ||
/** | ||
* An embedding function that automatically creates vector representation for a given column. | ||
*/ | ||
export interface EmbeddingFunction<T> { | ||
export abstract class EmbeddingFunction< | ||
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do | ||
T = any, | ||
M extends FunctionOptions = FunctionOptions, | ||
> { | ||
/** | ||
* The name of the column that will be used as input for the Embedding Function. | ||
*/ | ||
sourceColumn: string; | ||
/** | ||
* The data type of the embedding | ||
* Convert the embedding function to a JSON object | ||
* It is used to serialize the embedding function to the schema | ||
* It's important that any object returned by this method contains all the necessary | ||
* information to recreate the embedding function | ||
* | ||
* The embedding function should return `number`. This will be converted into | ||
* an Arrow float array. By default this will be Float32 but this property can | ||
* be used to control the conversion. | ||
*/ | ||
embeddingDataType?: Float; | ||
/** | ||
* The dimension of the embedding | ||
* It should return the same object that was passed to the constructor | ||
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly | ||
* | ||
* This is optional, normally this can be determined by looking at the results of | ||
* `embed`. If this is not specified, and there is an attempt to apply the embedding | ||
* to an empty table, then that process will fail. | ||
* @example | ||
* ```ts | ||
* class MyEmbeddingFunction extends EmbeddingFunction { | ||
* constructor(options: {model: string, timeout: number}) { | ||
* super(); | ||
* this.model = options.model; | ||
* this.timeout = options.timeout; | ||
* } | ||
* toJSON() { | ||
* return { | ||
* model: this.model, | ||
* timeout: this.timeout, | ||
* }; | ||
* } | ||
* ``` | ||
*/ | ||
embeddingDimension?: number; | ||
abstract toJSON(): Partial<M>; | ||
/** | ||
* The name of the column that will contain the embedding | ||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* By default this is "vector" | ||
* @param optionsOrDatatype - The options for the field or the datatype | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
destColumn?: string; | ||
/**
 * Declare a source column for this embedding function.
 *
 * @param optionsOrDatatype - either a data type, or field options whose
 *   `datatype` property holds one
 * @returns a tuple of the sanitized data type and a metadata map whose
 *   "source_column_for" entry points at this embedding function
 * @throws Error when no data type can be determined from the argument
 */
sourceField(
  optionsOrDatatype: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
  const requested = isDataType(optionsOrDatatype)
    ? optionsOrDatatype
    : optionsOrDatatype?.datatype;
  if (!requested) {
    throw new Error("Datatype is required");
  }

  const datatype = sanitizeType(requested);
  const metadata = new Map<string, EmbeddingFunction>([
    ["source_column_for", this],
  ]);
  return [datatype, metadata];
}
/** | ||
* Should the source column be excluded from the resulting table | ||
* vectorField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* By default the source column is included. Set this to true and | ||
* only the embedding will be stored. | ||
* @param options - The options for the field | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
excludeSource?: boolean; | ||
/**
 * Declare a vector column for this embedding function.
 *
 * The dimension comes from `ndims()` and falls back to `options.dims`.
 * Without a `datatype` option the column is a FixedSizeList of nullable
 * Float32 items. A FixedSizeList `datatype` is used as given; a Float
 * `datatype` is wrapped through `newVectorType`.
 *
 * @param options - optional `datatype` and/or `dims` for the field
 * @returns a tuple of the vector data type and a metadata map whose
 *   "vector_column_for" entry points at this embedding function
 * @throws Error when a dimension is needed but unavailable, or when
 *   `datatype` is neither a FixedSizeList nor a Float
 */
vectorField(
  options?: Partial<FieldOptions>,
): [DataType, Map<string, EmbeddingFunction>] {
  const requested = options?.datatype;
  const dims = this.ndims() ?? options?.dims;

  // Only some branches need a dimension, so the check is deferred until used.
  const requireDims = (): number => {
    if (dims === undefined) {
      throw new Error("ndims is required for vector field");
    }
    return dims;
  };

  let vectorType: DataType;
  if (!requested) {
    vectorType = new FixedSizeList(
      requireDims(),
      new Field("item", new Float32(), true),
    );
  } else if (isFixedSizeList(requested)) {
    vectorType = requested;
  } else if (isFloat(requested)) {
    vectorType = newVectorType(requireDims(), requested);
  } else {
    throw new Error(
      "Expected FixedSizeList or Float as datatype for vector field",
    );
  }

  const metadata = new Map<string, EmbeddingFunction>([
    ["vector_column_for", this],
  ]);
  return [vectorType, metadata];
}
/** The number of dimensions of the embeddings */ | ||
ndims(): number | undefined { | ||
return undefined; | ||
} | ||
/** The datatype of the embeddings */ | ||
abstract embeddingDataType(): Float; | ||
/** | ||
* Creates a vector representation for the given values. | ||
*/ | ||
embed: (data: T[]) => Promise<number[][]>; | ||
} | ||
abstract computeSourceEmbeddings( | ||
data: T[], | ||
): Promise<number[][] | Float32Array[] | Float64Array[]>; | ||
/** Test if the input seems to be an embedding function */ | ||
export function isEmbeddingFunction<T>( | ||
value: unknown, | ||
): value is EmbeddingFunction<T> { | ||
if (typeof value !== "object" || value === null) { | ||
return false; | ||
/**
 * Compute the embedding for a single query.
 *
 * Wraps `data` in a one-element array, passes it to
 * `computeSourceEmbeddings`, and returns the first resulting embedding.
 *
 * @param data - the query value to embed
 * @returns the embedding computed for `data`
 */
async computeQueryEmbeddings(
  data: T,
): Promise<number[] | Float32Array | Float64Array> {
  const embeddings = await this.computeSourceEmbeddings([data]);
  return embeddings[0];
}
if (!("sourceColumn" in value) || !("embed" in value)) { | ||
return false; | ||
} | ||
return ( | ||
typeof value.sourceColumn === "string" && typeof value.embed === "function" | ||
); | ||
} | ||
export interface FieldOptions<T extends DataType = DataType> { | ||
datatype: T; | ||
dims?: number; | ||
} |
@@ -1,2 +0,113 @@ | ||
export { EmbeddingFunction, isEmbeddingFunction } from "./embedding_function"; | ||
export { OpenAIEmbeddingFunction } from "./openai"; | ||
// Copyright 2023 Lance Developers. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
import { DataType, Field, Schema } from "../arrow"; | ||
import { isDataType } from "../arrow"; | ||
import { sanitizeType } from "../sanitize"; | ||
import { EmbeddingFunction } from "./embedding_function"; | ||
import { EmbeddingFunctionConfig, getRegistry } from "./registry"; | ||
export { EmbeddingFunction } from "./embedding_function"; | ||
// We need to explicitly export '*' so that the `register` decorator actually registers the class. | ||
export * from "./openai"; | ||
export * from "./registry"; | ||
/**
 * Build an Arrow schema whose metadata records embedding functions.
 *
 * Each entry of `fields` is either a plain data type or the
 * `[datatype, metadata]` tuple returned by `sourceField` / `vectorField`.
 * Every field is created with the nullable flag set, and each data type is
 * passed through `sanitizeType`. Embedding functions collected from the
 * tuples are turned into schema metadata by the registry.
 *
 * @param fields - mapping from column name to data type or field tuple
 * @returns Schema
 * @example
 * ```ts
 * class MyEmbeddingFunction extends EmbeddingFunction {
 *   // ...
 * }
 * const func = new MyEmbeddingFunction();
 * const schema = LanceSchema({
 *   id: new Int32(),
 *   text: func.sourceField(new Utf8()),
 *   vector: func.vectorField(),
 *   // optional: specify the datatype and/or dimensions
 *   vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
 * });
 *
 * const table = await db.createTable("my_table", data, { schema });
 * ```
 */
export function LanceSchema(
  fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>,
): Schema {
  const arrowFields: Field[] = [];
  const embeddingFunctions = new Map<
    EmbeddingFunction,
    Partial<EmbeddingFunctionConfig>
  >();

  for (const [name, spec] of Object.entries(fields)) {
    if (isDataType(spec)) {
      arrowFields.push(new Field(name, sanitizeType(spec), true));
      continue;
    }
    // Anything that is not a data type is taken to be a field tuple.
    const [dtype, fieldMetadata] = spec as [
      object,
      Map<string, EmbeddingFunction>,
    ];
    arrowFields.push(new Field(name, sanitizeType(dtype), true));
    parseEmbeddingFunctions(embeddingFunctions, name, fieldMetadata);
  }

  const registry = getRegistry();
  const configs = Array.from(
    embeddingFunctions.values(),
  ) as EmbeddingFunctionConfig[];
  return new Schema(arrowFields, registry.getTableMetadata(configs));
}
/**
 * Record which column a field tuple's metadata assigns to an embedding function.
 *
 * A "source_column_for" entry sets `sourceColumn` to `key`; otherwise a
 * "vector_column_for" entry sets `vectorColumn` to `key`. An existing config
 * for the same function is extended; a new one also stores the function.
 * Metadata with neither entry leaves the map untouched.
 *
 * @param embeddingFunctions - accumulator keyed by embedding function; mutated in place
 * @param key - name of the column being processed
 * @param metadata - metadata map taken from the field tuple
 */
function parseEmbeddingFunctions(
  embeddingFunctions: Map<EmbeddingFunction, Partial<EmbeddingFunctionConfig>>,
  key: string,
  metadata: Map<string, EmbeddingFunction>,
): void {
  // Merge `patch` into the function's existing config, or start a new one.
  const upsert = (
    embedFunction: EmbeddingFunction,
    patch: Partial<EmbeddingFunctionConfig>,
  ): void => {
    const existing = embeddingFunctions.get(embedFunction);
    embeddingFunctions.set(
      embedFunction,
      existing !== undefined
        ? { ...existing, ...patch }
        : { ...patch, function: embedFunction },
    );
  };

  if (metadata.has("source_column_for")) {
    upsert(metadata.get("source_column_for")!, { sourceColumn: key });
    return;
  }
  if (metadata.has("vector_column_for")) {
    upsert(metadata.get("vector_column_for")!, { vectorColumn: key });
  }
}
@@ -16,13 +16,27 @@ // Copyright 2023 Lance Developers. | ||
import type OpenAI from "openai"; | ||
import { type EmbeddingFunction } from "./embedding_function"; | ||
import { Float, Float32 } from "../arrow"; | ||
import { EmbeddingFunction } from "./embedding_function"; | ||
import { register } from "./registry"; | ||
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> { | ||
private readonly _openai: OpenAI; | ||
private readonly _modelName: string; | ||
export type OpenAIOptions = { | ||
apiKey?: string; | ||
model?: string; | ||
}; | ||
constructor( | ||
sourceColumn: string, | ||
openAIKey: string, | ||
modelName: string = "text-embedding-ada-002", | ||
) { | ||
@register("openai") | ||
export class OpenAIEmbeddingFunction extends EmbeddingFunction< | ||
string, | ||
OpenAIOptions | ||
> { | ||
#openai: OpenAI; | ||
#modelName: string; | ||
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) { | ||
super(); | ||
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; | ||
if (!openAIKey) { | ||
throw new Error("OpenAI API key is required"); | ||
} | ||
const modelName = options?.model ?? "text-embedding-ada-002"; | ||
/** | ||
@@ -40,3 +54,2 @@ * @type {import("openai").default} | ||
this.sourceColumn = sourceColumn; | ||
const configuration = { | ||
@@ -46,9 +59,32 @@ apiKey: openAIKey, | ||
this._openai = new Openai(configuration); | ||
this._modelName = modelName; | ||
this.#openai = new Openai(configuration); | ||
this.#modelName = modelName; | ||
} | ||
async embed(data: string[]): Promise<number[][]> { | ||
const response = await this._openai.embeddings.create({ | ||
model: this._modelName, | ||
toJSON() { | ||
return { | ||
model: this.#modelName, | ||
}; | ||
} | ||
/**
 * Number of dimensions of the embeddings produced by the configured model.
 *
 * @returns the embedding width for the model names listed below, or
 *   `undefined` for any other model name, so that callers comparing against
 *   `undefined` or using `??` fall back to an explicitly supplied dimension
 */
ndims(): number | undefined {
  switch (this.#modelName) {
    case "text-embedding-ada-002":
    case "text-embedding-3-small":
      return 1536;
    case "text-embedding-3-large":
      return 3072;
    default:
      // Must be `undefined`, not `null`: a `null` result passes
      // `!== undefined` checks and would be treated as a real dimension.
      return undefined;
  }
}
embeddingDataType(): Float { | ||
return new Float32(); | ||
} | ||
async computeSourceEmbeddings(data: string[]): Promise<number[][]> { | ||
const response = await this.#openai.embeddings.create({ | ||
model: this.#modelName, | ||
input: data, | ||
@@ -64,3 +100,13 @@ }); | ||
sourceColumn: string; | ||
/**
 * Compute the embedding for a single query string.
 *
 * Sends `data` as the `input` of one embeddings request for the configured
 * model and returns the embedding of the first result.
 *
 * @param data - the query text
 * @returns the embedding of the first item in the response
 * @throws Error when `data` is not a string
 */
async computeQueryEmbeddings(data: string): Promise<number[]> {
  const isText = typeof data === "string";
  if (!isText) {
    throw new Error("Data must be a string");
  }

  const request = {
    model: this.#modelName,
    input: data,
  };
  const { data: results } = await this.#openai.embeddings.create(request);
  return results[0].embedding;
}
} |
@@ -15,3 +15,3 @@ // Copyright 2024 Lance Developers. | ||
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "apache-arrow"; | ||
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow"; | ||
import { type IvfPqOptions } from "./indices"; | ||
@@ -174,2 +174,3 @@ import { | ||
const tbl = await this.toArrow(); | ||
// eslint-disable-next-line @typescript-eslint/no-unsafe-return | ||
@@ -176,0 +177,0 @@ return tbl.toArray(); |
@@ -23,2 +23,3 @@ // Copyright 2023 LanceDB Developers. | ||
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type"; | ||
import { | ||
@@ -79,6 +80,5 @@ Binary, | ||
Utf8, | ||
} from "apache-arrow"; | ||
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type"; | ||
} from "./arrow"; | ||
function sanitizeMetadata( | ||
export function sanitizeMetadata( | ||
metadataLike?: unknown, | ||
@@ -102,3 +102,3 @@ ): Map<string, string> | undefined { | ||
function sanitizeInt(typeLike: object) { | ||
export function sanitizeInt(typeLike: object) { | ||
if ( | ||
@@ -117,3 +117,3 @@ !("bitWidth" in typeLike) || | ||
function sanitizeFloat(typeLike: object) { | ||
export function sanitizeFloat(typeLike: object) { | ||
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") { | ||
@@ -125,3 +125,3 @@ throw Error("Expected a Float Type to have a `precision` property"); | ||
function sanitizeDecimal(typeLike: object) { | ||
export function sanitizeDecimal(typeLike: object) { | ||
if ( | ||
@@ -142,3 +142,3 @@ !("scale" in typeLike) || | ||
function sanitizeDate(typeLike: object) { | ||
export function sanitizeDate(typeLike: object) { | ||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
@@ -150,3 +150,3 @@ throw Error("Expected a Date type to have a `unit` property"); | ||
function sanitizeTime(typeLike: object) { | ||
export function sanitizeTime(typeLike: object) { | ||
if ( | ||
@@ -165,3 +165,3 @@ !("unit" in typeLike) || | ||
function sanitizeTimestamp(typeLike: object) { | ||
export function sanitizeTimestamp(typeLike: object) { | ||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
@@ -177,3 +177,3 @@ throw Error("Expected a Timestamp type to have a `unit` property"); | ||
function sanitizeTypedTimestamp( | ||
export function sanitizeTypedTimestamp( | ||
typeLike: object, | ||
@@ -194,3 +194,3 @@ // eslint-disable-next-line @typescript-eslint/naming-convention | ||
function sanitizeInterval(typeLike: object) { | ||
export function sanitizeInterval(typeLike: object) { | ||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
@@ -202,3 +202,3 @@ throw Error("Expected an Interval type to have a `unit` property"); | ||
function sanitizeList(typeLike: object) { | ||
export function sanitizeList(typeLike: object) { | ||
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
@@ -215,3 +215,3 @@ throw Error( | ||
function sanitizeStruct(typeLike: object) { | ||
export function sanitizeStruct(typeLike: object) { | ||
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
@@ -225,3 +225,3 @@ throw Error( | ||
function sanitizeUnion(typeLike: object) { | ||
export function sanitizeUnion(typeLike: object) { | ||
if ( | ||
@@ -250,3 +250,3 @@ !("typeIds" in typeLike) || | ||
function sanitizeTypedUnion( | ||
export function sanitizeTypedUnion( | ||
typeLike: object, | ||
@@ -273,3 +273,3 @@ // eslint-disable-next-line @typescript-eslint/naming-convention | ||
function sanitizeFixedSizeBinary(typeLike: object) { | ||
export function sanitizeFixedSizeBinary(typeLike: object) { | ||
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") { | ||
@@ -283,3 +283,3 @@ throw Error( | ||
function sanitizeFixedSizeList(typeLike: object) { | ||
export function sanitizeFixedSizeList(typeLike: object) { | ||
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") { | ||
@@ -302,3 +302,3 @@ throw Error("Expected a FixedSizeList type to have a `listSize` property"); | ||
function sanitizeMap(typeLike: object) { | ||
export function sanitizeMap(typeLike: object) { | ||
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
@@ -320,3 +320,3 @@ throw Error( | ||
function sanitizeDuration(typeLike: object) { | ||
export function sanitizeDuration(typeLike: object) { | ||
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
@@ -328,3 +328,3 @@ throw Error("Expected a Duration type to have a `unit` property"); | ||
function sanitizeDictionary(typeLike: object) { | ||
export function sanitizeDictionary(typeLike: object) { | ||
if (!("id" in typeLike) || typeof typeLike.id !== "number") { | ||
@@ -351,3 +351,3 @@ throw Error("Expected a Dictionary type to have an `id` property"); | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
function sanitizeType(typeLike: unknown): DataType<any> { | ||
export function sanitizeType(typeLike: unknown): DataType<any> { | ||
if (typeof typeLike !== "object" || typeLike === null) { | ||
@@ -472,3 +472,3 @@ throw Error("Expected a Type but object was null/undefined"); | ||
function sanitizeField(fieldLike: unknown): Field { | ||
export function sanitizeField(fieldLike: unknown): Field { | ||
if (fieldLike instanceof Field) { | ||
@@ -475,0 +475,0 @@ return fieldLike; |
@@ -15,4 +15,5 @@ // Copyright 2024 Lance Developers. | ||
import { Schema, tableFromIPC } from "apache-arrow"; | ||
import { Data, fromDataToBuffer } from "./arrow"; | ||
import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow"; | ||
import { getRegistry } from "./embedding/registry"; | ||
import { IndexOptions } from "./indices"; | ||
@@ -126,4 +127,10 @@ import { | ||
const mode = options?.mode ?? "append"; | ||
const schema = await this.schema(); | ||
const registry = getRegistry(); | ||
const functions = registry.parseFunctions(schema.metadata); | ||
const buffer = await fromDataToBuffer(data); | ||
const buffer = await fromDataToBuffer( | ||
data, | ||
functions.values().next().value, | ||
); | ||
await this.inner.add(buffer, mode); | ||
@@ -130,0 +137,0 @@ } |
/// <reference types="node" /> | ||
import { Table as ArrowTable, type Float, Schema } from "apache-arrow"; | ||
import { Table as ArrowTable, Binary, DataType, FixedSizeBinary, FixedSizeList, Float, Int, LargeBinary, List, Null, Schema, Struct, Utf8 } from "apache-arrow"; | ||
import { type EmbeddingFunction } from "./embedding/embedding_function"; | ||
import { EmbeddingFunctionConfig } from "./embedding/registry"; | ||
export * from "apache-arrow"; | ||
export declare function isArrowTable(value: object): value is ArrowTable; | ||
export declare function isDataType(value: unknown): value is DataType; | ||
export declare function isNull(value: unknown): value is Null; | ||
export declare function isInt(value: unknown): value is Int; | ||
export declare function isFloat(value: unknown): value is Float; | ||
export declare function isBinary(value: unknown): value is Binary; | ||
export declare function isLargeBinary(value: unknown): value is LargeBinary; | ||
export declare function isUtf8(value: unknown): value is Utf8; | ||
/** Type guard for the Arrow `LargeUtf8` data type. */
export declare function isLargeUtf8(value: unknown): value is import("apache-arrow").LargeUtf8;
/** Type guard for the Arrow `Bool` data type. */
export declare function isBool(value: unknown): value is import("apache-arrow").Bool;
/** Type guard for the Arrow `Decimal` data type. */
export declare function isDecimal(value: unknown): value is import("apache-arrow").Decimal;
/** Type guard for the Arrow date data type (exported by apache-arrow as `Date_`). */
export declare function isDate(value: unknown): value is import("apache-arrow").Date_;
/** Type guard for the Arrow `Time` data type. */
export declare function isTime(value: unknown): value is import("apache-arrow").Time;
/** Type guard for the Arrow `Timestamp` data type. */
export declare function isTimestamp(value: unknown): value is import("apache-arrow").Timestamp;
/** Type guard for the Arrow `Interval` data type. */
export declare function isInterval(value: unknown): value is import("apache-arrow").Interval;
/** Type guard for the Arrow `Duration` data type. */
export declare function isDuration(value: unknown): value is import("apache-arrow").Duration;
export declare function isList(value: unknown): value is List; | ||
export declare function isStruct(value: unknown): value is Struct; | ||
/** Type guard for the Arrow `Union` data type. */
export declare function isUnion(value: unknown): value is import("apache-arrow").Union;
export declare function isFixedSizeBinary(value: unknown): value is FixedSizeBinary; | ||
export declare function isFixedSizeList(value: unknown): value is FixedSizeList; | ||
/** Data type accepted by NodeJS SDK */ | ||
@@ -120,7 +143,7 @@ export type Data = Record<string, unknown>[] | ArrowTable; | ||
*/ | ||
export declare function makeArrowTable(data: Array<Record<string, unknown>>, options?: Partial<MakeArrowTableOptions>): ArrowTable; | ||
export declare function makeArrowTable(data: Array<Record<string, unknown>>, options?: Partial<MakeArrowTableOptions>, metadata?: Map<string, string>): ArrowTable; | ||
/** | ||
* Create an empty Arrow table with the provided schema | ||
*/ | ||
export declare function makeEmptyTable(schema: Schema): ArrowTable; | ||
export declare function makeEmptyTable(schema: Schema, metadata?: Map<string, string>): ArrowTable; | ||
/** | ||
@@ -144,3 +167,5 @@ * Convert an Array of records into an Arrow Table, optionally applying an | ||
*/ | ||
export declare function convertToTable<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, makeTableOptions?: Partial<MakeArrowTableOptions>): Promise<ArrowTable>; | ||
export declare function convertToTable(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunctionConfig, makeTableOptions?: Partial<MakeArrowTableOptions>): Promise<ArrowTable>; | ||
/** Creates the Arrow Type for a Vector column with dimension `dim` */ | ||
export declare function newVectorType<T extends Float>(dim: number, innerType: T): FixedSizeList<T>; | ||
/** | ||
@@ -153,3 +178,3 @@ * Serialize an Array of records into a buffer using the Arrow IPC File serialization | ||
*/ | ||
export declare function fromRecordsToBuffer<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromRecordsToBuffer(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -162,3 +187,3 @@ * Serialize an Array of records into a buffer using the Arrow IPC Stream serialization | ||
*/ | ||
export declare function fromRecordsToStreamBuffer<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromRecordsToStreamBuffer(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -172,3 +197,3 @@ * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization | ||
*/ | ||
export declare function fromTableToBuffer<T>(table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromTableToBuffer(table: ArrowTable, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -182,3 +207,3 @@ * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization | ||
*/ | ||
export declare function fromDataToBuffer<T>(data: Data, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromDataToBuffer(data: Data, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -192,3 +217,3 @@ * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization | ||
*/ | ||
export declare function fromTableToStreamBuffer<T>(table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer>; | ||
export declare function fromTableToStreamBuffer(table: ArrowTable, embeddings?: EmbeddingFunctionConfig, schema?: Schema): Promise<Buffer>; | ||
/** | ||
@@ -195,0 +220,0 @@ * Create an empty table with the given schema |
@@ -15,6 +15,129 @@ "use strict"; | ||
// limitations under the License. | ||
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
var desc = Object.getOwnPropertyDescriptor(m, k); | ||
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { | ||
desc = { enumerable: true, get: function() { return m[k]; } }; | ||
} | ||
Object.defineProperty(o, k2, desc); | ||
}) : (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
o[k2] = m[k]; | ||
})); | ||
var __exportStar = (this && this.__exportStar) || function(m, exports) { | ||
for (var p in m) if (p !== "default" && !Object.prototype.hasOwnProperty.call(exports, p)) __createBinding(exports, m, p); | ||
}; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.createEmptyTable = exports.fromTableToStreamBuffer = exports.fromDataToBuffer = exports.fromTableToBuffer = exports.fromRecordsToStreamBuffer = exports.fromRecordsToBuffer = exports.convertToTable = exports.makeEmptyTable = exports.makeArrowTable = exports.MakeArrowTableOptions = exports.VectorColumnOptions = void 0; | ||
exports.createEmptyTable = exports.fromTableToStreamBuffer = exports.fromDataToBuffer = exports.fromTableToBuffer = exports.fromRecordsToStreamBuffer = exports.fromRecordsToBuffer = exports.newVectorType = exports.convertToTable = exports.makeEmptyTable = exports.makeArrowTable = exports.MakeArrowTableOptions = exports.VectorColumnOptions = exports.isFixedSizeList = exports.isFixedSizeBinary = exports.isUnion = exports.isStruct = exports.isList = exports.isDuration = exports.isInterval = exports.isTimestamp = exports.isTime = exports.isDate = exports.isDecimal = exports.isBool = exports.isLargeUtf8 = exports.isUtf8 = exports.isLargeBinary = exports.isBinary = exports.isFloat = exports.isInt = exports.isNull = exports.isDataType = exports.isArrowTable = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const registry_1 = require("./embedding/registry"); | ||
const sanitize_1 = require("./sanitize"); | ||
__exportStar(require("apache-arrow"), exports); | ||
/**
 * Report whether `value` is an Arrow table.
 *
 * A real `Table` instance is accepted directly; otherwise the object is
 * accepted when it exposes both a `schema` and a `batches` property.
 */
function isArrowTable(value) {
    const hasTableShape = () => "schema" in value && "batches" in value;
    return value instanceof apache_arrow_1.Table ? true : hasTableShape();
}
exports.isArrowTable = isArrowTable; | ||
/**
 * Report whether `value` is any Arrow data type.
 *
 * A `DataType` instance is accepted directly; otherwise each `DataType.isXxx`
 * static check is tried in order until one accepts the value.
 */
function isDataType(value) {
    if (value instanceof apache_arrow_1.DataType) {
        return true;
    }
    // Checked in this order, stopping at the first match.
    const staticChecks = [
        "isNull",
        "isInt",
        "isFloat",
        "isBinary",
        "isLargeBinary",
        "isUtf8",
        "isLargeUtf8",
        "isBool",
        "isDecimal",
        "isDate",
        "isTime",
        "isTimestamp",
        "isInterval",
        "isDuration",
        "isList",
        "isStruct",
        "isUnion",
        "isFixedSizeBinary",
        "isFixedSizeList",
        "isMap",
        "isDictionary",
    ];
    return staticChecks.some((check) => apache_arrow_1.DataType[check](value));
}
exports.isDataType = isDataType; | ||
/** True when `value` is a `Null` instance or `DataType.isNull` accepts it. */
function isNull(value) {
    if (value instanceof apache_arrow_1.Null) {
        return true;
    }
    return apache_arrow_1.DataType.isNull(value);
}
exports.isNull = isNull; | ||
/** True when `value` is an `Int` instance or `DataType.isInt` accepts it. */
function isInt(value) {
    if (value instanceof apache_arrow_1.Int) {
        return true;
    }
    return apache_arrow_1.DataType.isInt(value);
}
exports.isInt = isInt; | ||
/** True when `value` is a `Float` instance or `DataType.isFloat` accepts it. */
function isFloat(value) {
    if (value instanceof apache_arrow_1.Float) {
        return true;
    }
    return apache_arrow_1.DataType.isFloat(value);
}
exports.isFloat = isFloat; | ||
/** True when `value` is a `Binary` instance or `DataType.isBinary` accepts it. */
function isBinary(value) {
    if (value instanceof apache_arrow_1.Binary) {
        return true;
    }
    return apache_arrow_1.DataType.isBinary(value);
}
exports.isBinary = isBinary; | ||
/** True when `value` is a `LargeBinary` instance or `DataType.isLargeBinary` accepts it. */
function isLargeBinary(value) {
    if (value instanceof apache_arrow_1.LargeBinary) {
        return true;
    }
    return apache_arrow_1.DataType.isLargeBinary(value);
}
exports.isLargeBinary = isLargeBinary; | ||
/** True when `value` is a `Utf8` instance or `DataType.isUtf8` accepts it. */
function isUtf8(value) {
    if (value instanceof apache_arrow_1.Utf8) {
        return true;
    }
    return apache_arrow_1.DataType.isUtf8(value);
}
exports.isUtf8 = isUtf8; | ||
function isLargeUtf8(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isLargeUtf8(value); | ||
} | ||
exports.isLargeUtf8 = isLargeUtf8; | ||
function isBool(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isBool(value); | ||
} | ||
exports.isBool = isBool; | ||
function isDecimal(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isDecimal(value); | ||
} | ||
exports.isDecimal = isDecimal; | ||
function isDate(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isDate(value); | ||
} | ||
exports.isDate = isDate; | ||
function isTime(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isTime(value); | ||
} | ||
exports.isTime = isTime; | ||
function isTimestamp(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isTimestamp(value); | ||
} | ||
exports.isTimestamp = isTimestamp; | ||
function isInterval(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isInterval(value); | ||
} | ||
exports.isInterval = isInterval; | ||
function isDuration(value) { | ||
return value instanceof apache_arrow_1.Utf8 || apache_arrow_1.DataType.isDuration(value); | ||
} | ||
exports.isDuration = isDuration; | ||
/** True when `value` is a `List` instance or `DataType.isList` accepts it. */
function isList(value) {
    if (value instanceof apache_arrow_1.List) {
        return true;
    }
    return apache_arrow_1.DataType.isList(value);
}
exports.isList = isList; | ||
/** True when `value` is a `Struct` instance or `DataType.isStruct` accepts it. */
function isStruct(value) {
    if (value instanceof apache_arrow_1.Struct) {
        return true;
    }
    return apache_arrow_1.DataType.isStruct(value);
}
exports.isStruct = isStruct; | ||
function isUnion(value) { | ||
return value instanceof apache_arrow_1.Struct || apache_arrow_1.DataType.isUnion(value); | ||
} | ||
exports.isUnion = isUnion; | ||
/** True when `value` is a `FixedSizeBinary` instance or `DataType.isFixedSizeBinary` accepts it. */
function isFixedSizeBinary(value) {
    if (value instanceof apache_arrow_1.FixedSizeBinary) {
        return true;
    }
    return apache_arrow_1.DataType.isFixedSizeBinary(value);
}
exports.isFixedSizeBinary = isFixedSizeBinary; | ||
/** True when `value` is a `FixedSizeList` instance or `DataType.isFixedSizeList` accepts it. */
function isFixedSizeList(value) {
    if (value instanceof apache_arrow_1.FixedSizeList) {
        return true;
    }
    return apache_arrow_1.DataType.isFixedSizeList(value);
}
exports.isFixedSizeList = isFixedSizeList; | ||
/* | ||
@@ -172,3 +295,3 @@ * Options to control how a column should be converted to a vector array | ||
*/ | ||
function makeArrowTable(data, options) { | ||
function makeArrowTable(data, options, metadata) { | ||
if (data.length === 0 && | ||
@@ -251,10 +374,27 @@ (options?.schema === undefined || options?.schema === null)) { | ||
const firstTable = new apache_arrow_1.Table(columns); | ||
const batchesFixed = firstTable.batches.map( | ||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion | ||
(batch) => new apache_arrow_1.RecordBatch(opt.schema, batch.data)); | ||
return new apache_arrow_1.Table(opt.schema, batchesFixed); | ||
const batchesFixed = firstTable.batches.map((batch) => new apache_arrow_1.RecordBatch(opt.schema, batch.data)); | ||
let schema; | ||
if (metadata !== undefined) { | ||
let schemaMetadata = opt.schema.metadata; | ||
if (schemaMetadata.size === 0) { | ||
schemaMetadata = metadata; | ||
} | ||
else { | ||
for (const [key, entry] of schemaMetadata.entries()) { | ||
schemaMetadata.set(key, entry); | ||
} | ||
} | ||
schema = new apache_arrow_1.Schema(opt.schema.fields, schemaMetadata); | ||
} | ||
else { | ||
schema = opt.schema; | ||
} | ||
return new apache_arrow_1.Table(schema, batchesFixed); | ||
} | ||
else { | ||
return new apache_arrow_1.Table(columns); | ||
const tbl = new apache_arrow_1.Table(columns); | ||
if (metadata !== undefined) { | ||
// biome-ignore lint/suspicious/noExplicitAny: <explanation> | ||
tbl.schema.metadata = metadata; | ||
} | ||
return tbl; | ||
} | ||
@@ -265,4 +405,4 @@ exports.makeArrowTable = makeArrowTable; | ||
*/ | ||
function makeEmptyTable(schema) { | ||
return makeArrowTable([], { schema }); | ||
function makeEmptyTable(schema, metadata) { | ||
return makeArrowTable([], { schema }, metadata); | ||
} | ||
@@ -329,5 +469,48 @@ exports.makeEmptyTable = makeEmptyTable; | ||
} | ||
/**
 * Helper function to apply embeddings from metadata to an input table.
 *
 * The embedding functions are parsed from `schema.metadata` via the registry.
 * For each one, the source column of `table` is embedded and the result is
 * stored under the function's vector column (default `"vector"`), using the
 * FixedSizeList type that `schema` declares for that column.
 *
 * @param table - The input Arrow table; must contain every source column and
 *   none of the destination vector columns.
 * @param schema - The target schema; its metadata describes the embedding
 *   functions and its fields give the vector column types.
 * @returns The table with the vector columns added, aligned to `schema`.
 * @throws Error if a source column is missing, a vector column already exists,
 *   the table has more than one batch, the embedding count does not match the
 *   input count, the vector column is absent from `schema`, or its declared
 *   type is not a FixedSizeList.
 */
async function applyEmbeddingsFromMetadata(table, schema) {
    const registry = (0, registry_1.getRegistry)();
    const functions = registry.parseFunctions(schema.metadata);
    const columns = Object.fromEntries(table.schema.fields.map((field) => [
        field.name,
        table.getChild(field.name),
    ]));
    for (const functionEntry of functions.values()) {
        const sourceColumn = columns[functionEntry.sourceColumn];
        const destColumn = functionEntry.vectorColumn ?? "vector";
        if (sourceColumn === undefined) {
            throw new Error(`Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`);
        }
        if (columns[destColumn] !== undefined) {
            throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`);
        }
        if (table.batches.length > 1) {
            throw new Error("Internal error: `makeArrowTable` unexpectedly created a table with more than one batch");
        }
        const values = sourceColumn.toArray();
        const vectors = await functionEntry.function.computeSourceEmbeddings(values);
        if (vectors.length !== values.length) {
            throw new Error("Embedding function did not return an embedding for each input element");
        }
        // Look the destination field up explicitly: dereferencing `.type` on a
        // missing field would otherwise surface as an opaque TypeError.
        const destField = schema.fields.find((f) => f.name === destColumn);
        if (destField === undefined) {
            throw new Error(`Cannot apply embedding function because the vector column '${destColumn}' was not present in the schema`);
        }
        const dtype = destField.type;
        if (!isFixedSizeList(dtype)) {
            throw new Error("Expected FixedSizeList as datatype for vector field, instead got: " +
                dtype);
        }
        const destType = (0, sanitize_1.sanitizeType)(dtype);
        columns[destColumn] = makeVector(vectors, destType);
    }
    const newTable = new apache_arrow_1.Table(columns);
    return alignTable(newTable, schema);
}
/** Helper function to apply embeddings to an input table */ | ||
async function applyEmbeddings(table, embeddings, schema) { | ||
if (embeddings == null) { | ||
if (schema?.metadata.has("embedding_functions")) { | ||
return applyEmbeddingsFromMetadata(table, schema); | ||
} | ||
else if (embeddings == null || embeddings === undefined) { | ||
return table; | ||
@@ -347,4 +530,4 @@ } | ||
const sourceColumn = newColumns[embeddings.sourceColumn]; | ||
const destColumn = embeddings.destColumn ?? "vector"; | ||
const innerDestType = embeddings.embeddingDataType ?? new apache_arrow_1.Float32(); | ||
const destColumn = embeddings.vectorColumn ?? "vector"; | ||
const innerDestType = embeddings.function.embeddingDataType() ?? new apache_arrow_1.Float32(); | ||
if (sourceColumn === undefined) { | ||
@@ -360,4 +543,5 @@ throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`); | ||
} | ||
if (embeddings.embeddingDimension !== undefined) { | ||
const destType = newVectorType(embeddings.embeddingDimension, innerDestType); | ||
const dimensions = embeddings.function.ndims(); | ||
if (dimensions !== undefined) { | ||
const destType = newVectorType(dimensions, innerDestType); | ||
newColumns[destColumn] = makeVector([], destType); | ||
@@ -386,3 +570,3 @@ } | ||
const values = sourceColumn.toArray(); | ||
const vectors = await embeddings.embed(values); | ||
const vectors = await embeddings.function.computeSourceEmbeddings(values); | ||
if (vectors.length !== values.length) { | ||
@@ -430,5 +614,6 @@ throw new Error("Embedding function did not return an embedding for each input element"); | ||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements | ||
const children = new apache_arrow_1.Field("item", innerType, true); | ||
const children = new apache_arrow_1.Field("item", (0, sanitize_1.sanitizeType)(innerType), true); | ||
return new apache_arrow_1.FixedSizeList(dim, children); | ||
} | ||
exports.newVectorType = newVectorType; | ||
/** | ||
@@ -495,8 +680,8 @@ * Serialize an Array of records into a buffer using the Arrow IPC File serialization | ||
} | ||
if (data instanceof apache_arrow_1.Table) { | ||
if (isArrowTable(data)) { | ||
return fromTableToBuffer(data, embeddings, schema); | ||
} | ||
else { | ||
const table = await convertToTable(data); | ||
return fromTableToBuffer(table, embeddings, schema); | ||
const table = await convertToTable(data, embeddings, { schema }); | ||
return fromTableToBuffer(table); | ||
} | ||
@@ -561,6 +746,18 @@ } | ||
// if they are not, we throw an error | ||
for (const field of schema.fields) { | ||
if (field.type instanceof apache_arrow_1.FixedSizeList) { | ||
for (let field of schema.fields) { | ||
if (isFixedSizeList(field.type)) { | ||
field = (0, sanitize_1.sanitizeField)(field); | ||
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { | ||
missingEmbeddingFields.push(field); | ||
if (schema.metadata.has("embedding_functions")) { | ||
const embeddings = JSON.parse(schema.metadata.get("embedding_functions")); | ||
if ( | ||
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f` | ||
embeddings.find((f) => f["vectorColumn"] === field.name) === | ||
undefined) { | ||
missingEmbeddingFields.push(field); | ||
} | ||
} | ||
else { | ||
missingEmbeddingFields.push(field); | ||
} | ||
} | ||
@@ -567,0 +764,0 @@ else { |
@@ -1,2 +0,3 @@ | ||
import { Table as ArrowTable, Schema } from "apache-arrow"; | ||
import { Table as ArrowTable, Schema } from "./arrow"; | ||
import { EmbeddingFunctionConfig } from "./embedding/registry"; | ||
import { ConnectionOptions, Connection as LanceDbConnection } from "./native"; | ||
@@ -42,2 +43,4 @@ import { Table } from "./table"; | ||
storageOptions?: Record<string, string>; | ||
schema?: Schema; | ||
embeddingFunction?: EmbeddingFunctionConfig; | ||
} | ||
@@ -44,0 +47,0 @@ export interface OpenTableOptions { |
@@ -17,4 +17,4 @@ "use strict"; | ||
exports.Connection = exports.connect = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const arrow_1 = require("./arrow"); | ||
const registry_1 = require("./embedding/registry"); | ||
const native_1 = require("./native"); | ||
@@ -113,9 +113,9 @@ const table_1 = require("./table"); | ||
let table; | ||
if (data instanceof apache_arrow_1.Table) { | ||
if ((0, arrow_1.isArrowTable)(data)) { | ||
table = data; | ||
} | ||
else { | ||
table = (0, arrow_1.makeArrowTable)(data); | ||
table = (0, arrow_1.makeArrowTable)(data, options); | ||
} | ||
const buf = await (0, arrow_1.fromTableToBuffer)(table); | ||
const buf = await (0, arrow_1.fromTableToBuffer)(table, options?.embeddingFunction, options?.schema); | ||
const innerTable = await this.inner.createTable(name, buf, mode, cleanseStorageOptions(options?.storageOptions)); | ||
@@ -135,3 +135,9 @@ return new table_1.Table(innerTable); | ||
} | ||
const table = (0, arrow_1.makeEmptyTable)(schema); | ||
let metadata = undefined; | ||
if (options?.embeddingFunction !== undefined) { | ||
const embeddingFunction = options.embeddingFunction; | ||
const registry = (0, registry_1.getRegistry)(); | ||
metadata = registry.getTableMetadata([embeddingFunction]); | ||
} | ||
const table = (0, arrow_1.makeEmptyTable)(schema, metadata); | ||
const buf = await (0, arrow_1.fromTableToBuffer)(table); | ||
@@ -138,0 +144,0 @@ const innerTable = await this.inner.createEmptyTable(name, buf, mode, cleanseStorageOptions(options?.storageOptions)); |
@@ -1,45 +0,71 @@ | ||
import { type Float } from "apache-arrow"; | ||
import "reflect-metadata"; | ||
import { DataType, Float } from "../arrow"; | ||
/** | ||
* Options for a given embedding function | ||
*/ | ||
export interface FunctionOptions { | ||
[key: string]: any; | ||
} | ||
/** | ||
* An embedding function that automatically creates vector representation for a given column. | ||
*/ | ||
export interface EmbeddingFunction<T> { | ||
export declare abstract class EmbeddingFunction<T = any, M extends FunctionOptions = FunctionOptions> { | ||
/** | ||
* The name of the column that will be used as input for the Embedding Function. | ||
*/ | ||
sourceColumn: string; | ||
/** | ||
* The data type of the embedding | ||
* Convert the embedding function to a JSON object | ||
* It is used to serialize the embedding function to the schema | ||
* It's important that any object returned by this method contains all the necessary | ||
* information to recreate the embedding function | ||
* | ||
* The embedding function should return `number`. This will be converted into | ||
* an Arrow float array. By default this will be Float32 but this property can | ||
* be used to control the conversion. | ||
*/ | ||
embeddingDataType?: Float; | ||
/** | ||
* The dimension of the embedding | ||
* It should return the same object that was passed to the constructor | ||
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly | ||
* | ||
* This is optional, normally this can be determined by looking at the results of | ||
* `embed`. If this is not specified, and there is an attempt to apply the embedding | ||
* to an empty table, then that process will fail. | ||
* @example | ||
* ```ts | ||
* class MyEmbeddingFunction extends EmbeddingFunction { | ||
* constructor(options: {model: string, timeout: number}) { | ||
* super(); | ||
* this.model = options.model; | ||
* this.timeout = options.timeout; | ||
* } | ||
* toJSON() { | ||
* return { | ||
* model: this.model, | ||
* timeout: this.timeout, | ||
* }; | ||
* } | ||
* ``` | ||
*/ | ||
embeddingDimension?: number; | ||
abstract toJSON(): Partial<M>; | ||
/** | ||
* The name of the column that will contain the embedding | ||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* By default this is "vector" | ||
* @param optionsOrDatatype - The options for the field or the datatype | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
destColumn?: string; | ||
sourceField(optionsOrDatatype: Partial<FieldOptions> | DataType): [DataType, Map<string, EmbeddingFunction>]; | ||
/** | ||
* Should the source column be excluded from the resulting table | ||
* vectorField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* By default the source column is included. Set this to true and | ||
* only the embedding will be stored. | ||
* @param options - The options for the field | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
excludeSource?: boolean; | ||
vectorField(options?: Partial<FieldOptions>): [DataType, Map<string, EmbeddingFunction>]; | ||
/** The number of dimensions of the embeddings */ | ||
ndims(): number | undefined; | ||
/** The datatype of the embeddings */ | ||
abstract embeddingDataType(): Float; | ||
/** | ||
* Creates a vector representation for the given values. | ||
*/ | ||
embed: (data: T[]) => Promise<number[][]>; | ||
abstract computeSourceEmbeddings(data: T[]): Promise<number[][] | Float32Array[] | Float64Array[]>; | ||
/** | ||
Compute the embeddings for a single query | ||
*/ | ||
computeQueryEmbeddings(data: T): Promise<number[] | Float32Array | Float64Array>; | ||
} | ||
/** Test if the input seems to be an embedding function */ | ||
export declare function isEmbeddingFunction<T>(value: unknown): value is EmbeddingFunction<T>; | ||
export interface FieldOptions<T extends DataType = DataType> { | ||
datatype: T; | ||
dims?: number; | ||
} |
"use strict"; | ||
// Copyright 2023 Lance Developers. | ||
// Copyright 2024 Lance Developers. | ||
// | ||
@@ -16,13 +16,74 @@ // Licensed under the Apache License, Version 2.0 (the "License"); | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.isEmbeddingFunction = void 0; | ||
/** Test if the input seems to be an embedding function */ | ||
function isEmbeddingFunction(value) { | ||
if (typeof value !== "object" || value === null) { | ||
return false; | ||
exports.EmbeddingFunction = void 0; | ||
require("reflect-metadata"); | ||
const arrow_1 = require("../arrow"); | ||
const sanitize_1 = require("../sanitize"); | ||
/** | ||
* An embedding function that automatically creates vector representation for a given column. | ||
*/ | ||
class EmbeddingFunction { | ||
/** | ||
* sourceField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* @param optionsOrDatatype - The options for the field or the datatype | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
sourceField(optionsOrDatatype) { | ||
let datatype = (0, arrow_1.isDataType)(optionsOrDatatype) | ||
? optionsOrDatatype | ||
: optionsOrDatatype?.datatype; | ||
if (!datatype) { | ||
throw new Error("Datatype is required"); | ||
} | ||
datatype = (0, sanitize_1.sanitizeType)(datatype); | ||
const metadata = new Map(); | ||
metadata.set("source_column_for", this); | ||
return [datatype, metadata]; | ||
} | ||
if (!("sourceColumn" in value) || !("embed" in value)) { | ||
return false; | ||
/** | ||
* vectorField is used in combination with `LanceSchema` to provide a declarative data model | ||
* | ||
* @param options - The options for the field | ||
* | ||
* @see {@link lancedb.LanceSchema} | ||
*/ | ||
vectorField(options) { | ||
let dtype; | ||
const dims = this.ndims() ?? options?.dims; | ||
if (!options?.datatype) { | ||
if (dims === undefined) { | ||
throw new Error("ndims is required for vector field"); | ||
} | ||
dtype = new arrow_1.FixedSizeList(dims, new arrow_1.Field("item", new arrow_1.Float32(), true)); | ||
} | ||
else { | ||
if ((0, arrow_1.isFixedSizeList)(options.datatype)) { | ||
dtype = options.datatype; | ||
} | ||
else if ((0, arrow_1.isFloat)(options.datatype)) { | ||
if (dims === undefined) { | ||
throw new Error("ndims is required for vector field"); | ||
} | ||
dtype = (0, arrow_1.newVectorType)(dims, options.datatype); | ||
} | ||
else { | ||
throw new Error("Expected FixedSizeList or Float as datatype for vector field"); | ||
} | ||
} | ||
const metadata = new Map(); | ||
metadata.set("vector_column_for", this); | ||
return [dtype, metadata]; | ||
} | ||
return (typeof value.sourceColumn === "string" && typeof value.embed === "function"); | ||
/** The number of dimensions of the embeddings */ | ||
ndims() { | ||
return undefined; | ||
} | ||
/** | ||
Compute the embeddings for a single query | ||
*/ | ||
async computeQueryEmbeddings(data) { | ||
return this.computeSourceEmbeddings([data]).then((embeddings) => embeddings[0]); | ||
} | ||
} | ||
exports.isEmbeddingFunction = isEmbeddingFunction; | ||
exports.EmbeddingFunction = EmbeddingFunction; |
@@ -1,2 +0,28 @@ | ||
export { EmbeddingFunction, isEmbeddingFunction } from "./embedding_function"; | ||
export { OpenAIEmbeddingFunction } from "./openai"; | ||
import { Schema } from "../arrow"; | ||
import { EmbeddingFunction } from "./embedding_function"; | ||
export { EmbeddingFunction } from "./embedding_function"; | ||
export * from "./openai"; | ||
export * from "./registry"; | ||
/** | ||
* Create a schema with embedding functions. | ||
* | ||
* @param fields | ||
* @returns Schema | ||
* @example | ||
* ```ts | ||
* class MyEmbeddingFunction extends EmbeddingFunction { | ||
* // ... | ||
* } | ||
* const func = new MyEmbeddingFunction(); | ||
* const schema = LanceSchema({ | ||
* id: new Int32(), | ||
* text: func.sourceField(new Utf8()), | ||
* vector: func.vectorField(), | ||
* // optional: specify the datatype and/or dimensions | ||
* vector2: func.vectorField({ datatype: new Float32(), dims: 3}), | ||
* }); | ||
* | ||
* const table = await db.createTable("my_table", data, { schema }); | ||
* ``` | ||
*/ | ||
export declare function LanceSchema(fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>): Schema; |
"use strict"; | ||
// Copyright 2023 Lance Developers. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
var desc = Object.getOwnPropertyDescriptor(m, k); | ||
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { | ||
desc = { enumerable: true, get: function() { return m[k]; } }; | ||
} | ||
Object.defineProperty(o, k2, desc); | ||
}) : (function(o, m, k, k2) { | ||
if (k2 === undefined) k2 = k; | ||
o[k2] = m[k]; | ||
})); | ||
var __exportStar = (this && this.__exportStar) || function(m, exports) { | ||
for (var p in m) if (p !== "default" && !Object.prototype.hasOwnProperty.call(exports, p)) __createBinding(exports, m, p); | ||
}; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.OpenAIEmbeddingFunction = exports.isEmbeddingFunction = void 0; | ||
exports.LanceSchema = exports.EmbeddingFunction = void 0; | ||
const arrow_1 = require("../arrow"); | ||
const arrow_2 = require("../arrow"); | ||
const sanitize_1 = require("../sanitize"); | ||
const registry_1 = require("./registry"); | ||
var embedding_function_1 = require("./embedding_function"); | ||
Object.defineProperty(exports, "isEmbeddingFunction", { enumerable: true, get: function () { return embedding_function_1.isEmbeddingFunction; } }); | ||
var openai_1 = require("./openai"); | ||
Object.defineProperty(exports, "OpenAIEmbeddingFunction", { enumerable: true, get: function () { return openai_1.OpenAIEmbeddingFunction; } }); | ||
Object.defineProperty(exports, "EmbeddingFunction", { enumerable: true, get: function () { return embedding_function_1.EmbeddingFunction; } }); | ||
// We need to explicitly export '*' so that the `register` decorator actually registers the class. | ||
__exportStar(require("./openai"), exports); | ||
__exportStar(require("./registry"), exports); | ||
/**
 * Build an Arrow schema whose metadata records its embedding functions.
 *
 * Each entry of `fields` is either a plain Arrow datatype or a
 * `[datatype, metadata]` pair as returned by `sourceField` / `vectorField`.
 * Every field is created nullable with a sanitized datatype.
 *
 * @param fields - Mapping from column name to datatype or `[datatype, metadata]`
 * @returns Schema
 * @example
 * ```ts
 * class MyEmbeddingFunction extends EmbeddingFunction {
 *   // ...
 * }
 * const func = new MyEmbeddingFunction();
 * const schema = LanceSchema({
 *   id: new Int32(),
 *   text: func.sourceField(new Utf8()),
 *   vector: func.vectorField(),
 *   // optional: specify the datatype and/or dimensions
 *   vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
 * });
 *
 * const table = await db.createTable("my_table", data, { schema });
 * ```
 */
function LanceSchema(fields) {
    const arrowFields = [];
    const embeddingFunctions = new Map();
    for (const [name, spec] of Object.entries(fields)) {
        if ((0, arrow_2.isDataType)(spec)) {
            arrowFields.push(new arrow_1.Field(name, (0, sanitize_1.sanitizeType)(spec), true));
            continue;
        }
        // Otherwise the entry is a [datatype, metadata] pair.
        const [dtype, fieldMetadata] = spec;
        arrowFields.push(new arrow_1.Field(name, (0, sanitize_1.sanitizeType)(dtype), true));
        parseEmbeddingFunctions(embeddingFunctions, name, fieldMetadata);
    }
    const configs = Array.from(embeddingFunctions.values());
    const tableMetadata = (0, registry_1.getRegistry)().getTableMetadata(configs);
    return new arrow_1.Schema(arrowFields, tableMetadata);
}
exports.LanceSchema = LanceSchema; | ||
/**
 * Record which role (source column or vector column) a schema field plays
 * for an embedding function.
 *
 * `metadata` is inspected for the key "source_column_for" and, only if that
 * is absent, for "vector_column_for". The value stored under the key is the
 * embedding function; it is used as the key into `embeddingFunctions`. An
 * existing entry is copied with the column name added or overwritten;
 * otherwise a new entry holding the column name and the function is created.
 *
 * @param embeddingFunctions - Map from embedding function to its
 *   accumulated config (`sourceColumn`, `vectorColumn`, `function`);
 *   mutated in place
 * @param key - name of the schema field being processed
 * @param metadata - Map attached to the field; when it is `undefined`/`null`
 *   or carries neither marker key the call is a no-op
 */
function parseEmbeddingFunctions(embeddingFunctions, key, metadata) {
    // A field without metadata takes part in no embedding function.
    if (metadata === undefined || metadata === null) {
        return;
    }
    // "source_column_for" takes precedence when both markers are present.
    const role = metadata.has("source_column_for")
        ? { marker: "source_column_for", property: "sourceColumn" }
        : metadata.has("vector_column_for")
            ? { marker: "vector_column_for", property: "vectorColumn" }
            : undefined;
    if (role === undefined) {
        return;
    }
    const embedFunction = metadata.get(role.marker);
    const current = embeddingFunctions.get(embedFunction);
    const updated = current !== undefined
        ? { ...current, [role.property]: key }
        : { [role.property]: key, function: embedFunction };
    embeddingFunctions.set(embedFunction, updated);
}
@@ -1,8 +0,17 @@ | ||
import { type EmbeddingFunction } from "./embedding_function"; | ||
export declare class OpenAIEmbeddingFunction implements EmbeddingFunction<string> { | ||
private readonly _openai; | ||
private readonly _modelName; | ||
constructor(sourceColumn: string, openAIKey: string, modelName?: string); | ||
embed(data: string[]): Promise<number[][]>; | ||
sourceColumn: string; | ||
import { Float } from "../arrow"; | ||
import { EmbeddingFunction } from "./embedding_function"; | ||
export type OpenAIOptions = { | ||
apiKey?: string; | ||
model?: string; | ||
}; | ||
export declare class OpenAIEmbeddingFunction extends EmbeddingFunction<string, OpenAIOptions> { | ||
#private; | ||
constructor(options?: OpenAIOptions); | ||
toJSON(): { | ||
model: string; | ||
}; | ||
ndims(): number; | ||
embeddingDataType(): Float; | ||
computeSourceEmbeddings(data: string[]): Promise<number[][]>; | ||
computeQueryEmbeddings(data: string): Promise<number[]>; | ||
} |
@@ -15,8 +15,26 @@ "use strict"; | ||
// limitations under the License. | ||
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) { | ||
var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d; | ||
if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc); | ||
else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r; | ||
return c > 3 && r && Object.defineProperty(target, key, r), r; | ||
}; | ||
var __metadata = (this && this.__metadata) || function (k, v) { | ||
if (typeof Reflect === "object" && typeof Reflect.metadata === "function") return Reflect.metadata(k, v); | ||
}; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.OpenAIEmbeddingFunction = void 0; | ||
class OpenAIEmbeddingFunction { | ||
_openai; | ||
_modelName; | ||
constructor(sourceColumn, openAIKey, modelName = "text-embedding-ada-002") { | ||
const arrow_1 = require("../arrow"); | ||
const embedding_function_1 = require("./embedding_function"); | ||
const registry_1 = require("./registry"); | ||
let OpenAIEmbeddingFunction = class OpenAIEmbeddingFunction extends embedding_function_1.EmbeddingFunction { | ||
#openai; | ||
#modelName; | ||
constructor(options = { model: "text-embedding-ada-002" }) { | ||
super(); | ||
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; | ||
if (!openAIKey) { | ||
throw new Error("OpenAI API key is required"); | ||
} | ||
const modelName = options?.model ?? "text-embedding-ada-002"; | ||
/** | ||
@@ -34,12 +52,31 @@ * @type {import("openai").default} | ||
} | ||
this.sourceColumn = sourceColumn; | ||
const configuration = { | ||
apiKey: openAIKey, | ||
}; | ||
this._openai = new Openai(configuration); | ||
this._modelName = modelName; | ||
this.#openai = new Openai(configuration); | ||
this.#modelName = modelName; | ||
} | ||
async embed(data) { | ||
const response = await this._openai.embeddings.create({ | ||
model: this._modelName, | ||
toJSON() { | ||
return { | ||
model: this.#modelName, | ||
}; | ||
} | ||
ndims() { | ||
switch (this.#modelName) { | ||
case "text-embedding-ada-002": | ||
return 1536; | ||
case "text-embedding-3-large": | ||
return 3072; | ||
case "text-embedding-3-small": | ||
return 1536; | ||
default: | ||
return null; | ||
} | ||
} | ||
embeddingDataType() { | ||
return new arrow_1.Float32(); | ||
} | ||
async computeSourceEmbeddings(data) { | ||
const response = await this.#openai.embeddings.create({ | ||
model: this.#modelName, | ||
input: data, | ||
@@ -53,4 +90,17 @@ }); | ||
} | ||
sourceColumn; | ||
} | ||
async computeQueryEmbeddings(data) { | ||
if (typeof data !== "string") { | ||
throw new Error("Data must be a string"); | ||
} | ||
const response = await this.#openai.embeddings.create({ | ||
model: this.#modelName, | ||
input: data, | ||
}); | ||
return response.data[0].embedding; | ||
} | ||
}; | ||
exports.OpenAIEmbeddingFunction = OpenAIEmbeddingFunction; | ||
exports.OpenAIEmbeddingFunction = OpenAIEmbeddingFunction = __decorate([ | ||
(0, registry_1.register)("openai"), | ||
__metadata("design:paramtypes", [Object]) | ||
], OpenAIEmbeddingFunction); |
@@ -1,2 +0,2 @@ | ||
import { Table as ArrowTable, RecordBatch } from "apache-arrow"; | ||
import { Table as ArrowTable, RecordBatch } from "./arrow"; | ||
import { RecordBatchIterator as NativeBatchIterator, Query as NativeQuery, Table as NativeTable, VectorQuery as NativeVectorQuery } from "./native"; | ||
@@ -3,0 +3,0 @@ export declare class RecordBatchIterator implements AsyncIterator<RecordBatch> { |
@@ -17,3 +17,3 @@ "use strict"; | ||
exports.Query = exports.VectorQuery = exports.QueryBase = exports.RecordBatchIterator = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const arrow_1 = require("./arrow"); | ||
class RecordBatchIterator { | ||
@@ -38,3 +38,3 @@ promisedInner; | ||
} | ||
const tbl = (0, apache_arrow_1.tableFromIPC)(n); | ||
const tbl = (0, arrow_1.tableFromIPC)(n); | ||
if (tbl.batches.length != 1) { | ||
@@ -153,3 +153,3 @@ throw new Error("Expected only one batch"); | ||
} | ||
return new apache_arrow_1.Table(batches); | ||
return new arrow_1.Table(batches); | ||
} | ||
@@ -156,0 +156,0 @@ /** Collect the results as an array of objects. */ |
@@ -1,2 +0,23 @@ | ||
import { Schema } from "apache-arrow"; | ||
import type { TKeys } from "apache-arrow/type"; | ||
import { DataType, Date_, Decimal, DenseUnion, Dictionary, Duration, Field, FixedSizeBinary, FixedSizeList, Float, Int, Interval, List, Map_, Schema, SparseUnion, Struct, Time, Timestamp, TimestampMicrosecond, TimestampMillisecond, TimestampNanosecond, TimestampSecond, Type, Union } from "./arrow"; | ||
export declare function sanitizeMetadata(metadataLike?: unknown): Map<string, string> | undefined; | ||
export declare function sanitizeInt(typeLike: object): Int<Type.Int | Type.Int8 | Type.Int16 | Type.Int32 | Type.Int64 | Type.Uint8 | Type.Uint16 | Type.Uint32 | Type.Uint64>; | ||
export declare function sanitizeFloat(typeLike: object): Float<Type.Float | Type.Float16 | Type.Float32 | Type.Float64>; | ||
export declare function sanitizeDecimal(typeLike: object): Decimal; | ||
export declare function sanitizeDate(typeLike: object): Date_<import("apache-arrow/type").Dates>; | ||
export declare function sanitizeTime(typeLike: object): Time<Type.Time | Type.TimeSecond | Type.TimeMillisecond | Type.TimeMicrosecond | Type.TimeNanosecond>; | ||
export declare function sanitizeTimestamp(typeLike: object): Timestamp<Type.Timestamp | Type.TimestampSecond | Type.TimestampMillisecond | Type.TimestampMicrosecond | Type.TimestampNanosecond>; | ||
export declare function sanitizeTypedTimestamp(typeLike: object, Datatype: typeof TimestampNanosecond | typeof TimestampMicrosecond | typeof TimestampMillisecond | typeof TimestampSecond): TimestampSecond | TimestampMillisecond | TimestampMicrosecond | TimestampNanosecond; | ||
export declare function sanitizeInterval(typeLike: object): Interval<Type.Interval | Type.IntervalDayTime | Type.IntervalYearMonth>; | ||
export declare function sanitizeList(typeLike: object): List<any>; | ||
export declare function sanitizeStruct(typeLike: object): Struct<any>; | ||
export declare function sanitizeUnion(typeLike: object): Union<Type.Union | Type.DenseUnion | Type.SparseUnion>; | ||
export declare function sanitizeTypedUnion(typeLike: object, UnionType: typeof DenseUnion | typeof SparseUnion): SparseUnion | DenseUnion; | ||
export declare function sanitizeFixedSizeBinary(typeLike: object): FixedSizeBinary; | ||
export declare function sanitizeFixedSizeList(typeLike: object): FixedSizeList<any>; | ||
export declare function sanitizeMap(typeLike: object): Map_<any, any>; | ||
export declare function sanitizeDuration(typeLike: object): Duration<Type.Duration | Type.DurationSecond | Type.DurationMillisecond | Type.DurationMicrosecond | Type.DurationNanosecond>; | ||
export declare function sanitizeDictionary(typeLike: object): Dictionary<DataType<any, any>, TKeys>; | ||
export declare function sanitizeType(typeLike: unknown): DataType<any>; | ||
export declare function sanitizeField(fieldLike: unknown): Field; | ||
/** | ||
@@ -3,0 +24,0 @@ * Convert something schemaLike into a Schema instance |
@@ -16,11 +16,4 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.sanitizeSchema = void 0; | ||
// The utilities in this file help sanitize data from the user's arrow | ||
// library into the types expected by vectordb's arrow library. Node | ||
// generally allows for multiple versions of the same library (and sometimes | ||
// even multiple copies of the same version) to be installed at the same | ||
// time. However, arrow-js uses instanceof, which expects that the input | ||
// comes from the exact same library instance. This is not always the case | ||
// and so we must sanitize the input to ensure that it is compatible. | ||
const apache_arrow_1 = require("apache-arrow"); | ||
exports.sanitizeSchema = exports.sanitizeField = exports.sanitizeType = exports.sanitizeDictionary = exports.sanitizeDuration = exports.sanitizeMap = exports.sanitizeFixedSizeList = exports.sanitizeFixedSizeBinary = exports.sanitizeTypedUnion = exports.sanitizeUnion = exports.sanitizeStruct = exports.sanitizeList = exports.sanitizeInterval = exports.sanitizeTypedTimestamp = exports.sanitizeTimestamp = exports.sanitizeTime = exports.sanitizeDate = exports.sanitizeDecimal = exports.sanitizeFloat = exports.sanitizeInt = exports.sanitizeMetadata = void 0; | ||
const arrow_1 = require("./arrow"); | ||
function sanitizeMetadata(metadataLike) { | ||
@@ -40,2 +33,3 @@ if (metadataLike === undefined || metadataLike === null) { | ||
} | ||
exports.sanitizeMetadata = sanitizeMetadata; | ||
function sanitizeInt(typeLike) { | ||
@@ -48,4 +42,5 @@ if (!("bitWidth" in typeLike) || | ||
} | ||
return new apache_arrow_1.Int(typeLike.isSigned, typeLike.bitWidth); | ||
return new arrow_1.Int(typeLike.isSigned, typeLike.bitWidth); | ||
} | ||
exports.sanitizeInt = sanitizeInt; | ||
function sanitizeFloat(typeLike) { | ||
@@ -55,4 +50,5 @@ if (!("precision" in typeLike) || typeof typeLike.precision !== "number") { | ||
} | ||
return new apache_arrow_1.Float(typeLike.precision); | ||
return new arrow_1.Float(typeLike.precision); | ||
} | ||
exports.sanitizeFloat = sanitizeFloat; | ||
function sanitizeDecimal(typeLike) { | ||
@@ -67,4 +63,5 @@ if (!("scale" in typeLike) || | ||
} | ||
return new apache_arrow_1.Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); | ||
return new arrow_1.Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); | ||
} | ||
exports.sanitizeDecimal = sanitizeDecimal; | ||
function sanitizeDate(typeLike) { | ||
@@ -74,4 +71,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Date_(typeLike.unit); | ||
return new arrow_1.Date_(typeLike.unit); | ||
} | ||
exports.sanitizeDate = sanitizeDate; | ||
function sanitizeTime(typeLike) { | ||
@@ -84,4 +82,5 @@ if (!("unit" in typeLike) || | ||
} | ||
return new apache_arrow_1.Time(typeLike.unit, typeLike.bitWidth); | ||
return new arrow_1.Time(typeLike.unit, typeLike.bitWidth); | ||
} | ||
exports.sanitizeTime = sanitizeTime; | ||
function sanitizeTimestamp(typeLike) { | ||
@@ -95,4 +94,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Timestamp(typeLike.unit, timezone); | ||
return new arrow_1.Timestamp(typeLike.unit, timezone); | ||
} | ||
exports.sanitizeTimestamp = sanitizeTimestamp; | ||
function sanitizeTypedTimestamp(typeLike, | ||
@@ -107,2 +107,3 @@ // eslint-disable-next-line @typescript-eslint/naming-convention | ||
} | ||
exports.sanitizeTypedTimestamp = sanitizeTypedTimestamp; | ||
function sanitizeInterval(typeLike) { | ||
@@ -112,4 +113,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Interval(typeLike.unit); | ||
return new arrow_1.Interval(typeLike.unit); | ||
} | ||
exports.sanitizeInterval = sanitizeInterval; | ||
function sanitizeList(typeLike) { | ||
@@ -122,4 +124,5 @@ if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
} | ||
return new apache_arrow_1.List(sanitizeField(typeLike.children[0])); | ||
return new arrow_1.List(sanitizeField(typeLike.children[0])); | ||
} | ||
exports.sanitizeList = sanitizeList; | ||
function sanitizeStruct(typeLike) { | ||
@@ -129,4 +132,5 @@ if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
} | ||
return new apache_arrow_1.Struct(typeLike.children.map((child) => sanitizeField(child))); | ||
return new arrow_1.Struct(typeLike.children.map((child) => sanitizeField(child))); | ||
} | ||
exports.sanitizeStruct = sanitizeStruct; | ||
function sanitizeUnion(typeLike) { | ||
@@ -141,6 +145,7 @@ if (!("typeIds" in typeLike) || | ||
} | ||
return new apache_arrow_1.Union(typeLike.mode, | ||
return new arrow_1.Union(typeLike.mode, | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
typeLike.typeIds, typeLike.children.map((child) => sanitizeField(child))); | ||
} | ||
exports.sanitizeUnion = sanitizeUnion; | ||
function sanitizeTypedUnion(typeLike, | ||
@@ -157,2 +162,3 @@ // eslint-disable-next-line @typescript-eslint/naming-convention | ||
} | ||
exports.sanitizeTypedUnion = sanitizeTypedUnion; | ||
function sanitizeFixedSizeBinary(typeLike) { | ||
@@ -162,4 +168,5 @@ if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") { | ||
} | ||
return new apache_arrow_1.FixedSizeBinary(typeLike.byteWidth); | ||
return new arrow_1.FixedSizeBinary(typeLike.byteWidth); | ||
} | ||
exports.sanitizeFixedSizeBinary = sanitizeFixedSizeBinary; | ||
function sanitizeFixedSizeList(typeLike) { | ||
@@ -175,4 +182,5 @@ if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") { | ||
} | ||
return new apache_arrow_1.FixedSizeList(typeLike.listSize, sanitizeField(typeLike.children[0])); | ||
return new arrow_1.FixedSizeList(typeLike.listSize, sanitizeField(typeLike.children[0])); | ||
} | ||
exports.sanitizeFixedSizeList = sanitizeFixedSizeList; | ||
function sanitizeMap(typeLike) { | ||
@@ -185,6 +193,7 @@ if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { | ||
} | ||
return new apache_arrow_1.Map_( | ||
return new arrow_1.Map_( | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
typeLike.children.map((field) => sanitizeField(field)), typeLike.keysSorted); | ||
} | ||
exports.sanitizeMap = sanitizeMap; | ||
function sanitizeDuration(typeLike) { | ||
@@ -194,4 +203,5 @@ if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { | ||
} | ||
return new apache_arrow_1.Duration(typeLike.unit); | ||
return new arrow_1.Duration(typeLike.unit); | ||
} | ||
exports.sanitizeDuration = sanitizeDuration; | ||
function sanitizeDictionary(typeLike) { | ||
@@ -210,4 +220,5 @@ if (!("id" in typeLike) || typeof typeLike.id !== "number") { | ||
} | ||
return new apache_arrow_1.Dictionary(sanitizeType(typeLike.dictionary), sanitizeType(typeLike.indices), typeLike.id, typeLike.isOrdered); | ||
return new arrow_1.Dictionary(sanitizeType(typeLike.dictionary), sanitizeType(typeLike.indices), typeLike.id, typeLike.isOrdered); | ||
} | ||
exports.sanitizeDictionary = sanitizeDictionary; | ||
// biome-ignore lint/suspicious/noExplicitAny: skip | ||
@@ -232,100 +243,100 @@ function sanitizeType(typeLike) { | ||
switch (typeId) { | ||
case apache_arrow_1.Type.NONE: | ||
case arrow_1.Type.NONE: | ||
throw Error("Received a Type with a typeId of NONE"); | ||
case apache_arrow_1.Type.Null: | ||
return new apache_arrow_1.Null(); | ||
case apache_arrow_1.Type.Int: | ||
case arrow_1.Type.Null: | ||
return new arrow_1.Null(); | ||
case arrow_1.Type.Int: | ||
return sanitizeInt(typeLike); | ||
case apache_arrow_1.Type.Float: | ||
case arrow_1.Type.Float: | ||
return sanitizeFloat(typeLike); | ||
case apache_arrow_1.Type.Binary: | ||
return new apache_arrow_1.Binary(); | ||
case apache_arrow_1.Type.Utf8: | ||
return new apache_arrow_1.Utf8(); | ||
case apache_arrow_1.Type.Bool: | ||
return new apache_arrow_1.Bool(); | ||
case apache_arrow_1.Type.Decimal: | ||
case arrow_1.Type.Binary: | ||
return new arrow_1.Binary(); | ||
case arrow_1.Type.Utf8: | ||
return new arrow_1.Utf8(); | ||
case arrow_1.Type.Bool: | ||
return new arrow_1.Bool(); | ||
case arrow_1.Type.Decimal: | ||
return sanitizeDecimal(typeLike); | ||
case apache_arrow_1.Type.Date: | ||
case arrow_1.Type.Date: | ||
return sanitizeDate(typeLike); | ||
case apache_arrow_1.Type.Time: | ||
case arrow_1.Type.Time: | ||
return sanitizeTime(typeLike); | ||
case apache_arrow_1.Type.Timestamp: | ||
case arrow_1.Type.Timestamp: | ||
return sanitizeTimestamp(typeLike); | ||
case apache_arrow_1.Type.Interval: | ||
case arrow_1.Type.Interval: | ||
return sanitizeInterval(typeLike); | ||
case apache_arrow_1.Type.List: | ||
case arrow_1.Type.List: | ||
return sanitizeList(typeLike); | ||
case apache_arrow_1.Type.Struct: | ||
case arrow_1.Type.Struct: | ||
return sanitizeStruct(typeLike); | ||
case apache_arrow_1.Type.Union: | ||
case arrow_1.Type.Union: | ||
return sanitizeUnion(typeLike); | ||
case apache_arrow_1.Type.FixedSizeBinary: | ||
case arrow_1.Type.FixedSizeBinary: | ||
return sanitizeFixedSizeBinary(typeLike); | ||
case apache_arrow_1.Type.FixedSizeList: | ||
case arrow_1.Type.FixedSizeList: | ||
return sanitizeFixedSizeList(typeLike); | ||
case apache_arrow_1.Type.Map: | ||
case arrow_1.Type.Map: | ||
return sanitizeMap(typeLike); | ||
case apache_arrow_1.Type.Duration: | ||
case arrow_1.Type.Duration: | ||
return sanitizeDuration(typeLike); | ||
case apache_arrow_1.Type.Dictionary: | ||
case arrow_1.Type.Dictionary: | ||
return sanitizeDictionary(typeLike); | ||
case apache_arrow_1.Type.Int8: | ||
return new apache_arrow_1.Int8(); | ||
case apache_arrow_1.Type.Int16: | ||
return new apache_arrow_1.Int16(); | ||
case apache_arrow_1.Type.Int32: | ||
return new apache_arrow_1.Int32(); | ||
case apache_arrow_1.Type.Int64: | ||
return new apache_arrow_1.Int64(); | ||
case apache_arrow_1.Type.Uint8: | ||
return new apache_arrow_1.Uint8(); | ||
case apache_arrow_1.Type.Uint16: | ||
return new apache_arrow_1.Uint16(); | ||
case apache_arrow_1.Type.Uint32: | ||
return new apache_arrow_1.Uint32(); | ||
case apache_arrow_1.Type.Uint64: | ||
return new apache_arrow_1.Uint64(); | ||
case apache_arrow_1.Type.Float16: | ||
return new apache_arrow_1.Float16(); | ||
case apache_arrow_1.Type.Float32: | ||
return new apache_arrow_1.Float32(); | ||
case apache_arrow_1.Type.Float64: | ||
return new apache_arrow_1.Float64(); | ||
case apache_arrow_1.Type.DateMillisecond: | ||
return new apache_arrow_1.DateMillisecond(); | ||
case apache_arrow_1.Type.DateDay: | ||
return new apache_arrow_1.DateDay(); | ||
case apache_arrow_1.Type.TimeNanosecond: | ||
return new apache_arrow_1.TimeNanosecond(); | ||
case apache_arrow_1.Type.TimeMicrosecond: | ||
return new apache_arrow_1.TimeMicrosecond(); | ||
case apache_arrow_1.Type.TimeMillisecond: | ||
return new apache_arrow_1.TimeMillisecond(); | ||
case apache_arrow_1.Type.TimeSecond: | ||
return new apache_arrow_1.TimeSecond(); | ||
case apache_arrow_1.Type.TimestampNanosecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampNanosecond); | ||
case apache_arrow_1.Type.TimestampMicrosecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampMicrosecond); | ||
case apache_arrow_1.Type.TimestampMillisecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampMillisecond); | ||
case apache_arrow_1.Type.TimestampSecond: | ||
return sanitizeTypedTimestamp(typeLike, apache_arrow_1.TimestampSecond); | ||
case apache_arrow_1.Type.DenseUnion: | ||
return sanitizeTypedUnion(typeLike, apache_arrow_1.DenseUnion); | ||
case apache_arrow_1.Type.SparseUnion: | ||
return sanitizeTypedUnion(typeLike, apache_arrow_1.SparseUnion); | ||
case apache_arrow_1.Type.IntervalDayTime: | ||
return new apache_arrow_1.IntervalDayTime(); | ||
case apache_arrow_1.Type.IntervalYearMonth: | ||
return new apache_arrow_1.IntervalYearMonth(); | ||
case apache_arrow_1.Type.DurationNanosecond: | ||
return new apache_arrow_1.DurationNanosecond(); | ||
case apache_arrow_1.Type.DurationMicrosecond: | ||
return new apache_arrow_1.DurationMicrosecond(); | ||
case apache_arrow_1.Type.DurationMillisecond: | ||
return new apache_arrow_1.DurationMillisecond(); | ||
case apache_arrow_1.Type.DurationSecond: | ||
return new apache_arrow_1.DurationSecond(); | ||
case arrow_1.Type.Int8: | ||
return new arrow_1.Int8(); | ||
case arrow_1.Type.Int16: | ||
return new arrow_1.Int16(); | ||
case arrow_1.Type.Int32: | ||
return new arrow_1.Int32(); | ||
case arrow_1.Type.Int64: | ||
return new arrow_1.Int64(); | ||
case arrow_1.Type.Uint8: | ||
return new arrow_1.Uint8(); | ||
case arrow_1.Type.Uint16: | ||
return new arrow_1.Uint16(); | ||
case arrow_1.Type.Uint32: | ||
return new arrow_1.Uint32(); | ||
case arrow_1.Type.Uint64: | ||
return new arrow_1.Uint64(); | ||
case arrow_1.Type.Float16: | ||
return new arrow_1.Float16(); | ||
case arrow_1.Type.Float32: | ||
return new arrow_1.Float32(); | ||
case arrow_1.Type.Float64: | ||
return new arrow_1.Float64(); | ||
case arrow_1.Type.DateMillisecond: | ||
return new arrow_1.DateMillisecond(); | ||
case arrow_1.Type.DateDay: | ||
return new arrow_1.DateDay(); | ||
case arrow_1.Type.TimeNanosecond: | ||
return new arrow_1.TimeNanosecond(); | ||
case arrow_1.Type.TimeMicrosecond: | ||
return new arrow_1.TimeMicrosecond(); | ||
case arrow_1.Type.TimeMillisecond: | ||
return new arrow_1.TimeMillisecond(); | ||
case arrow_1.Type.TimeSecond: | ||
return new arrow_1.TimeSecond(); | ||
case arrow_1.Type.TimestampNanosecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampNanosecond); | ||
case arrow_1.Type.TimestampMicrosecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampMicrosecond); | ||
case arrow_1.Type.TimestampMillisecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampMillisecond); | ||
case arrow_1.Type.TimestampSecond: | ||
return sanitizeTypedTimestamp(typeLike, arrow_1.TimestampSecond); | ||
case arrow_1.Type.DenseUnion: | ||
return sanitizeTypedUnion(typeLike, arrow_1.DenseUnion); | ||
case arrow_1.Type.SparseUnion: | ||
return sanitizeTypedUnion(typeLike, arrow_1.SparseUnion); | ||
case arrow_1.Type.IntervalDayTime: | ||
return new arrow_1.IntervalDayTime(); | ||
case arrow_1.Type.IntervalYearMonth: | ||
return new arrow_1.IntervalYearMonth(); | ||
case arrow_1.Type.DurationNanosecond: | ||
return new arrow_1.DurationNanosecond(); | ||
case arrow_1.Type.DurationMicrosecond: | ||
return new arrow_1.DurationMicrosecond(); | ||
case arrow_1.Type.DurationMillisecond: | ||
return new arrow_1.DurationMillisecond(); | ||
case arrow_1.Type.DurationSecond: | ||
return new arrow_1.DurationSecond(); | ||
default: | ||
@@ -335,4 +346,5 @@ throw new Error("Unrecoginized type id in schema: " + typeId); | ||
} | ||
exports.sanitizeType = sanitizeType; | ||
function sanitizeField(fieldLike) { | ||
if (fieldLike instanceof apache_arrow_1.Field) { | ||
if (fieldLike instanceof arrow_1.Field) { | ||
return fieldLike; | ||
@@ -361,4 +373,5 @@ } | ||
} | ||
return new apache_arrow_1.Field(name, type, nullable, metadata); | ||
return new arrow_1.Field(name, type, nullable, metadata); | ||
} | ||
exports.sanitizeField = sanitizeField; | ||
/** | ||
@@ -372,3 +385,3 @@ * Convert something schemaLike into a Schema instance | ||
function sanitizeSchema(schemaLike) { | ||
if (schemaLike instanceof apache_arrow_1.Schema) { | ||
if (schemaLike instanceof arrow_1.Schema) { | ||
return schemaLike; | ||
@@ -390,4 +403,4 @@ } | ||
const sanitizedFields = schemaLike.fields.map((field) => sanitizeField(field)); | ||
return new apache_arrow_1.Schema(sanitizedFields, metadata); | ||
return new arrow_1.Schema(sanitizedFields, metadata); | ||
} | ||
exports.sanitizeSchema = sanitizeSchema; |
@@ -1,3 +0,2 @@ | ||
import { Schema } from "apache-arrow"; | ||
import { Data } from "./arrow"; | ||
import { Data, Schema } from "./arrow"; | ||
import { IndexOptions } from "./indices"; | ||
@@ -4,0 +3,0 @@ import { AddColumnsSql, ColumnAlteration, IndexConfig, OptimizeStats, Table as _NativeTable } from "./native"; |
@@ -17,4 +17,4 @@ "use strict"; | ||
exports.Table = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const arrow_1 = require("./arrow"); | ||
const registry_1 = require("./embedding/registry"); | ||
const query_1 = require("./query"); | ||
@@ -60,3 +60,3 @@ /** | ||
const schemaBuf = await this.inner.schema(); | ||
const tbl = (0, apache_arrow_1.tableFromIPC)(schemaBuf); | ||
const tbl = (0, arrow_1.tableFromIPC)(schemaBuf); | ||
return tbl.schema; | ||
@@ -70,3 +70,6 @@ } | ||
const mode = options?.mode ?? "append"; | ||
const buffer = await (0, arrow_1.fromDataToBuffer)(data); | ||
const schema = await this.schema(); | ||
const registry = (0, registry_1.getRegistry)(); | ||
const functions = registry.parseFunctions(schema.metadata); | ||
const buffer = await (0, arrow_1.fromDataToBuffer)(data, functions.values().next().value); | ||
await this.inner.add(buffer, mode); | ||
@@ -73,0 +76,0 @@ } |
{ | ||
"name": "@lancedb/lancedb", | ||
"version": "0.5.0", | ||
"main": "./dist/index.js", | ||
"types": "./dist/index.d.ts", | ||
"version": "0.5.1", | ||
"main": "dist/index.js", | ||
"exports": { | ||
".": "./dist/index.js", | ||
"./embedding": "./dist/embedding/index.js" | ||
}, | ||
"types": "dist/index.d.ts", | ||
"napi": { | ||
@@ -72,11 +76,12 @@ "name": "lancedb", | ||
"apache-arrow": "^15.0.0", | ||
"openai": "^4.29.2" | ||
"openai": "^4.29.2", | ||
"reflect-metadata": "^0.2.2" | ||
}, | ||
"optionalDependencies": { | ||
"@lancedb/lancedb-darwin-arm64": "0.5.0", | ||
"@lancedb/lancedb-linux-arm64-gnu": "0.5.0", | ||
"@lancedb/lancedb-darwin-x64": "0.5.0", | ||
"@lancedb/lancedb-linux-x64-gnu": "0.5.0", | ||
"@lancedb/lancedb-win32-x64-msvc": "0.5.0" | ||
"@lancedb/lancedb-darwin-arm64": "0.5.1", | ||
"@lancedb/lancedb-linux-arm64-gnu": "0.5.1", | ||
"@lancedb/lancedb-darwin-x64": "0.5.1", | ||
"@lancedb/lancedb-linux-x64-gnu": "0.5.1", | ||
"@lancedb/lancedb-win32-x64-msvc": "0.5.1" | ||
} | ||
} |
@@ -10,3 +10,5 @@ { | ||
"allowJs": true, | ||
"resolveJsonModule": true | ||
"resolveJsonModule": true, | ||
"emitDecoratorMetadata": true, | ||
"experimentalDecorators": true | ||
}, | ||
@@ -13,0 +15,0 @@ "exclude": ["./dist/*"], |
License Policy Violation
License: This package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
Environment variable access
Supply chain risk: Package accesses environment variables, which may be a sign of credential stuffing or data theft.
Found 1 instance in 1 package
License Policy Violation
License: This package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
485896
7303
67
11973
8
5
5
1
+ Added reflect-metadata@^0.2.2
+ Added @lancedb/lancedb-darwin-arm64@0.5.1 (transitive)
+ Added @lancedb/lancedb-darwin-x64@0.5.1 (transitive)
+ Added @lancedb/lancedb-linux-arm64-gnu@0.5.1 (transitive)
+ Added @lancedb/lancedb-linux-x64-gnu@0.5.1 (transitive)
+ Added @lancedb/lancedb-win32-x64-msvc@0.5.1 (transitive)
+ Added reflect-metadata@0.2.2 (transitive)
- Removed @lancedb/lancedb-darwin-arm64@0.5.0 (transitive)
- Removed @lancedb/lancedb-darwin-x64@0.5.0 (transitive)
- Removed @lancedb/lancedb-linux-arm64-gnu@0.5.0 (transitive)
- Removed @lancedb/lancedb-linux-x64-gnu@0.5.0 (transitive)
- Removed @lancedb/lancedb-win32-x64-msvc@0.5.0 (transitive)