Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 12:32:30 +08:00)
Move some vars to globals (#11017)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
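For orientation, a minimal, illustrative sketch of the shared-state pattern this refactor relies on, assuming a `common/globals.py` module that simply holds process-wide singletons. Only the attribute names (`retriever`, `docStoreConn`, `DOC_ENGINE`, `EMBEDDING_MDL`, `EMBEDDING_CFG`) are taken from the hunks below; everything else is an assumption, not the repository's actual module:

```python
# common/globals.py -- illustrative sketch, NOT the actual module from this commit.
# A plain module works as a process-wide singleton: every `from common import globals`
# returns the same module object, so attributes assigned here are visible everywhere.

retriever = None        # assumed: set once at startup to the shared retrieval object
docStoreConn = None     # assumed: shared document-store connection (ES/OpenSearch/Infinity)
DOC_ENGINE = "elasticsearch"   # overwritten from the DOC_ENGINE env var in rag/settings.py
EMBEDDING_MDL = ""      # assumed: default embedding model name
EMBEDDING_CFG = {}      # assumed: embedding service config (api_key, base_url, ...)
```

Call sites then read `globals.retriever`, `globals.docStoreConn`, and `globals.DOC_ENGINE` instead of pulling the same objects out of `api.settings`, which is exactly the substitution the hunks below perform.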
@@ -21,6 +21,7 @@ from copy import deepcopy
 from deepdoc.parser.utils import get_text
 from rag.app.qa import Excel
 from rag.nlp import rag_tokenizer
+from common import globals


 def beAdoc(d, q, a, eng, row_num=-1):
@@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
 def label_question(question, kbs):
     from api.db.services.knowledgebase_service import KnowledgebaseService
     from graphrag.utils import get_tags_from_cache, set_tags_to_cache
-    from api import settings
     tags = None
     tag_kb_ids = []
     for kb in kbs:
@@ -133,14 +133,14 @@ def label_question(question, kbs):
     if tag_kb_ids:
         all_tags = get_tags_from_cache(tag_kb_ids)
         if not all_tags:
-            all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
             set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
         else:
             all_tags = json.loads(all_tags)
         tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
         if not tag_kbs:
             return tags
-        tags = settings.retriever.tag_query(question,
+        tags = globals.retriever.tag_query(question,
                                             list(set([kb.tenant_id for kb in tag_kbs])),
                                             tag_kb_ids,
                                             all_tags,
@@ -20,10 +20,10 @@ import time
 import argparse
 from collections import defaultdict

+from common import globals
 from common.constants import LLMType
 from api.db.services.llm_service import LLMBundle
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api import settings
 from common.misc_utils import get_uuid
 from rag.nlp import tokenize, search
 from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+            ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                  0.0, self.vector_similarity_weight)
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")
@@ -77,9 +77,9 @@ class Benchmark:
     def init_index(self, vector_size: int):
         if self.initialized_index:
             return
-        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
-            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
-        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
+        if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
+            globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
+        globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
         self.initialized_index = True

     def ms_marco_index(self, file_path, index_name):
@@ -114,13 +114,13 @@ class Benchmark:
                 docs_count += len(docs)
                 docs, vector_size = self.embedding(docs)
                 self.init_index(vector_size)
-                settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+                globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
                 docs = []

         if docs:
             docs, vector_size = self.embedding(docs)
             self.init_index(vector_size)
-            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+            globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
         return qrels, texts

     def trivia_qa_index(self, file_path, index_name):
@@ -155,12 +155,12 @@ class Benchmark:
                 docs_count += len(docs)
                 docs, vector_size = self.embedding(docs)
                 self.init_index(vector_size)
-                settings.docStoreConn.insert(docs,self.index_name)
+                globals.docStoreConn.insert(docs,self.index_name)
                 docs = []

         docs, vector_size = self.embedding(docs)
         self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
         return qrels, texts

     def miracl_index(self, file_path, corpus_path, index_name):
@@ -210,12 +210,12 @@ class Benchmark:
                 docs_count += len(docs)
                 docs, vector_size = self.embedding(docs)
                 self.init_index(vector_size)
-                settings.docStoreConn.insert(docs, self.index_name)
+                globals.docStoreConn.insert(docs, self.index_name)
                 docs = []

         docs, vector_size = self.embedding(docs)
         self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
         return qrels, texts

     def save_results(self, qrels, run, texts, dataset, file_path):
@@ -21,7 +21,7 @@ from functools import partial
 import trio

 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
@@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.llm_service import LLMBundle
 from common.misc_utils import get_uuid
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from deepdoc.parser import ExcelParser
 from deepdoc.parser.mineru_parser import MinerUParser
 from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
@@ -18,7 +18,7 @@ from functools import partial
 import trio

 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.splitter.schema import SplitterFromUpstream
@@ -30,7 +30,6 @@ from zhipuai import ZhipuAI
 from common.log_utils import log_exception
 from common.token_utils import num_tokens_from_string, truncate
 from common import globals
-from api import settings
 import logging


@@ -74,9 +73,9 @@ class BuiltinEmbed(Base):
         embedding_cfg = globals.EMBEDDING_CFG
         if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
             with BuiltinEmbed._model_lock:
-                BuiltinEmbed._model_name = settings.EMBEDDING_MDL
-                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
-                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
+                BuiltinEmbed._model_name = globals.EMBEDDING_MDL
+                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
+                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
         self._model = BuiltinEmbed._model
         self._model_name = BuiltinEmbed._model_name
         self._max_tokens = BuiltinEmbed._max_tokens
@@ -18,13 +18,14 @@ import logging
 from common.config_utils import get_base_config, decrypt_database_config
 from common.file_utils import get_project_base_directory
 from common.misc_utils import pip_install_torch
+from common import globals

 # Server
 RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

 # Get storage type and document engine from system environment variables
 STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
-DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
+globals.DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

 ES = {}
 INFINITY = {}
@@ -35,11 +36,11 @@ OSS = {}
 OS = {}

 # Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
-if DOC_ENGINE == 'elasticsearch':
+if globals.DOC_ENGINE == 'elasticsearch':
     ES = get_base_config("es", {})
-elif DOC_ENGINE == 'opensearch':
+elif globals.DOC_ENGINE == 'opensearch':
     OS = get_base_config("os", {})
-elif DOC_ENGINE == 'infinity':
+elif globals.DOC_ENGINE == 'infinity':
     INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

 if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
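Because `rag/settings.py` now assigns the engine name onto the shared module at import time, any later importer observes the same value. A hypothetical consumer (not part of this commit) illustrating the read side:

```python
# Hypothetical helper, mirroring how the task_executor hunk below branches on the engine.
from common import globals


def table_parsing_supported() -> bool:
    # Infinity does not support the "table" parsing method (see the FIXME further down),
    # so callers can branch on the globally selected document engine.
    return (globals.DOC_ENGINE or "elasticsearch").lower() != "infinity"
```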
@@ -27,7 +27,7 @@ from api.db.services.canvas_service import UserCanvasService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from common.connection_utils import timeout
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from common.log_utils import init_root_logger
 from common.config_utils import show_configs
 from graphrag.general.index import run_graphrag_for_kb
@@ -68,6 +68,7 @@ from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
 from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
+from common import globals

 BATCH_SIZE = 64

@@ -349,7 +350,7 @@ async def build_chunks(task, progress_callback):
         examples = []
         all_tags = get_tags_from_cache(kb_ids)
         if not all_tags:
-            all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
+            all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
             set_tags_to_cache(kb_ids, all_tags)
         else:
             all_tags = json.loads(all_tags)
@@ -362,7 +363,7 @@ async def build_chunks(task, progress_callback):
             if task_canceled:
                 progress_callback(-1, msg="Task has been canceled.")
                 return
-            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+            if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                 examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
             else:
                 docs_to_tag.append(d)
@@ -423,7 +424,7 @@ def build_TOC(task, docs, progress_callback):

 def init_kb(row, vector_size: int):
     idxnm = search.index_name(row["tenant_id"])
-    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
+    return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


 async def embedding(docs, mdl, parser_config=None, callback=None):
@@ -647,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for doc_id in doc_ids:
-        for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+        for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
                                                fields=["content_with_weight", vctr_nm],
                                                sort_by_position=True):
             chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -698,7 +699,7 @@ async def delete_image(kb_id, chunk_id):

 async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
     for b in range(0, len(chunks), DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -715,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
                 TaskService.update_chunk_ids(task_id, chunk_ids_str)
             except DoesNotExist:
                 logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-                doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
+                doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
                 async with trio.open_nursery() as nursery:
                     for chunk_id in chunk_ids:
                         nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@@ -751,7 +752,7 @@ async def do_handle_task(task):
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

     # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
-    lower_case_doc_engine = settings.DOC_ENGINE.lower()
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
         error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
         progress_callback(-1, msg=error_message)
rag/utils/base64_image.py (new file, 75 lines)
@@ -0,0 +1,75 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import base64
import logging
from functools import partial
from io import BytesIO

from PIL import Image

test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)


async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
    import logging
    from io import BytesIO
    import trio
    from rag.svr.task_executor import minio_limiter
    if "image" not in d:
        return
    if not d["image"]:
        del d["image"]
        return

    with BytesIO() as output_buffer:
        if isinstance(d["image"], bytes):
            output_buffer.write(d["image"])
            output_buffer.seek(0)
        else:
            # If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
            if d["image"].mode in ("RGBA", "P"):
                converted_image = d["image"].convert("RGB")
                d["image"] = converted_image
            try:
                d["image"].save(output_buffer, format='JPEG')
            except OSError as e:
                logging.warning(
                    "Saving image exception, ignore: {}".format(str(e)))

        async with minio_limiter:
            await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue()))
        d["img_id"] = f"{bucket}-{objname}"
        if not isinstance(d["image"], bytes):
            d["image"].close()
        del d["image"]  # Remove image reference


def id2image(image_id:str|None, storage_get_func: partial):
    if not image_id:
        return
    arr = image_id.split("-")
    if len(arr) != 2:
        return
    bkt, nm = image_id.split("-")
    try:
        blob = storage_get_func(bucket=bkt, filename=nm)
        if not blob:
            return
        return Image.open(BytesIO(blob))
    except Exception as e:
        logging.exception(e)
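A hedged usage sketch for the new helpers. It assumes `STORAGE_IMPL` exposes `put`/`get` callables compatible with the keyword arguments used above (`bucket`/`fnm`/`binary` and `bucket`/`filename`), and that the call happens inside the task executor's trio context (since `image2id` imports `minio_limiter` from `rag.svr.task_executor`); the chunk dict and ids are invented for illustration:

```python
from functools import partial

from rag.utils.base64_image import id2image, image2id
from rag.utils.storage_factory import STORAGE_IMPL  # assumed to provide put/get as used here


async def stash_chunk_image(chunk: dict, kb_id: str, chunk_id: str):
    # Uploads chunk["image"] and replaces it with a "bucket-objname" reference.
    await image2id(chunk, partial(STORAGE_IMPL.put), objname=chunk_id, bucket=kb_id)
    # chunk["image"] is removed; chunk["img_id"] now equals f"{kb_id}-{chunk_id}".

    # Later, the reference can be resolved back into a PIL.Image
    # (id2image returns None if the id is empty, malformed, or the blob is missing).
    return id2image(chunk.get("img_id"), partial(STORAGE_IMPL.get))
```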