Don't release full image (#10654)
### What problem does this PR solve?

Introduced gpu profile in .env
Added Dockerfile_tei
Fix datrie
Removed LIGHTEN flag

### Type of change

- [x] Documentation Update
- [x] Refactoring
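As a rough illustration (not part of the diff), the selection logic this PR introduces boils down to a compose-profile check; the concrete profile names and the shape of `EMBEDDING_CFG` below are assumptions:

```python
# Hedged sketch: with LIGHTEN removed, the builtin embedding is only wired up when a
# TEI compose profile is active; profile names and config shape here are assumptions.
import os

EMBEDDING_CFG = {"api_key": "", "base_url": "http://localhost:6380"}  # assumed shape of settings.EMBEDDING_CFG


def builtin_embedding_enabled() -> bool:
    # e.g. COMPOSE_PROFILES=tei-cpu (or a gpu variant) set in docker/.env -- assumption
    return "tei-" in os.getenv("COMPOSE_PROFILES", "")


if builtin_embedding_enabled():
    print("Embedding requests go to the TEI service at", EMBEDDING_CFG["base_url"])
else:
    print("Builtin embedding disabled; configure an external embedding provider")
```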
@@ -14,9 +14,7 @@
# limitations under the License.
#
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin
@@ -25,15 +23,14 @@ import dashscope
import google.generativeai as genai
import numpy as np
import requests
from huggingface_hub import snapshot_download
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
from api import settings
import logging


class Base(ABC):
@@ -63,71 +60,42 @@ class Base(ABC):
        return 0


class DefaultEmbedding(Base):
    _FACTORY_NAME = "BAAI"
class BuiltinEmbed(Base):
    _FACTORY_NAME = "Builtin"
    MAX_TOKENS = {"Qwen/Qwen3-Embedding-0.6B": 30000, "BAAI/bge-m3": 8000, "BAAI/bge-small-en-v1.5": 500}
    _model = None
    _model_name = ""
    _max_tokens = 500
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        if not settings.LIGHTEN:
            input_cuda_visible_devices = None
            with DefaultEmbedding._model_lock:
                import torch
                from FlagEmbedding import FlagModel

                if "CUDA_VISIBLE_DEVICES" in os.environ:
                    input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
                os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # handle some issues with multiple GPUs when initializing the model

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    try:
                        DefaultEmbedding._model = FlagModel(
                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                            use_fp16=torch.cuda.is_available(),
                        )
                        DefaultEmbedding._model_name = model_name
                    except Exception:
                        model_dir = snapshot_download(
                            repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                        )
                        DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
                    finally:
                        if input_cuda_visible_devices:
                            # restore CUDA_VISIBLE_DEVICES
                            os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
        self._model = DefaultEmbedding._model
        self._model_name = DefaultEmbedding._model_name
        logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
        embedding_cfg = settings.EMBEDDING_CFG
        if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
            with BuiltinEmbed._model_lock:
                BuiltinEmbed._model_name = settings.EMBEDDING_MDL
                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
        self._model = BuiltinEmbed._model
        self._model_name = BuiltinEmbed._model_name
        self._max_tokens = BuiltinEmbed._max_tokens

    def encode(self, texts: list):
        batch_size = 16
        texts = [truncate(t, 2048) for t in texts]
        texts = [truncate(t, self._max_tokens) for t in texts]
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        ress = None
        for i in range(0, len(texts), batch_size):
            embeddings, token_count_delta = self._model.encode(texts[i : i + batch_size])
            token_count += token_count_delta
            if ress is None:
                ress = self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)
                ress = embeddings
            else:
                ress = np.concatenate((ress, self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)), axis=0)
                ress = np.concatenate((ress, embeddings), axis=0)
        return ress, token_count

    def encode_queries(self, text: str):
        token_count = num_tokens_from_string(text)
        return self._model.encode_queries([text], convert_to_numpy=False)[0][0].cpu().numpy(), token_count
        return self._model.encode_queries(text)

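A hedged usage sketch for the new `BuiltinEmbed` (only meaningful once settings are initialized and a `tei-*` compose profile points `settings.EMBEDDING_CFG` at a reachable TEI server; the texts below are placeholders):

```python
from rag.llm.embedding_model import BuiltinEmbed

# key is not used by the builtin path (assumption); model_name should be one the TEI server serves
mdl = BuiltinEmbed(key="", model_name="BAAI/bge-m3")
vecs, tokens = mdl.encode(["What is RAGFlow?", "RAGFlow is an open-source RAG engine."])
print(vecs.shape, tokens)  # (2, dim) embedding matrix and the total token count
qvec, qtokens = mdl.encode_queries("What is RAGFlow?")
```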
class OpenAIEmbed(Base):
@@ -326,51 +294,6 @@ class OllamaEmbed(Base):
            log_exception(_e, res)


class FastEmbed(DefaultEmbedding):
    _FACTORY_NAME = "FastEmbed"

    def __init__(
        self,
        key: str | None = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs,
    ):
        if not settings.LIGHTEN:
            with FastEmbed._model_lock:
                from fastembed import TextEmbedding

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    try:
                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                        DefaultEmbedding._model_name = model_name
                    except Exception:
                        cache_dir = snapshot_download(
                            repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                        )
                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
        self._model = DefaultEmbedding._model
        self._model_name = model_name

    def encode(self, texts: list):
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)

        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]

        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encoding = self._model.model.tokenizer.encode(text)
        embedding = next(self._model.query_embed(text))
        return np.array(embedding), len(encoding.ids)


class XinferenceEmbed(Base):
    _FACTORY_NAME = "Xinference"

@@ -407,14 +330,7 @@ class YoudaoEmbed(Base):
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        if not settings.LIGHTEN and not YoudaoEmbed._client:
            from BCEmbedding import EmbeddingModel as qanthing

            try:
                logging.info("LOADING BCE...")
                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
            except Exception:
                YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
        pass

    def encode(self, texts: list):
        batch_size = 10
@@ -885,21 +801,18 @@ class HuggingFaceEmbed(Base):
        self.base_url = base_url or "http://127.0.0.1:8080"

    def encode(self, texts: list):
        embeddings = []
        for text in texts:
            response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
            if response.status_code == 200:
                embedding = response.json()
                embeddings.append(embedding[0])
            else:
                raise Exception(f"Error: {response.status_code} - {response.text}")
        response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"})
        if response.status_code == 200:
            embeddings = response.json()
        else:
            raise Exception(f"Error: {response.status_code} - {response.text}")
        return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])

    def encode_queries(self, text):
    def encode_queries(self, text: str):
        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
        if response.status_code == 200:
            embedding = response.json()
            return np.array(embedding[0]), num_tokens_from_string(text)
            embedding = response.json()[0]
            return np.array(embedding), num_tokens_from_string(text)
        else:
            raise Exception(f"Error: {response.status_code} - {response.text}")

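The change above replaces one HTTP call per text with a single batched call. Roughly what that request looks like against a text-embeddings-inference server (the URL and inputs below are placeholders):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8080/embed",                     # assumed local TEI endpoint
    json={"inputs": ["first chunk", "second chunk"]},  # whole batch in one request
    headers={"Content-Type": "application/json"},
    timeout=30,
)
resp.raise_for_status()
vectors = resp.json()  # one embedding vector per input text
print(len(vectors), len(vectors[0]))
```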
@@ -14,21 +14,14 @@
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin

import httpx
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response

@@ -47,100 +40,6 @@ class Base(ABC):
        return total_token_count_from_response(resp)


class DefaultRerank(Base):
    _FACTORY_NAME = "BAAI"
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker

            with DefaultRerank._model_lock:
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def torch_empty_cache(self):
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as e:
            log_exception(e)

    def _process_batch(self, pairs, max_batch_size=None):
        """template method for subclass call"""
        old_dynamic_batch_size = self._dynamic_batch_size
        if max_batch_size is not None:
            self._dynamic_batch_size = max_batch_size
        res = np.array(len(pairs), dtype=float)
        i = 0
        while i < len(pairs):
            cur_i = i
            current_batch = self._dynamic_batch_size
            max_retries = 5
            retry_count = 0
            while retry_count < max_retries:
                try:
                    # call subclass implemented batch processing calculation
                    batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
                    res[i : i + current_batch] = batch_scores
                    i += current_batch
                    self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
                    break
                except RuntimeError as e:
                    if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
                        current_batch = max(current_batch // 2, self._min_batch_size)
                        self.torch_empty_cache()
                        i = cur_i  # reset i to the start of the current batch
                        retry_count += 1
                    else:
                        raise
            if retry_count >= max_retries:
                raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")

        self.torch_empty_cache()
        self._dynamic_batch_size = old_dynamic_batch_size
        return np.array(res)

    def _compute_batch_scores(self, batch_pairs, max_length=None):
        if max_length is None:
            scores = self._model.compute_score(batch_pairs, normalize=True)
        else:
            scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
        if not isinstance(scores, Iterable):
            scores = [scores]
        return scores

    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 4096
        res = self._process_batch(pairs, max_batch_size=batch_size)
        return np.array(res), token_count


class JinaRerank(Base):
    _FACTORY_NAME = "Jina"

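For reference, the CUDA-OOM backoff in `_process_batch` above boils down to the following standalone pattern (a distilled sketch, not the repo's API; `score_fn` is a placeholder):

```python
import numpy as np


def score_in_batches(pairs, score_fn, batch_size=8, min_batch=1, max_retries=5):
    scores = np.zeros(len(pairs), dtype=float)
    i = 0
    while i < len(pairs):
        current, retries = batch_size, 0
        while True:
            try:
                scores[i : i + current] = score_fn(pairs[i : i + current])
                i += current
                break
            except RuntimeError as e:
                # halve the batch and retry when the GPU runs out of memory
                if "CUDA out of memory" in str(e) and current > min_batch and retries < max_retries:
                    current = max(current // 2, min_batch)
                    retries += 1
                else:
                    raise
    return scores
```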
@@ -162,36 +61,6 @@ class JinaRerank(Base):
        return rank, self.total_token_count(res)


class YoudaoRerank(DefaultRerank):
    _FACTORY_NAME = "Youdao"
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        if not settings.LIGHTEN and not YoudaoRerank._model:
            from BCEmbedding import RerankerModel

            with YoudaoRerank._model_lock:
                if not YoudaoRerank._model:
                    try:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

        self._model = YoudaoRerank._model
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 8
        res = self._process_batch(pairs, max_batch_size=batch_size)
        return np.array(res), token_count


class XInferenceRerank(Base):
    _FACTORY_NAME = "Xinference"

@@ -514,7 +383,7 @@ class QWenRerank(Base):
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")


class HuggingfaceRerank(DefaultRerank):
class HuggingfaceRerank(Base):
    _FACTORY_NAME = "HuggingFace"

    @staticmethod

@@ -17,6 +17,7 @@ import os
import logging
from api.utils.configs import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory
from api.utils.common import pip_install_torch

# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
@@ -65,6 +66,7 @@ TAG_FLD = "tag_feas"

PARALLEL_DEVICES = 0
try:
    pip_install_torch()
    import torch.cuda
    PARALLEL_DEVICES = torch.cuda.device_count()
    logging.info(f"found {PARALLEL_DEVICES} gpus")

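The settings hunk above installs torch on demand and counts CUDA devices; the hunk is truncated before the `except` branch, so the fallback in this sketch is an assumption:

```python
import logging

PARALLEL_DEVICES = 0
try:
    import torch.cuda

    PARALLEL_DEVICES = torch.cuda.device_count()
    logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
    # no usable torch/CUDA: keep 0 parallel devices (assumed fallback)
    logging.info("torch.cuda unavailable, assuming 0 GPUs")
```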
@@ -1,109 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
@@ -29,6 +29,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
from api.utils.configs import show_configs
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
@@ -475,7 +476,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    tk_count = 0
    if len(tts) == len(cnts):
        vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
        tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
        tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
        tk_count += c

@timeout(60)
@@ -1061,7 +1062,10 @@ async def main():
/____/
    """)
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    show_configs()
    settings.init_settings()
    from api.settings import EMBEDDING_CFG
    logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}')
    print_rag_settings()
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)