ソースを参照

Add 2 embeding models from OpenAI (#812)

### What problem does this PR solve?

#810 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.6.0
KevinHuSh 1年前
コミット
e73ce39b66
コミッターのメールアドレスに関連付けられたアカウントが存在しません
2個のファイルの変更40行の追加0行の削除
  1. 30
    0
      api/db/init_data.py
  2. 10
    0
      api/db/services/llm_service.py

+ 30
- 0
api/db/init_data.py ファイルの表示

import os import os
import time import time
import uuid import uuid
from copy import deepcopy
from api.db import LLMType, UserTenantRole from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
"tags": "TEXT EMBEDDING,8K", "tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191, "max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value "model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-small",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-large",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, { }, {
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "whisper-1", "llm_name": "whisper-1",
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"]) LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
LLMService.filter_delete([LLMService.model.fid == "QAnything"]) LLMService.filter_delete([LLMService.model.fid == "QAnything"])
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
## insert openai two embedding models to the current openai user.
print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
for tid in tenant_ids:
for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
row = row.to_dict()
row["model_type"] = LLMType.EMBEDDING.value
row["llm_name"] = "text-embedding-3-small"
row["used_tokens"] = 0
try:
TenantLLMService.save(**row)
row = deepcopy(row)
row["llm_name"] = "text-embedding-3-large"
TenantLLMService.save(**row)
except Exception as e:
pass
break
""" """
drop table llm; drop table llm;
drop table llm_factories; drop table llm_factories;

+ 10
- 0
api/db/services/llm_service.py ファイルの表示

.execute() .execute()
return num return num
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where(
(cls.model.llm_factory == "OpenAI"),
~(cls.model.llm_name == "text-embedding-3-small"),
~(cls.model.llm_name == "text-embedding-3-large")
).dicts()
return list(objs)
class LLMBundle(object): class LLMBundle(object):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"): def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):

読み込み中…
キャンセル
保存