---
slug: /create-custim-embedding-functions-of-api
---

# Create a custom embedding function

You can create a custom embedding function by implementing the `EmbeddingFunction` protocol. A custom embedding function provides the following:

* A `__call__` method, which accepts `Documents` (`str` or `List[str]`) and returns `Embeddings` (`List[List[float]]`).
* Optionally, a `dimension` attribute that returns the vector dimension.

## Prerequisites

Before creating a custom embedding function, make sure it meets the following requirements:

* Implement the `__call__` method:
  * Each vector must have the same dimension.
  * Input: a single document (`str`) or multiple documents (`List[str]`).
  * Output: the embedding vectors, as `List[List[float]]`.
* (Recommended) Implement the `dimension` attribute:
  * Output: the dimension of the generated vectors, as an `int`.
  * When you create a collection, it lets you verify that the collection's vector dimension matches the embedding dimension.
* Handle special cases:
  * Convert a single string input to a list.
  * Return an empty list for empty inputs.
  * Make sure all vectors in the output have the same dimension.

## Example 1: Sentence Transformer custom embedding function

```python
from typing import List, Union
from pyseekdb import EmbeddingFunction, Client, HNSWConfiguration

Documents = Union[str, List[str]]
Embeddings = List[List[float]]


class SentenceTransformerCustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    A custom embedding function using sentence-transformers with a specific model.
    """

    def __init__(self, model_name: str = "all-mpnet-base-v2", device: str = "cpu"):  # TODO: your own model name and device
        """
        Initialize the sentence-transformer embedding function.

        The model is not loaded here; it is loaded lazily on first use.

        Args:
            model_name: Name of the sentence-transformers model to use
            device: Device to run the model on ('cpu' or 'cuda')
        """
        self.model_name = model_name
        self.device = device
        self._model = None
        self._dimension = None

    def _ensure_model_loaded(self):
        """
        Lazily load the model and record its embedding dimension.

        Raises:
            ImportError: If sentence-transformers is not installed.
        """
        if self._model is not None:
            return
        # Keep the try narrow so that ImportErrors raised while loading the
        # model itself are not misreported as a missing package.
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as err:
            raise ImportError(
                "sentence-transformers is not installed. "
                "Please install it with: pip install sentence-transformers"
            ) from err
        model = SentenceTransformer(self.model_name, device=self.device)
        # Encode a probe sentence to find out the embedding dimension.
        test_embedding = model.encode(["test"], convert_to_numpy=True)
        # Assign only after both steps succeed, so a failure leaves the object
        # unloaded and the next call retries instead of reporting dimension None.
        self._dimension = len(test_embedding[0])
        self._model = model

    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function (loads the model if needed)."""
        self._ensure_model_loaded()
        return self._dimension

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Single document (str) or list of documents (List[str])

        Returns:
            List of embedding vectors; an empty list for empty input
        """
        # Handle single string input
        if isinstance(input, str):
            input = [input]

        # Handle empty input without loading the model
        if not input:
            return []

        self._ensure_model_loaded()

        # Generate embeddings
        embeddings = self._model.encode(
            input,
            convert_to_numpy=True,
            show_progress_bar=False
        )

        # Convert numpy arrays to plain Python lists
        return [embedding.tolist() for embedding in embeddings]


# Use the custom embedding function
client = Client()

# Initialize embedding function with all-mpnet-base-v2 model (768 dimensions)
ef = SentenceTransformerCustomEmbeddingFunction(
    model_name='all-mpnet-base-v2',  # TODO: your own model name
    device='cpu'  # TODO: your own device
)

# Get the dimension from the embedding function
dimension = ef.dimension
print(f"Embedding dimension: {dimension}")

# Create collection with matching dimension
collection_name = "my_collection"
if client.has_collection(collection_name):
    client.delete_collection(collection_name)

collection = client.create_collection(
    name=collection_name,
    configuration=HNSWConfiguration(dimension=dimension, distance='cosine'),
    embedding_function=ef
)

# Test the embedding function
print("\nTesting embedding function...")
test_documents = ["Hello world", "This is a test", "Sentence transformers are great"]
embeddings = ef(test_documents)
print(f"Generated {len(embeddings)} embeddings")
print(f"Each embedding has {len(embeddings[0])} dimensions")

# Add some documents to the collection
print("\nAdding documents to collection...")
collection.add(
    ids=["1", "2", "3"],
    documents=test_documents,
    metadatas=[{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
)

# Query the collection
print("\nQuerying collection...")
results = collection.query(
    query_texts="Hello",
    n_results=2
)

print("\nQuery results:")
for doc_id, document, distance in zip(
    results['ids'][0], results['documents'][0], results['distances'][0]
):
    print(f"ID: {doc_id}")
    print(f"Document: {document}")
    print(f"Distance: {distance}")
    print()

# Clean up
client.delete_collection(name=collection_name)
print("Test completed successfully!")
```

## Example 2: OpenAI-compatible embedding function

This example calls an OpenAI-compatible embeddings API, such as a Qwen embedding model. Set the model name, API key, base URL, and dimension for your provider.

```python
from typing import List, Optional, Union
import os
from openai import OpenAI
from pyseekdb import EmbeddingFunction
import pyseekdb

Documents = Union[str, List[str]]
Embeddings = List[List[float]]


class QWenEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    A custom embedding function using an OpenAI-compatible embedding API.
    """

    def __init__(
        self,
        model_name: str = "",  # TODO: your own model name
        api_key: str = "",  # TODO: your own api key
        base_url: Optional[str] = None,  # TODO: your own base url
        dimension: Optional[int] = 1024,  # TODO: your own dimension
    ):
        """
        Initialize the OpenAI-compatible embedding function.

        Args:
            model_name: Name of the embedding model
            api_key: API key (if not provided, uses OPENAI_API_KEY env var)
            base_url: Base URL of the OpenAI-compatible endpoint. If None, the
                OpenAI client's default endpoint is used.
            dimension: Dimension of the vectors the model returns; must match
                the model you configure.

        Raises:
            ValueError: If no API key is given and OPENAI_API_KEY is not set.
        """
        self.model_name = model_name
        self.api_key = api_key or os.environ.get('OPENAI_API_KEY')
        if not self.api_key:
            raise ValueError("OpenAI API key is required")
        self._dimension = dimension
        # Create the client once and reuse it, instead of opening a new HTTP
        # client on every call. Passing base_url=None (rather than "") keeps
        # the client's default endpoint.
        self._client = OpenAI(api_key=self.api_key, base_url=base_url)

    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function"""
        if self._dimension is None:
            # Call API to get dimension (or use known values)
            raise ValueError("Dimension not set for this model")
        return self._dimension

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings using the OpenAI-compatible API.

        Args:
            input: Single document (str) or list of documents (List[str])

        Returns:
            List of embedding vectors, in the same order as the input
        """
        # Handle single string input
        if isinstance(input, str):
            input = [input]

        # Handle empty input without calling the API
        if not input:
            return []

        response = self._client.embeddings.create(
            model=self.model_name,
            input=input
        )

        # Each item carries the position of its input in `index`; sort on it
        # so the output order always matches the input order.
        return [item.embedding for item in sorted(response.data, key=lambda item: item.index)]


# Use the custom embedding function
collection_name = "my_collection"
ef = QWenEmbeddingFunction()
client = pyseekdb.Client()

if client.has_collection(collection_name):
    client.delete_collection(collection_name)

collection = client.create_collection(
    name=collection_name,
    embedding_function=ef
)

collection.add(
    ids=["1", "2", "3"],
    documents=["Hello", "World", "Hello World"],
    metadatas=[{"tag": "A"}, {"tag": "B"}, {"tag": "C"}]
)

results = collection.query(
    query_texts="Hello",
    n_results=2
)

for doc_id, document, metadata, distance in zip(
    results['ids'][0],
    results['documents'][0],
    results['metadatas'][0],
    results['distances'][0],
):
    print(doc_id)
    print(document)
    print(metadata)
    print(distance)
    print()

client.delete_collection(name=collection_name)
```