Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refactor ask decorator (#4116)
### What problem does this PR solve?

Refactor ask decorator.

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
@@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
             return model_name, None
         if len(arr) > 2:
             return "@".join(arr[0:-1]), arr[-1]
+
+        # model name must be xxx@yyy
         try:
-            fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
-            fact = set([f["name"] for f in fact])
-            if arr[-1] not in fact:
+            model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            model_providers = set([f["name"] for f in model_factories])
+            if arr[-1] not in model_providers:
                 return model_name, None
             return arr[0], arr[-1]
         except Exception as e:
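The rename (`fact` → `model_factories`/`model_providers`) makes the parsing rule above easier to follow: a model identifier may carry its provider after an `@`, as the new comment says (`xxx@yyy`). A minimal runnable sketch of that convention, with a hard-coded provider set standing in for the `conf/llm_factories.json` lookup (the function name and sample providers here are illustrative, not ragflow's API):

```python
# Minimal sketch of the "model@factory" parsing convention shown above.
# KNOWN_PROVIDERS stands in for the names loaded from conf/llm_factories.json.
KNOWN_PROVIDERS = {"OpenAI", "Tongyi-Qianwen", "BAAI"}  # illustrative subset

def split_model_and_provider(model_name: str):
    arr = model_name.split("@")
    if len(arr) < 2:
        # No "@" suffix: the whole string is the model name.
        return model_name, None
    if len(arr) > 2:
        # The model name itself contains "@"; only the last part is the provider.
        return "@".join(arr[0:-1]), arr[-1]
    if arr[-1] not in KNOWN_PROVIDERS:
        # Unknown suffix: treat the whole string as a plain model name.
        return model_name, None
    return arr[0], arr[-1]

# e.g. split_model_and_provider("bge-large@BAAI") -> ("bge-large", "BAAI")
```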
@@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
         if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
             llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
             if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
+                model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
                 model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
-                    "llm_name": llm_name, "api_base": ""}
+                                "llm_name": llm_name, "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
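For context on the hunk above: embedding and rerank models from the locally served providers (`Youdao`, `FastEmbed`, `BAAI`) get a synthesized key-less config when the tenant has none. A hedged sketch of that fallback shape, assuming a hypothetical `find_tenant_config` helper and simplified string-typed arguments (the error text is illustrative too):

```python
KEYLESS_PROVIDERS = ("Youdao", "FastEmbed", "BAAI")  # served locally, no API key

def find_tenant_config(mdlnm, provider):
    """Hypothetical stand-in for the tenant-scoped lookup that runs first."""
    return None  # pretend the tenant has not configured this model

def resolve_model_config(mdlnm, provider, llm_type):
    model_config = find_tenant_config(mdlnm, provider)
    if not model_config and llm_type in ("embedding", "rerank") and provider in KEYLESS_PROVIDERS:
        # Synthesize a key-less config for the locally served providers.
        model_config = {"llm_factory": provider, "api_key": "",
                        "llm_name": mdlnm, "api_base": ""}
    if not model_config:
        if not mdlnm:
            raise LookupError(f"Type of {llm_type} model is not set.")
        raise LookupError(f"Model({mdlnm}) not found")
    return model_config

print(resolve_model_config("bge-reranker", "BAAI", "rerank"))
```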
@@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
                 return num
             else:
                 tenant_llm = tenant_llms[0]
-                num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
-                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
+                num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
+                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
                     .execute()
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
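The change above only normalizes the spacing before the line-continuation backslashes; the chained `update(...).where(...).execute()` call is standard peewee. A self-contained sketch of the same pattern follows. Unlike the diff's Python-side `tenant_llm.used_tokens + used_tokens` (a read-modify-write on a fetched row), this sketch uses a server-side column expression, which avoids lost updates under concurrency; the `TenantLLM` model here is an illustrative stand-in, not ragflow's schema:

```python
from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(":memory:")

class TenantLLM(Model):
    tenant_id = CharField()
    llm_factory = CharField()
    llm_name = CharField()
    used_tokens = IntegerField(default=0)

    class Meta:
        database = db

db.create_tables([TenantLLM])
TenantLLM.create(tenant_id="t1", llm_factory="OpenAI", llm_name="gpt-4o")

# Server-side increment: "used_tokens = used_tokens + 42" runs in SQL,
# so concurrent callers cannot overwrite each other's counts.
num = (TenantLLM
       .update(used_tokens=TenantLLM.used_tokens + 42)
       .where(TenantLLM.tenant_id == "t1",
              TenantLLM.llm_factory == "OpenAI",
              TenantLLM.llm_name == "gpt-4o")
       .execute())
print(num)  # number of rows updated -> 1
```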
@@ -231,7 +233,7 @@ class LLMBundle(object):
         for lm in LLMService.query(llm_name=llm_name):
             self.max_length = lm.max_tokens
             break
-
+
     def encode(self, texts: list):
         embeddings, used_tokens = self.mdl.encode(texts)
         if not TenantLLMService.increase_usage(
@@ -274,11 +276,11 @@ class LLMBundle(object):

     def tts(self, text):
         for chunk in self.mdl.tts(text):
-            if isinstance(chunk,int):
+            if isinstance(chunk, int):
                 if not TenantLLMService.increase_usage(
-                    self.tenant_id, self.llm_type, chunk, self.llm_name):
-                    logging.error(
-                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
+                        self.tenant_id, self.llm_type, chunk, self.llm_name):
+                    logging.error(
+                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
             yield chunk
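The `isinstance(chunk, int)` check above implies a yield protocol: the backend generator yields audio chunks and then one final int carrying the usage count, which terminates the stream. A hedged sketch of that protocol with a fake backend (`fake_tts_backend` and `record_usage` are illustrative stand-ins):

```python
def fake_tts_backend(text: str):
    for word in text.split():
        yield word.encode()  # audio chunk stand-in
    yield len(text)          # final int: token/usage count

def record_usage(tokens: int):
    print(f"used tokens: {tokens}")

def tts(text: str):
    for chunk in fake_tts_backend(text):
        if isinstance(chunk, int):
            record_usage(chunk)  # the int is bookkeeping, not audio
            return               # terminates the generator
        yield chunk

audio = b"".join(tts("hello world"))  # usage is recorded as a side effect
```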
@@ -287,7 +289,8 @@ class LLMBundle(object):
         if isinstance(txt, int) and not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error(
-                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
+                                                                                                            used_tokens))
         return txt

     def chat_streamly(self, system, history, gen_conf):
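The accounting pattern in `chat` (and `chat_streamly` below) is: call the model, try to record the token usage, and on failure log but still return the answer to the caller. A hedged sketch of that pattern, where `fake_chat_model` and `record_usage` are illustrative stand-ins for the bundle's model call and `TenantLLMService.increase_usage`:

```python
import logging

def fake_chat_model(system: str, history: list, gen_conf: dict):
    return "a reply", 17  # (text, used_tokens)

def record_usage(tenant_id: str, used_tokens: int) -> bool:
    return False  # pretend the usage row could not be updated

def chat(tenant_id: str, system: str, history: list, gen_conf: dict) -> str:
    txt, used_tokens = fake_chat_model(system, history, gen_conf)
    if not record_usage(tenant_id, used_tokens):
        # Accounting failure is logged but never surfaced to the caller:
        # the answer is still returned.
        logging.error("can't update token usage for %s/CHAT, used_tokens: %s",
                      tenant_id, used_tokens)
    return txt

print(chat("t1", "You are helpful.", [], {}))
```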
@@ -296,6 +299,7 @@ class LLMBundle(object):
             if not TenantLLMService.increase_usage(
                     self.tenant_id, self.llm_type, txt, self.llm_name):
                 logging.error(
-                    "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
+                    "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
+                                                                                                                    txt))
                 return
             yield txt