add support for Gemini (#1465)

### What problem does this PR solve?

#1036

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Author: 黄腾
Date: 2024-07-11 15:41:00 +08:00
Committed by: GitHub
Parent: 2290c2a2f0
Commit: 3e9f444e6b
9 changed files with 263 additions and 2 deletions


```diff
@@ -31,7 +31,7 @@ import numpy as np
 import asyncio
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
+import google.generativeai as genai


 class Base(ABC):
     def __init__(self, key, model_name):
```
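Because the new `import google.generativeai as genai` sits at module scope, the embedding module will no longer import unless the Gemini SDK is installed. A minimal guard one might add locally is sketched below; the package name `google-generativeai` and the error message are assumptions, not part of this commit:

```python
# Sketch only: fail with a clear message if the Gemini SDK is missing.
# Assumes the google.generativeai module is provided by the
# google-generativeai package (not part of this commit).
try:
    import google.generativeai as genai
except ImportError as exc:
    raise ImportError(
        "GeminiEmbed requires the Gemini SDK; install it with "
        "`pip install google-generativeai`"
    ) from exc
```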
```diff
@@ -419,3 +419,27 @@ class BedrockEmbed(Base):
         return np.array(embeddings), token_count
 
 
+class GeminiEmbed(Base):
+    def __init__(self, key, model_name='models/text-embedding-004',
+                 **kwargs):
+        genai.configure(api_key=key)
+        self.model_name = 'models/' + model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 2048) for t in texts]
+        token_count = sum(num_tokens_from_string(text) for text in texts)
+        result = genai.embed_content(
+            model=self.model_name,
+            content=texts,
+            task_type="retrieval_document",
+            title="Embedding of list of strings")
+        return np.array(result['embedding']), token_count
+
+    def encode_queries(self, text):
+        result = genai.embed_content(
+            model=self.model_name,
+            content=truncate(text, 2048),
+            task_type="retrieval_document",
+            title="Embedding of single string")
+        token_count = num_tokens_from_string(text)
+        return np.array(result['embedding']), token_count
```
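For reference, a minimal usage sketch of the new class (not part of the diff; the import path `rag.llm.embedding_model`, the placeholder API key, and the sample strings are assumptions):

```python
# Hypothetical usage sketch; the import path and API key are placeholders.
from rag.llm.embedding_model import GeminiEmbed

# The constructor prepends "models/" to whatever model_name is passed,
# so the bare model id is supplied here.
embd = GeminiEmbed(key="YOUR_GEMINI_API_KEY", model_name="text-embedding-004")

# Batch-embed document chunks: returns (ndarray of vectors, total token count).
vectors, tokens = embd.encode(["first chunk of text", "second chunk of text"])
print(vectors.shape, tokens)

# Embed a single query string.
query_vec, query_tokens = embd.encode_queries("What does the Gemini support add?")
print(query_vec.shape, query_tokens)
```

Both methods truncate each input with `truncate(..., 2048)` and count tokens with `num_tokens_from_string` around the `genai.embed_content` call, mirroring the other embedding backends in this file.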