mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: init dataflow. (#9791)
### What problem does this PR solve? #9790 Close #9782 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
49
rag/flow/__init__.py
Normal file
49
rag/flow/__init__.py
Normal file
@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import importlib
|
||||
import inspect
|
||||
from types import ModuleType
|
||||
from typing import Dict, Type
|
||||
|
||||
_package_path = os.path.dirname(__file__)
|
||||
__all_classes: Dict[str, Type] = {}
|
||||
|
||||
def _import_submodules() -> None:
|
||||
for filename in os.listdir(_package_path): # noqa: F821
|
||||
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
|
||||
continue
|
||||
module_name = filename[:-3]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(f".{module_name}", package=__name__)
|
||||
_extract_classes_from_module(module) # noqa: F821
|
||||
except ImportError as e:
|
||||
print(f"Warning: Failed to import module {module_name}: {str(e)}")
|
||||
|
||||
def _extract_classes_from_module(module: ModuleType) -> None:
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (inspect.isclass(obj) and
|
||||
obj.__module__ == module.__name__ and not name.startswith("_")):
|
||||
__all_classes[name] = obj
|
||||
globals()[name] = obj
|
||||
|
||||
_import_submodules()
|
||||
|
||||
__all__ = list(__all_classes.keys()) + ["__all_classes"]
|
||||
|
||||
del _package_path, _import_submodules, _extract_classes_from_module
|
||||
59
rag/flow/base.py
Normal file
59
rag/flow/base.py
Normal file
@ -0,0 +1,59 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import os
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
import trio
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class ProcessParamBase(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.timeout = 100000000
|
||||
self.persist_logs = True
|
||||
|
||||
|
||||
class ProcessBase(ComponentBase):
|
||||
|
||||
def __init__(self, pipeline, id, param: ProcessParamBase):
|
||||
super().__init__(pipeline, id, param)
|
||||
self.callback = partial(self._canvas.callback, self.component_name)
|
||||
|
||||
async def invoke(self, **kwargs) -> dict[str, Any]:
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
for k,v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self.callback(-1, str(e))
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
async def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
47
rag/flow/begin.py
Normal file
47
rag/flow/begin.py
Normal file
@ -0,0 +1,47 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class FileParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def check(self):
|
||||
pass
|
||||
|
||||
|
||||
class File(ProcessBase):
|
||||
component_name = "File"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
if self._canvas._doc_id:
|
||||
e, doc = DocumentService.get_by_id(self._canvas._doc_id)
|
||||
if not e:
|
||||
self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
|
||||
return
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
|
||||
self.set_output("blob", STORAGE_IMPL.get(b, n))
|
||||
self.set_output("name", doc.name)
|
||||
else:
|
||||
file = kwargs.get("file")
|
||||
self.set_output("name", file["name"])
|
||||
self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))
|
||||
160
rag/flow/chunker.py
Normal file
160
rag/flow/chunker.py
Normal file
@ -0,0 +1,160 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import trio
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.nlp import naive_merge, naive_merge_with_images
|
||||
from rag.prompts.prompts import keyword_extraction, question_proposal
|
||||
|
||||
|
||||
class ChunkerParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
|
||||
self.method = "general"
|
||||
self.chunk_token_size = 512
|
||||
self.delimiter = "\n"
|
||||
self.overlapped_percent = 0
|
||||
self.page_rank = 0
|
||||
self.auto_keywords = 0
|
||||
self.auto_questions = 0
|
||||
self.tag_sets = []
|
||||
self.llm_setting = {
|
||||
"llm_name": "",
|
||||
"lang": "Chinese"
|
||||
}
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
|
||||
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
|
||||
self.check_nonnegative_number(self.page_rank, "Page rank value: (0, 10]")
|
||||
self.check_nonnegative_number(self.auto_keywords, "Auto-keyword value: (0, 10]")
|
||||
self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
|
||||
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
|
||||
|
||||
|
||||
class Chunker(ProcessBase):
|
||||
component_name = "Chunker"
|
||||
|
||||
def _general(self, **kwargs):
|
||||
self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
|
||||
if kwargs.get("output_format") in ["markdown", "text"]:
|
||||
cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
|
||||
return [{"text": c} for c in cks]
|
||||
|
||||
sections, section_images = [], []
|
||||
for o in kwargs["json"]:
|
||||
sections.append((o["text"], o.get("position_tag","")))
|
||||
section_images.append(o.get("image"))
|
||||
|
||||
chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
|
||||
return [{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c)
|
||||
} for c,img in zip(chunks,images)]
|
||||
|
||||
def _q_and_a(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _resume(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _manual(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _table(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _paper(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _book(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _laws(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _presentation(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _one(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
function_map = {
|
||||
"general": self._general,
|
||||
"q&a": self._q_and_a,
|
||||
"resume": self._resume,
|
||||
"manual": self._manual,
|
||||
"table": self._table,
|
||||
"paper": self._paper,
|
||||
"book": self._book,
|
||||
"laws": self._laws,
|
||||
"presentation": self._presentation,
|
||||
"one": self._one,
|
||||
}
|
||||
chunks = function_map[self._param.method](**kwargs)
|
||||
llm_setting = self._param.llm_setting
|
||||
|
||||
async def auto_keywords():
|
||||
nonlocal chunks, llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
|
||||
|
||||
async def doc_keyword_extraction(chat_mdl, ck, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "keywords", {"topn": topn})
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, ck["text"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "keywords", {"topn": topn})
|
||||
if cached:
|
||||
ck["keywords"] = cached.split(",")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ck in chunks:
|
||||
nursery.start_soon(doc_keyword_extraction, chat_mdl, ck, self._param.auto_keywords)
|
||||
|
||||
async def auto_questions():
|
||||
nonlocal chunks, llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
|
||||
|
||||
async def doc_question_proposal(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, ck["text"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "question", {"topn": topn})
|
||||
if cached:
|
||||
d["questions"] = cached.split("\n")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ck in chunks:
|
||||
nursery.start_soon(doc_question_proposal, chat_mdl, ck, self._param.auto_questions)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
if self._param.auto_questions:
|
||||
nursery.start_soon(auto_questions)
|
||||
if self._param.auto_keywords:
|
||||
nursery.start_soon(auto_keywords)
|
||||
|
||||
if self._param.page_rank:
|
||||
for ck in chunks:
|
||||
ck["page_rank"] = self._param.page_rank
|
||||
|
||||
self.set_output("chunks", chunks)
|
||||
107
rag/flow/parser.py
Normal file
107
rag/flow/parser.py
Normal file
@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import trio
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.llm.cv_model import Base as VLM
|
||||
from deepdoc.parser import ExcelParser
|
||||
|
||||
|
||||
class ParserParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setups = {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": ["pdf"],
|
||||
"output_format": "json"
|
||||
},
|
||||
"excel": {
|
||||
"output_format": "html"
|
||||
},
|
||||
"ppt": {},
|
||||
"image": {
|
||||
"parse_method": "ocr"
|
||||
},
|
||||
"email": {},
|
||||
"text": {},
|
||||
"audio": {},
|
||||
"video": {},
|
||||
}
|
||||
|
||||
def check(self):
|
||||
if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
|
||||
assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
|
||||
assert self.setups["pdf"].get("lang"), "No language specified."
|
||||
|
||||
|
||||
class Parser(ProcessBase):
|
||||
component_name = "Parser"
|
||||
|
||||
def _pdf(self, blob):
|
||||
self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
|
||||
conf = self._param.setups["pdf"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
if conf.get("parse_method") == "deepdoc":
|
||||
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
|
||||
elif conf.get("parse_method") == "plain_text":
|
||||
lines,_ = PlainParser()(blob)
|
||||
bboxes = [{"text": t} for t,_ in lines]
|
||||
else:
|
||||
assert conf.get("vlm_name")
|
||||
vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang"))
|
||||
lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback)
|
||||
bboxes = []
|
||||
for t, poss in lines:
|
||||
pn, x0, x1, top, bott = poss.split(" ")
|
||||
bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})
|
||||
|
||||
self.set_output("json", bboxes)
|
||||
mkdn = ""
|
||||
for b in bboxes:
|
||||
if b.get("layout_type", "") == "title":
|
||||
mkdn += "\n## "
|
||||
if b.get("layout_type", "") == "figure":
|
||||
mkdn += "\n".format(VLM.image2base64(b["image"]))
|
||||
continue
|
||||
mkdn += b.get("text", "") + "\n"
|
||||
self.set_output("markdown", mkdn)
|
||||
|
||||
def _excel(self, blob):
|
||||
self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
|
||||
conf = self._param.setups["excel"]
|
||||
excel_parser = ExcelParser()
|
||||
if conf.get("output_format") == "html":
|
||||
html = excel_parser.html(blob,1000000000)
|
||||
self.set_output("html", html)
|
||||
elif conf.get("output_format") == "json":
|
||||
self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
|
||||
elif conf.get("output_format") == "markdown":
|
||||
self.set_output("markdown", excel_parser.markdown(blob))
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
function_map = {
|
||||
"pdf": self._pdf,
|
||||
}
|
||||
for p_type, conf in self._param.setups.items():
|
||||
if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
|
||||
continue
|
||||
await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
|
||||
break
|
||||
121
rag/flow/pipeline.py
Normal file
121
rag/flow/pipeline.py
Normal file
@ -0,0 +1,121 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import trio
|
||||
from agent.canvas import Graph
|
||||
from api.db.services.document_service import DocumentService
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Pipeline(Graph):
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
self._doc_id = doc_id
|
||||
self._flow_id = flow_id
|
||||
self._kb_id = None
|
||||
if doc_id:
|
||||
self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
|
||||
assert self._kb_id, f"Can't find KB of this document: {doc_id}"
|
||||
|
||||
def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
bin = REDIS_CONN.get(log_key)
|
||||
obj = json.loads(bin.encode("utf-8"))
|
||||
if obj:
|
||||
if obj[-1]["component_name"] == component_name:
|
||||
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
|
||||
else:
|
||||
obj.append({
|
||||
"component_name": component_name,
|
||||
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
|
||||
})
|
||||
else:
|
||||
obj = [{
|
||||
"component_name": component_name,
|
||||
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
|
||||
}]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60*10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
def fetch_logs(self):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
bin = REDIS_CONN.get(log_key)
|
||||
if bin:
|
||||
return json.loads(bin.encode("utf-8"))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
REDIS_CONN.set_obj(log_key, [], 60*10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
if not self.path:
|
||||
self.path.append("begin")
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(self._doc_id, {
|
||||
"progress": random.randint(0,5)/100.,
|
||||
"progress_msg": "Start the pipeline...",
|
||||
"process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
})
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
if idx == 0:
|
||||
cpn_obj = self.get_component_obj(self.path[0])
|
||||
await cpn_obj.invoke(**kwargs)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
else:
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
while idx < len(self.path) and not self.error:
|
||||
last_cpn = self.get_component_obj(self.path[idx-1])
|
||||
cpn_obj = self.get_component_obj(self.path[idx])
|
||||
async def invoke():
|
||||
nonlocal last_cpn, cpn_obj
|
||||
await cpn_obj.invoke(**last_cpn.output())
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
break
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(self._doc_id, {
|
||||
"progress": 1 if not self.error else -1,
|
||||
"progress_msg": "Pipeline finished...\n" + self.error,
|
||||
"process_duration": time.perf_counter() - st
|
||||
})
|
||||
|
||||
57
rag/flow/tests/client.py
Normal file
57
rag/flow/tests/client.py
Normal file
@ -0,0 +1,57 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import trio
|
||||
from api import settings
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
|
||||
def print_logs(pipeline):
|
||||
last_logs = "[]"
|
||||
while True:
|
||||
time.sleep(5)
|
||||
logs = pipeline.fetch_logs()
|
||||
logs_str = json.dumps(logs)
|
||||
if logs_str != last_logs:
|
||||
print(logs_str)
|
||||
last_logs = logs_str
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
dsl_default_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"dsl_examples",
|
||||
"general_pdf_all.json",
|
||||
)
|
||||
parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
settings.init_settings()
|
||||
pipeline = Pipeline(open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx")
|
||||
pipeline.reset()
|
||||
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
thr = exe.submit(print_logs, pipeline)
|
||||
|
||||
trio.run(pipeline.run)
|
||||
thr.result()
|
||||
54
rag/flow/tests/dsl_examples/general_pdf_all.json
Normal file
54
rag/flow/tests/dsl_examples/general_pdf_all.json
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj":{
|
||||
"component_name": "File",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": ["parser:0"],
|
||||
"upstream": []
|
||||
},
|
||||
"parser:0": {
|
||||
"obj": {
|
||||
"component_name": "Parser",
|
||||
"params": {
|
||||
"setups": {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc",
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"pdf"
|
||||
],
|
||||
"output_format": "json"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"downstream": ["chunker:0"],
|
||||
"upstream": ["begin"]
|
||||
},
|
||||
"chunker:0": {
|
||||
"obj": {
|
||||
"component_name": "Chunker",
|
||||
"params": {
|
||||
"method": "general",
|
||||
"auto_keywords": 5
|
||||
}
|
||||
},
|
||||
"downstream": ["tokenizer:0"],
|
||||
"upstream": ["chunker:0"]
|
||||
},
|
||||
"tokenizer:0": {
|
||||
"obj": {
|
||||
"component_name": "Tokenizer",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": [],
|
||||
"upstream": ["chunker:0"]
|
||||
}
|
||||
},
|
||||
"path": []
|
||||
}
|
||||
134
rag/flow/tokenizer.py
Normal file
134
rag/flow/tokenizer.py
Normal file
@ -0,0 +1,134 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.settings import EMBEDDING_BATCH_SIZE
|
||||
from rag.svr.task_executor import embed_limiter
|
||||
from rag.utils import truncate
|
||||
|
||||
|
||||
class TokenizerParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.search_method = ["full_text", "embedding"]
|
||||
self.filename_embd_weight = 0.1
|
||||
|
||||
def check(self):
|
||||
for v in self.search_method:
|
||||
self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
|
||||
|
||||
|
||||
class Tokenizer(ProcessBase):
|
||||
component_name = "Tokenizer"
|
||||
|
||||
async def _embedding(self, name, chunks):
|
||||
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
|
||||
token_count = 0
|
||||
if self._canvas._kb_id:
|
||||
e, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id)
|
||||
embedding_id = kb.embd_id
|
||||
else:
|
||||
e, ten = TenantService.get_by_id(self._canvas._tenant_id)
|
||||
embedding_id = ten.embd_id
|
||||
embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
texts = []
|
||||
for c in chunks:
|
||||
if c.get("questions"):
|
||||
texts.append("\n".join(c["questions"]))
|
||||
else:
|
||||
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c["text"]))
|
||||
vts, c = embedding_model.encode([name])
|
||||
token_count += c
|
||||
tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0)
|
||||
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal embedding_model
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
cnts_ = np.concatenate((cnts_, vts), axis=0)
|
||||
token_count += c
|
||||
if i % 33 == 32:
|
||||
self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
|
||||
|
||||
cnts = cnts_
|
||||
title_w = float(self._param.filename_embd_weight)
|
||||
vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
v = vects[i].tolist()
|
||||
ck["q_%d_vec" % len(v)] = v
|
||||
return chunks, token_count
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
|
||||
if "full_text" in self._param.search_method:
|
||||
self.callback(random.randint(1,5)/100., "Start to tokenize.")
|
||||
if kwargs.get("chunks"):
|
||||
chunks = kwargs["chunks"]
|
||||
for i, ck in enumerate(chunks):
|
||||
if ck.get("questions"):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i*1./len(chunks)/parts)
|
||||
elif kwargs.get("output_format") in ["markdown", "text"]:
|
||||
ck = {
|
||||
"text": kwargs.get(kwargs["output_format"], "")
|
||||
}
|
||||
if "full_text" in self._param.search_method:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
chunks = [ck]
|
||||
else:
|
||||
chunks = kwargs["json"]
|
||||
for i, ck in enumerate(chunks):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i*1./len(chunks)/parts)
|
||||
|
||||
self.callback(1./parts, "Finish tokenizing.")
|
||||
|
||||
if "embedding" in self._param.search_method:
|
||||
self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
|
||||
chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
|
||||
self.set_output("embedding_token_consumption", token_count)
|
||||
|
||||
self.callback(1., "Finish embedding.")
|
||||
|
||||
self.set_output("chunks", chunks)
|
||||
Reference in New Issue
Block a user