mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add huggingface model (#2624)
### What problem does this PR solve?

#2469

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
committed by
GitHub
parent
1b2f66fc11
commit
96f56a3c43
@ -18,7 +18,7 @@ from .chat_model import *
|
||||
from .cv_model import *
|
||||
from .rerank_model import *
|
||||
from .sequence2txt_model import *
|
||||
from .tts_model import *
|
||||
from .tts_model import *
|
||||
|
||||
EmbeddingModel = {
|
||||
"Ollama": OllamaEmbed,
|
||||
@ -46,7 +46,8 @@ EmbeddingModel = {
|
||||
"SILICONFLOW": SILICONFLOWEmbed,
|
||||
"Replicate": ReplicateEmbed,
|
||||
"BaiduYiyan": BaiduYiyanEmbed,
|
||||
"Voyage AI": VoyageEmbed
|
||||
"Voyage AI": VoyageEmbed,
|
||||
"HuggingFace":HuggingFaceEmbed,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1414,3 +1414,4 @@ class GoogleChat(Base):
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield response._chunks[-1].usage_metadata.total_token_count
|
||||
|
||||
@ -678,3 +678,40 @@ class VoyageEmbed(Base):
|
||||
texts=text, model=self.model_name, input_type="query"
|
||||
)
|
||||
return np.array(res.embeddings), res.total_tokens
|
||||
|
||||
|
||||
class HuggingFaceEmbed(Base):
    """Embedding client for an HTTP server that exposes ``POST {base_url}/embed``.

    Every request body is ``{"inputs": ...}`` and the JSON response is treated
    as a list holding one embedding vector per input.  Token counts are
    computed locally with ``num_tokens_from_string``; they are not reported by
    the server.
    """

    def __init__(self, key, model_name, base_url=None):
        """Store the connection settings.

        Args:
            key: Credential for the provider.  It is stored on the instance
                but is not sent with any request made by this class.
            model_name: Name of the embedding model; must be non-empty.
            base_url: Root URL of the embedding server.  Falls back to
                ``http://127.0.0.1:8080`` when empty or ``None``.

        Raises:
            ValueError: If ``model_name`` is empty or ``None``.
        """
        if not model_name:
            raise ValueError("Model name cannot be None")
        self.key = key
        self.model_name = model_name
        self.base_url = base_url or "http://127.0.0.1:8080"

    def _embed(self, inputs):
        """POST ``inputs`` to the ``/embed`` endpoint and return the decoded JSON.

        Args:
            inputs: A single string or a list of strings to embed.

        Raises:
            Exception: If the server answers with a status other than 200.
        """
        # Strip a trailing slash so "http://host:8080/" does not become
        # "http://host:8080//embed".
        url = f"{self.base_url.rstrip('/')}/embed"
        # NOTE(review): no timeout is passed, so an unresponsive server blocks
        # the caller indefinitely.
        response = requests.post(
            url,
            json={"inputs": inputs},
            headers={'Content-Type': 'application/json'}
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} - {response.text}")
        return response.json()

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` and return ``(embeddings, token_count)``.

        The texts are sent in chunks of at most ``batch_size`` items per
        request instead of one request per text.

        Args:
            texts: Strings to embed.
            batch_size: Maximum number of texts per request; values below 1
                (or ``None``) are treated as 1.

        Returns:
            A tuple of ``np.array`` built from the returned vectors (in the
            same order as ``texts``) and the summed
            ``num_tokens_from_string`` count over all texts.

        Raises:
            Exception: On a non-200 response, or when the server returns a
                different number of vectors than texts sent in a request.
        """
        step = max(1, batch_size or 1)
        embeddings = []
        for start in range(0, len(texts), step):
            batch = texts[start:start + step]
            vectors = self._embed(batch)
            # Guard against silently misaligning vectors with their texts.
            if len(vectors) != len(batch):
                raise Exception(
                    f"Error: expected {len(batch)} embeddings, got {len(vectors)}"
                )
            embeddings.extend(vectors)
        return np.array(embeddings), sum(num_tokens_from_string(text) for text in texts)

    def encode_queries(self, text):
        """Embed a single query string and return ``(embedding, token_count)``.

        Raises:
            Exception: If the server answers with a status other than 200.
        """
        embedding = self._embed(text)
        # The response is a list of vectors; a single input yields one entry.
        return np.array(embedding[0]), num_tokens_from_string(text)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user