Feat: add basic Langfuse support for LLM module (#6443)

### What problem does this PR solve?

#6155

Adds basic Langfuse observability for the LLM module: per-tenant Langfuse credentials are stored through a new `TenantLangfuseService`, and the dialog `chat()` pipeline as well as every `LLMBundle` call (encode, rerank, describe, transcription, TTS, chat) now report traces and generations to Langfuse when valid keys are configured.

An example trace:

<img width="755" alt="image"
src="https://github.com/user-attachments/assets/25c1f852-5116-486c-a47f-6097187142ca"
/>
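
For context, the flow this PR enables, as a minimal sketch (the tenant ID, host, and key values are placeholders; the service and client calls mirror the diff below):

```python
from langfuse import Langfuse

from api.db.services.langfuse_service import TenantLangfuseService

# Store per-tenant Langfuse credentials (placeholder values).
TenantLangfuseService.save(
    tenant_id="tenant-123",
    host="https://cloud.langfuse.com",
    public_key="pk-lf-...",
    secret_key="sk-lf-...",
)

# What chat() and LLMBundle now do internally: rebuild a client from the
# stored keys, verify them, then wrap the LLM call in a trace/generation.
keys = TenantLangfuseService.filter_by_tenant(tenant_id="tenant-123")
if keys:
    langfuse = Langfuse(public_key=keys.public_key, secret_key=keys.secret_key, host=keys.host)
    if langfuse.auth_check():
        trace = langfuse.trace(name="my-dialog-my-llm")
        generation = trace.generation(name="chat", model="my-llm", input={"prompt": "..."})
        # ... run the model call ...
        generation.end(output={"output": "..."}, usage_details={"total_tokens": 42})
```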


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-03-24 13:18:47 +08:00
Committed by: GitHub
Parent: 0b63346a1a
Commit: 85eb367ede
7 changed files with 714 additions and 892 deletions

One of the changed files is not shown here because its diff is too large.

api/db/services/dialog_service.py

@@ -13,26 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import binascii
-import time
-from functools import partial
+import logging
 import re
+import time
 from copy import deepcopy
+from functools import partial
 from timeit import default_timer as timer

+from langfuse import Langfuse

 from agentic_reasoning import DeepResearcher
+from api import settings
 from api.db import LLMType, ParserType, StatusEnum
-from api.db.db_models import Dialog, DB
+from api.db.db_models import DB, Dialog
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService, LLMBundle
-from api import settings
+from api.db.services.langfuse_service import TenantLangfuseService
+from api.db.services.llm_service import LLMBundle, TenantLLMService
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
-    citation_prompt
-from rag.utils import rmSpace, num_tokens_from_string
+from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
@@ -41,17 +44,13 @@ class DialogService(CommonService):
     @classmethod
     @DB.connection_context()
-    def get_list(cls, tenant_id,
-                 page_number, items_per_page, orderby, desc, id, name):
+    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
         chats = cls.model.select()
         if id:
             chats = chats.where(cls.model.id == id)
         if name:
             chats = chats.where(cls.model.name == name)
-        chats = chats.where(
-            (cls.model.tenant_id == tenant_id)
-            & (cls.model.status == StatusEnum.VALID.value)
-        )
+        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
         if desc:
             chats = chats.order_by(cls.model.getter_by(orderby).desc())
         else:
@@ -72,13 +71,12 @@ def chat_solo(dialog, messages, stream=True):
     tts_mdl = None
     if prompt_config.get("tts"):
         tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
-    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
-           for m in messages if m["role"] != "system"]
+    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
     if stream:
         last_ans = ""
         for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
-            delta_ans = ans[len(last_ans):]
+            delta_ans = ans[len(last_ans) :]
             if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
@@ -110,6 +108,16 @@ def chat(dialog, messages, stream=True, **kwargs):
     check_llm_ts = timer()

+    langfuse_tracer = None
+    langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
+    if langfuse_keys:
+        langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
+        if langfuse.auth_check():
+            langfuse_tracer = langfuse
+            langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")

+    check_langfuse_tracer_ts = timer()

     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
     embedding_list = list(set([kb.embd_id for kb in kbs]))
     if len(embedding_list) != 1:
@@ -159,8 +167,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         if p["key"] not in kwargs and not p["optional"]:
             raise KeyError("Miss parameter: " + p["key"])
         if p["key"] not in kwargs:
-            prompt_config["system"] = prompt_config["system"].replace(
-                "{%s}" % p["key"], " ")
+            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

     if len(questions) > 1 and prompt_config.get("refine_multiturn"):
         questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
@@ -189,9 +196,11 @@ def chat(dialog, messages, stream=True, **kwargs):
     knowledges = []
     if prompt_config.get("reasoning", False):
-        reasoner = DeepResearcher(chat_mdl,
-                                  prompt_config,
-                                  partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
+        reasoner = DeepResearcher(
+            chat_mdl,
+            prompt_config,
+            partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
+        )

         for think in reasoner.thinking(kbinfos, " ".join(questions)):
             if isinstance(think, str):
@@ -200,31 +209,34 @@ def chat(dialog, messages, stream=True, **kwargs):
             elif stream:
                 yield think
     else:
-        kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
-                                      dialog.similarity_threshold,
-                                      dialog.vector_similarity_weight,
-                                      doc_ids=attachments,
-                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
-                                      rank_feature=label_question(" ".join(questions), kbs)
-                                      )
+        kbinfos = retriever.retrieval(
+            " ".join(questions),
+            embd_mdl,
+            tenant_ids,
+            dialog.kb_ids,
+            1,
+            dialog.top_n,
+            dialog.similarity_threshold,
+            dialog.vector_similarity_weight,
+            doc_ids=attachments,
+            top=dialog.top_k,
+            aggs=False,
+            rerank_mdl=rerank_mdl,
+            rank_feature=label_question(" ".join(questions), kbs),
+        )
         if prompt_config.get("tavily_api_key"):
             tav = Tavily(prompt_config["tavily_api_key"])
             tav_res = tav.retrieve_chunks(" ".join(questions))
             kbinfos["chunks"].extend(tav_res["chunks"])
             kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
         if prompt_config.get("use_kg"):
-            ck = settings.kg_retrievaler.retrieval(" ".join(questions),
-                                                   tenant_ids,
-                                                   dialog.kb_ids,
-                                                   embd_mdl,
-                                                   LLMBundle(dialog.tenant_id, LLMType.CHAT))
+            ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 kbinfos["chunks"].insert(0, ck)

         knowledges = kb_prompt(kbinfos, max_tokens)
-    logging.debug(
-        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     retrieval_ts = timer()
     if not knowledges and prompt_config.get("empty_response"):
@@ -239,16 +251,13 @@ def chat(dialog, messages, stream=True, **kwargs):
     prompt4citation = ""
     if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
         prompt4citation = citation_prompt()
-    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
-                for m in messages if m["role"] != "system"])
+    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
     assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
     prompt = msg[0]["content"]

     if "max_tokens" in gen_conf:
-        gen_conf["max_tokens"] = min(
-            gen_conf["max_tokens"],
-            max_tokens - used_token_count)
+        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

     def decorate_answer(answer):
         nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
@@ -262,14 +271,14 @@ def chat(dialog, messages, stream=True, **kwargs):
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
             if not re.search(r"##[0-9]+\$\$", answer):
-                answer, idx = retriever.insert_citations(answer,
-                                                         [ck["content_ltks"]
-                                                          for ck in kbinfos["chunks"]],
-                                                         [ck["vector"]
-                                                          for ck in kbinfos["chunks"]],
-                                                         embd_mdl,
-                                                         tkweight=1 - dialog.vector_similarity_weight,
-                                                         vtweight=dialog.vector_similarity_weight)
+                answer, idx = retriever.insert_citations(
+                    answer,
+                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
+                    [ck["vector"] for ck in kbinfos["chunks"]],
+                    embd_mdl,
+                    tkweight=1 - dialog.vector_similarity_weight,
+                    vtweight=dialog.vector_similarity_weight,
+                )
             else:
                 idx = set([])
                 for r in re.finditer(r"##([0-9]+)\$\$", answer):
@@ -278,8 +287,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                     idx.add(i)

             idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
-            recall_docs = [
-                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
             if not recall_docs:
                 recall_docs = kbinfos["doc_aggs"]
             kbinfos["doc_aggs"] = recall_docs
@@ -295,7 +303,8 @@ def chat(dialog, messages, stream=True, **kwargs):
         total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
         check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
-        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
+        check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
+        create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
         bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
         bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
         refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
@@ -304,45 +313,54 @@ def chat(dialog, messages, stream=True, **kwargs):
         retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

-        tk_num = num_tokens_from_string(think+answer)
+        tk_num = num_tokens_from_string(think + answer)
         prompt += "\n\n### Query:\n%s" % " ".join(questions)
         prompt = (
-            f"{prompt}\n\n"
-            "## Time elapsed:\n"
-            f"  - Total: {total_time_cost:.1f}ms\n"
-            f"  - Check LLM: {check_llm_time_cost:.1f}ms\n"
-            f"  - Create retriever: {create_retriever_time_cost:.1f}ms\n"
-            f"  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
-            f"  - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
-            f"  - Tune question: {refine_question_time_cost:.1f}ms\n"
-            f"  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
-            f"  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
-            f"  - Retrieval: {retrieval_time_cost:.1f}ms\n"
-            f"  - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
-            "## Token usage:\n"
-            f"  - Generated tokens(approximately): {tk_num}\n"
-            f"  - Token speed: {int(tk_num/(generate_result_time_cost/1000.))}/s"
+            f"{prompt}\n\n"
+            "## Time elapsed:\n"
+            f"  - Total: {total_time_cost:.1f}ms\n"
+            f"  - Check LLM: {check_llm_time_cost:.1f}ms\n"
+            f"  - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
+            f"  - Create retriever: {create_retriever_time_cost:.1f}ms\n"
+            f"  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
+            f"  - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
+            f"  - Tune question: {refine_question_time_cost:.1f}ms\n"
+            f"  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
+            f"  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
+            f"  - Retrieval: {retrieval_time_cost:.1f}ms\n"
+            f"  - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
+            "## Token usage:\n"
+            f"  - Generated tokens(approximately): {tk_num}\n"
+            f"  - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
         )
-        return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

+        langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
+        langfuse_output = {"time_elapsed:": re.sub(r"\n", "  \n", langfuse_output), "created_at": time.time()}
+        langfuse_generation.end(output=langfuse_output)

+        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

+    if langfuse_tracer:
+        langfuse_generation = langfuse_tracer.trace.generation(name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg})

     if stream:
         last_ans = ""
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf):
+        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
             if thought:
                 ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
             answer = ans
-            delta_ans = ans[len(last_ans):]
+            delta_ans = ans[len(last_ans) :]
             if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
-            yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans):]
+            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+        delta_ans = answer[len(last_ans) :]
         if delta_ans:
-            yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        yield decorate_answer(thought+answer)
+            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+        yield decorate_answer(thought + answer)
     else:
-        answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf)
+        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         res = decorate_answer(answer)
@@ -360,27 +378,22 @@ Table of database fields are as follows:
 Question are as follows:
 {}
 Please write the SQL, only SQL, without any other explanations or text.
-""".format(
-        index_name(tenant_id),
-        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
-        question
-    )
+""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
     tried_times = 0

     def get_table():
         nonlocal sys_prompt, user_prompt, question, tried_times
-        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
-            "temperature": 0.06})
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
         sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
         logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
         sql = re.sub(r".*select ", "select ", sql.lower())
         sql = re.sub(r" +", " ", sql)
         sql = re.sub(r"([;；]|```).*", "", sql)
-        if sql[:len("select ")] != "select ":
+        if sql[: len("select ")] != "select ":
             return None, None
         if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
-            if sql[:len("select *")] != "select *":
+            if sql[: len("select *")] != "select *":
                 sql = "select doc_id,docnm_kwd," + sql[6:]
             else:
                 flds = []
@@ -417,11 +430,7 @@ Please write the SQL, only SQL, without any other explanations or text.
 {}

 Please correct the error and write SQL again, only SQL, without any other explanations or text.
-""".format(
-                index_name(tenant_id),
-                "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
-                question, sql, tbl["error"]
-            )
+""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
             tbl, sql = get_table()
             logging.debug("TRY it again: {}".format(sql))
@@ -429,24 +438,18 @@ Please write the SQL, only SQL, without any other explanations or text.
     if tbl.get("error") or len(tbl["rows"]) == 0:
         return None

-    docid_idx = set([ii for ii, c in enumerate(
-        tbl["columns"]) if c["name"] == "doc_id"])
-    doc_name_idx = set([ii for ii, c in enumerate(
-        tbl["columns"]) if c["name"] == "docnm_kwd"])
-    column_idx = [ii for ii in range(
-        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
+    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
+    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
+    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

     # compose Markdown table
-    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
-                                                                          tbl["columns"][i]["name"])) for i in
-                              column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+    columns = (
+        "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+    )

-    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
-        ("|------|" if docid_idx and docid_idx else "")
+    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")

-    rows = ["|" +
-            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
-            "|" for r in tbl["rows"]]
+    rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
     rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
     if quota:
         rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
@@ -456,11 +459,7 @@ Please write the SQL, only SQL, without any other explanations or text.
     if not docid_idx or not doc_name_idx:
         logging.warning("SQL missing field: " + sql)
-        return {
-            "answer": "\n".join([columns, line, rows]),
-            "reference": {"chunks": [], "doc_aggs": []},
-            "prompt": sys_prompt
-        }
+        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

     docid_idx = list(docid_idx)[0]
     doc_name_idx = list(doc_name_idx)[0]
@@ -471,10 +470,11 @@ Please write the SQL, only SQL, without any other explanations or text.
         doc_aggs[r[docid_idx]]["count"] += 1
     return {
         "answer": "\n".join([columns, line, rows]),
-        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
-                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
-                                   doc_aggs.items()]},
-        "prompt": sys_prompt
+        "reference": {
+            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
+            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
+        },
+        "prompt": sys_prompt,
     }
@@ -498,10 +498,7 @@ def ask(question, kb_ids, tenant_id):
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
-                                  1, 12, 0.1, 0.3, aggs=False,
-                                  rank_feature=label_question(question, kbs)
-                                  )
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
     Role: You're a smart assistant. Your name is Miss R.
@@ -523,17 +520,9 @@ def ask(question, kb_ids, tenant_id):
     def decorate_answer(answer):
         nonlocal knowledges, kbinfos, prompt
-        answer, idx = retriever.insert_citations(answer,
-                                                 [ck["content_ltks"]
-                                                  for ck in kbinfos["chunks"]],
-                                                 [ck["vector"]
-                                                  for ck in kbinfos["chunks"]],
-                                                 embd_mdl,
-                                                 tkweight=0.7,
-                                                 vtweight=0.3)
+        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
-        recall_docs = [
-            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
         if not recall_docs:
             recall_docs = kbinfos["doc_aggs"]
         kbinfos["doc_aggs"] = recall_docs
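
The tracing added to `chat()` above reduces to the following pattern; this is a hypothetical helper distilled from the diff for review, not code in this PR:

```python
from langfuse import Langfuse


def open_chat_generation(keys, dialog_name, llm_name, prompt, messages):
    """Open a Langfuse generation for one dialog turn, or return None.

    `keys` is the row returned by TenantLangfuseService.filter_by_tenant().
    """
    if not keys:
        return None
    client = Langfuse(public_key=keys.public_key, secret_key=keys.secret_key, host=keys.host)
    if not client.auth_check():  # invalid or revoked keys: trace nothing
        return None
    trace = client.trace(name=f"{dialog_name}-{llm_name}")
    return trace.generation(name="chat", model=llm_name, input={"prompt": prompt, "messages": messages})


# decorate_answer() later closes the generation, attaching the "Time elapsed"
# breakdown: generation.end(output={"time_elapsed:": "...", "created_at": ...})
```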

api/db/services/langfuse_service.py (new file)

@@ -0,0 +1,56 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

import peewee

from api.db.db_models import DB, TenantLangfuse
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format


class TenantLangfuseService(CommonService):
    """
    All methods that modify the status should be enclosed within a DB.atomic() context to ensure atomicity
    and maintain data integrity in case of errors during execution.
    """

    model = TenantLangfuse

    @classmethod
    @DB.connection_context()
    def filter_by_tenant(cls, tenant_id):
        fields = [cls.model.host, cls.model.secret_key, cls.model.public_key]
        try:
            keys = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id).first()
            return keys
        except peewee.DoesNotExist:
            return None

    @classmethod
    def update_by_tenant(cls, tenant_id, langfuse_keys):
        fields = ["tenant_id", "host", "secret_key", "public_key"]
        return cls.model.update(**langfuse_keys).fields(*fields).where(cls.model.tenant_id == tenant_id).execute()

    @classmethod
    def save(cls, **kwargs):
        kwargs["create_time"] = current_timestamp()
        kwargs["create_date"] = datetime_format(datetime.now())
        kwargs["update_time"] = current_timestamp()
        kwargs["update_date"] = datetime_format(datetime.now())
        obj = cls.model.create(**kwargs)
        return obj
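
A usage sketch for the new service (tenant ID and keys are placeholders; the method signatures are the ones defined above):

```python
from api.db.services.langfuse_service import TenantLangfuseService

TENANT_ID = "tenant-123"  # placeholder

keys = TenantLangfuseService.filter_by_tenant(tenant_id=TENANT_ID)
if keys is None:
    # First-time setup: save() fills in create/update timestamps itself.
    TenantLangfuseService.save(
        tenant_id=TENANT_ID,
        host="https://cloud.langfuse.com",
        public_key="pk-lf-...",
        secret_key="sk-lf-...",
    )
else:
    # Rotate credentials for an existing tenant.
    TenantLangfuseService.update_by_tenant(
        TENANT_ID,
        {"tenant_id": TENANT_ID, "host": keys.host, "public_key": "pk-lf-new", "secret_key": "sk-lf-new"},
    )
```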

api/db/services/llm_service.py

@@ -15,10 +15,13 @@
 #
 import logging

+from langfuse import Langfuse

 from api import settings
 from api.db import LLMType
 from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
 from api.db.services.common_service import CommonService
+from api.db.services.langfuse_service import TenantLangfuseService
 from api.db.services.user_service import TenantService
 from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
@@ -49,16 +52,8 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [
-            cls.model.llm_factory,
-            LLMFactories.logo,
-            LLMFactories.tags,
-            cls.model.model_type,
-            cls.model.llm_name,
-            cls.model.used_tokens
-        ]
-        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
-            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
+        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
+        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

         return list(objs)
@@ -114,8 +109,7 @@ class TenantLLMService(CommonService):
                 model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
-                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
-                                "llm_name": llm_name, "api_base": ""}
+                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
@@ -124,43 +118,32 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type,
-                       llm_name=None, lang="Chinese"):
+    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
-            return EmbeddingModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
                 return
-            return RerankModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return
-            return CvModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], lang,
-                base_url=model_config["api_base"]
-            )
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])

         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return
-            return ChatModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
                 return
-            return Seq2txtModel[model_config["llm_factory"]](
-                key=model_config["api_key"], model_name=model_config["llm_name"],
-                lang=lang,
-                base_url=model_config["api_base"]
-            )
+            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
                 return
@@ -184,7 +167,7 @@ class TenantLLMService(CommonService):
             LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
             LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
             LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
-            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name
+            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
         }

         mdlnm = llm_map.get(llm_type)
@@ -195,17 +178,13 @@ class TenantLLMService(CommonService):
         llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

         try:
-            num = cls.model.update(
-                used_tokens=cls.model.used_tokens + used_tokens
-            ).where(
-                cls.model.tenant_id == tenant_id,
-                cls.model.llm_name == llm_name,
-                cls.model.llm_factory == llm_factory if llm_factory else True
-            ).execute()
+            num = (
+                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
+                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
+                .execute()
+            )
         except Exception:
-            logging.exception(
-                "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
-                tenant_id, llm_name)
+            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
             return 0

         return num
@@ -213,11 +192,7 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_openai_models(cls):
-        objs = cls.model.select().where(
-            (cls.model.llm_factory == "OpenAI"),
-            ~(cls.model.llm_name == "text-embedding-3-small"),
-            ~(cls.model.llm_name == "text-embedding-3-large")
-        ).dicts()
+        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)
@@ -226,87 +201,138 @@ class LLMBundle:
         self.tenant_id = tenant_id
         self.llm_type = llm_type
         self.llm_name = llm_name
-        self.mdl = TenantLLMService.model_instance(
-            tenant_id, llm_type, llm_name, lang=lang)
-        assert self.mdl, "Can't find model for {}/{}/{}".format(
-            tenant_id, llm_type, llm_name)
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
+        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)

+        langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
+        if langfuse_keys:
+            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
+            if langfuse.auth_check():
+                self.langfuse = langfuse
+                self.trace = self.langfuse.trace(name=f"{self.llm_type}-{self.llm_name}")
+        else:
+            self.langfuse = None
     def encode(self, texts: list):
+        if self.langfuse:
+            generation = self.trace.generation(name="encode", model=self.llm_name, input={"texts": texts})

         embeddings, used_tokens = self.mdl.encode(texts)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

+        if self.langfuse:
+            generation.end(usage_details={"total_tokens": used_tokens})

         return embeddings, used_tokens
     def encode_queries(self, query: str):
+        if self.langfuse:
+            generation = self.trace.generation(name="encode_queries", model=self.llm_name, input={"query": query})

         emd, used_tokens = self.mdl.encode_queries(query)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

+        if self.langfuse:
+            generation.end(usage_details={"total_tokens": used_tokens})

         return emd, used_tokens
     def similarity(self, query: str, texts: list):
+        if self.langfuse:
+            generation = self.trace.generation(name="similarity", model=self.llm_name, input={"query": query, "texts": texts})

         sim, used_tokens = self.mdl.similarity(query, texts)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))

+        if self.langfuse:
+            generation.end(usage_details={"total_tokens": used_tokens})

         return sim, used_tokens
     def describe(self, image, max_tokens=300):
+        if self.langfuse:
+            generation = self.trace.generation(name="describe", metadata={"model": self.llm_name})

         txt, used_tokens = self.mdl.describe(image, max_tokens)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

+        if self.langfuse:
+            generation.end(output={"output": txt}, usage_details={"total_tokens": used_tokens})

         return txt
     def describe_with_prompt(self, image, prompt):
+        if self.langfuse:
+            generation = self.trace.generation(name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})

         txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

+        if self.langfuse:
+            generation.end(output={"output": txt}, usage_details={"total_tokens": used_tokens})

         return txt
     def transcription(self, audio):
+        if self.langfuse:
+            generation = self.trace.generation(name="transcription", metadata={"model": self.llm_name})

         txt, used_tokens = self.mdl.transcription(audio)
-        if not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens):
-            logging.error(
-                "LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))

+        if self.langfuse:
+            generation.end(output={"output": txt}, usage_details={"total_tokens": used_tokens})

         return txt
     def tts(self, text):
+        if self.langfuse:
+            span = self.trace.span(name="tts", input={"text": text})

         for chunk in self.mdl.tts(text):
             if isinstance(chunk, int):
-                if not TenantLLMService.increase_usage(
-                        self.tenant_id, self.llm_type, chunk, self.llm_name):
-                    logging.error(
-                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
+                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
+                    logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
             yield chunk

+        if self.langfuse:
+            span.end()
     def chat(self, system, history, gen_conf):
+        if self.langfuse:
+            generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})

         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if isinstance(txt, int) and not TenantLLMService.increase_usage(
-                self.tenant_id, self.llm_type, used_tokens, self.llm_name):
-            logging.error(
-                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
-                                                                                                           used_tokens))
+        if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

+        if self.langfuse:
+            generation.end(output={"output": txt}, usage_details={"total_tokens": used_tokens})

         return txt
     def chat_streamly(self, system, history, gen_conf):
+        if self.langfuse:
+            generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

         output = ""
         for txt in self.mdl.chat_streamly(system, history, gen_conf):
             if isinstance(txt, int):
-                if not TenantLLMService.increase_usage(
-                        self.tenant_id, self.llm_type, txt, self.llm_name):
-                    logging.error(
-                        "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
-                                                                                                                        txt))
+                if self.langfuse:
+                    generation.end(output={"output": output})

+                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
+                    logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
                 return

             output = txt
             yield txt
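
Every `LLMBundle` method above follows the same wrapper shape; distilled into one hypothetical helper for review (not part of this PR):

```python
import logging

from api.db.services.llm_service import TenantLLMService


def traced(bundle, name, call, payload):
    """Run one model call with optional Langfuse tracing.

    `bundle` is an LLMBundle; `call` performs the underlying model invocation
    and returns (result, used_tokens); `payload` becomes the generation input.
    """
    generation = None
    if bundle.langfuse:
        generation = bundle.trace.generation(name=name, model=bundle.llm_name, input=payload)

    result, used_tokens = call()

    if not TenantLLMService.increase_usage(bundle.tenant_id, bundle.llm_type, used_tokens):
        logging.error("can't update token usage for %s/%s", bundle.tenant_id, bundle.llm_type)

    if generation:
        generation.end(output={"output": result}, usage_details={"total_tokens": used_tokens})

    return result, used_tokens


# e.g. encode() is equivalent to:
#   traced(self, "encode", lambda: self.mdl.encode(texts), {"texts": texts})
```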