Sfoglia il codice sorgente

Add 2 embeding models from OpenAI (#812)

### What problem does this PR solve?

#810 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.6.0
KevinHuSh 1 anno fa
parent
commit
e73ce39b66
Nessun account collegato all'indirizzo email del committer
2 ha cambiato i file con 40 aggiunte e 0 eliminazioni
  1. 30
    0
      api/db/init_data.py
  2. 10
    0
      api/db/services/llm_service.py

+ 30
- 0
api/db/init_data.py Vedi File

@@ -16,6 +16,7 @@
import os
import time
import uuid
from copy import deepcopy
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
@@ -166,6 +167,18 @@ def init_llm_factory():
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-small",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-large",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
@@ -376,6 +389,23 @@ def init_llm_factory():
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
## insert openai two embedding models to the current openai user.
print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
for tid in tenant_ids:
for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
row = row.to_dict()
row["model_type"] = LLMType.EMBEDDING.value
row["llm_name"] = "text-embedding-3-small"
row["used_tokens"] = 0
try:
TenantLLMService.save(**row)
row = deepcopy(row)
row["llm_name"] = "text-embedding-3-large"
TenantLLMService.save(**row)
except Exception as e:
pass
break
"""
drop table llm;
drop table llm_factories;

+ 10
- 0
api/db/services/llm_service.py Vedi File

@@ -135,6 +135,16 @@ class TenantLLMService(CommonService):
.execute()
return num
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where(
(cls.model.llm_factory == "OpenAI"),
~(cls.model.llm_name == "text-embedding-3-small"),
~(cls.model.llm_name == "text-embedding-3-large")
).dicts()
return list(objs)
class LLMBundle(object):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):

Loading…
Annulla
Salva