How to implement and test a vector store integration
This guide walks through how to implement and test a custom vector store that you have developed.
For testing, we will rely on the langchain-tests
dependency we added in the previous bootstrapping guide.
Implementation
Let's say you're building a simple integration package that provides a ParrotVectorStore
vector store integration for LangChain. Here's a simple example of what your project
structure might look like:
langchain-parrot-link/
├── langchain_parrot_link/
│ ├── __init__.py
│ └── vectorstores.py
├── tests/
│ ├── __init__.py
│ └── integration_tests/
│   └── test_vectorstores.py
├── pyproject.toml
└── README.md
Following the bootstrapping guide,
all of these files should already exist, except for
`vectorstores.py` and `test_vectorstores.py`.
We will implement these files in this guide.
First we need an implementation for our vector store. This implementation will depend
on your chosen database technology. `langchain-core` includes a minimal
in-memory vector store that we can use as a guide. You can access the code here.
For convenience, we also provide it below.
vectorstores.py
from __future__ import annotations
import json
import uuid
from collections.abc import Iterator, Sequence
from pathlib import Path
from typing import Any, Callable, Optional
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.load import dumpd, load
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores.utils import _cosine_similarity as cosine_similarity
from langchain_core.vectorstores.utils import maximal_marginal_relevance
class InMemoryVectorStore(VectorStore):
    """In-memory vector store implementation.

    Uses a dictionary, and computes cosine similarity for search using numpy.

    Each value in ``self.store`` is keyed by document id and is a dict with the
    keys ``"id"``, ``"vector"``, ``"text"`` and ``"metadata"``.
    """

    def __init__(self, embedding: Embeddings) -> None:
        """Initialize with the given embedding function.

        Args:
            embedding: embedding function to use.
        """
        # document id -> {"id", "vector", "text", "metadata"}
        self.store: dict[str, dict[str, Any]] = {}
        self.embedding = embedding

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding model used by this store."""
        return self.embedding

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_document(entry: dict[str, Any]) -> Document:
        """Build a ``Document`` (including its id) from a raw store entry."""
        return Document(
            id=entry["id"],
            page_content=entry["text"],
            metadata=entry["metadata"],
        )

    @staticmethod
    def _check_ids_length(ids: Optional[list[str]], n_texts: int) -> None:
        """Raise ``ValueError`` if explicit ``ids`` don't match the document count.

        An empty or missing ``ids`` is accepted (document ids / UUIDs are used).
        """
        if ids and len(ids) != n_texts:
            msg = (
                f"ids must be the same length as texts. "
                f"Got {len(ids)} ids and {n_texts} texts."
            )
            raise ValueError(msg)

    def _store_documents(
        self,
        documents: list[Document],
        vectors: list[list[float]],
        ids: Optional[list[str]],
    ) -> list[str]:
        """Write documents and their vectors into ``self.store``.

        Id precedence: the explicit ``ids`` (when non-empty), otherwise each
        document's own ``id``; a random UUID4 is used whenever that value is
        empty. Existing entries with the same id are overwritten.

        Returns:
            The ids under which the documents were stored, in input order.
        """
        id_iterator: Iterator[Optional[str]] = (
            iter(ids) if ids else iter(doc.id for doc in documents)
        )
        stored_ids: list[str] = []
        for doc, vector in zip(documents, vectors):
            candidate = next(id_iterator)
            doc_id = candidate if candidate else str(uuid.uuid4())
            stored_ids.append(doc_id)
            self.store[doc_id] = {
                "id": doc_id,
                "vector": vector,
                "text": doc.page_content,
                "metadata": doc.metadata,
            }
        return stored_ids

    # ------------------------------------------------------------------ #
    # Write / delete / read
    # ------------------------------------------------------------------ #

    def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
        """Delete documents by id; unknown ids are silently ignored.

        Args:
            ids: ids of the documents to delete. ``None`` or empty deletes nothing.
        """
        if ids:
            for _id in ids:
                self.store.pop(_id, None)

    async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
        """Async counterpart of :meth:`delete` (delegates to the sync version)."""
        self.delete(ids)

    def add_documents(
        self,
        documents: list[Document],
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Add documents to the store.

        Args:
            documents: documents to embed and store.
            ids: optional ids, one per document.

        Returns:
            The ids of the stored documents.

        Raises:
            ValueError: if ``ids`` is given and its length differs from
                ``documents``.
        """
        # Validate before embedding so a bad call doesn't pay for an embedding request.
        self._check_ids_length(ids, len(documents))
        texts = [doc.page_content for doc in documents]
        vectors = self.embedding.embed_documents(texts)
        return self._store_documents(documents, vectors, ids)

    async def aadd_documents(
        self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any
    ) -> list[str]:
        """Add documents to the store (async embedding).

        Same contract as :meth:`add_documents`.
        """
        self._check_ids_length(ids, len(documents))
        texts = [doc.page_content for doc in documents]
        vectors = await self.embedding.aembed_documents(texts)
        return self._store_documents(documents, vectors, ids)

    def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
        """Get documents by their ids.

        Args:
            ids: The ids of the documents to get.

        Returns:
            A list of Document objects. Ids not present in the store are
            skipped, so the result may be shorter than ``ids``.
        """
        return [
            self._to_document(entry)
            for doc_id in ids
            if (entry := self.store.get(doc_id))
        ]

    async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
        """Async get documents by their ids.

        Args:
            ids: The ids of the documents to get.

        Returns:
            A list of Document objects.
        """
        return self.get_by_ids(ids)

    # ------------------------------------------------------------------ #
    # Similarity search
    # ------------------------------------------------------------------ #

    def _similarity_search_with_score_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        filter: Optional[Callable[[Document], bool]] = None,
        **kwargs: Any,
    ) -> list[tuple[Document, float, list[float]]]:
        """Return up to ``k`` (document, cosine score, vector) triples, best first.

        ``filter``, if given, is called on every stored document before scoring
        and receives the same ``Document`` (including ``id``) that the search
        would return.
        """
        # Snapshot entries into a list so indices into `similarity` map back to them.
        docs = list(self.store.values())
        if filter is not None:
            docs = [doc for doc in docs if filter(self._to_document(doc))]
        # A non-positive k would otherwise slice from the end (e.g. [:-1]) and
        # return almost every document instead of none.
        if not docs or k <= 0:
            return []
        similarity = cosine_similarity([embedding], [doc["vector"] for doc in docs])[0]
        # Indices ordered by descending similarity, truncated to k.
        top_k_idx = similarity.argsort()[::-1][:k]
        return [
            (self._to_document(docs[idx]), float(similarity[idx]), docs[idx]["vector"])
            for idx in top_k_idx
        ]

    def similarity_search_with_score_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        filter: Optional[Callable[[Document], bool]] = None,
        **kwargs: Any,
    ) -> list[tuple[Document, float]]:
        """Return up to ``k`` (document, cosine score) pairs for a query vector."""
        return [
            (doc, similarity)
            for doc, similarity, _ in self._similarity_search_with_score_by_vector(
                embedding=embedding, k=k, filter=filter, **kwargs
            )
        ]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> list[tuple[Document, float]]:
        """Embed ``query`` and return up to ``k`` (document, score) pairs."""
        embedding = self.embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(embedding, k, **kwargs)

    async def asimilarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[tuple[Document, float]]:
        """Async-embedding counterpart of :meth:`similarity_search_with_score`."""
        embedding = await self.embedding.aembed_query(query)
        return self.similarity_search_with_score_by_vector(embedding, k, **kwargs)

    def similarity_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        **kwargs: Any,
    ) -> list[Document]:
        """Return up to ``k`` documents most similar to a query vector."""
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding, k, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    async def asimilarity_search_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[Document]:
        """Async counterpart of :meth:`similarity_search_by_vector`."""
        return self.similarity_search_by_vector(embedding, k, **kwargs)

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        """Return up to ``k`` documents most similar to ``query``."""
        return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        """Async counterpart of :meth:`similarity_search`."""
        return [
            doc
            for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
        ]

    # ------------------------------------------------------------------ #
    # Maximal marginal relevance
    # ------------------------------------------------------------------ #

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        """Select ``k`` diverse documents from the ``fetch_k`` nearest neighbours.

        Args:
            embedding: query vector.
            k: number of documents to return.
            fetch_k: number of nearest neighbours to consider.
            lambda_mult: trade-off passed to ``maximal_marginal_relevance``.

        Raises:
            ImportError: if numpy is not installed.
        """
        # Check the optional dependency before doing any search work.
        try:
            import numpy as np
        except ImportError as e:
            msg = (
                "numpy must be installed to use max_marginal_relevance_search "
                "pip install numpy"
            )
            raise ImportError(msg) from e

        prefetch_hits = self._similarity_search_with_score_by_vector(
            embedding=embedding,
            k=fetch_k,
            **kwargs,
        )
        mmr_chosen_indices = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            [vector for _, _, vector in prefetch_hits],
            k=k,
            lambda_mult=lambda_mult,
        )
        return [prefetch_hits[idx][0] for idx in mmr_chosen_indices]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        """Embed ``query`` and run :meth:`max_marginal_relevance_search_by_vector`."""
        embedding_vector = self.embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )

    async def amax_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        """Async-embedding counterpart of :meth:`max_marginal_relevance_search`."""
        embedding_vector = await self.embedding.aembed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )

    # ------------------------------------------------------------------ #
    # Constructors and persistence
    # ------------------------------------------------------------------ #

    @classmethod
    def from_texts(
        cls,
        texts: list[str],
        embedding: Embeddings,
        metadatas: Optional[list[dict]] = None,
        **kwargs: Any,
    ) -> InMemoryVectorStore:
        """Create a store and add ``texts`` (with optional ``metadatas``) to it."""
        store = cls(embedding=embedding)
        store.add_texts(texts=texts, metadatas=metadatas, **kwargs)
        return store

    @classmethod
    async def afrom_texts(
        cls,
        texts: list[str],
        embedding: Embeddings,
        metadatas: Optional[list[dict]] = None,
        **kwargs: Any,
    ) -> InMemoryVectorStore:
        """Async counterpart of :meth:`from_texts`."""
        store = cls(embedding=embedding)
        await store.aadd_texts(texts=texts, metadatas=metadatas, **kwargs)
        return store

    @classmethod
    def load(
        cls, path: str, embedding: Embeddings, **kwargs: Any
    ) -> InMemoryVectorStore:
        """Load a vector store from a file.

        NOTE(review): ``langchain_core.load.load`` reconstructs serialized
        objects; only load files you produced yourself or otherwise trust.

        Args:
            path: The path to load the vector store from.
            embedding: The embedding to use.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A VectorStore object.
        """
        _path: Path = Path(path)
        with _path.open("r", encoding="utf-8") as f:
            store = load(json.load(f))
        vectorstore = cls(embedding=embedding, **kwargs)
        vectorstore.store = store
        return vectorstore

    def dump(self, path: str) -> None:
        """Dump the vector store to a file as JSON.

        Parent directories are created if needed.

        Args:
            path: The path to dump the vector store to.
        """
        _path: Path = Path(path)
        _path.parent.mkdir(exist_ok=True, parents=True)
        with _path.open("w", encoding="utf-8") as f:
            json.dump(dumpd(self.store), f, indent=2)
All vector stores must inherit from the `VectorStore` base class. This interface consists of methods for writing, deleting and searching for documents in the vector store.
`VectorStore` supports a variety of synchronous and asynchronous search types (e.g.,
nearest-neighbor or maximal marginal relevance), as well as interfaces for adding
documents to the store. See the API Reference for all supported methods.
The required methods are tabulated below:
Method/Property | Description |
---|---|
add_documents | Add documents to the vector store. |
delete | Delete selected documents from vector store (by IDs) |
get_by_ids | Get selected documents from vector store (by IDs) |
similarity_search | Get documents most similar to a query. |
embeddings (property) | Embeddings object for vector store. |
from_texts | Instantiate vector store via adding texts. |
Note that `InMemoryVectorStore` implements some optional search types, as well as
convenience methods for loading and dumping the object to a file; these are not
required for all implementations.
The in-memory vector store is tested against the standard tests in the LangChain GitHub repository. You can always use this as a starting point.
Testing
To implement our test files, we will subclass test classes from the langchain_tests
package. These test classes contain the tests that will be run. We will just need to
configure what vector store implementation is tested.
Setup
First we need to install certain dependencies. These include:
- `pytest`: For running tests
- `pytest-asyncio`: For testing async functionality
- `langchain-tests`: For importing standard tests
- `langchain-core`: This should already be installed, but is needed to define our integration.
If you followed the previous bootstrapping guide, these should already be installed.
Add and configure standard tests
The `langchain-tests` package implements suites of tests for testing vector store
integrations. By subclassing the base classes for each standard test, you
get all of the standard tests for that type.
The full set of tests that run can be found in the API reference.
Here's how you would configure the standard tests for a typical vector store (using
ParrotVectorStore
as a placeholder):
# title="tests/integration_tests/test_vectorstores.py"
from typing import AsyncGenerator, Generator
import pytest
from langchain_core.vectorstores import VectorStore
from langchain_parrot_link.vectorstores import ParrotVectorStore
from langchain_tests.integration_tests.vectorstores import (
AsyncReadWriteTestSuite,
ReadWriteTestSuite,
)
class TestSync(ReadWriteTestSuite):
    @pytest.fixture()
    def vectorstore(self) -> Generator[VectorStore, None, None]:  # type: ignore
        """Provide a fresh, empty vector store to each synchronous test."""
        empty_store = ParrotVectorStore(self.get_embeddings())
        # The store handed to a test must contain no documents. If your backend
        # persists data between runs, clear it out here before yielding.
        try:
            yield empty_store
        finally:
            # Teardown: release resources or delete whatever the test wrote.
            pass
class TestAsync(AsyncReadWriteTestSuite):
    @pytest.fixture()
    async def vectorstore(self) -> AsyncGenerator[VectorStore, None]:  # type: ignore
        """Provide a fresh, empty vector store to each asynchronous test."""
        empty_store = ParrotVectorStore(self.get_embeddings())
        # The store handed to a test must contain no documents. If your backend
        # persists data between runs, clear it out here before yielding.
        try:
            yield empty_store
        finally:
            # Teardown: release resources or delete whatever the test wrote.
            pass
There are separate suites for testing synchronous and asynchronous methods. Configuring the tests consists of implementing pytest fixtures that set up an empty vector store before each test and tear it down after each test finishes.
For example, below is the ReadWriteTestSuite
for the Chroma
integration:
from typing import Generator
import pytest
from langchain_core.vectorstores import VectorStore
from langchain_tests.integration_tests.vectorstores import ReadWriteTestSuite
from langchain_chroma import Chroma
class TestSync(ReadWriteTestSuite):
    @pytest.fixture()
    def vectorstore(self) -> Generator[VectorStore, None, None]:  # type: ignore
        """Get an empty vectorstore."""
        store = Chroma(embedding_function=self.get_embeddings())
        try:
            yield store
        finally:
            # Drop the collection so the next test starts from an empty store,
            # even when this test failed.
            store.delete_collection()
Note that before the initial `yield`, we instantiate the vector store with an
embeddings object. This is a pre-defined "fake" embeddings model
that will generate short, arbitrary vectors for documents. You can use a different
embeddings object if desired.
In the `finally` block, we call whatever integration-specific logic is needed to
bring the vector store to a clean state. This logic is executed after each test,
even if the test fails.
Run standard tests
After setting tests up, you would run them with the following command from your project root:
pytest --asyncio-mode=auto tests/integration_tests
Test suite information and troubleshooting
Each test method documents:
- Troubleshooting tips;
- (If applicable) how the test can be skipped.
This information, along with the full set of tests that run, can be found in the API reference.