open-retrievals

Text Embeddings for Retrieval and RAG based on transformers

Documentation | Tutorials | 中文 (Chinese)

Open-Retrievals is an easy-to-use Python framework for obtaining state-of-the-art (SOTA) text embeddings for information retrieval and LLM retrieval-augmented generation (RAG), built on PyTorch and Transformers.

  • Contrastive learning enhanced embeddings
  • LLM embeddings
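Contrastive training pulls each query embedding toward its matching passage and pushes it away from the other passages in the same batch. As an illustration only, here is a minimal pure-Python sketch of the in-batch InfoNCE objective; the helper names and the temperature of 0.05 are assumptions for this example, and the library's own `InfoNCE` loss in `retrievals.losses` may differ in details such as batching, similarity function, or temperature.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce(queries: List[Vector], passages: List[Vector], temperature: float = 0.05) -> float:
    """In-batch InfoNCE: passages[i] is the positive for queries[i];
    every other passage in the batch serves as a negative."""
    losses = []
    for i, q in enumerate(queries):
        logits = [cosine(q, p) / temperature for p in passages]
        # log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
        # negative log-softmax probability of the positive passage
        losses.append(log_denominator - logits[i])
    return sum(losses) / len(losses)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)  # True: correctly matched pairs give the lower loss
```

A low temperature sharpens the softmax, so a small similarity gap between the positive and the negatives already produces a large loss difference.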

Installation

Prerequisites

pip install transformers
pip install faiss-cpu
pip install peft

With pip

pip install open-retrievals

Quick-start

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

# Example list of documents
documents = [
    "Open-retrievals is a text embedding library",
    "I can use it simply with a SOTA RAG application.",
]

# This will trigger the model download and initialization
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)

embeddings = model.encode(documents)
print(len(embeddings), len(embeddings[0]))  # 2 documents, each a 384-dimensional vector

Usage

Build Index and Retrieval

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

sentences = ['A dog is chasing car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
model.build_index(sentences)

matcher = AutoModelForRetrieval()
results = matcher.faiss_search("He plays guitar.")
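Judging by its name and the `faiss-cpu` prerequisite, `faiss_search` presumably delegates nearest-neighbour search to a FAISS index. To make the search step concrete, here is a pure-Python sketch of exhaustive search as performed by a flat (non-approximate) index such as FAISS's `IndexFlatL2`, which returns squared L2 distances. `FlatL2Index` is a hypothetical helper, not part of open-retrievals, and the toy 2-D vectors stand in for sentence embeddings.

```python
from typing import List, Tuple

Vector = List[float]


class FlatL2Index:
    """Brute-force nearest-neighbour index: compares the query against every
    stored vector, the same strategy as a flat FAISS index."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.vectors: List[Vector] = []

    def add(self, vectors: List[Vector]) -> None:
        for v in vectors:
            if len(v) != self.dim:
                raise ValueError(f"expected dimension {self.dim}, got {len(v)}")
            self.vectors.append(list(v))

    def search(self, query: Vector, k: int) -> Tuple[List[float], List[int]]:
        """Return (squared L2 distances, ids) of the k closest stored vectors."""
        scored = [
            (sum((q - x) ** 2 for q, x in zip(query, v)), idx)
            for idx, v in enumerate(self.vectors)
        ]
        scored.sort()  # ascending distance; ties broken by id
        top = scored[:k]
        return [d for d, _ in top], [i for _, i in top]


index = FlatL2Index(dim=2)
index.add([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
distances, ids = index.search([0.9, 1.2], k=2)
print(ids)  # [1, 0]
```

Exhaustive search is exact but costs one distance computation per stored vector, which is why large collections usually switch to approximate FAISS index types.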

Rerank

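from torch.optim import AdamW
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from retrievals import RerankCollator, RerankModel, RerankTrainer, RerankDataset

# data_args, model_args and training_args are assumed to be parsed beforehand,
# e.g. with transformers.HfArgumentParser
train_dataset = RerankDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = RerankModel(
    model_args.model_name_or_path,
    pooling_method="mean"
)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)

# number of training steps: examples / batch size (2) * epochs (1)
num_train_steps = int(len(train_dataset) / 2 * 1)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)

trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=RerankCollator(tokenizer, max_length=data_args.query_max_len),
)
# assuming RerankTrainer subclasses transformers.Trainer,
# which uses a preset .optimizer / .lr_scheduler instead of creating its own
trainer.optimizer = optimizer
trainer.lr_scheduler = lr_scheduler
trainer.train()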
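The expression `int(len(train_dataset) / 2 * 1)` reads as dataset size divided by a batch size of 2, times 1 epoch: the total number of optimizer steps the learning-rate schedule spans. The pure-Python sketch below spells out that arithmetic and a linear warmup-then-decay schedule, a common choice whose shape matches `transformers.get_linear_schedule_with_warmup`. The function names are hypothetical.

```python
import math


def num_training_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Optimizer steps for a run: batches per epoch (last partial batch included) times epochs."""
    return math.ceil(num_examples / batch_size) * epochs


def linear_lr(step: int, base_lr: float, total_steps: int, warmup_steps: int = 0) -> float:
    """Linear warmup from 0 to base_lr, then linear decay back to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * (step / max(1, warmup_steps))
    remaining = max(0, total_steps - step)
    return base_lr * (remaining / max(1, total_steps - warmup_steps))


steps = num_training_steps(num_examples=1001, batch_size=2, epochs=1)
print(steps)  # 501
print(linear_lr(0, 5e-5, steps), linear_lr(steps, 5e-5, steps))  # 5e-05 0.0
```

One design note: `int(...)` rounds down, so with 1001 examples it yields 500 and ignores the trailing partial batch, whereas `math.ceil` counts it.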

RAG with LangChain

  • Prerequisites
pip install langchain langchain-community chromadb
  • Server
from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import Chroma as Vectorstore


class DenseRetrieval:
    def __init__(self, persist_directory):
        embeddings = LangchainEmbedding(model_name="BAAI/bge-large-zh-v1.5")
        vectordb = Vectorstore(
            persist_directory=persist_directory,
            embedding_function=embeddings,
        )
        # keep up to 30 hits whose relevance score is at least 0.15
        self.retriever = vectordb.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": 0.15, "k": 30},
        )

        reranker_args = {
            "model": "../../inputs/bce-reranker-base_v1",
            "top_n": 7,
            "device": "cuda",
            "use_fp16": True,
        }
        self.reranker = LangchainReranker(**reranker_args)
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=self.retriever
        )

    def query(
        self,
        question: str
    ):
        docs = self.compression_retriever.get_relevant_documents(question)
        return docs
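The retriever above works in two stages: the retrieval settings ask the vector store for up to `k=30` candidates with a score threshold of 0.15, and the reranker rescores those candidates against the question and keeps the best `top_n=7`. Here is a minimal pure-Python sketch of that flow, assuming each first-stage candidate arrives as a `(text, score)` pair. `retrieve_then_rerank` and the word-overlap scorer are hypothetical stand-ins for the vector store and the cross-encoder reranker, not LangChain or open-retrievals APIs.

```python
from typing import Callable, List, Sequence, Tuple


def retrieve_then_rerank(
    query: str,
    candidates: Sequence[Tuple[str, float]],
    rerank_score: Callable[[str, str], float],
    k: int = 30,
    score_threshold: float = 0.15,
    top_n: int = 7,
) -> List[str]:
    """Keep at most k first-stage hits whose score clears score_threshold,
    then reorder them by a query-document score and return the best top_n."""
    # Stage 1: cheap vector-similarity filter, best k first.
    first_stage = sorted(
        (c for c in candidates if c[1] >= score_threshold),
        key=lambda c: c[1],
        reverse=True,
    )[:k]
    # Stage 2: rescore each surviving document jointly with the query.
    reranked = sorted(
        (doc for doc, _ in first_stage),
        key=lambda doc: rerank_score(query, doc),
        reverse=True,
    )
    return reranked[:top_n]


def word_overlap(query: str, doc: str) -> float:
    """Toy reranker: fraction of query words that appear in the document."""
    q = set(query.lower().split())
    return len(q & set(doc.lower().split())) / max(1, len(q))


hits = [("a cat sat", 0.9), ("he plays guitar", 0.5), ("noise", 0.1)]
print(retrieve_then_rerank("he plays guitar", hits, word_overlap, top_n=2))
# ['he plays guitar', 'a cat sat']
```

The split keeps the expensive scorer off most of the corpus: it only sees the `k` candidates that survive the cheap first stage.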

Use pretrained sentence embeddings

from retrievals import AutoModelForEmbedding

sentences = ["Hello world", "How are you?"]
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path, pooling_method="mean", normalize_embeddings=True)
sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
print(sentence_embeddings)
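With `pooling_method="mean"` the token vectors are presumably averaged into one sentence vector, and `normalize_embeddings=True` rescales that vector to unit length so that a dot product equals cosine similarity. A minimal pure-Python sketch of those two steps for a single sentence, assuming the usual attention-mask-weighted mean; `mean_pool` and `l2_normalize` are illustrative names, and the library itself operates on batched tensors.

```python
import math
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1 (padding is ignored)."""
    kept = [vec for vec, m in zip(token_embeddings, attention_mask) if m == 1]
    dim = len(token_embeddings[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec


tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is a padding token
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)
print(pooled)  # [2.0, 3.0]
print(l2_normalize(pooled))
```

Masking matters: without it the padding row would dominate the average.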

Finetune transformers by contrastive learning

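from torch.optim import AdamW
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval, RetrievalTrainer, PairCollator, TripletCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.data import RetrievalDataset, RerankDataset

# data_args, model_args and training_args are assumed to be parsed beforehand,
# e.g. with transformers.HfArgumentParser
train_dataset = RetrievalDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = AutoModelForEmbedding(
    model_args.model_name_or_path,
    pooling_method="cls"
)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)

# number of training steps: examples / batch size (2) * epochs (1)
num_train_steps = int(len(train_dataset) / 2 * 1)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)

trainer = RetrievalTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=TripletCollator(tokenizer, max_length=data_args.query_max_len),
    loss_fn=TripletLoss(),
)
# assuming RetrievalTrainer subclasses transformers.Trainer,
# which uses a preset .optimizer / .lr_scheduler instead of creating its own
trainer.optimizer = optimizer
trainer.lr_scheduler = lr_scheduler
trainer.train()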
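`TripletCollator` and `TripletLoss` presumably train on (query, positive, negative) triplets. The standard triplet margin loss is sketched below in pure Python for a single triplet; Euclidean distance and a margin of 1.0 are assumptions here, and `retrievals.losses.TripletLoss` may use a different distance, margin, or batch reduction.

```python
import math
from typing import List

Vector = List[float]


def euclidean(a: Vector, b: Vector) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor: Vector, positive: Vector, negative: Vector, margin: float = 1.0) -> float:
    """Hinge loss that reaches zero once the negative is at least `margin`
    farther from the anchor than the positive is."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


anchor, positive = [0.0, 0.0], [0.0, 1.0]
print(triplet_loss(anchor, positive, [0.0, 3.0]))  # 0.0: negative already far enough
print(triplet_loss(anchor, positive, [0.0, 1.5]))  # 0.5: negative still inside the margin
```

Because the loss is zero for easy triplets, training signal comes only from negatives that sit within the margin, which is why hard-negative mining is often paired with this loss.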

Finetune LLM for embedding by Contrastive learning


from retrievals import AutoModelForEmbedding

model = AutoModelForEmbedding(
    "mistralai/Mistral-7B-v0.1",
    pooling_method='cls',
    query_instruction='Instruct: Retrieve semantically similar text\nQuery: '
)
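The `query_instruction` string has the `Instruct: ...\nQuery: ` shape used by E5-Mistral-style LLM embeddings, where only queries carry the instruction and documents are encoded as-is. Decoder-only models such as Mistral attend causally, so many LLM-embedding recipes pool the hidden state of the last non-padding token rather than the first. The sketch below illustrates both ideas in pure Python, assuming right padding; `format_query` and `last_token_pool` are hypothetical helpers, and whether open-retrievals offers last-token pooling through `pooling_method` is not confirmed by the snippet above, which uses `'cls'`.

```python
from typing import List

QUERY_INSTRUCTION = "Instruct: Retrieve semantically similar text\nQuery: "


def format_query(query: str, instruction: str = QUERY_INSTRUCTION) -> str:
    """Prepend the task instruction to a query; documents are left unchanged."""
    return instruction + query


def last_token_pool(hidden_states: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Return the hidden state of the last real token (right padding assumed)."""
    last_index = sum(attention_mask) - 1
    return hidden_states[last_index]


print(format_query("capital of France"))
states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0]]  # last row is padding
print(last_token_pool(states, [1, 1, 1, 0]))  # [0.5, 0.6]
```

With left padding the last real token is always at the final position, so the index computation would change accordingly.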

Search by Cosine similarity/KNN

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

query_texts = ['A dog is chasing car.']
passage_texts = ['A man is playing a guitar.', 'A bee is flying low']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)

matcher = AutoModelForRetrieval(method='cosine')
dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1)
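With `method='cosine'`, `similarity_search` ranks passages by cosine similarity to each query. A pure-Python sketch of that ranking for small inputs, returning per-query scores and indices in the same `(dists, indices)` order as the call above; here the scores are cosine similarities (higher means closer), which is an assumption about what the library returns in `dists`. `cosine_top_k` is a hypothetical helper.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine_top_k(
    queries: List[Vector], passages: List[Vector], top_k: int = 1
) -> Tuple[List[List[float]], List[List[int]]]:
    """For each query, return its top_k cosine similarities and passage indices,
    highest similarity first."""

    def unit(v: Vector) -> Vector:
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    unit_passages = [unit(p) for p in passages]
    all_scores, all_indices = [], []
    for q in queries:
        uq = unit(q)
        # dot product of unit vectors = cosine similarity
        sims = [sum(a * b for a, b in zip(uq, p)) for p in unit_passages]
        order = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:top_k]
        all_scores.append([sims[i] for i in order])
        all_indices.append(order)
    return all_scores, all_indices


scores, indices = cosine_top_k([[1.0, 0.0]], [[0.0, 1.0], [2.0, 0.1]], top_k=1)
print(indices)  # [[1]]
```

Normalizing the passages once up front turns every comparison into a plain dot product, the same trick that `normalize_embeddings=True` enables.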

References & Acknowledgements
