You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

index.tsx 7.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. 'use client'
  2. import type { FC } from 'react'
  3. import { memo, useState } from 'react'
  4. import { useTranslation } from 'react-i18next'
  5. import { useContext } from 'use-context-selector'
  6. import cn from 'classnames'
  7. import { Settings04 } from '@/app/components/base/icons/src/vender/line/general'
  8. import ConfigContext from '@/context/debug-configuration'
  9. import TopKItem from '@/app/components/base/param-item/top-k-item'
  10. import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item'
  11. import Modal from '@/app/components/base/modal'
  12. import Button from '@/app/components/base/button'
  13. import RadioCard from '@/app/components/base/radio-card/simple'
  14. import { RETRIEVE_TYPE } from '@/types/app'
  15. import ModelSelector from '@/app/components/header/account-setting/model-page/model-selector'
  16. import { useProviderContext } from '@/context/provider-context'
  17. import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
  18. import Toast from '@/app/components/base/toast'
  19. import { DATASET_DEFAULT } from '@/config'
  20. import {
  21. MultiPathRetrieval,
  22. NTo1Retrieval,
  23. } from '@/app/components/base/icons/src/public/common'
  24. const ParamsConfig: FC = () => {
  25. const { t } = useTranslation()
  26. const [open, setOpen] = useState(false)
  27. const {
  28. datasetConfigs,
  29. setDatasetConfigs,
  30. } = useContext(ConfigContext)
  31. const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs)
  32. const type = tempDataSetConfigs.retrieval_model
  33. const setType = (value: RETRIEVE_TYPE) => {
  34. setTempDataSetConfigs({
  35. ...tempDataSetConfigs,
  36. retrieval_model: value,
  37. })
  38. }
  39. const {
  40. rerankDefaultModel,
  41. isRerankDefaultModelVaild,
  42. } = useProviderContext()
  43. const rerankModel = (() => {
  44. if (tempDataSetConfigs.reranking_model) {
  45. return {
  46. provider_name: tempDataSetConfigs.reranking_model.reranking_provider_name,
  47. model_name: tempDataSetConfigs.reranking_model.reranking_model_name,
  48. }
  49. }
  50. else if (rerankDefaultModel) {
  51. return {
  52. provider_name: rerankDefaultModel.model_provider.provider_name,
  53. model_name: rerankDefaultModel.model_name,
  54. }
  55. }
  56. })()
  57. const handleParamChange = (key: string, value: number) => {
  58. if (key === 'top_k') {
  59. setTempDataSetConfigs({
  60. ...tempDataSetConfigs,
  61. top_k: value,
  62. })
  63. }
  64. else if (key === 'score_threshold') {
  65. setTempDataSetConfigs({
  66. ...tempDataSetConfigs,
  67. score_threshold: value,
  68. })
  69. }
  70. }
  71. const handleSwitch = (key: string, enable: boolean) => {
  72. if (key === 'top_k')
  73. return
  74. setTempDataSetConfigs({
  75. ...tempDataSetConfigs,
  76. score_threshold_enabled: enable,
  77. })
  78. }
  79. const isValid = () => {
  80. let errMsg = ''
  81. if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
  82. if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (!rerankDefaultModel && isRerankDefaultModelVaild))
  83. errMsg = t('appDebug.datasetConfig.rerankModelRequired')
  84. }
  85. if (errMsg) {
  86. Toast.notify({
  87. type: 'error',
  88. message: errMsg,
  89. })
  90. }
  91. return !errMsg
  92. }
  93. const handleSave = () => {
  94. if (!isValid())
  95. return
  96. const config = { ...tempDataSetConfigs }
  97. if (config.retrieval_model === RETRIEVE_TYPE.multiWay && !config.reranking_model) {
  98. config.reranking_model = {
  99. reranking_provider_name: rerankDefaultModel?.model_provider.provider_name,
  100. reranking_model_name: rerankDefaultModel?.model_name,
  101. } as any
  102. }
  103. setDatasetConfigs(config)
  104. setOpen(false)
  105. }
  106. return (
  107. <div>
  108. <div
  109. className={cn('flex items-center rounded-md h-7 px-3 space-x-1 text-gray-700 cursor-pointer hover:bg-gray-200', open && 'bg-gray-200')}
  110. onClick={() => {
  111. setTempDataSetConfigs({
  112. ...datasetConfigs,
  113. top_k: datasetConfigs.top_k || DATASET_DEFAULT.top_k,
  114. score_threshold: datasetConfigs.score_threshold || DATASET_DEFAULT.score_threshold,
  115. })
  116. setOpen(true)
  117. }}
  118. >
  119. <Settings04 className="w-[14px] h-[14px]" />
  120. <div className='text-xs font-medium'>
  121. {t('appDebug.datasetConfig.params')}
  122. </div>
  123. </div>
  124. {
  125. open && (
  126. <Modal
  127. isShow={open}
  128. onClose={() => {
  129. setOpen(false)
  130. }}
  131. className='sm:min-w-[528px]'
  132. wrapperClassName='z-50'
  133. title={t('appDebug.datasetConfig.settingTitle')}
  134. >
  135. <div className='mt-2 space-y-3'>
  136. <RadioCard
  137. icon={<NTo1Retrieval className='shrink-0 mr-3 w-9 h-9 rounded-lg' />}
  138. title={t('appDebug.datasetConfig.retrieveOneWay.title')}
  139. description={t('appDebug.datasetConfig.retrieveOneWay.description')}
  140. isChosen={type === RETRIEVE_TYPE.oneWay}
  141. onChosen={() => { setType(RETRIEVE_TYPE.oneWay) }}
  142. />
  143. <RadioCard
  144. icon={<MultiPathRetrieval className='shrink-0 mr-3 w-9 h-9 rounded-lg' />}
  145. title={t('appDebug.datasetConfig.retrieveMultiWay.title')}
  146. description={t('appDebug.datasetConfig.retrieveMultiWay.description')}
  147. isChosen={type === RETRIEVE_TYPE.multiWay}
  148. onChosen={() => { setType(RETRIEVE_TYPE.multiWay) }}
  149. />
  150. </div>
  151. {type === RETRIEVE_TYPE.multiWay && (
  152. <>
  153. <div className='mt-6'>
  154. <div className='leading-[32px] text-[13px] font-medium text-gray-900'>{t('common.modelProvider.rerankModel.key')}</div>
  155. <div>
  156. <ModelSelector
  157. popClassName='!max-w-[100%] !w-full'
  158. value={rerankModel && { providerName: rerankModel.provider_name, modelName: rerankModel.model_name } as any}
  159. modelType={ModelType.reranking}
  160. onChange={(v) => {
  161. setTempDataSetConfigs({
  162. ...tempDataSetConfigs,
  163. reranking_model: {
  164. reranking_provider_name: v.model_provider.provider_name,
  165. reranking_model_name: v.model_name,
  166. },
  167. })
  168. }}
  169. />
  170. </div>
  171. </div>
  172. <div className='mt-4 space-y-4'>
  173. <TopKItem
  174. value={tempDataSetConfigs.top_k}
  175. onChange={handleParamChange}
  176. enable={true}
  177. />
  178. <ScoreThresholdItem
  179. value={tempDataSetConfigs.score_threshold}
  180. onChange={handleParamChange}
  181. enable={tempDataSetConfigs.score_threshold_enabled}
  182. hasSwitch={true}
  183. onSwitchChange={handleSwitch}
  184. />
  185. </div>
  186. </>
  187. )}
  188. <div className='mt-6 flex justify-end'>
  189. <Button className='mr-2 flex-shrink-0' onClick={() => {
  190. setOpen(false)
  191. }}>{t('common.operation.cancel')}</Button>
  192. <Button type='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
  193. </div>
  194. </Modal>
  195. )
  196. }
  197. </div>
  198. )
  199. }
  200. export default memo(ParamsConfig)