From e32ef75e997b108319c93be0edc8b3b1c3e1fb76 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Tue, 23 Jan 2024 19:45:36 +0800 Subject: [PATCH] Test chat API and refine ppt chunker (#42) --- api/apps/conversation_app.py | 12 ++-- api/db/db_models.py | 1 + api/db/services/llm_service.py | 113 ++++++++++++++++++++++++++++----- api/utils/file_utils.py | 4 +- rag/llm/chat_model.py | 6 +- rag/llm/cv_model.py | 6 +- rag/llm/embedding_model.py | 40 ++++++++---- rag/nlp/huchunk.py | 61 ++++++++++++------ rag/nlp/search.py | 56 ++++++++++------ rag/svr/parse_user_docs.py | 18 +++--- 10 files changed, 226 insertions(+), 91 deletions(-) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 130acb999..b26fe7e10 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -17,7 +17,7 @@ from flask import request from flask_login import login_required from api.db.services.dialog_service import DialogService, ConversationService from api.db import LLMType -from api.db.services.llm_service import LLMService, TenantLLMService +from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.utils.api_utils import get_json_result @@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs): if p["key"] not in kwargs: prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ") - model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id) - if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id)) - question = messages[-1]["content"] - embd_mdl = TenantLLMService.model_instance( - dialog.tenant_id, LLMType.EMBEDDING.value) + embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) + chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, top=1024, aggs=False) knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]] @@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs): used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) if "max_tokens" in gen_conf: gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count) - mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id) - answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) + answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) answer = retrievaler.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], diff --git a/api/db/db_models.py b/api/db/db_models.py index 0862dec83..b0580eb72 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -524,6 +524,7 @@ class Dialog(DataBaseModel): similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) top_n = IntegerField(default=6) + do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1") kb_ids = JSONField(null=False, default=[]) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 51914714e..0fb10b0a1 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -14,12 +14,12 @@ # limitations under the License. # from api.db.services.user_service import TenantService -from rag.llm import EmbeddingModel, CvModel +from api.settings import database_logger +from rag.llm import EmbeddingModel, CvModel, ChatModel from api.db import LLMType from api.db.db_models import DB, UserTenant from api.db.db_models import LLMFactories, LLM, TenantLLM from api.db.services.common_service import CommonService -from api.db import StatusEnum class LLMFactoriesService(CommonService): @@ -37,13 +37,19 @@ class TenantLLMService(CommonService): @DB.connection_context() def get_api_key(cls, tenant_id, model_name): objs = cls.query(tenant_id=tenant_id, llm_name=model_name) - if not objs: return + if not objs: + return return objs[0] @classmethod @DB.connection_context() def get_my_llms(cls, tenant_id): - fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name] + fields = [ + cls.model.llm_factory, + LLMFactories.logo, + LLMFactories.tags, + cls.model.model_type, + cls.model.llm_name] objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where( cls.model.tenant_id == tenant_id).dicts() @@ -51,23 +57,96 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() - def model_instance(cls, tenant_id, llm_type): - e,tenant = TenantService.get_by_id(tenant_id) - if not e: raise LookupError("Tenant not found") + def model_instance(cls, tenant_id, llm_type, llm_name=None): + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + raise LookupError("Tenant not found") - if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id - elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id - elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id - elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id - else: assert False, "LLM type error" + if llm_type == LLMType.EMBEDDING.value: + mdlnm = tenant.embd_id + elif llm_type == LLMType.SPEECH2TEXT.value: + mdlnm = tenant.asr_id + elif llm_type == LLMType.IMAGE2TEXT.value: + mdlnm = tenant.img2txt_id + elif llm_type == LLMType.CHAT.value: + mdlnm = tenant.llm_id if not llm_name else llm_name + else: + assert False, "LLM type error" model_config = cls.get_api_key(tenant_id, mdlnm) - if not model_config: raise LookupError("Model({}) not found".format(mdlnm)) + if not model_config: + raise LookupError("Model({}) not found".format(mdlnm)) model_config = model_config.to_dict() if llm_type == LLMType.EMBEDDING.value: - if model_config["llm_factory"] not in EmbeddingModel: return - return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) + if model_config["llm_factory"] not in EmbeddingModel: + return + return EmbeddingModel[model_config["llm_factory"]]( + model_config["api_key"], model_config["llm_name"]) if llm_type == LLMType.IMAGE2TEXT.value: - if model_config["llm_factory"] not in CvModel: return - return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) + if model_config["llm_factory"] not in CvModel: + return + return CvModel[model_config["llm_factory"]]( + model_config["api_key"], model_config["llm_name"]) + + if llm_type == LLMType.CHAT.value: + if model_config["llm_factory"] not in ChatModel: + return + return ChatModel[model_config["llm_factory"]]( + model_config["api_key"], model_config["llm_name"]) + + @classmethod + @DB.connection_context() + def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + raise LookupError("Tenant not found") + + if llm_type == LLMType.EMBEDDING.value: + mdlnm = tenant.embd_id + elif llm_type == LLMType.SPEECH2TEXT.value: + mdlnm = tenant.asr_id + elif llm_type == LLMType.IMAGE2TEXT.value: + mdlnm = tenant.img2txt_id + elif llm_type == LLMType.CHAT.value: + mdlnm = tenant.llm_id if not llm_name else llm_name + else: + assert False, "LLM type error" + + num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\ + .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ + .execute() + return num + + +class LLMBundle(object): + def __init__(self, tenant_id, llm_type, llm_name=None): + self.tenant_id = tenant_id + self.llm_type = llm_type + self.llm_name = llm_name + self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name) + assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name) + + def encode(self, texts: list, batch_size=32): + emd, used_tokens = self.mdl.encode(texts, batch_size) + if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) + return emd, used_tokens + + def encode_queries(self, query: str): + emd, used_tokens = self.mdl.encode_queries(query) + if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) + return emd, used_tokens + + def describe(self, image, max_tokens=300): + txt, used_tokens = self.mdl.describe(image, max_tokens) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id)) + return txt + + def chat(self, system, history, gen_conf): + txt, used_tokens = self.mdl.chat(system, history, gen_conf) + if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): + database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id)) + return txt diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 0aa650e2e..c3446b245 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -143,11 +143,11 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): + if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): return FileType.DOC.value if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): return FileType.AURAL.value if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): - return FileType.VISUAL \ No newline at end of file + return FileType.VISUAL diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 7dd8267b5..316263686 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -37,7 +37,7 @@ class GptTurbo(Base): model=self.model_name, messages=history, **gen_conf) - return res.choices[0].message.content.strip() + return res.choices[0].message.content.strip(), res.usage.completion_tokens from dashscope import Generation @@ -56,5 +56,5 @@ class QWenChat(Base): result_format='message' ) if response.status_code == HTTPStatus.OK: - return response.output.choices[0]['message']['content'] - return response.message + return response.output.choices[0]['message']['content'], response.usage.output_tokens + return response.message, 0 diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 371953533..67816a165 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -72,7 +72,7 @@ class GptV4(Base): messages=self.prompt(b64), max_tokens=max_tokens, ) - return res.choices[0].message.content.strip() + return res.choices[0].message.content.strip(), res.usage.total_tokens class QWenCV(Base): @@ -87,5 +87,5 @@ class QWenCV(Base): response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(self.image2base64(image))) if response.status_code == HTTPStatus.OK: - return response.output.choices[0]['message']['content'] - return response.message + return response.output.choices[0]['message']['content'], response.usage.output_tokens + return response.message, 0 diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 2d0694dd7..be914e60d 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -36,6 +36,9 @@ class Base(ABC): def encode(self, texts: list, batch_size=32): raise NotImplementedError("Please implement encode method!") + def encode_queries(self, text: str): + raise NotImplementedError("Please implement encode method!") + class HuEmbedding(Base): def __init__(self, key="", model_name=""): @@ -68,15 +71,18 @@ class HuEmbedding(Base): class OpenAIEmbed(Base): def __init__(self, key, model_name="text-embedding-ada-002"): - self.client = OpenAI(key) + self.client = OpenAI(api_key=key) self.model_name = model_name def encode(self, texts: list, batch_size=32): - token_count = 0 - for t in texts: token_count += num_tokens_from_string(t) res = self.client.embeddings.create(input=texts, model=self.model_name) - return [d["embedding"] for d in res["data"]], token_count + return np.array([d.embedding for d in res.data]), res.usage.total_tokens + + def encode_queries(self, text): + res = self.client.embeddings.create(input=[text], + model=self.model_name) + return np.array(res.data[0].embedding), res.usage.total_tokens class QWenEmbed(Base): @@ -84,16 +90,28 @@ class QWenEmbed(Base): dashscope.api_key = key self.model_name = model_name - def encode(self, texts: list, batch_size=32, text_type="document"): + def encode(self, texts: list, batch_size=10): import dashscope res = [] token_count = 0 - for txt in texts: + texts = [txt[:2048] for txt in texts] + for i in range(0, len(texts), batch_size): resp = dashscope.TextEmbedding.call( model=self.model_name, - input=txt[:2048], - text_type=text_type + input=texts[i:i+batch_size], + text_type="document" ) - res.append(resp["output"]["embeddings"][0]["embedding"]) - token_count += resp["usage"]["total_tokens"] - return res, token_count + embds = [[]] * len(resp["output"]["embeddings"]) + for e in resp["output"]["embeddings"]: + embds[e["text_index"]] = e["embedding"] + res.extend(embds) + token_count += resp["usage"]["input_tokens"] + return np.array(res), token_count + + def encode_queries(self, text): + resp = dashscope.TextEmbedding.call( + model=self.model_name, + input=text[:2048], + text_type="query" + ) + return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"] \ No newline at end of file diff --git a/rag/nlp/huchunk.py b/rag/nlp/huchunk.py index cc93f5faf..c8f9e4704 100644 --- a/rag/nlp/huchunk.py +++ b/rag/nlp/huchunk.py @@ -11,6 +11,11 @@ from io import BytesIO class HuChunker: + @dataclass + class Fields: + text_chunks: List = None + table_chunks: List = None + def __init__(self): self.MAX_LVL = 12 self.proj_patt = [ @@ -228,11 +233,6 @@ class HuChunker: class PdfChunker(HuChunker): - @dataclass - class Fields: - text_chunks: List = None - table_chunks: List = None - def __init__(self, pdf_parser): self.pdf = pdf_parser super().__init__() @@ -293,11 +293,6 @@ class PdfChunker(HuChunker): class DocxChunker(HuChunker): - @dataclass - class Fields: - text_chunks: List = None - table_chunks: List = None - def __init__(self, doc_parser): self.doc = doc_parser super().__init__() @@ -344,11 +339,6 @@ class DocxChunker(HuChunker): class ExcelChunker(HuChunker): - @dataclass - class Fields: - text_chunks: List = None - table_chunks: List = None - def __init__(self, excel_parser): self.excel = excel_parser super().__init__() @@ -370,18 +360,51 @@ class PptChunker(HuChunker): def __init__(self): super().__init__() + def __extract(self, shape): + if shape.shape_type == 19: + tb = shape.table + rows = [] + for i in range(1, len(tb.rows)): + rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) + return "\n".join(rows) + + if shape.has_text_frame: + return shape.text_frame.text + + if shape.shape_type == 6: + texts = [] + for p in shape.shapes: + t = self.__extract(p) + if t: texts.append(t) + return "\n".join(texts) + def __call__(self, fnm): from pptx import Presentation ppt = Presentation(fnm) if isinstance( fnm, str) else Presentation( BytesIO(fnm)) - flds = self.Fields() - flds.text_chunks = [] + txts = [] for slide in ppt.slides: + texts = [] for shape in slide.shapes: - if hasattr(shape, "text"): - flds.text_chunks.append((shape.text, None)) + txt = self.__extract(shape) + if txt: texts.append(txt) + txts.append("\n".join(texts)) + + import aspose.slides as slides + import aspose.pydrawing as drawing + imgs = [] + with slides.Presentation(BytesIO(fnm)) as presentation: + for slide in presentation.slides: + buffered = BytesIO() + slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) + imgs.append(buffered.getvalue()) + assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) + + flds = self.Fields() + flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))] flds.table_chunks = [] + return flds diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 29539cedd..d42909fce 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -58,7 +58,8 @@ class Dealer: if req["available_int"] == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) else: - bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) + bqry.filter.append( + Q("bool", must_not=Q("range", available_int={"lt": 1}))) bqry.boost = 0.05 s = Search() @@ -87,9 +88,12 @@ class Dealer: q_vec = [] if req.get("vector"): assert emb_mdl, "No embedding model selected" - s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps) + s["knn"] = self._vector( + qst, emb_mdl, req.get( + "similarity", 0.4), ps) s["knn"]["filter"] = bqry.to_dict() - if "highlight" in s: del s["highlight"] + if "highlight" in s: + del s["highlight"] q_vec = s["knn"]["query_vector"] es_logger.info("【Q】: {}".format(json.dumps(s))) res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src) @@ -175,7 +179,8 @@ class Dealer: def trans2floats(txt): return [float(t) for t in txt.split("\t")] - def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7): + def insert_citations(self, answer, chunks, chunk_v, + embd_mdl, tkweight=0.3, vtweight=0.7): pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) for i in range(1, len(pieces)): if re.match(r"[a-z][.?;!][ \n]", pieces[i]): @@ -184,47 +189,57 @@ class Dealer: idx = [] pieces_ = [] for i, t in enumerate(pieces): - if len(t) < 5: continue + if len(t) < 5: + continue idx.append(i) pieces_.append(t) es_logger.info("{} => {}".format(answer, pieces_)) - if not pieces_: return answer + if not pieces_: + return answer - ans_v, c = embd_mdl.encode(pieces_) + ans_v, _ = embd_mdl.encode(pieces_) assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( len(ans_v[0]), len(chunk_v[0])) chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks] cites = {} - for i,a in enumerate(pieces_): + for i, a in enumerate(pieces_): sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], chunk_v, - huqie.qie(pieces_[i]).split(" "), + huqie.qie( + pieces_[i]).split(" "), chunks_tks, tkweight, vtweight) mx = np.max(sim) * 0.99 - if mx < 0.55: continue - cites[idx[i]] = list(set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] + if mx < 0.55: + continue + cites[idx[i]] = list( + set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] res = "" - for i,p in enumerate(pieces): + for i, p in enumerate(pieces): res += p - if i not in idx:continue - if i not in cites:continue - res += "##%s$$"%"$".join(cites[i]) + if i not in idx: + continue + if i not in cites: + continue + res += "##%s$$" % "$".join(cites[i]) return res - def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"): + def rerank(self, sres, query, tkweight=0.3, + vtweight=0.7, cfield="content_ltks"): ins_embd = [ Dealer.trans2floats( - sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids] + sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids] if not ins_embd: return [], [], [] - ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids] + ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") + for i in sres.ids] sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, - huqie.qie(query).split(" "), + huqie.qie( + query).split(" "), ins_tw, tkweight, vtweight) return sim, tksim, vtsim @@ -237,7 +252,8 @@ class Dealer: def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2, vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} - if not question: return ranks + if not question: + return ranks req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top, "question": question, "vector": True, "similarity": similarity_threshold} diff --git a/rag/svr/parse_user_docs.py b/rag/svr/parse_user_docs.py index 0000c6a38..88bc585e9 100644 --- a/rag/svr/parse_user_docs.py +++ b/rag/svr/parse_user_docs.py @@ -49,7 +49,7 @@ from rag.nlp.huchunk import ( ) from api.db import LLMType from api.db.services.document_service import DocumentService -from api.db.services.llm_service import TenantLLMService +from api.db.services.llm_service import TenantLLMService, LLMBundle from api.settings import database_logger from api.utils import get_format_time from api.utils.file_utils import get_project_base_directory @@ -62,7 +62,7 @@ EXC = ExcelChunker(ExcelParser()) PPT = PptChunker() -def chuck_doc(name, binary, cvmdl=None): +def chuck_doc(name, binary, tenant_id, cvmdl=None): suff = os.path.split(name)[-1].lower().split(".")[-1] if suff.find("pdf") >= 0: return PDF(binary) @@ -127,7 +127,7 @@ def build(row, cvmdl): 100., "Finished preparing! Start to slice file!", True) try: cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) - obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl) + obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl) except Exception as e: if re.search("(No such file|not found)", str(e)): set_progress( @@ -236,12 +236,14 @@ def main(comm, mod): tmf = open(tm_fnm, "a+") for _, r in rows.iterrows(): - embd_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.EMBEDDING) - if not embd_mdl: - set_progress(r["id"], -1, "Can't find embedding model!") - cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"])) + try: + embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING) + cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT) + #TODO: sequence2text model + except Exception as e: + set_progress(r["id"], -1, str(e)) continue - cv_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.IMAGE2TEXT) + st_tm = timer() cks = build(r, cv_mdl) if not cks: