slug: /create-custim-embedding-functions-of-api

Create a custom embedding function

You can create a custom embedding function by implementing the EmbeddingFunction protocol. A custom embedding function provides the following:

  • An implementation of the __call__ method, which accepts Documents (str or List[str]) and returns Embeddings (List[List[float]]).

  • Optionally, a dimension property that returns the vector dimension.

Prerequisites

Before creating a custom embedding function, ensure the following:

  • Implement the __call__ method:

    • Input: a single document (str) or a list of documents (List[str]).
    • Output: the embedding vectors, of type List[List[float]].
    • Every returned vector must have the same dimension.
  • (Recommended) Implement the dimension property:

    • Output: the dimension (int) of the vectors generated by this function.
    • This allows the dimension to be verified when the collection is created.
  • Handle special cases (see the minimal sketch after this list):

    • Convert a single string input to a list.
    • Return an empty list for empty input.
    • Ensure all vectors in the output have the same dimension.
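
For reference, the following is a minimal, hypothetical embedding function that satisfies these requirements without calling any external model: it simply hashes tokens into a fixed-size vector. The class name, the default dimension of 8, and the hashing scheme are illustrative only; a real implementation would call an embedding model or API, as in the examples below.

from typing import List, Union
from pyseekdb import EmbeddingFunction

Documents = Union[str, List[str]]
Embeddings = List[List[float]]

class DummyEmbeddingFunction(EmbeddingFunction[Documents]):
    """A toy embedding function that only demonstrates the protocol contract."""

    def __init__(self, dim: int = 8):
        self._dimension = dim

    @property
    def dimension(self) -> int:
        """Dimension of the vectors produced by this function."""
        return self._dimension

    def __call__(self, input: Documents) -> Embeddings:
        # Special case: convert a single string input to a list
        if isinstance(input, str):
            input = [input]
        # Special case: return an empty list for empty input
        if not input:
            return []
        # One fixed-dimension vector per document (toy token hashing, not a real model)
        embeddings: Embeddings = []
        for doc in input:
            vector = [0.0] * self._dimension
            for token in doc.split():
                vector[hash(token) % self._dimension] += 1.0
            embeddings.append(vector)
        return embeddings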

Example 1: Sentence Transformer custom embedding function

from typing import List, Union
from pyseekdb import EmbeddingFunction, Client, HNSWConfiguration

Documents = Union[str, List[str]]
Embeddings = List[List[float]]

class SentenceTransformerCustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    A custom embedding function using sentence-transformers with a specific model.
    """
    
    def __init__(self, model_name: str = "all-mpnet-base-v2", device: str = "cpu"): # TODO: your own model name and device
        """
        Initialize the sentence-transformer embedding function.
        
        Args:
            model_name: Name of the sentence-transformers model to use
            device: Device to run the model on ('cpu' or 'cuda')
        """
        self.model_name = model_name
        self.device = device
        self._model = None
        self._dimension = None
    
    def _ensure_model_loaded(self):
        """Lazy load the embedding model"""
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
                self._model = SentenceTransformer(self.model_name, device=self.device)
                # Get dimension from model
                test_embedding = self._model.encode(["test"], convert_to_numpy=True)
                self._dimension = len(test_embedding[0])
            except ImportError:
                raise ImportError(
                    "sentence-transformers is not installed. "
                    "Please install it with: pip install sentence-transformers"
                )
    
    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function"""
        self._ensure_model_loaded()
        return self._dimension
    
    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.
        
        Args:
            input: Single document (str) or list of documents (List[str])
            
        Returns:
            List of embedding vectors
        """
        self._ensure_model_loaded()
        
        # Handle single string input
        if isinstance(input, str):
            input = [input]
        
        # Handle empty input
        if not input:
            return []
        
        # Generate embeddings
        embeddings = self._model.encode(
            input,
            convert_to_numpy=True,
            show_progress_bar=False
        )
        
        # Convert numpy arrays to lists
        return [embedding.tolist() for embedding in embeddings]

# Use the custom embedding function
client = Client()

# Initialize embedding function with all-mpnet-base-v2 model (768 dimensions)
ef = SentenceTransformerCustomEmbeddingFunction(
    model_name='all-mpnet-base-v2', # TODO: your own model name
    device='cpu' # TODO: your own device
)

# Get the dimension from the embedding function
dimension = ef.dimension
print(f"Embedding dimension: {dimension}")

# Create collection with matching dimension
collection_name = "my_collection"
if client.has_collection(collection_name):
    client.delete_collection(collection_name)

collection = client.create_collection(
    name=collection_name,
    configuration=HNSWConfiguration(dimension=dimension, distance='cosine'),
    embedding_function=ef
)

# Test the embedding function
print("\nTesting embedding function...")
test_documents = ["Hello world", "This is a test", "Sentence transformers are great"]
embeddings = ef(test_documents)
print(f"Generated {len(embeddings)} embeddings")
print(f"Each embedding has {len(embeddings[0])} dimensions")

# Add some documents to the collection
print("\nAdding documents to collection...")
collection.add(
    ids=["1", "2", "3"],
    documents=test_documents,
    metadatas=[{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
)

# Query the collection
print("\nQuerying collection...")
results = collection.query(
    query_texts="Hello",
    n_results=2
)

print("\nQuery results:")
for i in range(len(results['ids'][0])):
    print(f"ID: {results['ids'][0][i]}")
    print(f"Document: {results['documents'][0][i]}")
    print(f"Distance: {results['distances'][0][i]}")
    print()

# Clean up
client.delete_collection(name=collection_name)
print("Test completed successfully!")

Example 2: OpenAI embedding function

from typing import List, Union
import os
from openai import OpenAI
from pyseekdb import EmbeddingFunction
import pyseekdb

Documents = Union[str, List[str]]
Embeddings = List[List[float]]

class QWenEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    A custom embedding function that calls an OpenAI-compatible embedding API (for example, Qwen's).
    """
    
    def __init__(self, model_name: str = "", api_key: str = ""): # TODO: your own model name and api key
        """
        Initialize the OpenAI embedding function.
        
        Args:
            model_name: Name of the OpenAI embedding model
            api_key: OpenAI API key (if not provided, uses OPENAI_API_KEY env var)
        """
        self.model_name = model_name
        self.api_key = api_key or os.environ.get('OPENAI_API_KEY') # TODO: your own api key
        if not self.api_key:
            raise ValueError("OpenAI API key is required")
        
        self._dimension = 1024 # TODO: your own dimension
    
    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function"""
        if self._dimension is None:
            # Call API to get dimension (or use known values)
            raise ValueError("Dimension not set for this model")
        return self._dimension
    
    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings using OpenAI API.
        
        Args:
            input: Single document (str) or list of documents (List[str])
            
        Returns:
            List of embedding vectors
        """
        # Handle single string input
        if isinstance(input, str):
            input = [input]
        
        # Handle empty input
        if not input:
            return []
        
        # Call OpenAI API
        client = OpenAI(
            api_key=self.api_key,  
            base_url="" # TODO: your own base url
        )
        response = client.embeddings.create(
            model=self.model_name,
            input=input
        )
        
        # Extract embeddings
        embeddings = [item.embedding for item in response.data]
        return embeddings

# Use the custom embedding function
collection_name = "my_collection"
ef = QWenEmbeddingFunction()
client = pyseekdb.Client()

if client.has_collection(collection_name):
    client.delete_collection(collection_name)

collection = client.create_collection(
    name=collection_name,
    embedding_function=ef
)

collection.add(
    ids=["1", "2", "3"],
    documents=["Hello", "World", "Hello World"],
    metadatas=[{"tag": "A"}, {"tag": "B"}, {"tag": "C"}]
)

results = collection.query(
    query_texts="Hello",
    n_results=2
)
for i in range(len(results['ids'][0])):
    print(results['ids'][0][i])
    print(results['documents'][0][i])
    print(results['metadatas'][0][i])
    print(results['distances'][0][i])
    print()

client.delete_collection(name=collection_name)
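
Note that Example 2 creates the collection without an explicit vector configuration. If you want to pin the dimension and distance metric up front, as Example 1 does, you can build an HNSWConfiguration from the embedding function's dimension property when creating the collection (a sketch under the same setup as Example 2; the collection name is illustrative):

from pyseekdb import HNSWConfiguration

collection = client.create_collection(
    name="my_collection",
    configuration=HNSWConfiguration(dimension=ef.dimension, distance='cosine'),
    embedding_function=ef
)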