Co-authored-by: StyleZhang <jasonapring2015@outlook.com>tags/0.6.13
| @@ -4,10 +4,6 @@ import time | |||
| import numpy as np | |||
| from sklearn.manifold import TSNE | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.entity.embedding import Embeddings | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.models.document import Document | |||
| from core.rag.retrieval.retrival_methods import RetrievalMethod | |||
| @@ -45,17 +41,6 @@ class HitTestingService: | |||
| if not retrieval_model: | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| # get embedding model | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| provider=dataset.embedding_model_provider, | |||
| model=dataset.embedding_model | |||
| ) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| @@ -80,20 +65,10 @@ class HitTestingService: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) | |||
| return cls.compact_retrieve_response(dataset, query, all_documents) | |||
| @classmethod | |||
| def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]): | |||
| text_embeddings = [ | |||
| embeddings.embed_query(query) | |||
| ] | |||
| text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents])) | |||
| tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings) | |||
| query_position = tsne_position_data.pop(0) | |||
| def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): | |||
| i = 0 | |||
| records = [] | |||
| for document in documents: | |||
| @@ -113,7 +88,6 @@ class HitTestingService: | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata.get('score', None), | |||
| "tsne_position": tsne_position_data[i] | |||
| } | |||
| records.append(record) | |||
| @@ -123,7 +97,6 @@ class HitTestingService: | |||
| return { | |||
| "query": { | |||
| "content": query, | |||
| "tsne_position": query_position, | |||
| }, | |||
| "records": records | |||
| } | |||
| @@ -2,51 +2,16 @@ import type { FC } from 'react' | |||
| import React from 'react' | |||
| import cn from 'classnames' | |||
| import { useTranslation } from 'react-i18next' | |||
| import ReactECharts from 'echarts-for-react' | |||
| import { SegmentIndexTag } from '../documents/detail/completed' | |||
| import s from '../documents/detail/completed/style.module.css' | |||
| import type { SegmentDetailModel } from '@/models/datasets' | |||
| import Divider from '@/app/components/base/divider' | |||
| type IScatterChartProps = { | |||
| data: Array<number[]> | |||
| curr: Array<number[]> | |||
| } | |||
| const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => { | |||
| const option = { | |||
| xAxis: {}, | |||
| yAxis: {}, | |||
| tooltip: { | |||
| trigger: 'item', | |||
| axisPointer: { | |||
| type: 'cross', | |||
| }, | |||
| }, | |||
| series: [ | |||
| { | |||
| type: 'effectScatter', | |||
| symbolSize: 5, | |||
| data: curr, | |||
| }, | |||
| { | |||
| type: 'scatter', | |||
| symbolSize: 5, | |||
| data, | |||
| }, | |||
| ], | |||
| } | |||
| return ( | |||
| <ReactECharts option={option} style={{ height: 380, width: 430 }} /> | |||
| ) | |||
| } | |||
| type IHitDetailProps = { | |||
| segInfo?: Partial<SegmentDetailModel> & { id: string } | |||
| vectorInfo?: { curr: Array<number[]>; points: Array<number[]> } | |||
| } | |||
| const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => { | |||
| const HitDetail: FC<IHitDetailProps> = ({ segInfo }) => { | |||
| const { t } = useTranslation() | |||
| const renderContent = () => { | |||
| @@ -65,8 +30,8 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => { | |||
| } | |||
| return ( | |||
| <div className='flex flex-row overflow-x-auto'> | |||
| <div className="flex-1 bg-gray-25 p-6 min-w-[300px]"> | |||
| <div className='overflow-x-auto'> | |||
| <div className="bg-gray-25 p-6"> | |||
| <div className="flex items-center"> | |||
| <SegmentIndexTag | |||
| positionId={segInfo?.position || ''} | |||
| @@ -94,20 +59,6 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => { | |||
| })} | |||
| </div> | |||
| </div> | |||
| <div className="flex-1 bg-white p-6"> | |||
| <div className="flex items-center"> | |||
| <div className={cn(s.commonIcon, s.bezierCurveIcon)} /> | |||
| <span className={s.numberInfo}> | |||
| {t('datasetDocuments.segment.vectorHash')} | |||
| </span> | |||
| </div> | |||
| <div | |||
| className={cn(s.numberInfo, 'w-[400px] truncate text-gray-700 mt-1')} | |||
| > | |||
| {segInfo?.index_node_hash} | |||
| </div> | |||
| <ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} /> | |||
| </div> | |||
| </div> | |||
| ) | |||
| } | |||
| @@ -1,6 +1,6 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useEffect, useMemo, useState } from 'react' | |||
| import React, { useEffect, useState } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import useSWR from 'swr' | |||
| import { omit } from 'lodash-es' | |||
| @@ -62,8 +62,6 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => { | |||
| const total = recordsRes?.total || 0 | |||
| const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records]) | |||
| const onClickCard = (detail: HitTestingType) => { | |||
| setCurrParagraph({ paraInfo: detail, showModal: true }) | |||
| } | |||
| @@ -194,17 +192,13 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => { | |||
| </div> | |||
| </FloatRightContainer> | |||
| <Modal | |||
| className='!max-w-[960px] !p-0' | |||
| className='w-[520px] p-0' | |||
| closable | |||
| onClose={() => setCurrParagraph({ showModal: false })} | |||
| isShow={currParagraph.showModal} | |||
| > | |||
| {currParagraph.showModal && <HitDetail | |||
| segInfo={currParagraph.paraInfo?.segment} | |||
| vectorInfo={{ | |||
| curr: [[currParagraph.paraInfo?.tsne_position?.x || 0, currParagraph.paraInfo?.tsne_position.y || 0]], | |||
| points, | |||
| }} | |||
| />} | |||
| </Modal> | |||
| <Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> | |||