Mirror of https://github.com/infiniflow/ragflow.git
Fix IDE warnings (#12281)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
@@ -48,13 +48,15 @@ def main():
                     REDIS_CONN.transaction(key, file_bin, 12 * 60)
                     logging.info("CACHE: {}".format(loc))
                 except Exception as e:
-                    traceback.print_stack(e)
+                    logging.error(f"Error to get data from REDIS: {e}")
+                    traceback.print_stack()
         except Exception as e:
-            traceback.print_stack(e)
+            logging.error(f"Error to check REDIS connection: {e}")
+            traceback.print_stack()
 
 
 if __name__ == "__main__":
     while True:
         main()
         close_connection()
-        time.sleep(1)
+        time.sleep(1)
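A note on the hunk above: `traceback.print_stack()` takes an optional frame object, so the old `traceback.print_stack(e)` call with an exception instance would itself fail with `AttributeError`. A minimal sketch of the corrected pattern, assuming a hypothetical failing Redis helper:

```python
import logging
import traceback


def fetch_from_redis():
    # Hypothetical stand-in for the real REDIS_CONN call, used only for illustration.
    raise ConnectionError("REDIS unreachable")


try:
    fetch_from_redis()
except Exception as e:
    # Log the error message, then dump the current call stack.
    # print_stack() takes a frame (or nothing), never an exception instance.
    logging.error(f"Error to get data from REDIS: {e}")
    traceback.print_stack()
```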
@@ -19,16 +19,15 @@ import requests
 import base64
 import asyncio
 
-URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
+URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk'  # Default: https://demo.ragflow.io/v1/api/completion_aibotk
 
 JSON_DATA = {
-    "conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
-    "Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
-    "word": "" # User question, don't need to initialize
+    "conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx",  # Get conversation id from /api/new_conversation
+    "Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",  # RAGFlow Assistant Chat Bot API Key
+    "word": ""  # User question, don't need to initialize
 }
 
-DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" #Get DISCORD_BOT_KEY from Discord Application
-
+DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx"  # Get DISCORD_BOT_KEY from Discord Application
 
 intents = discord.Intents.default()
 intents.message_content = True
@@ -50,7 +49,7 @@ async def on_message(message):
     if len(message.content.split('> ')) == 1:
         await message.channel.send("Hi~ How can I help you? ")
     else:
-        JSON_DATA['word']=message.content.split('> ')[1]
+        JSON_DATA['word'] = message.content.split('> ')[1]
         response = requests.post(URL, json=JSON_DATA)
         response_data = response.json().get('data', [])
         image_bool = False
@@ -61,9 +60,9 @@ async def on_message(message):
             if i['type'] == 3:
                 image_bool = True
                 image_data = base64.b64decode(i['url'])
-                with open('tmp_image.png','wb') as file:
+                with open('tmp_image.png', 'wb') as file:
                     file.write(image_data)
-                image= discord.File('tmp_image.png')
+                image = discord.File('tmp_image.png')
 
         await message.channel.send(f"{message.author.mention}{res}")
 
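The Discord-bot hunks above, like most of this diff, are pure style normalization. For reference, a small sketch of the spacing conventions being applied (the pycodestyle codes are factual; the sample values are made up):

```python
# E225: surround assignment and comparison operators with spaces ("word=x" -> "word = x").
word = "hello"
# E231: put a space after each comma ("open('f','wb')" -> "open('f', 'wb')").
args = ("tmp_image.png", "wb")
# E261/E262: inline comments get two leading spaces and start with "# ".
timeout = 12 * 60  # seconds, with spaces around the arithmetic operator as well
```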
@@ -38,7 +38,8 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from common import settings
 from common.config_utils import show_configs
-from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
+from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, \
+    MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
 from common.constants import FileSource, TaskStatus
 from common.data_source.config import INDEX_BATCH_SIZE
 from common.data_source.confluence_connector import ConfluenceConnector
@@ -96,7 +97,7 @@ class SyncBase:
         if task["poll_range_start"]:
             next_update = task["poll_range_start"]
 
-        for document_batch in document_batch_generator:
+        for document_batch in document_batch_generator:
             if not document_batch:
                 continue
 
@@ -161,6 +162,7 @@ class SyncBase:
     def _get_source_prefix(self):
         return ""
 
+
 class _BlobLikeBase(SyncBase):
     DEFAULT_BUCKET_TYPE: str = "s3"
 
@@ -199,22 +201,27 @@ class _BlobLikeBase(SyncBase):
         )
         return document_batch_generator
 
+
 class S3(_BlobLikeBase):
     SOURCE_NAME: str = FileSource.S3
     DEFAULT_BUCKET_TYPE: str = "s3"
 
+
 class R2(_BlobLikeBase):
     SOURCE_NAME: str = FileSource.R2
     DEFAULT_BUCKET_TYPE: str = "r2"
 
+
 class OCI_STORAGE(_BlobLikeBase):
     SOURCE_NAME: str = FileSource.OCI_STORAGE
     DEFAULT_BUCKET_TYPE: str = "oci_storage"
 
+
 class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
     SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
     DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
 
+
 class Confluence(SyncBase):
     SOURCE_NAME: str = FileSource.CONFLUENCE
 
@@ -248,7 +255,9 @@ class Confluence(SyncBase):
             index_recursively=index_recursively,
         )
 
-        credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
+        credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
+                                                         connector_name=DocumentSource.CONFLUENCE,
+                                                         credential_json=self.conf["credentials"])
         self.connector.set_credentials_provider(credentials_provider)
 
         # Determine the time range for synchronization based on reindex or poll_range_start
@@ -280,7 +289,8 @@ class Confluence(SyncBase):
         doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
         for document, failure, next_checkpoint in doc_generator:
             if failure is not None:
-                logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
+                logging.warning("Confluence connector failure: %s",
+                                getattr(failure, "failure_message", failure))
                 continue
             if document is not None:
                 pending_docs.append(document)
@@ -300,7 +310,7 @@ class Confluence(SyncBase):
         async def async_wrapper():
             for batch in document_batches():
                 yield batch
-
+
         logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
         return async_wrapper()
 
@@ -314,10 +324,12 @@ class Notion(SyncBase):
         document_generator = (
             self.connector.load_from_state()
             if task["reindex"] == "1" or not task["poll_range_start"]
-            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
+            else self.connector.poll_source(task["poll_range_start"].timestamp(),
+                                            datetime.now(timezone.utc).timestamp())
         )
 
-        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
+        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
+            task["poll_range_start"])
         logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
         return document_generator
 
@@ -340,10 +352,12 @@ class Discord(SyncBase):
         document_generator = (
             self.connector.load_from_state()
             if task["reindex"] == "1" or not task["poll_range_start"]
-            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
+            else self.connector.poll_source(task["poll_range_start"].timestamp(),
+                                            datetime.now(timezone.utc).timestamp())
         )
 
-        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
+        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
+            task["poll_range_start"])
         logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
         return document_generator
 
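The Notion and Discord hunks above share one pattern: a full `load_from_state()` sync when the task is a reindex or has never polled, otherwise an incremental `poll_source()` over the window since the last poll. A condensed sketch of that selection logic (hypothetical helper name, connector assumed to expose the same two methods):

```python
from datetime import datetime, timezone


def choose_document_generator(connector, task):
    # Reindex requests and first-time tasks trigger a full sync; otherwise only
    # documents changed since poll_range_start are fetched.
    if task["reindex"] == "1" or not task["poll_range_start"]:
        return connector.load_from_state()
    return connector.poll_source(task["poll_range_start"].timestamp(),
                                 datetime.now(timezone.utc).timestamp())
```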
@@ -485,7 +499,8 @@ class GoogleDrive(SyncBase):
         doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
         for document, failure, next_checkpoint in doc_generator:
             if failure is not None:
-                logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
+                logging.warning("Google Drive connector failure: %s",
+                                getattr(failure, "failure_message", failure))
                 continue
             if document is not None:
                 pending_docs.append(document)
@@ -646,10 +661,10 @@ class WebDAV(SyncBase):
             remote_path=self.conf.get("remote_path", "/")
         )
         self.connector.load_credentials(self.conf["credentials"])
 
         logging.info(f"Task info: reindex={task['reindex']}, poll_range_start={task['poll_range_start']}")
 
-        if task["reindex"]=="1" or not task["poll_range_start"]:
+        if task["reindex"] == "1" or not task["poll_range_start"]:
             logging.info("Using load_from_state (full sync)")
             document_batch_generator = self.connector.load_from_state()
             begin_info = "totally"
@@ -659,14 +674,15 @@ class WebDAV(SyncBase):
             logging.info(f"Polling WebDAV from {task['poll_range_start']} (ts: {start_ts}) to now (ts: {end_ts})")
             document_batch_generator = self.connector.poll_source(start_ts, end_ts)
             begin_info = "from {}".format(task["poll_range_start"])
 
         logging.info("Connect to WebDAV: {}(path: {}) {}".format(
             self.conf["base_url"],
             self.conf.get("remote_path", "/"),
             begin_info
         ))
         return document_batch_generator
 
+
 class Moodle(SyncBase):
     SOURCE_NAME: str = FileSource.MOODLE
 
@@ -675,7 +691,7 @@ class Moodle(SyncBase):
             moodle_url=self.conf["moodle_url"],
             batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)
         )
-
+
         self.connector.load_credentials(self.conf["credentials"])
 
         # Determine the time range for synchronization based on reindex or poll_range_start
@@ -689,7 +705,7 @@ class Moodle(SyncBase):
             begin_info = "totally"
         else:
             document_generator = self.connector.poll_source(
-                poll_start.timestamp(),
+                poll_start.timestamp(),
                 datetime.now(timezone.utc).timestamp()
             )
             begin_info = "from {}".format(poll_start)
@@ -718,7 +734,7 @@ class BOX(SyncBase):
         token = AccessToken(
             access_token=credential['access_token'],
             refresh_token=credential['refresh_token'],
-            )
+        )
         auth.token_storage.store(token)
 
         self.connector.load_credentials(auth)
@@ -739,6 +755,7 @@ class BOX(SyncBase):
         logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
         return document_generator
 
+
 class Airtable(SyncBase):
     SOURCE_NAME: str = FileSource.AIRTABLE
 
@@ -784,6 +801,7 @@ class Airtable(SyncBase):
 
         return document_generator
 
+
 func_factory = {
     FileSource.S3: S3,
     FileSource.R2: R2,
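`func_factory` maps each `FileSource` value to the connector class that handles it. A self-contained sketch of how such a registry is typically consumed (simplified stand-in classes and enum, not the real connectors):

```python
from enum import Enum


class FileSource(str, Enum):  # simplified stand-in for common.constants.FileSource
    S3 = "s3"
    R2 = "r2"


class S3Sync:
    def run(self, task):
        return f"sync {task['id']} from S3"


class R2Sync:
    def run(self, task):
        return f"sync {task['id']} from R2"


func_factory = {FileSource.S3: S3Sync, FileSource.R2: R2Sync}


def dispatch(task):
    # Look the handler class up by the task's source, instantiate it, and run it.
    handler_cls = func_factory[FileSource(task["source"])]
    return handler_cls().run(task)


print(dispatch({"id": "t1", "source": "r2"}))  # sync t1 from R2
```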
@@ -92,7 +92,7 @@ FACTORY = {
 }
 
 TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
-    "dataflow" : PipelineTaskType.PARSE,
+    "dataflow": PipelineTaskType.PARSE,
     "raptor": PipelineTaskType.RAPTOR,
     "graphrag": PipelineTaskType.GRAPH_RAG,
     "mindmap": PipelineTaskType.MINDMAP,
@@ -221,7 +221,7 @@ async def get_storage_binary(bucket, name):
     return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
 
 
-@timeout(60*80, 1)
+@timeout(60 * 80, 1)
 async def build_chunks(task, progress_callback):
     if task["size"] > settings.DOC_MAXIMUM_SIZE:
         set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
@@ -283,7 +283,8 @@ async def build_chunks(task, progress_callback):
         try:
             d = copy.deepcopy(document)
             d.update(chunk)
-            d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+            d["id"] = xxhash.xxh64(
+                (chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
             d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
             d["create_timestamp_flt"] = datetime.now().timestamp()
             if not d.get("image"):
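The chunk-ID line above derives a stable identifier from the chunk text plus its document ID, so re-parsing the same document yields the same IDs. A minimal sketch of that hashing scheme, using the same `xxhash` call as the diff:

```python
import xxhash


def chunk_id(content_with_weight: str, doc_id: str) -> str:
    # "surrogatepass" keeps unusual Unicode from raising during encoding;
    # identical (text, doc_id) pairs always map to the same 64-bit hex digest.
    payload = (content_with_weight + str(doc_id)).encode("utf-8", "surrogatepass")
    return xxhash.xxh64(payload).hexdigest()


print(chunk_id("RAGFlow is a RAG engine.", "doc-42"))  # deterministic 16-char hex string
```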
@@ -328,9 +329,11 @@ async def build_chunks(task, progress_callback):
                 d["important_kwd"] = cached.split(",")
                 d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
             return
+
         tasks = []
         for d in docs:
-            tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
+            tasks.append(
+                asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
         try:
             await asyncio.gather(*tasks, return_exceptions=False)
         except Exception as e:
@@ -355,9 +358,11 @@ async def build_chunks(task, progress_callback):
             if cached:
                 d["question_kwd"] = cached.split("\n")
                 d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
+
         tasks = []
         for d in docs:
-            tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
+            tasks.append(
+                asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
         try:
             await asyncio.gather(*tasks, return_exceptions=False)
         except Exception as e:
@@ -374,15 +379,18 @@ async def build_chunks(task, progress_callback):
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
 
         async def gen_metadata_task(chat_mdl, d):
-            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", task["parser_config"]["metadata"])
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
+                                   task["parser_config"]["metadata"])
             if not cached:
                 async with chat_limiter:
                     cached = await gen_metadata(chat_mdl,
                                                 metadata_schema(task["parser_config"]["metadata"]),
                                                 d["content_with_weight"])
-                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", task["parser_config"]["metadata"])
+                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
+                              task["parser_config"]["metadata"])
             if cached:
                 d["metadata_obj"] = cached
+
         tasks = []
         for d in docs:
             tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
@@ -430,7 +438,8 @@ async def build_chunks(task, progress_callback):
             if task_canceled:
                 progress_callback(-1, msg="Task has been canceled.")
                 return None
-            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
+                    d[TAG_FLD]) > 0:
                 examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
             else:
                 docs_to_tag.append(d)
@@ -438,7 +447,7 @@ async def build_chunks(task, progress_callback):
         async def doc_content_tagging(chat_mdl, d, topn_tags):
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
             if not cached:
-                picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
+                picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
                 if not picked_examples:
                     picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
                 async with chat_limiter:
@@ -454,6 +463,7 @@ async def build_chunks(task, progress_callback):
             if cached:
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
                 d[TAG_FLD] = json.loads(cached)
+
         tasks = []
         for d in docs_to_tag:
             tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
@@ -473,21 +483,22 @@ def build_TOC(task, docs, progress_callback):
 def build_TOC(task, docs, progress_callback):
     progress_callback(msg="Start to generate table of content ...")
     chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
-    docs = sorted(docs, key=lambda d:(
+    docs = sorted(docs, key=lambda d: (
         d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
         d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
     ))
-    toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
-    logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
+    toc: list[dict] = asyncio.run(
+        run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
+    logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
     ii = 0
     while ii < len(toc):
         try:
             idx = int(toc[ii]["chunk_id"])
             del toc[ii]["chunk_id"]
             toc[ii]["ids"] = [docs[idx]["id"]]
-            if ii == len(toc) -1:
+            if ii == len(toc) - 1:
                 break
-            for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
+            for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
                 toc[ii]["ids"].append(docs[jj]["id"])
         except Exception as e:
             logging.exception(e)
@@ -499,7 +510,8 @@ def build_TOC(task, docs, progress_callback):
         d["toc_kwd"] = "toc"
         d["available_int"] = 0
         d["page_num_int"] = [100000000]
-        d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
+        d["id"] = xxhash.xxh64(
+            (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
         return d
     return None
 
@@ -532,12 +544,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     @timeout(60)
     def batch_encode(txts):
         nonlocal mdl
-        return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
+        return mdl.encode([truncate(c, mdl.max_length - 10) for c in txts])
 
     cnts_ = np.array([])
     for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
+            vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
         if len(cnts_) == 0:
             cnts_ = vts
         else:
@@ -545,7 +557,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
         tk_count += c
         callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
     cnts = cnts_
-    filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
+    filename_embd_weight = parser_config.get("filename_embd_weight", 0.1)  # due to the db support none value
     if not filename_embd_weight:
         filename_embd_weight = 0.1
     title_w = float(filename_embd_weight)
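The embedding hunks batch chunk texts, truncate each one below the model's limit, and run the blocking encode call in a worker thread so the event loop stays responsive. A compact sketch of that loop, with a fake embedder and naive truncation standing in for the real `LLMBundle`:

```python
import asyncio

import numpy as np

EMBEDDING_BATCH_SIZE = 16


class FakeEmbedder:
    max_length = 512

    def encode(self, texts):
        # Returns (vectors, token_count); vectors are random placeholders here.
        return np.random.rand(len(texts), 8), sum(len(t) for t in texts)


def truncate(text, max_len):
    # Stand-in for the real tokenizer-aware truncation.
    return text[:max_len]


async def embed_all(texts, mdl):
    def batch_encode(batch):
        return mdl.encode([truncate(t, mdl.max_length - 10) for t in batch])

    vectors, tokens = np.empty((0, 8)), 0
    for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
        # asyncio.to_thread keeps the event loop free while the model encodes.
        vts, c = await asyncio.to_thread(batch_encode, texts[i: i + EMBEDDING_BATCH_SIZE])
        vectors = np.concatenate((vectors, vts), axis=0)
        tokens += c
    return vectors, tokens


vecs, _ = asyncio.run(embed_all(["hello", "world"] * 20, FakeEmbedder()))
print(vecs.shape)  # (40, 8)
```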
@@ -588,7 +600,8 @@ async def run_dataflow(task: dict):
         return
 
     if not chunks:
-        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+                                           task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
         return
 
     embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
@@ -610,25 +623,27 @@ async def run_dataflow(task: dict):
         e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
         embedding_id = kb.embd_id
         embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
+
         @timeout(60)
         def batch_encode(txts):
             nonlocal embedding_model
             return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
+
         vects = np.array([])
         texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
-        delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
+        delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
         prog = 0.8
         for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
             async with embed_limiter:
-                vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
+                vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
             if len(vects) == 0:
                 vects = vts
             else:
                 vects = np.concatenate((vects, vts), axis=0)
             embedding_token_consumption += c
             prog += delta
-            if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
-                set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
+            if i % (len(texts) // settings.EMBEDDING_BATCH_SIZE / 100 + 1) == 1:
+                set_progress(task_id, prog=prog, msg=f"{i + 1} / {len(texts) // settings.EMBEDDING_BATCH_SIZE}")
 
         assert len(vects) == len(chunks)
         for i, ck in enumerate(chunks):
@@ -636,10 +651,10 @@ async def run_dataflow(task: dict):
             ck["q_%d_vec" % len(v)] = v
     except Exception as e:
         set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
-        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+                                           task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
         return
-
 
     metadata = {}
     for ck in chunks:
         ck["doc_id"] = doc_id
@@ -686,15 +701,19 @@ async def run_dataflow(task: dict):
     set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
     e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
     if not e:
-        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
+                                           task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
         return
 
     time_cost = timer() - start_ts
     task_time_cost = timer() - task_start_ts
     set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
-    DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
-    logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
-    PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
+    DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
+                                        task_time_cost)
+    logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
+                                                                        task_time_cost))
+    PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
+                                       dsl=str(pipeline))
 
 
 @timeout(3600)
@@ -702,7 +721,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
     fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
 
     raptor_config = kb_parser_config.get("raptor", {})
-    vctr_nm = "q_%d_vec"%vector_size
+    vctr_nm = "q_%d_vec" % vector_size
 
     res = []
     tk_count = 0
@@ -747,17 +766,17 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
         for x, doc_id in enumerate(doc_ids):
             chunks = []
             for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
-                                                   fields=["content_with_weight", vctr_nm],
-                                                   sort_by_position=True):
+                                                   fields=["content_with_weight", vctr_nm],
+                                                   sort_by_position=True):
                 chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
             await generate(chunks, doc_id)
-            callback(prog=(x+1.)/len(doc_ids))
+            callback(prog=(x + 1.) / len(doc_ids))
     else:
         chunks = []
         for doc_id in doc_ids:
            for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
-                                                  fields=["content_with_weight", vctr_nm],
-                                                  sort_by_position=True):
+                                                  fields=["content_with_weight", vctr_nm],
+                                                  sort_by_position=True):
                 chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
 
         await generate(chunks, fake_doc_id)
@@ -792,19 +811,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
         mom_ck["available_int"] = 0
         flds = list(mom_ck.keys())
         for fld in flds:
-            if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]:
+            if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
+                           "position_int"]:
                 del mom_ck[fld]
         mothers.append(mom_ck)
 
     for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
-        await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
+        await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
+                                search.index_name(task_tenant_id), task_dataset_id, )
         task_canceled = has_canceled(task_id)
         if task_canceled:
            progress_callback(-1, msg="Task has been canceled.")
            return False
 
     for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
+        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
+                                                   search.index_name(task_tenant_id), task_dataset_id, )
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -821,7 +843,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
         TaskService.update_chunk_ids(task_id, chunk_ids_str)
     except DoesNotExist:
         logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
+        doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
+                                                   search.index_name(task_tenant_id), task_dataset_id, )
         tasks = []
         for chunk_id in chunk_ids:
             tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
@@ -838,7 +861,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
     return True
 
 
-@timeout(60*60*3, 1)
+@timeout(60 * 60 * 3, 1)
 async def do_handle_task(task):
     task_type = task.get("task_type", "")
 
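`@timeout(60 * 60 * 3, 1)` bounds how long `do_handle_task` may run. The decorator itself is not shown in this diff; the sketch below is one common way to build such a guard with `asyncio.wait_for`, treating the second argument as a retry count, which is an assumption rather than RAGFlow's documented behavior:

```python
import asyncio
import functools


def timeout(seconds, attempts=1):
    # Hedged sketch only; RAGFlow's actual timeout decorator may differ.
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_err = None
            for _ in range(max(1, attempts)):
                try:
                    return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
                except asyncio.TimeoutError as e:
                    last_err = e
            raise last_err
        return wrapper
    return decorator


@timeout(0.1, 2)
async def slow():
    await asyncio.sleep(1)

# asyncio.run(slow())  # raises TimeoutError after two 0.1 s attempts
```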
@@ -914,7 +937,7 @@ async def do_handle_task(task):
                 },
             }
         )
-        if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
+        if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
             progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
             return
 
@@ -943,7 +966,7 @@ async def do_handle_task(task):
             doc_ids=task.get("doc_ids", []),
         )
         if fake_doc_ids := task.get("doc_ids", []):
-            task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
+            task_doc_id = fake_doc_ids[0]  # use the first document ID to represent this task for logging purposes
     # Either using graphrag or Standard chunking methods
     elif task_type == "graphrag":
         ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
@@ -968,11 +991,10 @@ async def do_handle_task(task):
                 }
             }
         )
-        if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
+        if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
             progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
             return
 
-
         graphrag_conf = kb_parser_config.get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
@@ -1030,7 +1052,7 @@ async def do_handle_task(task):
             return True
         e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
         return bool(e)
-
+
     try:
         if not await _maybe_insert_es(chunks):
             return
@@ -1084,8 +1106,8 @@ async def do_handle_task(task):
                 f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
             )
 
-async def handle_task():
 
+async def handle_task():
     global DONE_TASKS, FAILED_TASKS
     redis_msg, task = await collect()
     if not task:
@@ -1093,7 +1115,8 @@ async def handle_task():
         return
 
     task_type = task["task_type"]
-    pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
+    pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
+                                                             PipelineTaskType.PARSE) or PipelineTaskType.PARSE
 
     try:
         logging.info(f"handle_task begin for task {json.dumps(task)}")
@@ -1119,7 +1142,9 @@ async def handle_task():
         if task_type in ["graphrag", "raptor", "mindmap"]:
             task_document_ids = task["doc_ids"]
             if not task.get("dataflow_id", ""):
-                PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
+                PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
+                                                                      task_type=pipeline_task_type,
+                                                                      fake_document_ids=task_document_ids)
 
     redis_msg.ack()
 
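One non-cosmetic detail worth noting in the handle_task hunk: `TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE` falls back to PARSE both when the key is missing and when the stored value is falsy. A minimal sketch of the difference (hypothetical mapping, not the real enum):

```python
PARSE = "parse"
mapping = {"raptor": "raptor", "broken": None}  # a falsy value sneaks into the table

print(mapping.get("missing", PARSE))           # parse  (missing key -> default)
print(mapping.get("broken", PARSE))            # None   (present but falsy -> .get() keeps it)
print(mapping.get("broken", PARSE) or PARSE)   # parse  ("or default" also covers falsy values)
```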
@@ -1249,6 +1274,7 @@ async def main():
     await asyncio.gather(report_task, return_exceptions=True)
     logging.error("BUG!!! You should not reach here!!!")
 
+
 if __name__ == "__main__":
     faulthandler.enable()
     init_root_logger(CONSUMER_NAME)