mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: support TOC transformer. (#11685)
### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
|
|||||||
return
|
return
|
||||||
if cks:
|
if cks:
|
||||||
kbinfos["chunks"] = cks
|
kbinfos["chunks"] = cks
|
||||||
|
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||||
if self._param.use_kg:
|
if self._param.use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(query,
|
ck = settings.kg_retriever.retrieval(query,
|
||||||
[kb.tenant_id for kb in kbs],
|
[kb.tenant_id for kb in kbs],
|
||||||
|
|||||||
@ -12,10 +12,17 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy, copy
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import xxhash
|
||||||
|
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||||
|
from rag.prompts.generator import run_toc_from_text
|
||||||
|
|
||||||
|
|
||||||
class ExtractorParam(ProcessParamBase, LLMParam):
|
class ExtractorParam(ProcessParamBase, LLMParam):
|
||||||
@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
|
|||||||
class Extractor(ProcessBase, LLM):
|
class Extractor(ProcessBase, LLM):
|
||||||
component_name = "Extractor"
|
component_name = "Extractor"
|
||||||
|
|
||||||
|
def _build_TOC(self, docs):
    """Generate a table-of-contents chunk for ``docs``.

    The chunks are ordered by ``page_num_int`` and then ``top_int`` (the
    first element is used when a value is a list), their ``text`` fields are
    handed to ``run_toc_from_text`` together with ``self.chat_mdl``, and
    every returned TOC entry has its ``chunk_id`` replaced by ``ids``: the
    ids of the chunks running from that entry's chunk through the next
    entry's chunk (inclusive). The last entry only receives its own chunk id.

    Args:
        docs: list of chunk dicts; each needs ``text`` and ``id``, and the
            last one (after sorting) needs ``doc_id``.

    Returns:
        A deep copy of the last sorted chunk turned into a TOC chunk
        (``content_with_weight`` holds the JSON-encoded TOC), or ``None``
        when no TOC entries were produced.
    """
    self.callback(message="Start to generate table of content ...")

    def _first(value):
        # Positional fields may be stored either as a scalar or as a list.
        return value[0] if isinstance(value, list) else value

    docs = sorted(
        docs,
        key=lambda d: (_first(d.get("page_num_int", 0)), _first(d.get("top_int", 0))),
    )

    # NOTE(review): trio.run starts a fresh Trio run — confirm this method is
    # never reached from inside an already running Trio loop.
    toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
    logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))

    last = len(toc) - 1
    for pos, entry in enumerate(toc):
        try:
            idx = int(entry["chunk_id"])
            del entry["chunk_id"]
            entry["ids"] = [docs[idx]["id"]]
            if pos == last:
                break
            # The next entry still carries its "chunk_id" at this point; the
            # upper bound is inclusive, so the chunk where the next section
            # starts is attached to this section as well.
            for jj in range(idx + 1, int(toc[pos + 1]["chunk_id"]) + 1):
                entry["ids"].append(docs[jj]["id"])
        except Exception as e:
            # Best effort: a malformed entry is logged and the remaining
            # entries are still processed.
            logging.exception(e)

    if not toc:
        return None

    # The module imports the *functions* ``deepcopy`` and ``copy`` from the
    # ``copy`` module, so ``copy.deepcopy`` would raise AttributeError; call
    # ``deepcopy`` directly.
    d = deepcopy(docs[-1])
    d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
    d["toc_kwd"] = "toc"
    d["available_int"] = 0
    d["page_num_int"] = [100000000]
    d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
    return d
|
||||||
|
|
||||||
async def _invoke(self, **kwargs):
|
async def _invoke(self, **kwargs):
|
||||||
self.set_output("output_format", "chunks")
|
self.set_output("output_format", "chunks")
|
||||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||||
@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
|
|||||||
chunks_key = k
|
chunks_key = k
|
||||||
|
|
||||||
if chunks:
|
if chunks:
|
||||||
|
if self._param.field_name == "toc":
|
||||||
|
toc = self._build_TOC(chunks)
|
||||||
|
chunks.append(toc)
|
||||||
|
self.set_output("chunks", chunks)
|
||||||
|
return
|
||||||
|
|
||||||
prog = 0
|
prog = 0
|
||||||
for i, ck in enumerate(chunks):
|
for i, ck in enumerate(chunks):
|
||||||
args[chunks_key] = ck["text"]
|
args[chunks_key] = ck["text"]
|
||||||
|
|||||||
@ -944,7 +944,7 @@ async def do_handle_task(task):
|
|||||||
logging.info(progress_message)
|
logging.info(progress_message)
|
||||||
progress_callback(msg=progress_message)
|
progress_callback(msg=progress_message)
|
||||||
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
||||||
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
|
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
|
||||||
|
|
||||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||||
start_ts = timer()
|
start_ts = timer()
|
||||||
|
|||||||
Reference in New Issue
Block a user