mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
fix user login issue (#85)
This commit is contained in:
@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
|
||||
|
||||
@manager.route('/login', methods=['POST', 'GET'])
|
||||
def login():
|
||||
userinfo = None
|
||||
login_channel = "password"
|
||||
if session.get("access_token"):
|
||||
login_channel = session["access_token_from"]
|
||||
if session["access_token_from"] == "github":
|
||||
userinfo = user_info_from_github(session["access_token"])
|
||||
elif not request.json:
|
||||
if not request.json:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg='Unautherized!')
|
||||
|
||||
email = request.json.get('email') if not userinfo else userinfo["email"]
|
||||
email = request.json.get('email', "")
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
if request.json is not None:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
avatar = ""
|
||||
try:
|
||||
avatar = download_img(userinfo["avatar_url"])
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_id, {
|
||||
"access_token": session["access_token"],
|
||||
"email": userinfo["email"],
|
||||
"avatar": avatar,
|
||||
"nickname": userinfo["login"],
|
||||
"login_channel": login_channel,
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
})
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return server_error_response(e)
|
||||
elif not request.json:
|
||||
login_user(users[0])
|
||||
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
|
||||
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
|
||||
password = request.json.get('password')
|
||||
try:
|
||||
@ -97,28 +62,50 @@ def login():
|
||||
|
||||
@manager.route('/github_callback', methods=['GET'])
|
||||
def github_callback():
|
||||
try:
|
||||
import requests
|
||||
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"code": request.args.get('code')
|
||||
},headers={"Accept": "application/json"})
|
||||
res = res.json()
|
||||
if "error" in res:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg=res["error_description"])
|
||||
import requests
|
||||
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"code": request.args.get('code')
|
||||
}, headers={"Accept": "application/json"})
|
||||
res = res.json()
|
||||
if "error" in res:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg=res["error_description"])
|
||||
|
||||
if "user:email" not in res["scope"].split(","):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
||||
if "user:email" not in res["scope"].split(","):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
||||
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
return redirect(url_for("user.login"), code=307)
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
userinfo = user_info_from_github(session["access_token"])
|
||||
users = UserService.query(email=userinfo["email"])
|
||||
user_id = get_uuid()
|
||||
if not users:
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(userinfo["avatar_url"])
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
avatar = ""
|
||||
users = user_register(user_id, {
|
||||
"access_token": session["access_token"],
|
||||
"email": userinfo["email"],
|
||||
"avatar": avatar,
|
||||
"nickname": userinfo["login"],
|
||||
"login_channel": "github",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
})
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return server_error_response(e)
|
||||
return redirect("/knowledge")
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
@ -208,7 +195,7 @@ def user_register(user_id, user):
|
||||
for llm in LLMService.query(fid=LLM_FACTORY):
|
||||
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
||||
|
||||
if not UserService.insert(**user):return
|
||||
if not UserService.save(**user):return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
|
||||
@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
|
||||
|
||||
|
||||
class ParserType(StrEnum):
|
||||
GENERAL = "general"
|
||||
PRESENTATION = "presentation"
|
||||
LAWS = "laws"
|
||||
MANUAL = "manual"
|
||||
|
||||
@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
|
||||
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ def init_superuser():
|
||||
"password": "admin",
|
||||
"nickname": "admin",
|
||||
"is_superuser": True,
|
||||
"email": "kai.hu@infiniflow.org",
|
||||
"email": "admin@ragflow.io",
|
||||
"creator": "system",
|
||||
"status": "1",
|
||||
}
|
||||
@ -61,7 +61,7 @@ def init_superuser():
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
||||
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
||||
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
@ -20,7 +22,7 @@ from api.db import UserTenantRole
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import User, Tenant
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import get_uuid, get_format_time
|
||||
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
||||
from api.db import StatusEnum
|
||||
|
||||
|
||||
@ -53,6 +55,11 @@ class UserService(CommonService):
|
||||
kwargs["id"] = get_uuid()
|
||||
if "password" in kwargs:
|
||||
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
||||
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
kwargs["update_date"] = datetime_format(datetime.now())
|
||||
obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return obj
|
||||
|
||||
@ -66,10 +73,10 @@ class UserService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user(cls, user_id, user_dict):
|
||||
date_time = get_format_time()
|
||||
with DB.atomic():
|
||||
if user_dict:
|
||||
user_dict["update_time"] = date_time
|
||||
user_dict["update_time"] = current_timestamp()
|
||||
user_dict["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
||||
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||
|
||||
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
||||
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
||||
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
||||
|
||||
# distribution
|
||||
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
||||
|
||||
Reference in New Issue
Block a user