Fix IDE warnings (#12281)

### What problem does this PR solve?

Fix IDE warnings: incorrect `traceback.print_stack(e)` calls, plus PEP 8 style issues such as missing whitespace around operators and after commas, inline-comment spacing, missing blank lines between top-level definitions, and over-long lines that are now wrapped, across 43 files including the data-source sync service and the task executor.
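
A minimal before/after-style sketch of the kinds of warnings this cleanup targets (hypothetical snippet, not taken from the repository; comments name the pycodestyle checks an IDE typically reports):

```python
# Hypothetical snippet; none of these names come from the codebase.
# Comments name the pycodestyle checks the old form would trip.


def add(a, b):
    return a + b  # "a+b" would trip E225 (missing whitespace around operator)


TOTAL = add(1, 2)  # "add(1,2)" would trip E231; "#comment" would trip E262
# The two blank lines above the assignment satisfy E302/E305
# (blank lines around top-level definitions).

if __name__ == "__main__":
    print(TOTAL)
```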

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Author: Jin Hai, 2025-12-29 12:01:18 +08:00 (committed by GitHub)
Commit: 01f0ced1e6 (parent: 647fb115a0)
43 changed files with 817 additions and 637 deletions

View File

@ -48,13 +48,15 @@ def main():
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))
except Exception as e:
traceback.print_stack(e)
logging.error(f"Error to get data from REDIS: {e}")
traceback.print_stack()
except Exception as e:
traceback.print_stack(e)
logging.error(f"Error to check REDIS connection: {e}")
traceback.print_stack()
if __name__ == "__main__":
while True:
main()
close_connection()
time.sleep(1)
time.sleep(1)
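
The hunk above also corrects a real misuse: `traceback.print_stack()` accepts an optional frame (plus `limit` and `file`), not an exception object, which is why the IDE flags `print_stack(e)`. A standalone sketch of the corrected pattern, with a hypothetical helper standing in for the Redis code:

```python
import logging
import traceback


def warm_cache(fetch):
    """Hypothetical stand-in for the Redis loop above: log the error message,
    then print the current call stack without passing the exception object."""
    try:
        return fetch()
    except Exception as e:
        logging.error(f"Error to get data from REDIS: {e}")
        traceback.print_stack()  # no argument: print_stack() takes a frame, not an exception
        # To log the exception's own traceback instead of the call stack,
        # logging.exception(...) or traceback.print_exc() would be the usual choice.
        return None


if __name__ == "__main__":
    def fetch():
        raise ConnectionError("redis unavailable")

    warm_cache(fetch)
```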

View File

@ -19,16 +19,15 @@ import requests
import base64
import asyncio
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk
JSON_DATA = {
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
"word": "" # User question, don't need to initialize
"conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
"Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
"word": "" # User question, don't need to initialize
}
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" #Get DISCORD_BOT_KEY from Discord Application
DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" # Get DISCORD_BOT_KEY from Discord Application
intents = discord.Intents.default()
intents.message_content = True
@ -50,7 +49,7 @@ async def on_message(message):
if len(message.content.split('> ')) == 1:
await message.channel.send("Hi~ How can I help you? ")
else:
JSON_DATA['word']=message.content.split('> ')[1]
JSON_DATA['word'] = message.content.split('> ')[1]
response = requests.post(URL, json=JSON_DATA)
response_data = response.json().get('data', [])
image_bool = False
@ -61,9 +60,9 @@ async def on_message(message):
if i['type'] == 3:
image_bool = True
image_data = base64.b64decode(i['url'])
with open('tmp_image.png','wb') as file:
with open('tmp_image.png', 'wb') as file:
file.write(image_data)
image= discord.File('tmp_image.png')
image = discord.File('tmp_image.png')
await message.channel.send(f"{message.author.mention}{res}")

View File

@ -38,7 +38,8 @@ from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, \
MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector, AirtableConnector
from common.constants import FileSource, TaskStatus
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
@ -96,7 +97,7 @@ class SyncBase:
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for document_batch in document_batch_generator:
for document_batch in document_batch_generator:
if not document_batch:
continue
@ -161,6 +162,7 @@ class SyncBase:
def _get_source_prefix(self):
return ""
class _BlobLikeBase(SyncBase):
DEFAULT_BUCKET_TYPE: str = "s3"
@ -199,22 +201,27 @@ class _BlobLikeBase(SyncBase):
)
return document_batch_generator
class S3(_BlobLikeBase):
SOURCE_NAME: str = FileSource.S3
DEFAULT_BUCKET_TYPE: str = "s3"
class R2(_BlobLikeBase):
SOURCE_NAME: str = FileSource.R2
DEFAULT_BUCKET_TYPE: str = "r2"
class OCI_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.OCI_STORAGE
DEFAULT_BUCKET_TYPE: str = "oci_storage"
class GOOGLE_CLOUD_STORAGE(_BlobLikeBase):
SOURCE_NAME: str = FileSource.GOOGLE_CLOUD_STORAGE
DEFAULT_BUCKET_TYPE: str = "google_cloud_storage"
class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE
@ -248,7 +255,9 @@ class Confluence(SyncBase):
index_recursively=index_recursively,
)
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json=self.conf["credentials"])
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
@ -280,7 +289,8 @@ class Confluence(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
logging.warning("Confluence connector failure: %s",
getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@ -300,7 +310,7 @@ class Confluence(SyncBase):
async def async_wrapper():
for batch in document_batches():
yield batch
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
return async_wrapper()
@ -314,10 +324,12 @@ class Notion(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator
@ -340,10 +352,12 @@ class Discord(SyncBase):
document_generator = (
self.connector.load_from_state()
if task["reindex"] == "1" or not task["poll_range_start"]
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
else self.connector.poll_source(task["poll_range_start"].timestamp(),
datetime.now(timezone.utc).timestamp())
)
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(
task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator
@ -485,7 +499,8 @@ class GoogleDrive(SyncBase):
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
for document, failure, next_checkpoint in doc_generator:
if failure is not None:
logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
logging.warning("Google Drive connector failure: %s",
getattr(failure, "failure_message", failure))
continue
if document is not None:
pending_docs.append(document)
@ -646,10 +661,10 @@ class WebDAV(SyncBase):
remote_path=self.conf.get("remote_path", "/")
)
self.connector.load_credentials(self.conf["credentials"])
logging.info(f"Task info: reindex={task['reindex']}, poll_range_start={task['poll_range_start']}")
if task["reindex"]=="1" or not task["poll_range_start"]:
if task["reindex"] == "1" or not task["poll_range_start"]:
logging.info("Using load_from_state (full sync)")
document_batch_generator = self.connector.load_from_state()
begin_info = "totally"
@ -659,14 +674,15 @@ class WebDAV(SyncBase):
logging.info(f"Polling WebDAV from {task['poll_range_start']} (ts: {start_ts}) to now (ts: {end_ts})")
document_batch_generator = self.connector.poll_source(start_ts, end_ts)
begin_info = "from {}".format(task["poll_range_start"])
logging.info("Connect to WebDAV: {}(path: {}) {}".format(
self.conf["base_url"],
self.conf.get("remote_path", "/"),
begin_info
))
return document_batch_generator
class Moodle(SyncBase):
SOURCE_NAME: str = FileSource.MOODLE
@ -675,7 +691,7 @@ class Moodle(SyncBase):
moodle_url=self.conf["moodle_url"],
batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)
)
self.connector.load_credentials(self.conf["credentials"])
# Determine the time range for synchronization based on reindex or poll_range_start
@ -689,7 +705,7 @@ class Moodle(SyncBase):
begin_info = "totally"
else:
document_generator = self.connector.poll_source(
poll_start.timestamp(),
poll_start.timestamp(),
datetime.now(timezone.utc).timestamp()
)
begin_info = "from {}".format(poll_start)
@ -718,7 +734,7 @@ class BOX(SyncBase):
token = AccessToken(
access_token=credential['access_token'],
refresh_token=credential['refresh_token'],
)
)
auth.token_storage.store(token)
self.connector.load_credentials(auth)
@ -739,6 +755,7 @@ class BOX(SyncBase):
logging.info("Connect to Box: folder_id({}) {}".format(self.conf["folder_id"], begin_info))
return document_generator
class Airtable(SyncBase):
SOURCE_NAME: str = FileSource.AIRTABLE
@ -784,6 +801,7 @@ class Airtable(SyncBase):
return document_generator
func_factory = {
FileSource.S3: S3,
FileSource.R2: R2,
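
One note on the wrapped import at the top of this file: the fix uses a backslash continuation, while PEP 8 generally prefers implicit continuation inside parentheses. A sketch of the equivalent parenthesized form, assuming the same `common.data_source` layout as in the diff:

```python
# Equivalent parenthesized form of the wrapped import above
# (assumes the same common.data_source module layout).
from common.data_source import (
    AirtableConnector,
    BlobStorageConnector,
    DiscordConnector,
    DropboxConnector,
    GoogleDriveConnector,
    JiraConnector,
    MoodleConnector,
    NotionConnector,
    WebDAVConnector,
)
```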

View File

@ -92,7 +92,7 @@ FACTORY = {
}
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"dataflow" : PipelineTaskType.PARSE,
"dataflow": PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
"mindmap": PipelineTaskType.MINDMAP,
@ -221,7 +221,7 @@ async def get_storage_binary(bucket, name):
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60*80, 1)
@timeout(60 * 80, 1)
async def build_chunks(task, progress_callback):
if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
@ -283,7 +283,8 @@ async def build_chunks(task, progress_callback):
try:
d = copy.deepcopy(document)
d.update(chunk)
d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["id"] = xxhash.xxh64(
(chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
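
The wrapped line above also documents how chunk ids are derived: an xxhash-64 digest of the chunk text concatenated with the document id. A standalone sketch of that derivation (requires the `xxhash` package; the function name is illustrative):

```python
import xxhash


def chunk_id(content_with_weight: str, doc_id: str) -> str:
    """Deterministic chunk id, mirroring the wrapped line above:
    xxh64 over the chunk text concatenated with its document id."""
    raw = (content_with_weight + str(doc_id)).encode("utf-8", "surrogatepass")
    return xxhash.xxh64(raw).hexdigest()


if __name__ == "__main__":
    # Same inputs always produce the same id, so re-parsing a document
    # regenerates identical chunk ids.
    print(chunk_id("some chunk text", "doc-42"))
```
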
@ -328,9 +329,11 @@ async def build_chunks(task, progress_callback):
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
tasks.append(
asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
@ -355,9 +358,11 @@ async def build_chunks(task, progress_callback):
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
tasks.append(
asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
@ -374,15 +379,18 @@ async def build_chunks(task, progress_callback):
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
async def gen_metadata_task(chat_mdl, d):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata", task["parser_config"]["metadata"])
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
task["parser_config"]["metadata"])
if not cached:
async with chat_limiter:
cached = await gen_metadata(chat_mdl,
metadata_schema(task["parser_config"]["metadata"]),
d["content_with_weight"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata", task["parser_config"]["metadata"])
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
task["parser_config"]["metadata"])
if cached:
d["metadata_obj"] = cached
tasks = []
for d in docs:
tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
@ -430,7 +438,8 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@ -438,7 +447,7 @@ async def build_chunks(task, progress_callback):
async def doc_content_tagging(chat_mdl, d, topn_tags):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
picked_examples = random.choices(examples, k=2) if len(examples) > 2 else examples
if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter:
@ -454,6 +463,7 @@ async def build_chunks(task, progress_callback):
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
tasks = []
for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
@ -473,21 +483,22 @@ async def build_chunks(task, progress_callback):
def build_TOC(task, docs, progress_callback):
progress_callback(msg="Start to generate table of content ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
docs = sorted(docs, key=lambda d:(
docs = sorted(docs, key=lambda d: (
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
toc: list[dict] = asyncio.run(
run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
if ii == len(toc) - 1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
@ -499,7 +510,8 @@ def build_TOC(task, docs, progress_callback):
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
d["id"] = xxhash.xxh64(
(d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
@ -532,12 +544,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
@timeout(60)
def batch_encode(txts):
nonlocal mdl
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
return mdl.encode([truncate(c, mdl.max_length - 10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
@ -545,7 +557,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count += c
callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
cnts = cnts_
filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
filename_embd_weight = parser_config.get("filename_embd_weight", 0.1) # due to the db support none value
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
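
The embedding hunks above mostly adjust spacing, but the surrounding loop is the batched-encoding pattern: truncate each text to just under the model's maximum length, encode in fixed-size batches, and concatenate the vectors. A minimal standalone sketch with a stubbed encoder (all names and sizes are assumptions, not the project's API):

```python
import numpy as np

EMBEDDING_BATCH_SIZE = 16  # assumed stand-in for settings.EMBEDDING_BATCH_SIZE
MAX_LENGTH = 512           # assumed stand-in for mdl.max_length


def truncate(text: str, max_len: int) -> str:
    """Crude stand-in for the project's truncate(): clip to max_len characters."""
    return text[:max_len]


def batch_encode(texts: list[str]) -> np.ndarray:
    """Stub encoder: one fixed-size vector per input text."""
    return np.random.default_rng(0).random((len(texts), 8))


def embed_all(texts: list[str]) -> np.ndarray:
    vectors = np.array([])
    for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
        batch = [truncate(t, MAX_LENGTH - 10) for t in texts[i: i + EMBEDDING_BATCH_SIZE]]
        vts = batch_encode(batch)
        vectors = vts if len(vectors) == 0 else np.concatenate((vectors, vts), axis=0)
    return vectors


if __name__ == "__main__":
    print(embed_all([f"chunk {n}" for n in range(40)]).shape)  # -> (40, 8)
```
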
@ -588,7 +600,8 @@ async def run_dataflow(task: dict):
return
if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
@ -610,25 +623,27 @@ async def run_dataflow(task: dict):
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
delta = 0.20 / (len(texts) // settings.EMBEDDING_BATCH_SIZE + 1)
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
if i % (len(texts) // settings.EMBEDDING_BATCH_SIZE / 100 + 1) == 1:
set_progress(task_id, prog=prog, msg=f"{i + 1} / {len(texts) // settings.EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
@ -636,10 +651,10 @@ async def run_dataflow(task: dict):
ck["q_%d_vec" % len(v)] = v
except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
metadata = {}
for ck in chunks:
ck["doc_id"] = doc_id
@ -686,15 +701,19 @@ async def run_dataflow(task: dict):
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
dsl=str(pipeline))
@timeout(3600)
@ -702,7 +721,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
vctr_nm = "q_%d_vec"%vector_size
vctr_nm = "q_%d_vec" % vector_size
res = []
tk_count = 0
@ -747,17 +766,17 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for x, doc_id in enumerate(doc_ids):
chunks = []
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, doc_id)
callback(prog=(x+1.)/len(doc_ids))
callback(prog=(x + 1.) / len(doc_ids))
else:
chunks = []
for doc_id in doc_ids:
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
await generate(chunks, fake_doc_id)
@ -792,19 +811,22 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0
flds = list(mom_ck.keys())
for fld in flds:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int", "position_int"]:
if fld not in ["id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
"position_int"]:
del mom_ck[fld]
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -821,7 +843,8 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
@ -838,7 +861,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
return True
@timeout(60*60*3, 1)
@timeout(60 * 60 * 3, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
@ -914,7 +937,7 @@ async def do_handle_task(task):
},
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
@ -943,7 +966,7 @@ async def do_handle_task(task):
doc_ids=task.get("doc_ids", []),
)
if fake_doc_ids := task.get("doc_ids", []):
task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes
# Either using graphrag or Standard chunking methods
elif task_type == "graphrag":
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
@ -968,11 +991,10 @@ async def do_handle_task(task):
}
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
@ -1030,7 +1052,7 @@ async def do_handle_task(task):
return True
e = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(e)
try:
if not await _maybe_insert_es(chunks):
return
@ -1084,8 +1106,8 @@ async def do_handle_task(task):
f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled."
)
async def handle_task():
async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
@ -1093,7 +1115,8 @@ async def handle_task():
return
task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
@ -1119,7 +1142,9 @@ async def handle_task():
if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
task_type=pipeline_task_type,
fake_document_ids=task_document_ids)
redis_msg.ack()
@ -1249,6 +1274,7 @@ async def main():
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)