Refactor ask decorator (#4116)

### What problem does this PR solve?

Refactor ask decorator

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Authored by Jin Hai on 2024-12-19 18:13:33 +08:00, committed by GitHub
parent 478da3118c, commit 213218a094
2 changed files with 141 additions and 102 deletions


@@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
return model_name, None
if len(arr) > 2:
return "@".join(arr[0:-1]), arr[-1]
# model name must be xxx@yyy
try:
fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
fact = set([f["name"] for f in fact])
if arr[-1] not in fact:
model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
model_providers = set([f["name"] for f in model_factories])
if arr[-1] not in model_providers:
return model_name, None
return arr[0], arr[-1]
except Exception as e:
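This hunk renames the reused `fact` variable to `model_factories` / `model_providers` and documents the `xxx@yyy` convention. For readers outside the codebase, a minimal self-contained sketch of the surrounding parsing logic, assuming a function named `split_model_name_and_factory`, a plain `conf_path` argument in place of `get_project_base_directory()`, and a short-name early return that is not visible in the hunk:

```python
import json
import os

def split_model_name_and_factory(model_name, conf_path=os.path.join("conf", "llm_factories.json")):
    # A model identifier is either plain ("bge-m3") or "name@Factory";
    # extra '@' characters are treated as part of the model name itself.
    arr = model_name.split("@")
    if len(arr) < 2:
        return model_name, None
    if len(arr) > 2:
        return "@".join(arr[0:-1]), arr[-1]
    # exactly one '@': model name must be xxx@yyy, and yyy must be a known provider
    try:
        with open(conf_path, "r") as fp:
            model_factories = json.load(fp)["factory_llm_infos"]
        model_providers = {f["name"] for f in model_factories}
        if arr[-1] not in model_providers:
            return model_name, None
        return arr[0], arr[-1]
    except Exception:
        # unreadable config: fall back to treating the whole string as the model name
        return model_name, None
```

With a config that lists "BAAI", `split_model_name_and_factory("bge-large@BAAI")` would return `("bge-large", "BAAI")`; an unknown suffix leaves the name untouched.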
@@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
"llm_name": llm_name, "api_base": ""}
"llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
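The changes in this hunk are whitespace only: a space after `"api_key":` and a re-aligned continuation line. As a hedged restatement of the fallback it belongs to, with the project's `LLMType`/`LLMService` lookups replaced by plain arguments so the sketch stands alone:

```python
def builtin_model_config(is_embedding_or_rerank, factory_id, mdlnm, llm_name):
    # factory_id stands in for the llm[0].fid lookup in the real code.
    # Locally served embedding/rerank backends need no API key; the legacy
    # "flag-embedding" name maps to a Tongyi-Qianwen entry.
    model_config = None
    if is_embedding_or_rerank and factory_id in ("Youdao", "FastEmbed", "BAAI"):
        model_config = {"llm_factory": factory_id, "api_key": "", "llm_name": mdlnm, "api_base": ""}
    if not model_config:
        if mdlnm == "flag-embedding":
            model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                            "llm_name": llm_name, "api_base": ""}
        elif not mdlnm:
            raise LookupError("Model name is not set.")
    return model_config
```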
@@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
return num
else:
tenant_llm = tenant_llms[0]
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
.execute()
except Exception:
logging.exception("TenantLLMService.increase_usage got exception")
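Only the backslash continuations gain a space here. For reference, a minimal peewee sketch of the same usage counter; the model class and in-memory SQLite database are stand-ins rather than the project's real schema, and unlike the code above it folds the increment into the UPDATE itself, which keeps concurrent calls from losing tokens:

```python
from peewee import SqliteDatabase, Model, CharField, BigIntegerField

db = SqliteDatabase(":memory:")  # stand-in; the project configures its own database

class TenantLLM(Model):
    tenant_id = CharField()
    llm_factory = CharField()
    llm_name = CharField()
    used_tokens = BigIntegerField(default=0)

    class Meta:
        database = db

db.create_tables([TenantLLM])

def increase_usage(tenant_id, llm_factory, llm_name, used_tokens):
    # Atomic in-database increment; returns the number of rows updated
    # (0 means no matching tenant/model row, which callers can treat as failure).
    return (TenantLLM
            .update(used_tokens=TenantLLM.used_tokens + used_tokens)
            .where((TenantLLM.tenant_id == tenant_id)
                   & (TenantLLM.llm_factory == llm_factory)
                   & (TenantLLM.llm_name == llm_name))
            .execute())
```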
@@ -231,7 +233,7 @@ class LLMBundle(object):
for lm in LLMService.query(llm_name=llm_name):
self.max_length = lm.max_tokens
break
def encode(self, texts: list):
embeddings, used_tokens = self.mdl.encode(texts)
if not TenantLLMService.increase_usage(
@@ -274,11 +276,11 @@ class LLMBundle(object):
def tts(self, text):
for chunk in self.mdl.tts(text):
if isinstance(chunk,int):
if isinstance(chunk, int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error(
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error(
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
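Besides the `isinstance(chunk, int)` spacing fix, this hunk re-indents the error branch, which makes the control flow easy to misread in diff form. A hedged stand-alone version of the pattern, where `tts_chunks` and `record_usage` are hypothetical stand-ins for `self.mdl.tts(text)` and `TenantLLMService.increase_usage(...)`:

```python
import logging

def tts_with_usage(tts_chunks, record_usage, tenant_id):
    # The underlying TTS model yields audio chunks and, at the end, a plain int
    # carrying the token count; the int triggers bookkeeping and ends the stream.
    for chunk in tts_chunks:
        if isinstance(chunk, int):
            if not record_usage(tenant_id, chunk):
                logging.error("LLMBundle.tts can't update token usage for %s/TTS", tenant_id)
            return
        yield chunk
```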
@@ -287,7 +289,8 @@ class LLMBundle(object):
if isinstance(txt, int) and not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error(
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
used_tokens))
return txt
def chat_streamly(self, system, history, gen_conf):
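This hunk only re-wraps a long `str.format` call across two lines. A hedged side note, not something the PR itself does: logging's lazy %-style formatting keeps such calls on one short line and defers string building until an ERROR record is actually emitted; a small illustration with placeholder values:

```python
import logging

tenant_id, llm_name, used_tokens = "tenant-1", "example-model", 128  # placeholder values

# Arguments are interpolated only if the record is actually handled.
logging.error("LLMBundle.chat can't update token usage for %s/CHAT llm_name: %s, used_tokens: %s",
              tenant_id, llm_name, used_tokens)
```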
@@ -296,6 +299,7 @@ class LLMBundle(object):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error(
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
txt))
return
yield txt
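The last hunk applies the same wrapping to `chat_streamly`. For completeness, a hedged sketch of the streaming pattern it sits in, mirroring the TTS generator above; `stream` and `record_usage` are again hypothetical stand-ins, and the assumption that a trailing int carries the total token count is inferred from the `increase_usage(..., txt, ...)` call in the hunk:

```python
import logging

def chat_streamly_with_usage(stream, record_usage, tenant_id, llm_name):
    # Forward text deltas to the caller; a plain int ends the stream and is
    # recorded as the total token usage for the chat call.
    for txt in stream:
        if isinstance(txt, int):
            if not record_usage(tenant_id, txt, llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for %s/CHAT llm_name: %s, content: %s",
                              tenant_id, llm_name, txt)
            return
        yield txt
```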