+280
| 'use strict'; | ||
| var pg = require('pg'); | ||
| var ai = require('ai'); | ||
| var zod = require('zod'); | ||
| var mysql$1 = require('mysql2'); | ||
| var sqlite3 = require('sqlite3'); | ||
| function _interopNamespaceDefault(e) { | ||
| var n = Object.create(null); | ||
| if (e) { | ||
| Object.keys(e).forEach(function (k) { | ||
| if (k !== 'default') { | ||
| var d = Object.getOwnPropertyDescriptor(e, k); | ||
| Object.defineProperty(n, k, d.get ? d : { | ||
| enumerable: true, | ||
| get: function () { return e[k]; } | ||
| }); | ||
| } | ||
| }); | ||
| } | ||
| n.default = e; | ||
| return Object.freeze(n); | ||
| } | ||
| var mysql__namespace = /*#__PURE__*/_interopNamespaceDefault(mysql$1); | ||
| class PostgresTool { | ||
| url; | ||
| client; | ||
| constructor(dbUrl) { | ||
| this.url = dbUrl; | ||
| } | ||
| async initialize() { | ||
| this.client = new pg.Client(this.url); | ||
| await this.client.connect(); | ||
| } | ||
| async describe() { | ||
| const result = await this.client.query(` | ||
| SELECT table_schema, | ||
| table_name, | ||
| string_agg( | ||
| format('%s %s %s', column_name, | ||
| CASE | ||
| WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')' | ||
| WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')' | ||
| WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')' | ||
| ELSE data_type | ||
| END, | ||
| CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END), | ||
| ', | ||
| ' ORDER BY ordinal_position | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('pg_catalog', 'information_schema') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| `); | ||
| const createTableStatements = result.rows.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| const schema = createTableStatements.join("\n\n"); | ||
| return { | ||
| database: "PostgreSQL", | ||
| description: schema | ||
| }; | ||
| } | ||
| async query(query) { | ||
| const result = await this.client.query(query); | ||
| const res = Array.isArray(result) ? result : [result]; | ||
| return res.map((r) => r.rows); | ||
| } | ||
| } | ||
| const descriptionTemplate = (schema) => `Query a database with the following schema: | ||
| ${JSON.stringify(schema)}`; | ||
| async function sqlTool(db, { notes } = {}) { | ||
| await db.initialize(); | ||
| const schema = await db.describe(); | ||
| schema.notes = notes; | ||
| return ai.tool({ | ||
| description: descriptionTemplate(schema), | ||
| execute: async ({ query }) => { | ||
| return await db.query(query); | ||
| }, | ||
| parameters: zod.z.object({ | ||
| query: zod.z.string().describe( | ||
| `${schema.database} Query to execute. | ||
| Notes: ${schema.notes?.join( | ||
| ", " | ||
| )}` | ||
| ) | ||
| }) | ||
| }); | ||
| } | ||
| class MySQLTool { | ||
| url; | ||
| client; | ||
| constructor(dbUrl) { | ||
| this.url = dbUrl; | ||
| } | ||
| async initialize() { | ||
| this.client = mysql__namespace.createConnection(this.url); | ||
| this.client.connect(); | ||
| } | ||
| async describe() { | ||
| const res = await new Promise( | ||
| (resolve, reject) => this.client.query( | ||
| { | ||
| sql: `SELECT table_schema, | ||
| table_name, | ||
| GROUP_CONCAT( | ||
| CONCAT( | ||
| column_name, ' ', | ||
| CASE | ||
| WHEN data_type = 'varchar' THEN CONCAT('VARCHAR(', character_maximum_length, ')') | ||
| WHEN data_type = 'char' THEN CONCAT('CHAR(', character_maximum_length, ')') | ||
| WHEN data_type = 'decimal' THEN CONCAT('DECIMAL(', numeric_precision, ',', numeric_scale, ')') | ||
| ELSE data_type | ||
| END, ' ', | ||
| IF(is_nullable = 'YES', 'NULL', 'NOT NULL') | ||
| ) ORDER BY ordinal_position | ||
| SEPARATOR ' | ||
| ' -- Newline separator | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| ` | ||
| }, | ||
| (err, rows, _fields) => { | ||
| if (err) { | ||
| reject(err); | ||
| } | ||
| resolve(rows); | ||
| } | ||
| ) | ||
| ); | ||
| const createTableStatements = res.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| const schema = createTableStatements.join("\n\n"); | ||
| return { | ||
| database: "MySQL", | ||
| description: schema | ||
| }; | ||
| } | ||
| async query(query) { | ||
| const res = await new Promise( | ||
| (resolve, reject) => this.client.query({ sql: query }, (err, res2) => { | ||
| if (err) { | ||
| reject(err); | ||
| } | ||
| resolve(res2); | ||
| }) | ||
| ); | ||
| return res; | ||
| } | ||
| } | ||
| function getDatabaseSchema(db) { | ||
| return new Promise((resolve, reject) => { | ||
| const getTablesQuery = ` | ||
| SELECT name | ||
| FROM sqlite_master | ||
| WHERE type = 'table' | ||
| AND name NOT LIKE 'sqlite_%' | ||
| ORDER BY name; | ||
| `; | ||
| db.all(getTablesQuery, [], (err, tables) => { | ||
| if (err) { | ||
| return reject(`Failed to retrieve tables: ${err.message}`); | ||
| } | ||
| if (!tables.length) { | ||
| return resolve("No tables in this database."); | ||
| } | ||
| let schemaOutput = ""; | ||
| let pendingTables = tables.length; | ||
| tables.forEach((name) => { | ||
| db.all(`PRAGMA table_info(${name})`, [], (err2, columns) => { | ||
| if (err2) { | ||
| return reject( | ||
| `Failed to get table_info for ${name}: ${err2.message}` | ||
| ); | ||
| } | ||
| schemaOutput += `Table: ${name} | ||
| `; | ||
| columns.forEach((col) => { | ||
| let columnStr = ` ${col.name} ${col.type}`; | ||
| if (col.pk === 1) { | ||
| columnStr += " PRIMARY KEY"; | ||
| } | ||
| if (col.notnull === 1) { | ||
| columnStr += " NOT NULL"; | ||
| } | ||
| if (col.dflt_value !== null) { | ||
| columnStr += ` DEFAULT ${col.dflt_value}`; | ||
| } | ||
| schemaOutput += columnStr + "\n"; | ||
| }); | ||
| schemaOutput += "\n"; | ||
| pendingTables -= 1; | ||
| if (pendingTables === 0) { | ||
| resolve(schemaOutput.trim()); | ||
| } | ||
| }); | ||
| }); | ||
| }); | ||
| }); | ||
| } | ||
| class SQLiteTool { | ||
| path; | ||
| db; | ||
| constructor(dbPath) { | ||
| this.path = dbPath; | ||
| } | ||
| static inMemory() { | ||
| return new SQLiteTool(":memory:"); | ||
| } | ||
| async initialize() { | ||
| this.db = new sqlite3.Database(this.path); | ||
| } | ||
| async describe() { | ||
| const description = await getDatabaseSchema(this.db); | ||
| return { | ||
| database: "PostgreSQL", | ||
| description: ` | ||
| sqlite3 database at ${this.path} | ||
| schema: | ||
| ${description} | ||
| `.trim() | ||
| }; | ||
| } | ||
| async query(query) { | ||
| return await new Promise((resolve, reject) => { | ||
| this.db.run(query, (res, err) => { | ||
| if (err !== null) { | ||
| reject(err); | ||
| } else { | ||
| res.all((err2, rows) => { | ||
| if (err2 !== null) { | ||
| reject(err2); | ||
| } else { | ||
| resolve(rows); | ||
| } | ||
| }); | ||
| } | ||
| }); | ||
| }); | ||
| } | ||
| } | ||
| async function postgres(dbUrl, options) { | ||
| return await sqlTool(new PostgresTool(dbUrl), options); | ||
| } | ||
| async function mysql(dbUrl, options) { | ||
| return await sqlTool(new MySQLTool(dbUrl), options); | ||
| } | ||
| async function sqlite(dbPath, options) { | ||
| return await sqlTool(new SQLiteTool(dbPath), options); | ||
| } | ||
| exports.mysql = mysql; | ||
| exports.postgres = postgres; | ||
| exports.sqlite = sqlite; |
| import * as ai from 'ai'; | ||
| import * as zod from 'zod'; | ||
| interface SqlToolOptions { | ||
| notes?: string[]; | ||
| } | ||
| declare function postgres(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{ | ||
| query: zod.ZodString; | ||
| }, "strip", zod.ZodTypeAny, { | ||
| query?: string; | ||
| }, { | ||
| query?: string; | ||
| }>, unknown[]> & { | ||
| execute: (args: { | ||
| query?: string; | ||
| }, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>; | ||
| }>; | ||
| declare function mysql(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{ | ||
| query: zod.ZodString; | ||
| }, "strip", zod.ZodTypeAny, { | ||
| query?: string; | ||
| }, { | ||
| query?: string; | ||
| }>, unknown[]> & { | ||
| execute: (args: { | ||
| query?: string; | ||
| }, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>; | ||
| }>; | ||
| declare function sqlite(dbPath: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{ | ||
| query: zod.ZodString; | ||
| }, "strip", zod.ZodTypeAny, { | ||
| query?: string; | ||
| }, { | ||
| query?: string; | ||
| }>, unknown[]> & { | ||
| execute: (args: { | ||
| query?: string; | ||
| }, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>; | ||
| }>; | ||
| export { mysql, postgres, sqlite }; |
| import * as ai from 'ai'; | ||
| import * as zod from 'zod'; | ||
| interface SqlToolOptions { | ||
| notes?: string[]; | ||
| } | ||
| declare function postgres(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{ | ||
| query: zod.ZodString; | ||
| }, "strip", zod.ZodTypeAny, { | ||
| query?: string; | ||
| }, { | ||
| query?: string; | ||
| }>, unknown[]> & { | ||
| execute: (args: { | ||
| query?: string; | ||
| }, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>; | ||
| }>; | ||
| declare function mysql(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{ | ||
| query: zod.ZodString; | ||
| }, "strip", zod.ZodTypeAny, { | ||
| query?: string; | ||
| }, { | ||
| query?: string; | ||
| }>, unknown[]> & { | ||
| execute: (args: { | ||
| query?: string; | ||
| }, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>; | ||
| }>; | ||
| declare function sqlite(dbPath: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{ | ||
| query: zod.ZodString; | ||
| }, "strip", zod.ZodTypeAny, { | ||
| query?: string; | ||
| }, { | ||
| query?: string; | ||
| }>, unknown[]> & { | ||
| execute: (args: { | ||
| query?: string; | ||
| }, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>; | ||
| }>; | ||
| export { mysql, postgres, sqlite }; |
+257
| import { Client } from 'pg'; | ||
| import { tool } from 'ai'; | ||
| import { z } from 'zod'; | ||
| import * as mysql$1 from 'mysql2'; | ||
| import { Database } from 'sqlite3'; | ||
| class PostgresTool { | ||
| url; | ||
| client; | ||
| constructor(dbUrl) { | ||
| this.url = dbUrl; | ||
| } | ||
| async initialize() { | ||
| this.client = new Client(this.url); | ||
| await this.client.connect(); | ||
| } | ||
| async describe() { | ||
| const result = await this.client.query(` | ||
| SELECT table_schema, | ||
| table_name, | ||
| string_agg( | ||
| format('%s %s %s', column_name, | ||
| CASE | ||
| WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')' | ||
| WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')' | ||
| WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')' | ||
| ELSE data_type | ||
| END, | ||
| CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END), | ||
| ', | ||
| ' ORDER BY ordinal_position | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('pg_catalog', 'information_schema') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| `); | ||
| const createTableStatements = result.rows.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| const schema = createTableStatements.join("\n\n"); | ||
| return { | ||
| database: "PostgreSQL", | ||
| description: schema | ||
| }; | ||
| } | ||
| async query(query) { | ||
| const result = await this.client.query(query); | ||
| const res = Array.isArray(result) ? result : [result]; | ||
| return res.map((r) => r.rows); | ||
| } | ||
| } | ||
| const descriptionTemplate = (schema) => `Query a database with the following schema: | ||
| ${JSON.stringify(schema)}`; | ||
| async function sqlTool(db, { notes } = {}) { | ||
| await db.initialize(); | ||
| const schema = await db.describe(); | ||
| schema.notes = notes; | ||
| return tool({ | ||
| description: descriptionTemplate(schema), | ||
| execute: async ({ query }) => { | ||
| return await db.query(query); | ||
| }, | ||
| parameters: z.object({ | ||
| query: z.string().describe( | ||
| `${schema.database} Query to execute. | ||
| Notes: ${schema.notes?.join( | ||
| ", " | ||
| )}` | ||
| ) | ||
| }) | ||
| }); | ||
| } | ||
| class MySQLTool { | ||
| url; | ||
| client; | ||
| constructor(dbUrl) { | ||
| this.url = dbUrl; | ||
| } | ||
| async initialize() { | ||
| this.client = mysql$1.createConnection(this.url); | ||
| this.client.connect(); | ||
| } | ||
| async describe() { | ||
| const res = await new Promise( | ||
| (resolve, reject) => this.client.query( | ||
| { | ||
| sql: `SELECT table_schema, | ||
| table_name, | ||
| GROUP_CONCAT( | ||
| CONCAT( | ||
| column_name, ' ', | ||
| CASE | ||
| WHEN data_type = 'varchar' THEN CONCAT('VARCHAR(', character_maximum_length, ')') | ||
| WHEN data_type = 'char' THEN CONCAT('CHAR(', character_maximum_length, ')') | ||
| WHEN data_type = 'decimal' THEN CONCAT('DECIMAL(', numeric_precision, ',', numeric_scale, ')') | ||
| ELSE data_type | ||
| END, ' ', | ||
| IF(is_nullable = 'YES', 'NULL', 'NOT NULL') | ||
| ) ORDER BY ordinal_position | ||
| SEPARATOR ' | ||
| ' -- Newline separator | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| ` | ||
| }, | ||
| (err, rows, _fields) => { | ||
| if (err) { | ||
| reject(err); | ||
| } | ||
| resolve(rows); | ||
| } | ||
| ) | ||
| ); | ||
| const createTableStatements = res.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| const schema = createTableStatements.join("\n\n"); | ||
| return { | ||
| database: "MySQL", | ||
| description: schema | ||
| }; | ||
| } | ||
| async query(query) { | ||
| const res = await new Promise( | ||
| (resolve, reject) => this.client.query({ sql: query }, (err, res2) => { | ||
| if (err) { | ||
| reject(err); | ||
| } | ||
| resolve(res2); | ||
| }) | ||
| ); | ||
| return res; | ||
| } | ||
| } | ||
| function getDatabaseSchema(db) { | ||
| return new Promise((resolve, reject) => { | ||
| const getTablesQuery = ` | ||
| SELECT name | ||
| FROM sqlite_master | ||
| WHERE type = 'table' | ||
| AND name NOT LIKE 'sqlite_%' | ||
| ORDER BY name; | ||
| `; | ||
| db.all(getTablesQuery, [], (err, tables) => { | ||
| if (err) { | ||
| return reject(`Failed to retrieve tables: ${err.message}`); | ||
| } | ||
| if (!tables.length) { | ||
| return resolve("No tables in this database."); | ||
| } | ||
| let schemaOutput = ""; | ||
| let pendingTables = tables.length; | ||
| tables.forEach((name) => { | ||
| db.all(`PRAGMA table_info(${name})`, [], (err2, columns) => { | ||
| if (err2) { | ||
| return reject( | ||
| `Failed to get table_info for ${name}: ${err2.message}` | ||
| ); | ||
| } | ||
| schemaOutput += `Table: ${name} | ||
| `; | ||
| columns.forEach((col) => { | ||
| let columnStr = ` ${col.name} ${col.type}`; | ||
| if (col.pk === 1) { | ||
| columnStr += " PRIMARY KEY"; | ||
| } | ||
| if (col.notnull === 1) { | ||
| columnStr += " NOT NULL"; | ||
| } | ||
| if (col.dflt_value !== null) { | ||
| columnStr += ` DEFAULT ${col.dflt_value}`; | ||
| } | ||
| schemaOutput += columnStr + "\n"; | ||
| }); | ||
| schemaOutput += "\n"; | ||
| pendingTables -= 1; | ||
| if (pendingTables === 0) { | ||
| resolve(schemaOutput.trim()); | ||
| } | ||
| }); | ||
| }); | ||
| }); | ||
| }); | ||
| } | ||
| class SQLiteTool { | ||
| path; | ||
| db; | ||
| constructor(dbPath) { | ||
| this.path = dbPath; | ||
| } | ||
| static inMemory() { | ||
| return new SQLiteTool(":memory:"); | ||
| } | ||
| async initialize() { | ||
| this.db = new Database(this.path); | ||
| } | ||
| async describe() { | ||
| const description = await getDatabaseSchema(this.db); | ||
| return { | ||
| database: "PostgreSQL", | ||
| description: ` | ||
| sqlite3 database at ${this.path} | ||
| schema: | ||
| ${description} | ||
| `.trim() | ||
| }; | ||
| } | ||
| async query(query) { | ||
| return await new Promise((resolve, reject) => { | ||
| this.db.run(query, (res, err) => { | ||
| if (err !== null) { | ||
| reject(err); | ||
| } else { | ||
| res.all((err2, rows) => { | ||
| if (err2 !== null) { | ||
| reject(err2); | ||
| } else { | ||
| resolve(rows); | ||
| } | ||
| }); | ||
| } | ||
| }); | ||
| }); | ||
| } | ||
| } | ||
| async function postgres(dbUrl, options) { | ||
| return await sqlTool(new PostgresTool(dbUrl), options); | ||
| } | ||
| async function mysql(dbUrl, options) { | ||
| return await sqlTool(new MySQLTool(dbUrl), options); | ||
| } | ||
| async function sqlite(dbPath, options) { | ||
| return await sqlTool(new SQLiteTool(dbPath), options); | ||
| } | ||
| export { mysql, postgres, sqlite }; |
| import * as mysql from "mysql2"; | ||
| import { Schema, Database } from "../database"; | ||
| export class MySQLTool implements Database { | ||
| private url: string; | ||
| private client: mysql.Connection; | ||
| constructor(dbUrl: string) { | ||
| this.url = dbUrl; | ||
| } | ||
| async initialize() { | ||
| this.client = mysql.createConnection(this.url); | ||
| this.client.connect(); | ||
| } | ||
| async describe(): Promise<Schema> { | ||
| const res: { | ||
| table_schema: string; | ||
| table_name: string; | ||
| columns: string; | ||
| }[] = await new Promise((resolve, reject) => | ||
| this.client.query( | ||
| { | ||
| sql: `SELECT table_schema, | ||
| table_name, | ||
| GROUP_CONCAT( | ||
| CONCAT( | ||
| column_name, ' ', | ||
| CASE | ||
| WHEN data_type = 'varchar' THEN CONCAT('VARCHAR(', character_maximum_length, ')') | ||
| WHEN data_type = 'char' THEN CONCAT('CHAR(', character_maximum_length, ')') | ||
| WHEN data_type = 'decimal' THEN CONCAT('DECIMAL(', numeric_precision, ',', numeric_scale, ')') | ||
| ELSE data_type | ||
| END, ' ', | ||
| IF(is_nullable = 'YES', 'NULL', 'NOT NULL') | ||
| ) ORDER BY ordinal_position | ||
| SEPARATOR '\n' -- Newline separator | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| `, | ||
| }, | ||
| (err, rows, _fields) => { | ||
| if (err) { | ||
| reject(err); | ||
| } | ||
| resolve(rows as any); | ||
| } | ||
| ) | ||
| ); | ||
| const createTableStatements = res.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| // Construct the CREATE TABLE statement | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| const schema = createTableStatements.join("\n\n"); | ||
| return { | ||
| database: "MySQL", | ||
| description: schema, | ||
| }; | ||
| } | ||
| async query(query: string): Promise<object[]> { | ||
| const res = await new Promise((resolve, reject) => | ||
| this.client.query({ sql: query }, (err, res) => { | ||
| if (err) { | ||
| reject(err); | ||
| } | ||
| resolve(res); | ||
| }) | ||
| ); | ||
| return res as object[]; | ||
| } | ||
| } |
| import { Client, QueryResult } from "pg"; | ||
| import { Schema, Database } from "../database"; | ||
| export class PostgresTool implements Database { | ||
| private url: string; | ||
| private client: Client; | ||
| constructor(dbUrl: string) { | ||
| this.url = dbUrl; | ||
| } | ||
| async initialize() { | ||
| this.client = new Client(this.url); | ||
| await this.client.connect(); | ||
| } | ||
| async describe(): Promise<Schema> { | ||
| const result = await this.client.query(` | ||
| SELECT table_schema, | ||
| table_name, | ||
| string_agg( | ||
| format('%s %s %s', column_name, | ||
| CASE | ||
| WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')' | ||
| WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')' | ||
| WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')' | ||
| ELSE data_type | ||
| END, | ||
| CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END), | ||
| ',\n ' ORDER BY ordinal_position | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('pg_catalog', 'information_schema') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| `); | ||
| const createTableStatements = result.rows.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| // Construct the CREATE TABLE statement | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| const schema = createTableStatements.join("\n\n"); | ||
| return { | ||
| database: "PostgreSQL", | ||
| description: schema, | ||
| }; | ||
| } | ||
| async query(query: string) { | ||
| const result = await this.client.query(query); | ||
| const res = Array.isArray(result) ? result : [result]; | ||
| return res.map((r: QueryResult) => r.rows); | ||
| } | ||
| } |
| import { RunResult, Database as SQLite } from "sqlite3"; | ||
| import { Schema, Database } from "../database"; | ||
| function getDatabaseSchema(db: SQLite): Promise<string> { | ||
| return new Promise((resolve, reject) => { | ||
| // Retrieve all "normal" tables (exclude SQLite's internal tables) | ||
| const getTablesQuery = ` | ||
| SELECT name | ||
| FROM sqlite_master | ||
| WHERE type = 'table' | ||
| AND name NOT LIKE 'sqlite_%' | ||
| ORDER BY name; | ||
| `; | ||
| db.all(getTablesQuery, [], (err, tables) => { | ||
| if (err) { | ||
| return reject(`Failed to retrieve tables: ${err.message}`); | ||
| } | ||
| if (!tables.length) { | ||
| return resolve("No tables in this database."); | ||
| } | ||
| let schemaOutput = ""; | ||
| let pendingTables = tables.length; | ||
| // Iterate over tables and get column info | ||
| tables.forEach((name) => { | ||
| db.all(`PRAGMA table_info(${name})`, [], (err, columns) => { | ||
| if (err) { | ||
| return reject( | ||
| `Failed to get table_info for ${name}: ${err.message}`, | ||
| ); | ||
| } | ||
| // Add table header | ||
| schemaOutput += `Table: ${name}\n`; | ||
| // List each column with its type and other attributes | ||
| columns.forEach((col) => { | ||
| let columnStr = ` ${col.name} ${col.type}`; | ||
| if (col.pk === 1) { | ||
| columnStr += " PRIMARY KEY"; | ||
| } | ||
| if (col.notnull === 1) { | ||
| columnStr += " NOT NULL"; | ||
| } | ||
| if (col.dflt_value !== null) { | ||
| columnStr += ` DEFAULT ${col.dflt_value}`; | ||
| } | ||
| schemaOutput += columnStr + "\n"; | ||
| }); | ||
| // Add a blank line after each table | ||
| schemaOutput += "\n"; | ||
| // If all tables have been processed, resolve the output | ||
| pendingTables -= 1; | ||
| if (pendingTables === 0) { | ||
| resolve(schemaOutput.trim()); | ||
| } | ||
| }); | ||
| }); | ||
| }); | ||
| }); | ||
| } | ||
| export class SQLiteTool implements Database { | ||
| private path: string; | ||
| private db: SQLite; | ||
| constructor(dbPath: string) { | ||
| this.path = dbPath; | ||
| } | ||
| static inMemory() { | ||
| return new SQLiteTool(":memory:"); | ||
| } | ||
| async initialize() { | ||
| this.db = new SQLite(this.path); | ||
| } | ||
| async describe(): Promise<Schema> { | ||
| const description: string = await getDatabaseSchema(this.db); | ||
| return { | ||
| database: "PostgreSQL", | ||
| description: ` | ||
| sqlite3 database at ${this.path} | ||
| schema: | ||
| ${description} | ||
| `.trim(), | ||
| }; | ||
| } | ||
| async query(query: string): Promise<unknown[]> { | ||
| return await new Promise((resolve, reject) => { | ||
| this.db.run(query, (res: RunResult, err: Error | null) => { | ||
| if (err !== null) { | ||
| reject(err); | ||
| } else { | ||
| res.all((err, rows) => { | ||
| if (err !== null) { | ||
| reject(err); | ||
| } else { | ||
| resolve(rows as unknown[]); | ||
| } | ||
| }); | ||
| } | ||
| }); | ||
| }); | ||
| } | ||
| } |
+5
-2
| { | ||
| "name": "ai-sql", | ||
| "description": "Give any Vercel AI SDK project the ability to interact with PostgreSQL, SQlite or MySQL databases in one line.", | ||
| "version": "0.0.1-4", | ||
| "version": "0.0.1-6", | ||
| "main": "./dist/index.cjs", | ||
| "module": "./dist/index.mjs", | ||
| "types": "./dist/index.d.cts", | ||
| "type": "module", | ||
| "exports": { | ||
@@ -33,5 +34,7 @@ "require": { | ||
| "dependencies": { | ||
| "@types/pg": "^8.11.10", | ||
| "ai": "^4.0.23", | ||
| "mysql2": "^3.12.0", | ||
| "pg": "^8.13.1", | ||
| "@types/pg": "^8.11.10", | ||
| "sqlite3": "^5.1.7", | ||
| "zod": "^3.24.1" | ||
@@ -38,0 +41,0 @@ }, |
+43
-5
@@ -7,14 +7,13 @@ # AI SQL | ||
| Note: We rely on bun for running these, but not for installing. | ||
| ```ts | ||
| import * as ai from "ai"; | ||
| import { createOpenAI } from "@ai-sdk/openai"; | ||
| import { postgres, sqlTool } from "ai-sql"; // or mysql, sqlite | ||
| import { postgres } from "ai-sql"; // or mysql, sqlite | ||
| const openai = createOpenAI({ | ||
| compatibility: "strict", | ||
| apiKey: process.env.OPENAI_API_KEY!, | ||
| }); | ||
| const model = openai("gpt-4-turbo"); | ||
| const { text } = await ai.generateText({ | ||
@@ -24,3 +23,3 @@ model: openai("gpt-4-turbo"), | ||
| tools: { | ||
| postgreSQL: await sqlTool(postgres(process.env.POSTGRES_URL!)), | ||
| database: await postgres(process.env.POSTGRES_URL!), | ||
| }, | ||
@@ -30,1 +29,40 @@ maxSteps: 3, | ||
| ``` | ||
| For more examples, see the [example](./example) directory. | ||
| ## Creating a provider | ||
| ```typescript | ||
| import { Schema, Database, sqlTool } from "ai-sql"; | ||
| export class MyDbTool implements Database { | ||
| async initialize() { | ||
| // do setup here | ||
| } | ||
| async describe(): Promise<Schema> { | ||
| // describe the schema | ||
| return { | ||
| // database type | ||
| database: "my database", | ||
| // stringified schema representation | ||
| description: ` | ||
| create table messages ( | ||
| id integer primary key, | ||
| text string not null, | ||
| ); | ||
| `, | ||
| }; | ||
| } | ||
| async query(query: string) { | ||
| // return result rows here | ||
| return []; | ||
| } | ||
| } | ||
| const tools = { | ||
| database: await sqlTool(new MyDbTool()), | ||
| }; | ||
| ``` |
+37
-13
@@ -1,16 +0,8 @@ | ||
| export interface ColumnDefinition { | ||
| name: string; | ||
| type: string; | ||
| primaryKey: boolean; | ||
| } | ||
| import { tool } from "ai"; | ||
| import { z } from "zod"; | ||
| export interface TableDefinition { | ||
| name: string; | ||
| columns: ColumnDefinition[]; | ||
| } | ||
| export interface Schema { | ||
| database: string; | ||
| tables: TableDefinition[]; | ||
| description: string; | ||
| notes?: string[]; | ||
| } | ||
@@ -23,3 +15,35 @@ | ||
| query: (query: string) => Promise<object[]>; | ||
| query: (query: string) => Promise<unknown[]>; | ||
| } | ||
| const descriptionTemplate = (schema: unknown) => | ||
| `Query a database with the following schema:\n\n${JSON.stringify(schema)}`; | ||
| export interface SqlToolOptions { | ||
| notes?: string[]; | ||
| } | ||
| export async function sqlTool(db: Database, { notes }: SqlToolOptions = {}) { | ||
| await db.initialize(); | ||
| const schema = await db.describe(); | ||
| schema.notes = notes; | ||
| return tool({ | ||
| description: descriptionTemplate(schema), | ||
| execute: async ({ query }) => { | ||
| return await db.query(query); | ||
| }, | ||
| parameters: z.object({ | ||
| query: z | ||
| .string() | ||
| .describe( | ||
| `${schema.database} Query to execute.\n\nNotes: ${schema.notes?.join( | ||
| ", " | ||
| )}` | ||
| ), | ||
| }), | ||
| }); | ||
| } |
+13
-53
@@ -1,56 +0,16 @@ | ||
| import { tool } from "ai"; | ||
| import { z } from "zod"; | ||
| import { Client, QueryResult } from "pg"; | ||
| import { PostgresTool } from "./providers/postgres"; | ||
| import { sqlTool, SqlToolOptions } from "./database"; | ||
| import { MySQLTool } from "./providers/mysql"; | ||
| import { SQLiteTool } from "./providers/sqlite"; | ||
| export const postgreSQLTool = async (database_url: string) => { | ||
| const client = new Client(database_url); | ||
| await client.connect(); | ||
| export async function postgres(dbUrl: string, options?: SqlToolOptions) { | ||
| return await sqlTool(new PostgresTool(dbUrl), options); | ||
| } | ||
| const result = await client.query(` | ||
| SELECT table_schema, | ||
| table_name, | ||
| string_agg( | ||
| format('%s %s %s', column_name, | ||
| CASE | ||
| WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')' | ||
| WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')' | ||
| WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')' | ||
| ELSE data_type | ||
| END, | ||
| CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END), | ||
| ',\n ' ORDER BY ordinal_position | ||
| ) AS columns | ||
| FROM information_schema.columns | ||
| WHERE table_schema NOT IN ('pg_catalog', 'information_schema') | ||
| GROUP BY table_schema, table_name | ||
| ORDER BY table_schema, table_name; | ||
| `); | ||
| export async function mysql(dbUrl: string, options?: SqlToolOptions) { | ||
| return await sqlTool(new MySQLTool(dbUrl), options); | ||
| } | ||
| const createTableStatements = result.rows.map((row) => { | ||
| const { table_schema, table_name, columns } = row; | ||
| // Construct the CREATE TABLE statement | ||
| return ` | ||
| CREATE TABLE ${table_schema}.${table_name} ( | ||
| ${columns} | ||
| ); | ||
| `.trim(); | ||
| }); | ||
| let schema = createTableStatements.join("\n\n"); | ||
| return tool({ | ||
| description: `Query PostgreSQL Database with the following schema:\n\n${schema}`, | ||
| execute: async (query) => { | ||
| const result = await client.query(query.query!); | ||
| const res = Array.isArray(result) ? result : [result]; | ||
| const rows = res.map((r: QueryResult) => r.rows); | ||
| return rows; | ||
| }, | ||
| parameters: z.object({ | ||
| query: z.string().describe("SQL Query to execute"), | ||
| }), | ||
| }); | ||
| }; | ||
| export async function sqlite(dbPath: string, options?: SqlToolOptions) { | ||
| return await sqlTool(new SQLiteTool(dbPath), options); | ||
| } |
Sorry, the diff of this file is not supported yet
| { | ||
| "name": "example", | ||
| "version": "1.0.0", | ||
| "main": "index.js", | ||
| "scripts": { | ||
| "test": "echo \"Error: no test specified\" && exit 1" | ||
| }, | ||
| "keywords": [], | ||
| "author": "", | ||
| "license": "ISC", | ||
| "description": "", | ||
| "dependencies": { | ||
| "@ai-sdk/openai": "^1.0.12", | ||
| "ai": "^4.0.23", | ||
| "ai-sql": "file:..", | ||
| "dotenv": "^16.4.7" | ||
| } | ||
| } |
| import { postgreSQLTool } from "../../src"; | ||
| import * as dotenv from "dotenv"; | ||
| import * as ai from "ai"; | ||
| dotenv.config(); | ||
| import { createOpenAI } from "@ai-sdk/openai"; | ||
| const openai = createOpenAI({ | ||
| compatibility: "strict", | ||
| apiKey: process.env.OPENAI_API_KEY!, | ||
| }); | ||
| const params = { | ||
| model: openai("gpt-4-turbo"), | ||
| maxSteps: 3, | ||
| }; | ||
| const { text: text1 } = await ai.generateText({ | ||
| ...params, | ||
| tools: { | ||
| postgreSQL: await postgreSQLTool(process.env.POSTGRES_URL!), | ||
| }, | ||
| prompt: | ||
| "Create an employees table and an hours table to track work, add some example empoyees and hours (about 1 month worth of entries for each employee), and return all employees", | ||
| }); | ||
| console.log("OUTPUT", text1); | ||
| const { text: text2 } = await ai.generateText({ | ||
| ...params, | ||
| tools: { | ||
| postgreSQL: await postgreSQLTool(process.env.POSTGRES_URL!), | ||
| }, | ||
| prompt: | ||
| "How many employees worked more than 10 hours during any week of the month?", | ||
| }); | ||
| console.log("OUTPUT", text2); |
Major refactor
Supply chain riskPackage has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package
Environment variable access
Supply chain riskPackage accesses environment variables, which may be a sign of credential stuffing or data theft.
Found 2 instances in 1 package
Mixed license
LicensePackage contains multiple licenses.
Found 1 instance in 1 package
29021
78.06%12
50%0
-100%793
717.53%66
135.71%1
-75%Yes
NaN6
50%1
Infinity%+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added
+ Added