You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

index.tsx 9.8KB


  1. import type {
  2. FC,
  3. ReactNode,
  4. } from 'react'
  5. import { useMemo, useState } from 'react'
  6. import { useTranslation } from 'react-i18next'
  7. import type {
  8. DefaultModel,
  9. FormValue,
  10. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  11. import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  12. import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
  13. import {
  14. useModelList,
  15. } from '@/app/components/header/account-setting/model-provider-page/hooks'
  16. import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
  17. import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
  18. import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
  19. import {
  20. PortalToFollowElem,
  21. PortalToFollowElemContent,
  22. PortalToFollowElemTrigger,
  23. } from '@/app/components/base/portal-to-follow-elem'
  24. import LLMParamsPanel from './llm-params-panel'
  25. import TTSParamsPanel from './tts-params-panel'
  26. import { useProviderContext } from '@/context/provider-context'
  27. import cn from '@/utils/classnames'
  28. import Toast from '@/app/components/base/toast'
  29. import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params'
  /**
   * Props for the ModelParameterModal component defined below.
   *
   * Required props come first; everything after them is optional presentation
   * or behaviour tweaking.
   */
  export type ModelParameterModalProps = {
    /** Forwarded unchanged to the LLM parameter panel. */
    isAdvancedMode: boolean
    /**
     * Current model configuration. The component reads `provider`, `model`,
     * `completion_params`, `language` and `voice` from it.
     */
    value: any
    /** Receives the complete next model configuration on every change. */
    setModel: (model: any) => void

    /** When given, replaces the built-in trigger element. */
    renderTrigger?: (v: TriggerProps) => ReactNode
    /** When true, clicking the trigger does not open the popup. */
    readonly?: boolean
    /** Places the popup to the left and is forwarded to the default trigger. */
    isInWorkflow?: boolean
    /** Renders the agent-specific trigger instead of the default one. */
    isAgentStrategy?: boolean
    /**
     * `&`-separated model types and/or feature names, or `all`.
     * The component defaults this to the text-generation model type.
     */
    scope?: string
    /** Extra class names for the popup panel. */
    popupClassName?: string
    /** Extra class names for the portal content wrapper. */
    portalToFollowElemContentClassName?: string
  }
  /**
   * Popover for choosing a model (restricted by `scope`) and editing its parameters.
   *
   * The trigger is `renderTrigger(...)` when provided, `AgentModelTrigger` when
   * `isAgentStrategy` is set, and the default `Trigger` otherwise. The opened panel
   * shows a `ModelSelector` plus the LLM or TTS parameter panel that matches the
   * currently selected model's type. Every change is reported through `setModel`
   * as a complete next `value` object.
   */
  const ModelParameterModal: FC<ModelParameterModalProps> = ({
    popupClassName,
    portalToFollowElemContentClassName,
    isAdvancedMode,
    value,
    setModel,
    renderTrigger,
    readonly,
    isInWorkflow,
    isAgentStrategy,
    scope = ModelTypeEnum.textGeneration,
  }) => {
    const { t } = useTranslation()
    const { isAPIKeySet } = useProviderContext()
    const [open, setOpen] = useState(false)

    // Memoized so the array keeps a stable identity between renders. An inline
    // `scope.split('&')` would hand every dependent useMemo below a new array on
    // each render and defeat the memoization.
    const scopeArray = useMemo(() => scope.split('&'), [scope])

    // Scope entries that are not model types are passed to ModelSelector as
    // `scopeFeatures`; 'all' disables this filtering entirely.
    const scopeFeatures = useMemo(() => {
      if (scopeArray.includes('all'))
        return []
      return scopeArray.filter(item => ![
        ModelTypeEnum.textGeneration,
        ModelTypeEnum.textEmbedding,
        ModelTypeEnum.rerank,
        ModelTypeEnum.moderation,
        ModelTypeEnum.speech2text,
        ModelTypeEnum.tts,
      ].includes(item as ModelTypeEnum))
    }, [scopeArray])

    const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
    const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
    const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
    const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
    const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
    const { data: ttsList } = useModelList(ModelTypeEnum.tts)

    // 'all' concatenates every list. Otherwise the first model type found in
    // `scope` (checked in the order below) selects a single list; if none
    // matches, the list is empty.
    const scopedModelList = useMemo(() => {
      const resultList: any[] = []
      if (scopeArray.includes('all')) {
        return [
          ...textGenerationList,
          ...textEmbeddingList,
          ...rerankList,
          ...sttList,
          ...ttsList,
          ...moderationList,
        ]
      }
      if (scopeArray.includes(ModelTypeEnum.textGeneration))
        return textGenerationList
      if (scopeArray.includes(ModelTypeEnum.textEmbedding))
        return textEmbeddingList
      if (scopeArray.includes(ModelTypeEnum.rerank))
        return rerankList
      if (scopeArray.includes(ModelTypeEnum.moderation))
        return moderationList
      if (scopeArray.includes(ModelTypeEnum.speech2text))
        return sttList
      if (scopeArray.includes(ModelTypeEnum.tts))
        return ttsList
      return resultList
    }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

    // Resolve the selected provider/model objects from the scoped list.
    const { currentProvider, currentModel } = useMemo(() => {
      const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
      const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
      return {
        currentProvider,
        currentModel,
      }
    }, [scopedModelList, value?.provider, value?.model])

    // The selection counts as deprecated when it is not present in the scoped list.
    const hasDeprecated = useMemo(() => {
      return !currentProvider || !currentModel
    }, [currentModel, currentProvider])

    const modelDisabled = useMemo(() => {
      return currentModel?.status !== ModelStatusEnum.active
    }, [currentModel?.status])

    const disabled = useMemo(() => {
      return !isAPIKeySet || hasDeprecated || modelDisabled
    }, [hasDeprecated, isAPIKeySet, modelDisabled])

    /**
     * Switch to another model.
     *
     * For text-generation models, the current `completion_params` are reconciled
     * against the new model through `fetchAndMergeValidCompletionParams`. Removed
     * keys are reported in a warning toast. If that call fails, an error toast is
     * shown and the new model is applied with empty completion params.
     */
    // NOTE(review): selecting two models in quick succession can let the earlier
    // request resolve last and overwrite the later choice. Consider a latest-request
    // guard if that turns out to matter in practice.
    const handleChangeModel = async ({ provider, model }: DefaultModel) => {
      const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
      const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
      const model_type = targetModelItem?.model_type as string

      let nextCompletionParams: FormValue = {}
      if (model_type === ModelTypeEnum.textGeneration) {
        try {
          const { params: filtered, removedDetails } = await fetchAndMergeValidCompletionParams(
            provider,
            model,
            value?.completion_params,
          )
          nextCompletionParams = filtered
          const keys = Object.keys(removedDetails || {})
          if (keys.length) {
            Toast.notify({
              type: 'warning',
              message: `${t('common.modelProvider.parametersInvalidRemoved')}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`,
            })
          }
        }
        catch {
          Toast.notify({ type: 'error', message: t('common.error') })
        }
      }

      setModel({
        provider,
        model,
        model_type,
        ...(model_type === ModelTypeEnum.textGeneration
          ? {
            mode: targetModelItem?.model_properties.mode as string,
            completion_params: nextCompletionParams,
          }
          : {}),
      })
    }

    // Replace only `completion_params`. The previous version also spread
    // `value.completionParams` (camelCase) into the top level of the model value,
    // which could leak parameter keys there or override fields like `provider`.
    const handleLLMParamsChange = (newParams: FormValue) => {
      setModel({
        ...value,
        completion_params: newParams,
      })
    }

    const handleTTSParamsChange = (language: string, voice: string) => {
      setModel({
        ...value,
        language,
        voice,
      })
    }

    return (
      <PortalToFollowElem
        open={open}
        onOpenChange={setOpen}
        placement={isInWorkflow ? 'left' : 'bottom-end'}
        offset={4}
      >
        <div className='relative'>
          <PortalToFollowElemTrigger
            onClick={() => {
              // Read-only mode keeps the popup closed.
              if (readonly)
                return
              setOpen(v => !v)
            }}
            className='block'
          >
            {
              renderTrigger
                ? renderTrigger({
                  open,
                  disabled,
                  modelDisabled,
                  hasDeprecated,
                  currentProvider,
                  currentModel,
                  providerName: value?.provider,
                  modelId: value?.model,
                })
                : (isAgentStrategy
                  ? <AgentModelTrigger
                    disabled={disabled}
                    hasDeprecated={hasDeprecated}
                    currentProvider={currentProvider}
                    currentModel={currentModel}
                    providerName={value?.provider}
                    modelId={value?.model}
                    scope={scope}
                  />
                  : <Trigger
                    disabled={disabled}
                    isInWorkflow={isInWorkflow}
                    modelDisabled={modelDisabled}
                    hasDeprecated={hasDeprecated}
                    currentProvider={currentProvider}
                    currentModel={currentModel}
                    providerName={value?.provider}
                    modelId={value?.model}
                  />
                )
            }
          </PortalToFollowElemTrigger>
          <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
            <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
              <div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
                <div className='relative'>
                  <div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
                    {t('common.modelProvider.model').toLocaleUpperCase()}
                  </div>
                  <ModelSelector
                    defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
                    modelList={scopedModelList}
                    scopeFeatures={scopeFeatures}
                    onSelect={handleChangeModel}
                  />
                </div>
                {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
                  <div className='my-3 h-px bg-divider-subtle' />
                )}
                {currentModel?.model_type === ModelTypeEnum.textGeneration && (
                  <LLMParamsPanel
                    provider={value?.provider}
                    modelId={value?.model}
                    completionParams={value?.completion_params || {}}
                    onCompletionParamsChange={handleLLMParamsChange}
                    isAdvancedMode={isAdvancedMode}
                  />
                )}
                {currentModel?.model_type === ModelTypeEnum.tts && (
                  <TTSParamsPanel
                    currentModel={currentModel}
                    language={value?.language}
                    voice={value?.voice}
                    onChange={handleTTSParamsChange}
                  />
                )}
              </div>
            </div>
          </PortalToFollowElemContent>
        </div>
      </PortalToFollowElem>
    )
  }

  export default ModelParameterModal