312 lines
9.5 KiB
TypeScript
312 lines
9.5 KiB
TypeScript
/**
 * Document Clustering with Gemini Embeddings
 *
 * Demonstrates automatic grouping of similar documents using K-means
 * clustering, with a similarity-threshold method shown as an alternative.
 * Useful for topic modeling, content organization, and duplicate detection.
 *
 * Setup:
 * 1. npm install @google/genai@^1.27.0
 * 2. export GEMINI_API_KEY="your-api-key"
 *
 * Usage:
 * npx tsx clustering.ts
 */
|
|
|
|
import { GoogleGenAI } from "@google/genai";
|
|
|
|
/**
 * A piece of text to be clustered, optionally carrying its embedding vector.
 */
interface Document {
  /** Identifier shown next to the document when results are printed. */
  id: string;

  /** The raw text content of the document. */
  text: string;

  /** Embedding vector for the document; absent until one has been assigned. */
  embedding?: Array<number>;
}
|
|
|
|
/**
 * A group of documents together with the vector that represents the group.
 */
interface Cluster {
  /** Zero-based cluster number. */
  id: number;

  /** Representative vector of the cluster. */
  centroid: Array<number>;

  /** Documents assigned to this cluster. */
  documents: Array<Document>;
}
|
|
|
|
/**
 * Compute the cosine similarity between two equal-length vectors.
 *
 * @param a - First vector.
 * @param b - Second vector; must have the same length as `a`.
 * @returns `dot(a, b) / (|a| * |b|)`. Returns 0 when either vector has zero
 *   magnitude (including two empty vectors), where the ratio is undefined.
 * @throws Error if the vectors differ in length.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  if (a.length !== b.length) {
    throw new Error('Vector dimensions must match');
  }

  let dotProduct = 0;
  let sumSquaresA = 0;
  let sumSquaresB = 0;

  for (let i = 0; i < a.length; i++) {
    dotProduct += a[i] * b[i];
    sumSquaresA += a[i] * a[i];
    sumSquaresB += b[i] * b[i];
  }

  const denominator = Math.sqrt(sumSquaresA) * Math.sqrt(sumSquaresB);

  // A zero-magnitude vector has no direction. Returning 0 instead of 0/0
  // keeps NaN out of callers' `>` and Math.max comparisons.
  if (denominator === 0) {
    return 0;
  }

  return dotProduct / denominator;
}
|
|
|
|
/**
 * K-means clustering over document embeddings, using cosine similarity as
 * the closeness measure.
 *
 * Centroids are seeded from `k` distinct documents picked with
 * `Math.random()`, so cluster numbering and membership can differ between
 * runs. Progress is logged to the console.
 *
 * @param documents - Documents to cluster; every one must have an `embedding`.
 * @param k - Number of clusters; an integer from 1 to `documents.length`.
 * @param maxIterations - Upper bound on the number of assign/update rounds.
 * @returns One cluster per centroid (a cluster may hold no documents), each
 *   with its final centroid and the documents nearest to it, in input order.
 * @throws Error if `documents` is empty, any document lacks an embedding, or
 *   `k` is not an integer in the range 1..documents.length.
 */
function kMeansClustering(
  documents: Document[],
  k: number = 3,
  maxIterations: number = 100
): Cluster[] {
  // Check every document, not only the first, so a missing embedding fails
  // here with a clear message rather than deep inside the similarity loop.
  if (documents.length === 0 || documents.some(doc => !doc.embedding)) {
    throw new Error('Documents must have embeddings');
  }

  // Seeding draws k *distinct* documents; with k > documents.length the
  // rejection loop below could never terminate.
  if (!Number.isInteger(k) || k < 1 || k > documents.length) {
    throw new Error(
      `k must be an integer between 1 and ${documents.length} (got ${k})`
    );
  }

  const embeddings = documents.map(doc => doc.embedding!);

  // 1. Initialize centroids from k distinct, randomly chosen documents
  const centroids: number[][] = [];
  const usedIndices = new Set<number>();

  while (centroids.length < k) {
    const randomIndex = Math.floor(Math.random() * embeddings.length);
    if (usedIndices.has(randomIndex)) continue;

    usedIndices.add(randomIndex);
    // Copy so the centroid never aliases a document's own embedding array.
    centroids.push([...embeddings[randomIndex]]);
  }

  // Index of the centroid most similar to `embedding`; ties keep the lowest
  // index. Used for both the per-iteration and the final assignment.
  const nearestCentroid = (embedding: number[]): number => {
    let closestCluster = 0;
    let maxSimilarity = -Infinity;

    for (let i = 0; i < centroids.length; i++) {
      const similarity = cosineSimilarity(embedding, centroids[i]);
      if (similarity > maxSimilarity) {
        maxSimilarity = similarity;
        closestCluster = i;
      }
    }

    return closestCluster;
  };

  console.log(`🔄 Starting K-means clustering (k=${k}, max iterations=${maxIterations})\n`);

  // 2. Iterate until convergence
  let iteration = 0;
  let converged = false;

  while (iteration < maxIterations && !converged) {
    // Assignment step: group embeddings by their nearest centroid
    const members: number[][][] = Array.from({ length: k }, () => []);
    for (const embedding of embeddings) {
      members[nearestCentroid(embedding)].push(embedding);
    }

    // Update step: each centroid becomes the mean of its members
    converged = true;
    for (let i = 0; i < k; i++) {
      const cluster = members[i];

      // An empty cluster keeps its previous centroid.
      if (cluster.length === 0) continue;

      const newCentroid = cluster[0].map((_, dim) =>
        cluster.reduce((sum, embedding) => sum + embedding[dim], 0) / cluster.length
      );

      // Any centroid that moved noticeably means another round is needed.
      if (cosineSimilarity(centroids[i], newCentroid) < 0.9999) {
        converged = false;
      }

      centroids[i] = newCentroid;
    }

    iteration++;

    if (iteration % 10 === 0) {
      console.log(`Iteration ${iteration}...`);
    }
  }

  // Only claim convergence when it actually happened; the loop also ends
  // when maxIterations is exhausted.
  console.log(
    converged
      ? `✅ Converged after ${iteration} iterations\n`
      : `⚠️ Stopped after ${iteration} iterations without converging\n`
  );

  // Build final clusters: one pass assigning each document to its nearest
  // final centroid, so every document lands in exactly one cluster.
  const finalClusters: Cluster[] = centroids.map((centroid, i) => ({
    id: i,
    centroid,
    documents: []
  }));

  documents.forEach((doc, idx) => {
    finalClusters[nearestCentroid(embeddings[idx])].documents.push(doc);
  });

  return finalClusters;
}
|
|
|
|
/**
 * Greedy similarity-threshold clustering (an alternative to K-means).
 *
 * Walks the documents in order. Each document not yet placed starts a new
 * cluster and pulls in every other unplaced document whose cosine similarity
 * to it is at least `threshold`. The seed document's embedding serves as the
 * cluster centroid.
 *
 * @param documents - Documents to cluster; they are expected to carry embeddings.
 * @param threshold - Minimum cosine similarity to the seed for membership.
 * @returns The clusters, numbered from 0 in order of creation.
 * @throws Error if `documents` is empty or the first document has no embedding.
 */
function clusterByThreshold(
  documents: Document[],
  threshold: number = 0.8
): Cluster[] {
  if (!(documents.length > 0 && documents[0].embedding)) {
    throw new Error('Documents must have embeddings');
  }

  const result: Cluster[] = [];
  const taken = new Set<number>();

  for (let seedIdx = 0; seedIdx < documents.length; seedIdx++) {
    // Already absorbed into an earlier cluster.
    if (taken.has(seedIdx)) continue;

    const seed = documents[seedIdx];
    taken.add(seedIdx);
    const members = [seed];

    for (let candidateIdx = 0; candidateIdx < documents.length; candidateIdx++) {
      if (candidateIdx === seedIdx || taken.has(candidateIdx)) continue;

      const candidate = documents[candidateIdx];
      const score = cosineSimilarity(seed.embedding!, candidate.embedding!);
      if (score < threshold) continue;

      members.push(candidate);
      taken.add(candidateIdx);
    }

    result.push({
      id: result.length,
      centroid: seed.embedding!,
      documents: members
    });
  }

  return result;
}
|
|
|
|
/**
 * Log a summary of the given clusters to the console.
 *
 * Prints a banner containing `method`, then for each cluster its 1-based
 * number, document count, and one line per document with the document id and
 * at most the first 80 characters of its text ('...' appended when the text
 * is longer). Ends with the total number of clusters.
 *
 * @param clusters - Clusters to print.
 * @param method - Label shown in the banner.
 */
function printClusters(clusters: Cluster[], method: string): void {
  const previewLength = 80;

  console.log(`━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━`);
  console.log(`${method} Results`);
  console.log(`━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n`);

  for (const { id, documents } of clusters) {
    console.log(`📁 Cluster ${id + 1} (${documents.length} documents):`);
    console.log('─'.repeat(50));

    for (const doc of documents) {
      let preview = doc.text.slice(0, previewLength);
      if (doc.text.length > previewLength) {
        preview += '...';
      }
      console.log(` • [${doc.id}] ${preview}`);
    }

    console.log('');
  }

  console.log(`Total clusters: ${clusters.length}\n`);
}
|
|
|
|
/**
 * Example usage: embed nine sample documents with the Gemini API, cluster
 * them with K-means and with the threshold method, and print both results
 * plus pairwise-similarity statistics for each K-means cluster.
 *
 * Requires the GEMINI_API_KEY environment variable. Any failure is logged
 * and the process exits with status 1.
 */
async function main(): Promise<void> {
  try {
    const apiKey = process.env.GEMINI_API_KEY;
    if (!apiKey) {
      throw new Error('GEMINI_API_KEY environment variable not set');
    }

    const ai = new GoogleGenAI({ apiKey });

    // Sample documents (3 topics: Geography, AI/ML, Food)
    const documents: Document[] = [
      // Geography
      { id: 'doc1', text: 'Paris is the capital of France. It is known for the Eiffel Tower and the Louvre Museum.' },
      { id: 'doc2', text: 'London is the capital of the United Kingdom and home to Big Ben and Buckingham Palace.' },
      { id: 'doc3', text: 'Rome is the capital of Italy and famous for the Colosseum and Vatican City.' },

      // AI/ML
      { id: 'doc4', text: 'Machine learning is a subset of artificial intelligence that enables computers to learn from data.' },
      { id: 'doc5', text: 'Deep learning uses neural networks with multiple layers to learn complex patterns in data.' },
      { id: 'doc6', text: 'Natural language processing is a branch of AI that helps computers understand human language.' },

      // Food
      { id: 'doc7', text: 'Pizza originated in Italy and is now popular worldwide. It typically has a tomato base and cheese.' },
      { id: 'doc8', text: 'Sushi is a Japanese dish made with vinegared rice and various ingredients like raw fish.' },
      { id: 'doc9', text: 'Tacos are a traditional Mexican food consisting of a tortilla filled with various ingredients.' }
    ];

    console.log(`\n📚 Generating embeddings for ${documents.length} documents...\n`);

    // Generate embeddings one document at a time so progress is logged in order.
    for (const doc of documents) {
      // The @google/genai SDK takes `contents` and returns an `embeddings`
      // array with one entry per input.
      const response = await ai.models.embedContent({
        model: 'gemini-embedding-001',
        contents: doc.text,
        config: {
          taskType: 'CLUSTERING', // ← Optimized for clustering
          outputDimensionality: 768
        }
      });

      const embedding = response.embeddings ? response.embeddings[0] : undefined;
      if (!embedding || !embedding.values) {
        throw new Error(`No embedding returned for ${doc.id}`);
      }

      doc.embedding = embedding.values;
      console.log(`✅ Embedded: ${doc.id}`);
    }

    console.log('');

    // Method 1: K-means clustering
    const kMeansClusters = kMeansClustering(documents, 3, 100);
    printClusters(kMeansClusters, 'K-Means Clustering (k=3)');

    // Method 2: Threshold-based clustering
    console.log('🔄 Running threshold-based clustering (threshold=0.7)...\n');
    const thresholdClusters = clusterByThreshold(documents, 0.7);
    printClusters(thresholdClusters, 'Threshold-Based Clustering (≥70% similarity)');

    // Example: Find intra-cluster similarities
    console.log(`━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━`);
    console.log('Cluster Quality Analysis');
    console.log(`━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n`);

    kMeansClusters.forEach(cluster => {
      // Pairwise statistics need at least one pair.
      if (cluster.documents.length < 2) return;

      const similarities: number[] = [];

      // Every unordered pair (i, j) with i < j, compared once.
      for (let i = 0; i < cluster.documents.length; i++) {
        for (let j = i + 1; j < cluster.documents.length; j++) {
          const sim = cosineSimilarity(
            cluster.documents[i].embedding!,
            cluster.documents[j].embedding!
          );
          similarities.push(sim);
        }
      }

      const avgSimilarity = similarities.reduce((a, b) => a + b, 0) / similarities.length;
      const minSimilarity = Math.min(...similarities);
      const maxSimilarity = Math.max(...similarities);

      console.log(`Cluster ${cluster.id + 1}:`);
      console.log(` Documents: ${cluster.documents.map(d => d.id).join(', ')}`);
      console.log(` Avg similarity: ${(avgSimilarity * 100).toFixed(1)}%`);
      console.log(` Min similarity: ${(minSimilarity * 100).toFixed(1)}%`);
      console.log(` Max similarity: ${(maxSimilarity * 100).toFixed(1)}%`);
      console.log('');
    });

  } catch (error: unknown) {
    // Anything can be thrown, so narrow before reading `.message`.
    const message = error instanceof Error ? error.message : String(error);
    console.error('❌ Error:', message);
    process.exit(1);
  }
}
|
|
|
|
// Run the demo. main() catches its own errors, logs them, and exits with
// status 1, so the returned promise is not awaited here.
main();
|
|
|
|
/**
 * Expected output (K-means, typical run):
 *
 * Cluster 1: Geography documents (Paris, London, Rome)
 * Cluster 2: AI/ML documents (Machine learning, Deep learning, NLP)
 * Cluster 3: Food documents (Pizza, Sushi, Tacos)
 *
 * Cluster numbering, and occasionally membership, can differ between runs
 * because the initial K-means centroids are chosen at random.
 *
 * This demonstrates how embeddings capture semantic meaning,
 * allowing automatic topic discovery without manual labeling.
 */
|