Browse Source

Refa. (#7022)

### What problem does this PR solve?


### Type of change

- [x] Refactoring
tags/v0.18.0
Kevin Hu 6 months ago
parent
commit
5af2d57086
No account linked to committer's email address
3 changed files with 78 additions and 15 deletions
  1. 39
    4
      api/db/services/dialog_service.py
  2. 38
    10
      api/db/services/user_service.py
  3. 1
    1
      rag/prompts.py

+ 39
- 4
api/db/services/dialog_service.py View File

@@ -14,6 +14,7 @@
# limitations under the License.
#
import binascii
from datetime import datetime
import logging
import re
import time
@@ -31,6 +32,7 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
@@ -42,6 +44,39 @@ from rag.utils.tavily_conn import Tavily
class DialogService(CommonService):
model = Dialog

@classmethod
def save(cls, **kwargs):
"""Save a new record to database.

This method creates a new record in the database with the provided field values,
forcing an insert operation rather than an update.

Args:
**kwargs: Record field values as keyword arguments.

Returns:
Model instance: The created record object.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj

@classmethod
def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs.

This method updates multiple records in the database, identified by their IDs.
It automatically updates the update_time and update_date fields for each record.

Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list:
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()

@classmethod
@DB.connection_context()
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
@@ -434,11 +469,11 @@ Please write the SQL, only SQL, without any other explanations or text.
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.

The SQL error you provided last time is as follows:
{}
@@ -461,7 +496,7 @@ Please write the SQL, only SQL, without any other explanations or text.

# compose Markdown table
columns = (
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
)

line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@@ -557,4 +592,4 @@ def ask(question, kb_ids, tenant_id):
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
yield decorate_answer(answer)

+ 38
- 10
api/db/services/user_service.py View File

@@ -30,10 +30,10 @@ from rag.settings import MINIO

class UserService(CommonService):
"""Service class for managing user-related database operations.
This class extends CommonService to provide specialized functionality for user management,
including authentication, user creation, updates, and deletions.
Attributes:
model: The User model class for database operations.
"""
@@ -43,10 +43,10 @@ class UserService(CommonService):
@DB.connection_context()
def filter_by_id(cls, user_id):
"""Retrieve a user by their ID.
Args:
user_id: The unique identifier of the user.
Returns:
User object if found, None otherwise.
"""
@@ -60,11 +60,11 @@ class UserService(CommonService):
@DB.connection_context()
def query_user(cls, email, password):
"""Authenticate a user with email and password.
Args:
email: User's email address.
password: User's password in plain text.
Returns:
User object if authentication successful, None otherwise.
"""
@@ -111,10 +111,10 @@ class UserService(CommonService):

class TenantService(CommonService):
"""Service class for managing tenant-related database operations.
This class extends CommonService to provide functionality for tenant management,
including tenant information retrieval and credit management.
Attributes:
model: The Tenant model class for database operations.
"""
@@ -170,15 +170,24 @@ class TenantService(CommonService):

class UserTenantService(CommonService):
"""Service class for managing user-tenant relationship operations.
This class extends CommonService to handle the many-to-many relationship
between users and tenants, managing user roles and tenant memberships.
Attributes:
model: The UserTenant model class for database operations.
"""
model = UserTenant

@classmethod
@DB.connection_context()
def filter_by_id(cls, user_tenant_id):
try:
user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get()
return user_tenant
except peewee.DoesNotExist:
return None

@classmethod
@DB.connection_context()
def save(cls, **kwargs):
@@ -191,6 +200,7 @@ class UserTenantService(CommonService):
@DB.connection_context()
def get_by_tenant_id(cls, tenant_id):
fields = [
cls.model.id,
cls.model.user_id,
cls.model.status,
cls.model.role,
@@ -222,3 +232,21 @@ class UserTenantService(CommonService):
return list(cls.model.select(*fields)
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())

@classmethod
@DB.connection_context()
def get_num_members(cls, user_id: str):
cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar()
return cnt_members

@classmethod
@DB.connection_context()
def filter_by_tenant_and_user_id(cls, tenant_id, user_id):
try:
user_tenant = cls.model.select().where(
(cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) &
(cls.model.user_id == user_id)
).first()
return user_tenant
except peewee.DoesNotExist:
return None

+ 1
- 1
rag/prompts.py View File

@@ -52,7 +52,7 @@ def chunks_format(reference):
def llm_id2llm_type(llm_id):
from api.db.services.llm_service import TenantLLMService

llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)

llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:

Loading…
Cancel
Save