Bladeren bron

fix: escape double quotation marks in the vector DB search query (#6506)

tags/0.6.15
Sangmin Ahn 1 jaar geleden
bovenliggende
commit
093b8ca475
No account linked to committer's email address
2 gewijzigde bestanden met toevoegingen van 12 en 4 verwijderingen
  1. 7
    3
      api/core/rag/datasource/retrieval_service.py
  2. 5
    1
      api/services/hit_testing_service.py

+ 7
- 3
api/core/rag/datasource/retrieval_service.py Bestand weergeven

) )


documents = keyword.search( documents = keyword.search(
query,
cls.escape_query_for_search(query),
top_k=top_k top_k=top_k
) )
all_documents.extend(documents) all_documents.extend(documents)
) )


documents = vector.search_by_vector( documents = vector.search_by_vector(
query,
cls.escape_query_for_search(query),
search_type='similarity_score_threshold', search_type='similarity_score_threshold',
top_k=top_k, top_k=top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
) )


documents = vector_processor.search_by_full_text( documents = vector_processor.search_by_full_text(
query,
cls.escape_query_for_search(query),
top_k=top_k top_k=top_k
) )
if documents: if documents:
all_documents.extend(documents) all_documents.extend(documents)
except Exception as e: except Exception as e:
exceptions.append(str(e)) exceptions.append(str(e))

@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')

+ 5
- 1
api/services/hit_testing_service.py Bestand weergeven



all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id, dataset_id=dataset.id,
query=query,
query=cls.escape_query_for_search(query),
top_k=retrieval_model['top_k'], top_k=retrieval_model['top_k'],
score_threshold=retrieval_model['score_threshold'] score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None, if retrieval_model['score_threshold_enabled'] else None,


if not query or len(query) > 250: if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters') raise ValueError('Query is required and cannot exceed 250 characters')

@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')

Laden…
Annuleren
Opslaan