
Add graphrag (#1793)

### What problem does this PR solve?

#1594

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.9.0
Kevin Hu, 1 year ago
Commit 152072f900
74 changed files: 2,522 additions and 105 deletions
1. agent/README.md (+0, -0)
2. agent/README_zh.md (+0, -0)
3. agent/__init__.py (+0, -0)
4. agent/canvas.py (+3, -3)
5. agent/component/__init__.py (+0, -0)
6. agent/component/answer.py (+1, -1)
7. agent/component/arxiv.py (+2, -4)
8. agent/component/baidu.py (+2, -2)
9. agent/component/base.py (+2, -2)
10. agent/component/begin.py (+2, -3)
11. agent/component/bing.py (+2, -2)
12. agent/component/categorize.py (+2, -5)
13. agent/component/cite.py (+1, -1)
14. agent/component/duckduckgo.py (+2, -4)
15. agent/component/generate.py (+1, -3)
16. agent/component/google.py (+2, -2)
17. agent/component/googlescholar.py (+2, -2)
18. agent/component/keyword.py (+2, -2)
19. agent/component/message.py (+1, -4)
20. agent/component/pubmed.py (+2, -4)
21. agent/component/relevant.py (+1, -1)
22. agent/component/retrieval.py (+1, -1)
23. agent/component/rewrite.py (+1, -1)
24. agent/component/switch.py (+1, -6)
25. agent/component/wikipedia.py (+2, -2)
26. agent/settings.py (+0, -0)
27. agent/templates/HR_callout_zh.json (+0, -0)
28. agent/templates/customer_service.json (+0, -0)
29. agent/templates/general_chat_bot.json (+0, -0)
30. agent/templates/interpreter.json (+0, -0)
31. agent/templates/websearch_assistant.json (+0, -0)
32. agent/test/client.py (+2, -3)
33. agent/test/dsl_examples/categorize.json (+0, -0)
34. agent/test/dsl_examples/customer_service.json (+0, -0)
35. agent/test/dsl_examples/headhunter_zh.json (+0, -0)
36. agent/test/dsl_examples/intergreper.json (+0, -0)
37. agent/test/dsl_examples/interpreter.json (+0, -0)
38. agent/test/dsl_examples/keyword_wikipedia_and_generate.json (+0, -0)
39. agent/test/dsl_examples/retrieval_and_generate.json (+0, -0)
40. agent/test/dsl_examples/retrieval_categorize_and_generate.json (+0, -0)
41. agent/test/dsl_examples/retrieval_relevant_and_generate.json (+0, -0)
42. agent/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json (+0, -0)
43. agent/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json (+0, -0)
44. api/apps/api_app.py (+1, -3)
45. api/apps/canvas_app.py (+1, -4)
46. api/apps/chunk_app.py (+36, -10)
47. api/apps/dataset_api.py (+2, -1)
48. api/db/__init__.py (+1, -0)
49. api/db/init_data.py (+3, -3)
50. api/db/services/dialog_service.py (+8, -5)
51. api/db/services/task_service.py (+2, -0)
52. api/settings.py (+3, -1)
53. graphrag/claim_extractor.py (+278, -0)
54. graphrag/claim_prompt.py (+84, -0)
55. graphrag/community_report_prompt.py (+171, -0)
56. graphrag/community_reports_extractor.py (+135, -0)
57. graphrag/description_summary.py (+167, -0)
58. graphrag/entity_embedding.py (+78, -0)
59. graphrag/entity_resolution.py (+212, -0)
60. graphrag/entity_resolution_prompt.py (+74, -0)
61. graphrag/graph_extractor.py (+319, -0)
62. graphrag/graph_prompt.py (+121, -0)
63. graphrag/index.py (+160, -0)
64. graphrag/leiden.py (+160, -0)
65. graphrag/mind_map_extractor.py (+137, -0)
66. graphrag/mind_map_prompt.py (+42, -0)
67. graphrag/search.py (+109, -0)
68. graphrag/smoke.py (+52, -0)
69. graphrag/utils.py (+74, -0)
70. rag/app/knowledge_graph.py (+30, -0)
71. rag/app/naive.py (+3, -0)
72. rag/nlp/__init__.py (+1, -1)
73. rag/nlp/search.py (+18, -17)
74. rag/svr/task_executor.py (+3, -2)

graph/README.md → agent/README.md

graph/README_zh.md → agent/README_zh.md

graph/__init__.py → agent/__init__.py

graph/canvas.py → agent/canvas.py

@@ -22,9 +22,9 @@ from functools import partial

import pandas as pd

from graph.component import component_class
from graph.component.base import ComponentBase
from graph.settings import flow_logger, DEBUG
from agent.component import component_class
from agent.component.base import ComponentBase
from agent.settings import flow_logger, DEBUG


class Canvas(ABC):
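With the package renamed from `graph` to `agent`, callers switch to the new import path. Below is a minimal sketch of loading one of the bundled DSL examples with the renamed package; the `Canvas` constructor arguments shown are an assumption, not taken from this diff.

```python
# Hedged sketch: using the renamed package to load a bundled DSL example.
# The Canvas constructor signature and the tenant id are assumptions.
from agent.canvas import Canvas

with open("agent/test/dsl_examples/retrieval_and_generate.json") as f:
    dsl = f.read()

canvas = Canvas(dsl, "<tenant_id>")  # assumed: Canvas(dsl_json_str, tenant_id)
```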

graph/component/__init__.py → agent/component/__init__.py

graph/component/answer.py → agent/component/answer.py

@@ -19,7 +19,7 @@ from functools import partial

import pandas as pd

from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class AnswerParam(ComponentParamBase):

graph/component/arxiv.py → agent/component/arxiv.py

@@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
from functools import partial
import arxiv
import pandas as pd
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class ArXivParam(ComponentParamBase):

graph/component/baidu.py → agent/component/baidu.py

@@ -19,8 +19,8 @@ from functools import partial
import pandas as pd
import requests
import re
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
class BaiduParam(ComponentParamBase):

graph/component/base.py → agent/component/base.py

@@ -23,8 +23,8 @@ from typing import List, Dict, Tuple, Union

import pandas as pd

from graph import settings
from graph.settings import flow_logger, DEBUG
from agent import settings
from agent.settings import flow_logger, DEBUG

_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"

graph/component/begin.py → agent/component/begin.py

@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from functools import partial

import pandas as pd
from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class BeginParam(ComponentParamBase):


graph/component/bing.py → agent/component/bing.py

@@ -16,8 +16,8 @@
from abc import ABC
import requests
import pandas as pd
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
class BingParam(ComponentParamBase):

graph/component/categorize.py → agent/component/categorize.py

@@ -14,13 +14,10 @@
# limitations under the License.
#
from abc import ABC

import pandas as pd

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from graph.settings import DEBUG
from agent.component import GenerateParam, Generate
from agent.settings import DEBUG


class CategorizeParam(GenerateParam):

graph/component/cite.py → agent/component/cite.py

@@ -21,7 +21,7 @@ from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class CiteParam(ComponentParamBase):

graph/component/duckduckgo.py → agent/component/duckduckgo.py

@@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
from functools import partial
from duckduckgo_search import DDGS
import pandas as pd
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class DuckDuckGoParam(ComponentParamBase):

graph/component/generate.py → agent/component/generate.py

@@ -15,13 +15,11 @@
#
import re
from functools import partial

import pandas as pd

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class GenerateParam(ComponentParamBase):

graph/component/google.py → agent/component/google.py

@@ -16,8 +16,8 @@
from abc import ABC
from serpapi import GoogleSearch
import pandas as pd
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
class GoogleParam(ComponentParamBase):

graph/component/googlescholar.py → agent/component/googlescholar.py

@@ -15,8 +15,8 @@
#
from abc import ABC
import pandas as pd
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
from scholarly import scholarly

graph/component/keyword.py → agent/component/keyword.py

@@ -17,8 +17,8 @@ import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from graph.settings import DEBUG
from agent.component import GenerateParam, Generate
from agent.settings import DEBUG


class KeywordExtractParam(GenerateParam):

graph/component/message.py → agent/component/message.py

@@ -16,10 +16,7 @@
import random
from abc import ABC
from functools import partial

import pandas as pd

from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class MessageParam(ComponentParamBase):

graph/component/pubmed.py → agent/component/pubmed.py

@@ -13,14 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
from functools import partial
from Bio import Entrez
import pandas as pd
import xml.etree.ElementTree as ET
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class PubMedParam(ComponentParamBase):

graph/component/relevant.py → agent/component/relevant.py

@@ -16,7 +16,7 @@
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from agent.component import GenerateParam, Generate
from rag.utils import num_tokens_from_string, encoder



graph/component/retrieval.py → agent/component/retrieval.py

@@ -21,7 +21,7 @@ from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class RetrievalParam(ComponentParamBase):

graph/component/rewrite.py → agent/component/rewrite.py

@@ -16,7 +16,7 @@
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from agent.component import GenerateParam, Generate


class RewriteQuestionParam(GenerateParam):

graph/component/switch.py → agent/component/switch.py

@@ -16,12 +16,7 @@
from abc import ABC

import pandas as pd

from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from graph.component.base import ComponentBase, ComponentParamBase
from agent.component.base import ComponentBase, ComponentParamBase


class SwitchParam(ComponentParamBase):

graph/component/wikipedia.py → agent/component/wikipedia.py

@@ -18,8 +18,8 @@ from abc import ABC
from functools import partial
import wikipedia
import pandas as pd
from graph.settings import DEBUG
from graph.component.base import ComponentBase, ComponentParamBase
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class WikipediaParam(ComponentParamBase):

graph/settings.py → agent/settings.py

graph/templates/HR_callout_zh.json → agent/templates/HR_callout_zh.json

graph/templates/customer_service.json → agent/templates/customer_service.json

graph/templates/general_chat_bot.json → agent/templates/general_chat_bot.json

graph/templates/interpreter.json → agent/templates/interpreter.json

graph/templates/websearch_assistant.json → agent/templates/websearch_assistant.json

graph/test/client.py → agent/test/client.py

@@ -16,9 +16,8 @@
import argparse
import os
from functools import partial
import readline
from graph.canvas import Canvas
from graph.settings import DEBUG
from agent.canvas import Canvas
from agent.settings import DEBUG

if __name__ == '__main__':
parser = argparse.ArgumentParser()

graph/test/dsl_examples/categorize.json → agent/test/dsl_examples/categorize.json

graph/test/dsl_examples/customer_service.json → agent/test/dsl_examples/customer_service.json

graph/test/dsl_examples/headhunter_zh.json → agent/test/dsl_examples/headhunter_zh.json

graph/test/dsl_examples/intergreper.json → agent/test/dsl_examples/intergreper.json

graph/test/dsl_examples/interpreter.json → agent/test/dsl_examples/interpreter.json

graph/test/dsl_examples/keyword_wikipedia_and_generate.json → agent/test/dsl_examples/keyword_wikipedia_and_generate.json

graph/test/dsl_examples/retrieval_and_generate.json → agent/test/dsl_examples/retrieval_and_generate.json

graph/test/dsl_examples/retrieval_categorize_and_generate.json → agent/test/dsl_examples/retrieval_categorize_and_generate.json

graph/test/dsl_examples/retrieval_relevant_and_generate.json → agent/test/dsl_examples/retrieval_relevant_and_generate.json

graph/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json → agent/test/dsl_examples/retrieval_relevant_keyword_baidu_and_generate.json

graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json → agent/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json


api/apps/api_app.py (+1, -3)

@@ -20,7 +20,7 @@ from datetime import datetime, timedelta
from flask import request, Response
from flask_login import login_required, current_user
from api.db import FileType, ParserType, FileSource, LLMType
from api.db import FileType, ParserType, FileSource
from api.db.db_models import APIToken, API4Conversation, Task, File
from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService
@@ -29,7 +29,6 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler
@@ -38,7 +37,6 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
from itsdangerous import URLSafeTimedSerializer
from api.utils.file_utils import filename_type, thumbnail
from rag.nlp import keyword_extraction
from rag.utils.minio_conn import MINIO

api/apps/canvas_app.py (+1, -4)

@@ -15,15 +15,12 @@
#
import json
from functools import partial

from flask import request, Response
from flask_login import login_required, current_user

from api.db.db_models import UserCanvas
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request
from graph.canvas import Canvas
from agent.canvas import Canvas


@manager.route('/templates', methods=['GET'])

api/apps/chunk_app.py (+36, -10)

@@ -14,6 +14,8 @@
# limitations under the License.
#
import datetime
import json
import traceback
from flask import request
from flask_login import login_required, current_user
@@ -29,7 +31,7 @@ from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils.api_utils import get_json_result
import hashlib
import re
@@ -61,7 +63,8 @@ def list_chunk():
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get(
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
id].get(
"content_with_weight", ""),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
@@ -136,11 +139,11 @@ def set():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
@@ -185,7 +188,7 @@ def switch():
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("chunk_ids","doc_id")
@validate_request("chunk_ids", "doc_id")
def rm():
req = request.json
try:
@@ -230,11 +233,11 @@ def create():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
@@ -277,9 +280,10 @@ def retrieval_test():
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
@@ -290,3 +294,25 @@ def retrieval_test():
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)
@manager.route('/knowledge_graph', methods=['GET'])
@login_required
def knowledge_graph():
doc_id = request.args["doc_id"]
req = {
"doc_ids":[doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
tenant_id = DocumentService.get_tenant_id(doc_id)
sres = retrievaler.search(req, search.index_name(tenant_id))
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
try:
obj[ty] = json.loads(sres.field[id]["content_with_weight"])
except Exception as e:
print(traceback.format_exc(), flush=True)
return get_json_result(data=obj)
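For context, below is a hedged sketch of calling the new route from a client. The host, the `/v1/chunk` URL prefix, the session cookie, and the `data` response envelope are assumptions; the `doc_id` query parameter and the `graph`/`mind_map` keys come from the handler above.

```python
# Hedged sketch (not part of this PR): querying the new knowledge_graph route.
# Host, URL prefix, auth cookie, and the response envelope are assumptions.
import requests

resp = requests.get(
    "http://localhost:9380/v1/chunk/knowledge_graph",    # prefix assumed
    params={"doc_id": "<doc_id>"},                        # from request.args["doc_id"]
    cookies={"session": "<logged-in session cookie>"},    # @login_required
)
payload = resp.json().get("data", {})                     # envelope assumed
graph, mind_map = payload.get("graph"), payload.get("mind_map")
```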

api/apps/dataset_api.py (+2, -1)

@@ -623,7 +623,7 @@ def doc_parse_callback(doc_id, prog=None, msg=""):
if cancel:
raise Exception("The parsing process has been cancelled!")

"""
def doc_parse(binary, doc_name, parser_name, tenant_id, doc_id):
match parser_name:
case "book":
@@ -656,6 +656,7 @@ def doc_parse(binary, doc_name, parser_name, tenant_id, doc_id):
return False

return True
"""


@manager.route("/<dataset_id>/documents/<document_id>/status", methods=["POST"])

api/db/__init__.py (+1, -0)

@@ -85,6 +85,7 @@ class ParserType(StrEnum):
PICTURE = "picture"
ONE = "one"
AUDIO = "audio"
KG = "knowledge_graph"
class FileSource(StrEnum):

api/db/init_data.py (+3, -3)

@@ -122,7 +122,7 @@ def init_llm_factory():
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio"})
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph"})
## insert openai two embedding models to the current openai user.
print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
@@ -145,7 +145,7 @@ def init_llm_factory():
"""
drop table llm;
drop table llm_factories;
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio';
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
alter table knowledgebase modify avatar longtext;
alter table user modify avatar longtext;
alter table dialog modify icon longtext;
@@ -153,7 +153,7 @@ def init_llm_factory():
def add_graph_templates():
dir = os.path.join(get_project_base_directory(), "graph", "templates")
dir = os.path.join(get_project_base_directory(), "agent", "templates")
for fnm in os.listdir(dir):
try:
cnvs = json.load(open(os.path.join(dir, fnm), "r"))

api/db/services/dialog_service.py (+8, -5)

@@ -18,12 +18,12 @@ import json
import re
from copy import deepcopy
from api.db import LLMType
from api.db import LLMType, ParserType
from api.db.db_models import Dialog, Conversation
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler
from api.settings import chat_logger, retrievaler, kg_retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp import keyword_extraction
from rag.nlp.search import index_name
@@ -101,6 +101,9 @@ def chat(dialog, messages, stream=True, **kwargs):
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
questions = [m["content"] for m in messages if m["role"] == "user"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
if llm_id2llm_type(dialog.llm_id) == "image2text":
@@ -138,7 +141,7 @@ def chat(dialog, messages, stream=True, **kwargs):
else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
@@ -147,7 +150,7 @@ def chat(dialog, messages, stream=True, **kwargs):
#self-rag
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
@@ -179,7 +182,7 @@ def chat(dialog, messages, stream=True, **kwargs):
nonlocal prompt_config, knowledges, kwargs, kbinfos
refs = []
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retrievaler.insert_citations(answer,
answer, idx = retr.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]

api/db/services/task_service.py (+2, -0)

@@ -139,6 +139,8 @@ def queue_tasks(doc, bucket, name):
page_size = doc["parser_config"].get("task_page_size", 22)
if doc["parser_id"] == "one":
page_size = 1000000000
if doc["parser_id"] == "knowledge_graph":
page_size = 1000000000
if not do_layout:
page_size = 1000000000
page_ranges = doc["parser_config"].get("pages")

api/settings.py (+3, -1)

@@ -34,6 +34,7 @@ chat_logger = getLogger("chat")
from rag.utils.es_conn import ELASTICSEARCH
from rag.nlp import search
from graphrag import search as kg_search
from api.utils import get_base_config, decrypt_database_config
API_VERSION = "v1"
@@ -131,7 +132,7 @@ IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio")
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph")
# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
@@ -204,6 +205,7 @@ PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
class CustomEnum(Enum):
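The new `kg_retrievaler` sits alongside the existing `retrievaler`; elsewhere in this PR (dialog_service.py and chunk_app.py above) the choice between them hinges on the knowledge base's parser type. A minimal sketch of that selection pattern is below; `pick_retriever` is an illustrative helper, not a function added by the commit.

```python
# Minimal sketch of the retriever-selection pattern introduced in this PR.
# pick_retriever is an illustrative helper, not a function added by the commit.
from api.db import ParserType
from api.settings import retrievaler, kg_retrievaler

def pick_retriever(parser_id: str):
    """Use the knowledge-graph retriever only for KG-parsed knowledge bases."""
    return kg_retrievaler if parser_id == ParserType.KG else retrievaler
```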

graphrag/claim_extractor.py (+278, -0)

@@ -0,0 +1,278 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import argparse
import json
import logging
import re
import traceback
from dataclasses import dataclass
from typing import Any

import tiktoken

from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
CLAIM_MAX_GLEANINGS = 1
log = logging.getLogger(__name__)


@dataclass
class ClaimExtractorResult:
"""Claim extractor result class definition."""

output: list[dict]
source_docs: dict[str, Any]


class ClaimExtractor:
"""Claim extractor class definition."""

_llm: CompletionLLM
_extraction_prompt: str
_summary_prompt: str
_output_formatter_prompt: str
_input_text_key: str
_input_entity_spec_key: str
_input_claim_description_key: str
_tuple_delimiter_key: str
_record_delimiter_key: str
_completion_delimiter_key: str
_max_gleanings: int
_on_error: ErrorHandlerFn

def __init__(
self,
llm_invoker: CompletionLLM,
extraction_prompt: str | None = None,
input_text_key: str | None = None,
input_entity_spec_key: str | None = None,
input_claim_description_key: str | None = None,
input_resolved_entities_key: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
completion_delimiter_key: str | None = None,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
self._input_text_key = input_text_key or "input_text"
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._input_claim_description_key = (
input_claim_description_key or "claim_description"
)
self._input_resolved_entities_key = (
input_resolved_entities_key or "resolved_entities"
)
self._max_gleanings = (
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)

# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

def __call__(
self, inputs: dict[str, Any], prompt_variables: dict | None = None
) -> ClaimExtractorResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
texts = inputs[self._input_text_key]
entity_spec = str(inputs[self._input_entity_spec_key])
claim_description = inputs[self._input_claim_description_key]
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
source_doc_map = {}

prompt_args = {
self._input_entity_spec_key: entity_spec,
self._input_claim_description_key: claim_description,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
}

all_claims: list[dict] = []
for doc_index, text in enumerate(texts):
document_id = f"d{doc_index}"
try:
claims = self._process_document(prompt_args, text, doc_index)
all_claims += [
self._clean_claim(c, document_id, resolved_entities) for c in claims
]
source_doc_map[document_id] = text
except Exception as e:
log.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),
{"doc_index": doc_index, "text": text},
)
continue

return ClaimExtractorResult(
output=all_claims,
source_docs=source_doc_map,
)

def _clean_claim(
self, claim: dict, document_id: str, resolved_entities: dict
) -> dict:
# clean the parsed claims to remove any claims with status = False
obj = claim.get("object_id", claim.get("object"))
subject = claim.get("subject_id", claim.get("subject"))

# If subject or object in resolved entities, then replace with resolved entity
obj = resolved_entities.get(obj, obj)
subject = resolved_entities.get(subject, subject)
claim["object_id"] = obj
claim["subject_id"] = subject
claim["doc_id"] = document_id
return claim

def _process_document(
self, prompt_args: dict, doc, doc_index: int
) -> list[dict]:
record_delimiter = prompt_args.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_args.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
variables = {
self._input_text_key: doc,
**prompt_args,
}
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
results = self._llm.chat(text, [], gen_conf)
claims = results.strip().removesuffix(completion_delimiter)
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
extension = self._llm.chat("", history, gen_conf)
claims += record_delimiter + extension.strip().removesuffix(
completion_delimiter
)

# If this isn't the last loop, check to see if we should continue
if i >= self._max_gleanings - 1:
break

history.append({"role": "assistant", "content": extension})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._llm.chat("", history, self._loop_args)
if continuation != "YES":
break

result = self._parse_claim_tuples(claims, prompt_args)
for r in result:
r["doc_id"] = f"{doc_index}"
return result

def _parse_claim_tuples(
self, claims: str, prompt_variables: dict
) -> list[dict[str, Any]]:
"""Parse claim tuples."""
record_delimiter = prompt_variables.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_variables.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
tuple_delimiter = prompt_variables.get(
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
)

def pull_field(index: int, fields: list[str]) -> str | None:
return fields[index].strip() if len(fields) > index else None

result: list[dict[str, Any]] = []
claims_values = (
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
)
for claim in claims_values:
claim = claim.strip().removeprefix("(").removesuffix(")")
claim = re.sub(r".*Output:", "", claim)

# Ignore the completion delimiter
if claim == completion_delimiter:
continue

claim_fields = claim.split(tuple_delimiter)
o = {
"subject_id": pull_field(0, claim_fields),
"object_id": pull_field(1, claim_fields),
"type": pull_field(2, claim_fields),
"status": pull_field(3, claim_fields),
"start_date": pull_field(4, claim_fields),
"end_date": pull_field(5, claim_fields),
"description": pull_field(6, claim_fields),
"source_text": pull_field(7, claim_fields),
"doc_id": pull_field(8, claim_fields),
}
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
continue
result.append(o)
return result


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler

ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",
"claim_description": ""
}
claim = ex(info)
print(json.dumps(claim.output, ensure_ascii=False, indent=2))
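Beyond the CLI entry point above, the extractor can also be driven programmatically; here is a hedged sketch mirroring the `__main__` block, with tenant and document IDs left as placeholders.

```python
# Hedged sketch of programmatic use, mirroring the __main__ block above.
# Tenant and document IDs are placeholders.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from graphrag.claim_extractor import ClaimExtractor

extractor = ClaimExtractor(LLMBundle("<tenant_id>", LLMType.CHAT))
chunks = retrievaler.chunk_list("<doc_id>", "<tenant_id>", max_count=12,
                                fields=["content_with_weight"])
result = extractor({
    "input_text": [c["content_with_weight"] for c in chunks],
    "entity_specs": "organization, person",
    "claim_description": "red flags associated with an entity",
})
for claim in result.output:  # dicts with subject_id, object_id, type, status, ...
    print(claim)
```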

graphrag/claim_prompt.py (+84, -0)

@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

CLAIM_EXTRACTION_PROMPT = """
################
-Target activity-
################
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.

################
-Goal-
################
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.

################
-Steps-
################
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
For each claim, extract the following information:
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.

- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
- 5. If there's nothing satisfy the above requirements, just keep output empty.
- 6. When finished, output {completion_delimiter}

################
-Examples-
################
Example 1:
Entity specification: organization
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{completion_delimiter}

###########################
Example 2:
Entity specification: Company A, Person C
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{record_delimiter}
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
{completion_delimiter}

################
-Real Data-
################
Use the following input for your answer.
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output:"""


CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"
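These templates are filled via `perform_variable_replacements` in `claim_extractor.py`; a hedged sketch of that substitution follows, using the default delimiters defined there (the input text is a placeholder).

```python
# Hedged sketch: filling the claim-extraction template the way the extractor
# does. Delimiter values are the defaults from claim_extractor.py.
from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT
from graphrag.utils import perform_variable_replacements

variables = {
    "entity_specs": "organization, person",
    "claim_description": "red flags associated with an entity",
    "tuple_delimiter": "<|>",
    "record_delimiter": "##",
    "completion_delimiter": "<|COMPLETE|>",
    "input_text": "<document text>",
}
prompt = perform_variable_replacements(CLAIM_EXTRACTION_PROMPT, variables=variables)
```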

graphrag/community_report_prompt.py (+171, -0)

@@ -0,0 +1,171 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

COMMUNITY_REPORT_PROMPT = """
You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.

# Goal
Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.

# Report Structure

The report should include the following sections:

- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.

Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
{{
"title": <report_title>,
"summary": <executive_summary>,
"rating": <impact_severity_rating>,
"rating_explanation": <rating_explanation>,
"findings": [
{{
"summary":<insight_1_summary>,
"explanation": <insight_1_explanation>
}},
{{
"summary":<insight_2_summary>,
"explanation": <insight_2_explanation>
}}
]
}}

# Grounding Rules

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."

where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.

Do not include information where the supporting evidence for it is not provided.


# Example Input
-----------
Text:

-Entities-

id,entity,description
5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza

-Relationships-

id,source,target,description
37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March

Output:
{{
"title": "Verdant Oasis Plaza and Unity March",
"summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
"rating": 5.0,
"rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
"findings": [
{{
"summary": "Verdant Oasis Plaza as the central location",
"explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]"
}},
{{
"summary": "Harmony Assembly's role in the community",
"explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]"
}},
{{
"summary": "Unity March as a significant event",
"explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]"
}},
{{
"summary": "Role of Tribune Spotlight",
"explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]"
}}
]
}}


# Real Data

Use the following text for your answer. Do not make anything up in your answer.

Text:

-Entities-
{entity_df}

-Relationships-
{relation_df}

The report should include the following sections:

- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.

Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
{{
"title": <report_title>,
"summary": <executive_summary>,
"rating": <impact_severity_rating>,
"rating_explanation": <rating_explanation>,
"findings": [
{{
"summary":<insight_1_summary>,
"explanation": <insight_1_explanation>
}},
{{
"summary":<insight_2_summary>,
"explanation": <insight_2_explanation>
}}
]
}}

# Grounding Rules

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."

where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.

Do not include information where the supporting evidence for it is not provided.

Output:"""

graphrag/community_reports_extractor.py (+135, -0)

@@ -0,0 +1,135 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

import json
import logging
import re
import traceback
from dataclasses import dataclass
from typing import Any, List

import networkx as nx
import pandas as pd

from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types

log = logging.getLogger(__name__)


@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""

output: List[str]
structured_output: List[dict]


class CommunityReportsExtractor:
"""Community reports extractor class definition."""

_llm: CompletionLLM
_extraction_prompt: str
_output_formatter_prompt: str
_on_error: ErrorHandlerFn
_max_report_length: int

def __init__(
self,
llm_invoker: CompletionLLM,
extraction_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
max_report_length: int | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_report_length = max_report_length or 1500

def __call__(self, graph: nx.Graph):
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = []
res_dict = []
for level, comm in communities.items():
for cm_id, ents in comm.items():
weight = ents["weight"]
ents = ents["nodes"]
ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)

prompt_variables = {
"entity_df": ent_df.to_csv(index_label="id"),
"relation_df": rela_df.to_csv(index_label="id")
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.5}
try:
response = self._llm.chat(text, [], gen_conf)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
print(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]): continue
response["weight"] = weight
response["entities"] = ents
except Exception as e:
print("ERROR: ", traceback.format_exc())
self._on_error(e, traceback.format_exc(), None)
continue

add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))
res_dict.append(response)

return CommunityReportsResult(
structured_output=res_dict,
output=res_str,
)

def _get_text_output(self, parsed_output: dict) -> str:
title = parsed_output.get("title", "Report")
summary = parsed_output.get("summary", "")
findings = parsed_output.get("findings", [])

def finding_summary(finding: dict):
if isinstance(finding, str):
return finding
return finding.get("summary")

def finding_explanation(finding: dict):
if isinstance(finding, str):
return ""
return finding.get("explanation")

report_sections = "\n\n".join(
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
)
return f"# {title}\n\n{summary}\n\n{report_sections}"

graphrag/description_summary.py (+167, -0)

@@ -0,0 +1,167 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

import argparse
import html
import json
import logging
import numbers
import re
import traceback
from collections.abc import Callable
from dataclasses import dataclass

from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx

from rag.utils import num_tokens_from_string

SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.

#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""

# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128


@dataclass
class SummarizationResult:
"""Unipartite graph extraction result class definition."""

items: str | tuple[str, str]
description: str


class SummarizeExtractor:
"""Unipartite graph extractor class definition."""

_llm: CompletionLLM
_entity_name_key: str
_input_descriptions_key: str
_summarization_prompt: str
_on_error: ErrorHandlerFn
_max_summary_length: int
_max_input_tokens: int

def __init__(
self,
llm_invoker: CompletionLLM,
entity_name_key: str | None = None,
input_descriptions_key: str | None = None,
summarization_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
max_summary_length: int | None = None,
max_input_tokens: int | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._entity_name_key = entity_name_key or "entity_name"
self._input_descriptions_key = input_descriptions_key or "description_list"

self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

def __call__(
self,
items: str | tuple[str, str],
descriptions: list[str],
) -> SummarizationResult:
"""Call method definition."""
result = ""
if len(descriptions) == 0:
result = ""
if len(descriptions) == 1:
result = descriptions[0]
else:
result = self._summarize_descriptions(items, descriptions)

return SummarizationResult(
items=items,
description=result or "",
)

def _summarize_descriptions(
self, items: str | tuple[str, str], descriptions: list[str]
) -> str:
"""Summarize descriptions into a single description."""
sorted_items = sorted(items) if isinstance(items, list) else items

# Safety check, should always be a list
if not isinstance(descriptions, list):
descriptions = [descriptions]

# Iterate over descriptions, adding all until the max input tokens is reached
usable_tokens = self._max_input_tokens - num_tokens_from_string(
self._summarization_prompt
)
descriptions_collected = []
result = ""

for i, description in enumerate(descriptions):
usable_tokens -= num_tokens_from_string(description)
descriptions_collected.append(description)

# If buffer is full, or all descriptions have been added, summarize
if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
i == len(descriptions) - 1
):
# Calculate result (final or partial)
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

# If we go for another loop, reset values to new
if i != len(descriptions) - 1:
descriptions_collected = [result]
usable_tokens = (
self._max_input_tokens
- num_tokens_from_string(self._summarization_prompt)
- num_tokens_from_string(result)
)

return result

def _summarize_descriptions_with_llm(
self, items: str | tuple[str, str] | list[str], descriptions: list[str]
):
"""Summarize descriptions using the LLM."""
variables = {
self._entity_name_key: json.dumps(items),
self._input_descriptions_key: json.dumps(sorted(descriptions)),
}
text = perform_variable_replacements(self._summarization_prompt, variables=variables)
return self._llm.chat("", [{"role": "user", "content": text}])

+ 78
- 0
graphrag/entity_embedding.py

@@ -0,0 +1,78 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

from dataclasses import dataclass
from typing import Any

import numpy as np
import networkx as nx
import graspologic as gc

from graphrag.leiden import stable_largest_connected_component


@dataclass
class NodeEmbeddings:
"""Node embeddings class definition."""

nodes: list[str]
embeddings: np.ndarray


def embed_nod2vec(
graph: nx.Graph | nx.DiGraph,
dimensions: int = 1536,
num_walks: int = 10,
walk_length: int = 40,
window_size: int = 2,
iterations: int = 3,
random_seed: int = 86,
) -> NodeEmbeddings:
"""Generate node embeddings using Node2Vec."""
# generate embedding
lcc_tensors = gc.embed.node2vec_embed( # type: ignore
graph=graph,
dimensions=dimensions,
window_size=window_size,
iterations=iterations,
num_walks=num_walks,
walk_length=walk_length,
random_seed=random_seed,
)
return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])


def run(graph: nx.Graph, args: dict[str, Any]) -> dict[str, list[float]]:
"""Run method definition."""
if args.get("use_lcc", True):
graph = stable_largest_connected_component(graph)

# create graph embedding using node2vec
embeddings = embed_nod2vec(
graph=graph,
dimensions=args.get("dimensions", 1536),
num_walks=args.get("num_walks", 10),
walk_length=args.get("walk_length", 40),
window_size=args.get("window_size", 2),
iterations=args.get("iterations", 3),
random_seed=args.get("random_seed", 86),
)

pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
sorted_pairs = sorted(pairs, key=lambda x: x[0])

return dict(sorted_pairs)
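
A minimal sketch, assuming graspologic is installed and graph is an nx.Graph produced by GraphExtractor; the node2vec parameters are illustrative, and run() returns a dict mapping node name to its embedding vector:

from graphrag import entity_embedding

# graph: nx.Graph from GraphExtractor (see the earlier sketches)
node_vectors = entity_embedding.run(graph, {"dimensions": 128, "num_walks": 10})
name, vec = next(iter(node_vectors.items()))
print(name, len(vec))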

+ 212
- 0
graphrag/entity_resolution.py

@@ -0,0 +1,212 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import traceback
from dataclasses import dataclass
from typing import Any

import networkx as nx
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"


@dataclass
class EntityResolutionResult:
"""Entity resolution result class definition."""

output: nx.Graph


class EntityResolution:
"""Entity resolution class definition."""

_llm: CompletionLLM
_resolution_prompt: str
_output_formatter_prompt: str
_on_error: ErrorHandlerFn
_record_delimiter_key: str
_entity_index_delimiter_key: str
_resolution_result_delimiter_key: str

def __init__(
self,
llm_invoker: CompletionLLM,
resolution_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
record_delimiter_key: str | None = None,
entity_index_delimiter_key: str | None = None,
resolution_result_delimiter_key: str | None = None,
input_text_key: str | None = None
):
"""Init method definition."""
self._llm = llm_invoker
self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._entity_index_delimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
self._input_text_key = input_text_key or "input_text"

def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}

# Wire defaults into the prompt variables
prompt_variables = {
**prompt_variables,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
            self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key)
or DEFAULT_ENTITY_INDEX_DELIMITER,
self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
}

nodes = graph.nodes
entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
node_clusters = {entity_type: [] for entity_type in entity_types}

for node in nodes:
node_clusters[graph.nodes[node]['entity_type']].append(node)

candidate_resolution = {entity_type: [] for entity_type in entity_types}
for node_cluster in node_clusters.items():
candidate_resolution_tmp = []
for a in node_cluster[1]:
for b in node_cluster[1]:
if a == b:
continue
if self.is_similarity(a, b) and (b, a) not in candidate_resolution_tmp:
candidate_resolution_tmp.append((a, b))
if candidate_resolution_tmp:
candidate_resolution[node_cluster[0]] = candidate_resolution_tmp

gen_conf = {"temperature": 0.5}
resolution_result = set()
for candidate_resolution_i in candidate_resolution.items():
if candidate_resolution_i[1]:
try:
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
                            f'Question {index + 1}: name of {candidate_resolution_i[0]} A is {candidate[0]}, name of {candidate_resolution_i[0]} B is {candidate[1]}')
                    sent = 'question above' if len(candidate_resolution_i[1]) == 1 else f'above {len(candidate_resolution_i[1])} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)

variables = {
**prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)

response = self._llm.chat(text, [], gen_conf)
result = self._process_results(len(candidate_resolution_i[1]), response,
prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
                                                   prompt_variables.get(self._entity_index_delimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
except Exception as e:
logging.exception("error entity resolution")
self._on_error(e, traceback.format_exc(), None)

connect_graph = nx.Graph()
connect_graph.add_edges_from(resolution_result)
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
for remove_node in remove_nodes:
remove_node_neighbors = graph[remove_node]
graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
if remove_node_neighbor == keep_node:
graph.remove_edge(keep_node, remove_node)
continue
if graph.has_edge(keep_node, remove_node_neighbor):
graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
'weight']
graph[keep_node][remove_node_neighbor]['description'] += \
graph[remove_node][remove_node_neighbor]['description']
graph.remove_edge(remove_node, remove_node_neighbor)
else:
graph.add_edge(keep_node, remove_node_neighbor,
weight=graph[remove_node][remove_node_neighbor]['weight'],
description=graph[remove_node][remove_node_neighbor]['description'],
source_id="")
graph.remove_edge(remove_node, remove_node_neighbor)
graph.remove_node(remove_node)

for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

return EntityResolutionResult(
output=graph,
)

def _process_results(
self,
records_length: int,
results: str,
record_delimiter: str,
entity_index_delimiter: str,
resolution_result_delimiter: str
) -> list:
ans_list = []
records = [r.strip() for r in results.split(record_delimiter)]
for record in records:
            pattern_int = rf"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
match_int = re.search(pattern_int, record)
res_int = int(str(match_int.group(1) if match_int else '0'))
if res_int > records_length:
continue

pattern_bool = f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
match_bool = re.search(pattern_bool, record)
res_bool = str(match_bool.group(1) if match_bool else '')

if res_int and res_bool:
if res_bool.lower() == 'yes':
ans_list.append((res_int, "yes"))

return ans_list

def is_similarity(self, a, b):
if is_english(a) and is_english(b):
if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
return True

if len(set(a) & set(b)) > 0:
return True

return False
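
A minimal sketch mirroring how graphrag/index.py wires EntityResolution; node pairs that the LLM judges to be the same entity are merged into one node:

from graphrag.entity_resolution import EntityResolution

er = EntityResolution(llm_bdl)      # llm_bdl: LLMBundle as in the earlier sketches
graph = er(graph).output            # graph: nx.Graph from GraphExtractor
print(len(graph.nodes()), "entities left after resolution")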

+ 74
- 0
graphrag/entity_resolution_prompt.py

@@ -0,0 +1,74 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

ENTITY_RESOLUTION_PROMPT = """
-Goal-
Please answer the following Question as required

-Steps-
1. Identify each line of questioning as required

2. Return output in English as a single list of each line answer in steps 1. Use **{record_delimiter}** as the list delimiter.

######################
-Examples-
######################
Example 1:

Question:
When determining whether two Products are the same, you should only focus on critical properties and overlook noisy factors.

Demonstration 1: name of Product A is : "computer", name of Product B is :"phone" No, Product A and Product B are different products.
Question 1: name of Product A is : "television", name of Product B is :"TV"
Question 2: name of Product A is : "cup", name of Product B is :"mug"
Question 3: name of Product A is : "soccer", name of Product B is :"football"
Question 4: name of Product A is : "pen", name of Product B is :"eraser"

Use domain knowledge of Products to help understand the text and answer the above 4 questions in the format: For Question i, Yes, Product A and Product B are the same product. or No, Product A and Product B are different products. For Question i+1, (repeat the above procedures)
################
Output:
(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, Product A and Product B are the same product.){record_delimiter}
(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
#############################

Example 2:

Question:
When determining whether two toponym are the same, you should only focus on critical properties and overlook noisy factors.

Demonstration 1: name of toponym A is : "nanjing", name of toponym B is :"nanjing city" Yes, toponym A and toponym B are the same toponym.
Question 1: name of toponym A is : "Chicago", name of toponym B is :"ChiTown"
Question 2: name of toponym A is : "Shanghai", name of toponym B is :"Zhengzhou"
Question 3: name of toponym A is : "Beijing", name of toponym B is :"Peking"
Question 4: name of toponym A is : "Los Angeles", name of toponym B is :"Cleveland"

Use domain knowledge of toponym to help understand the text and answer the above 4 questions in the format: For Question i, Yes, toponym A and toponym B are the same toponym. or No, toponym A and toponym B are different toponym. For Question i+1, (repeat the above procedures)
################
Output:
(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are same toponym.){record_delimiter}
(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter}
(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are the same toponym.){record_delimiter}
(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter}
#############################

-Real Data-
######################
Question:{input_text}
######################
Output:
"""

+ 319
- 0
graphrag/graph_extractor.py

@@ -0,0 +1,319 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import numbers
import re
import traceback
from dataclasses import dataclass
from typing import Any, Mapping
import tiktoken
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1


@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""

output: nx.Graph
source_docs: dict[Any, Any]


class GraphExtractor:
"""Unipartite graph extractor class definition."""

_llm: CompletionLLM
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
_entity_types_key: str
_input_text_key: str
_completion_delimiter_key: str
_entity_name_key: str
_input_descriptions_key: str
_extraction_prompt: str
_summarization_prompt: str
_loop_args: dict[str, Any]
_max_gleanings: int
_on_error: ErrorHandlerFn

def __init__(
self,
llm_invoker: CompletionLLM,
prompt: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
entity_types_key: str | None = None,
completion_delimiter_key: str | None = None,
join_descriptions=True,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._join_descriptions = join_descriptions
self._input_text_key = input_text_key or "input_text"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._entity_types_key = entity_types_key or "entity_types"
self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)

# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

def __call__(
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
) -> GraphExtractionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
all_records: dict[int, str] = {}
source_doc_map: dict[int, str] = {}

# Wire defaults into the prompt variables
prompt_variables = {
**prompt_variables,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
self._entity_types_key: ",".join(
prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
),
}

for doc_index, text in enumerate(texts):
try:
# Invoke the entity extraction
result = self._process_document(text, prompt_variables)
source_doc_map[doc_index] = text
all_records[doc_index] = result
except Exception as e:
logging.exception("error extracting graph")
self._on_error(
e,
traceback.format_exc(),
{
"doc_index": doc_index,
"text": text,
},
)

output = self._process_results(
all_records,
prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
)

return GraphExtractionResult(
output=output,
source_docs=source_doc_map,
)

def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
variables = {
**prompt_variables,
self._input_text_key: text,
}
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
response = self._llm.chat(text, [], gen_conf)

results = response or ""
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._llm.chat("", history, gen_conf)
results += response or ""

# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._llm.chat("", history, self._loop_args)
if continuation != "YES":
break

return results

def _process_results(
self,
results: dict[int, str],
tuple_delimiter: str,
record_delimiter: str,
) -> nx.Graph:
"""Parse the result string to create an undirected unipartite graph.

Args:
- results - dict of results from the extraction chain
- tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
- record_delimiter - delimiter between records, default is '##'
Returns:
- output - unipartite graph in graphML format
"""
graph = nx.Graph()
for source_doc_id, extracted_data in results.items():
records = [r.strip() for r in extracted_data.split(record_delimiter)]

for record in records:
record = re.sub(r"^\(|\)$", "", record.strip())
record_attributes = record.split(tuple_delimiter)

if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])

if entity_name in graph.nodes():
node = graph.nodes[entity_name]
if self._join_descriptions:
node["description"] = "\n".join(
list({
*_unpack_descriptions(node),
entity_description,
})
)
else:
if len(entity_description) > len(node["description"]):
node["description"] = entity_description
node["source_id"] = ", ".join(
list({
*_unpack_source_ids(node),
str(source_doc_id),
})
)
node["entity_type"] = (
entity_type if entity_type != "" else node["entity_type"]
)
else:
graph.add_node(
entity_name,
entity_type=entity_type,
description=entity_description,
source_id=str(source_doc_id),
weight=1
)

if (
record_attributes[0] == '"relationship"'
and len(record_attributes) >= 5
):
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_source_id = clean_str(str(source_doc_id))
                    # relationship_strength arrives as a string, so parse it instead of an isinstance check
                    try:
                        weight = float(record_attributes[-1])
                    except ValueError:
                        weight = 1.0
if source not in graph.nodes():
graph.add_node(
source,
entity_type="",
description="",
source_id=edge_source_id,
weight=1
)
if target not in graph.nodes():
graph.add_node(
target,
entity_type="",
description="",
source_id=edge_source_id,
weight=1
)
if graph.has_edge(source, target):
edge_data = graph.get_edge_data(source, target)
if edge_data is not None:
weight += edge_data["weight"]
if self._join_descriptions:
edge_description = "\n".join(
list({
*_unpack_descriptions(edge_data),
edge_description,
})
)
edge_source_id = ", ".join(
list({
*_unpack_source_ids(edge_data),
str(source_doc_id),
})
)
graph.add_edge(
source,
target,
weight=weight,
description=edge_description,
source_id=edge_source_id,
)

for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return graph


def _unpack_descriptions(data: Mapping) -> list[str]:
value = data.get("description", None)
return [] if value is None else value.split("\n")


def _unpack_source_ids(data: Mapping) -> list[str]:
value = data.get("source_id", None)
return [] if value is None else value.split(", ")
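
A minimal sketch of running GraphExtractor on its own; the text and entity types are illustrative, and the prompt variables are passed the same way graphrag/index.py does:

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graphrag.graph_extractor import GraphExtractor

llm_bdl = LLMBundle("<tenant_id>", LLMType.CHAT)    # placeholder tenant id
ext = GraphExtractor(llm_bdl, max_gleanings=1)
res = ext(["Alex and Taylor examined the device in Washington."],
          {"entity_types": ["person", "technology", "location"]})
for name, attr in res.output.nodes(data=True):
    print(name, attr["entity_type"], attr["rank"])
for s, t, attr in res.output.edges(data=True):
    print(s, "->", t, attr["weight"])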




+ 121
- 0
graphrag/graph_prompt.py

@@ -0,0 +1,121 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
GRAPH_EXTRACTION_PROMPT = """
-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)

3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.

4. When finished, output {completion_delimiter}

######################
-Examples-
######################
Example 1:

Entity_types: [person, technology, mission, organization, location]
Text:
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.

Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”

The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.

It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
################
Output:
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter}
#############################
Example 2:

Entity_types: [person, technology, mission, organization, location]
Text:
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.

Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.

Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
#############
Output:
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter}
#############################
Example 3:

Entity_types: [person, role, technology, organization, event, location, concept]
Text:
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.

"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."

Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."

Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.

The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
#############
Output:
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:"""

CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"

+ 160
- 0
graphrag/index.py

@@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from concurrent.futures import ThreadPoolExecutor
import json
from functools import reduce
from typing import List
import networkx as nx
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer
from rag.utils import num_tokens_from_string


def be_children(obj: dict):
arr = []
for k,v in obj.items():
k = re.sub(r"\*+", "", k)
if not k :continue
arr.append({
"id": k,
"children": be_children(v) if isinstance(v, dict) else []
})
return arr


def graph_merge(g1, g2):
g = g2.copy()
for n, attr in g1.nodes(data=True):
if n not in g2.nodes():
            g.add_node(n, **attr)
continue

g.nodes[n]["weight"] += 1
if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
g.nodes[n]["description"] += "\n" + attr["description"]

for source, target, attr in g1.edges(data=True):
if g.has_edge(source, target):
g[source][target].update({"weight": attr["weight"]+1})
continue
g.add_edge(source, target, **attr)

for node_degree in g.degree:
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return g


def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=["organization", "person", "location", "event", "time"]):
llm_bdl = LLMBundle(tenant_id, LLMType.CHAT)
ext = GraphExtractor(llm_bdl)
    left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
    left_token_count = min(left_token_count, llm_bdl.max_length * 0.4)

assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"

texts, graphs = [], []
cnt = 0
threads = []
exe = ThreadPoolExecutor(max_workers=12)
for i in range(len(chunks[:512])):
tkn_cnt = num_tokens_from_string(chunks[i])
if cnt+tkn_cnt >= left_token_count and texts:
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
texts = []
cnt = 0
texts.append(chunks[i])
cnt += tkn_cnt
if texts:
        threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))

callback(0.5, "Extracting entities.")
graphs = []
for i, _ in enumerate(threads):
graphs.append(_.result().output)
callback(0.5 + 0.1*i/len(threads))

graph = reduce(graph_merge, graphs)
er = EntityResolution(llm_bdl)
graph = er(graph).output

_chunks = chunks
chunks = []
for n, attr in graph.nodes(data=True):
if attr.get("rank", 0) == 0:
print(f"Ignore entity: {n}")
continue
chunk = {
"name_kwd": n,
"important_kwd": [n],
"title_tks": rag_tokenizer.tokenize(n),
"content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(attr["description"]),
"knowledge_graph_kwd": "entity",
"rank_int": attr["rank"],
"weight_int": attr["weight"]
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)

callback(0.6, "Extracting community reports.")
cr = CommunityReportsExtractor(llm_bdl)
cr = cr(graph)
for community, desc in zip(cr.structured_output, cr.output):
chunk = {
"title_tks": rag_tokenizer.tokenize(community["title"]),
"content_with_weight": desc,
"content_ltks": rag_tokenizer.tokenize(desc),
"knowledge_graph_kwd": "community_report",
"weight_flt": community["weight"],
"entities_kwd": community["entities"],
"important_kwd": community["entities"]
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)

chunks.append(
{
"content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
"knowledge_graph_kwd": "graph"
})

callback(0.75, "Extracting mind graph.")
mindmap = MindMapExtractor(llm_bdl)
mg = mindmap(_chunks).output
if not len(mg.keys()): return chunks

    if len(mg.keys()) > 1:
        md_map = {"id": "root", "children": [{"id": re.sub(r"\*+", "", k), "children": be_children(v)} for k, v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)]}
    else:
        md_map = {"id": re.sub(r"\*+", "", list(mg.keys())[0]), "children": be_children(list(mg.values())[0])}
print(json.dumps(md_map, ensure_ascii=False, indent=2))
chunks.append(
{
"content_with_weight": json.dumps(md_map, ensure_ascii=False, indent=2),
"knowledge_graph_kwd": "mind_map"
})

return chunks
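
A minimal end-to-end sketch; the tenant id and chunk texts are placeholders, and the callback signature matches how it is invoked above (a progress float, optionally followed by a message):

from graphrag.index import build_knowlege_graph_chunks

def progress(prog=None, msg=""):
    print(prog, msg)

chunks = build_knowlege_graph_chunks(
    "<tenant_id>",
    ["<chunk 1 text>", "<chunk 2 text>"],
    progress,
    entity_types=["organization", "person", "location", "event", "time"],
)
print([c.get("knowledge_graph_kwd") for c in chunks])    # entity / community_report / graph / mind_map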







+ 160
- 0
graphrag/leiden.py

@@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

import logging
from typing import Any, cast, List
import html
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component

import networkx as nx

log = logging.getLogger(__name__)


def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Ensure an undirected graph with the same relationships will always be read the same way."""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))

# If the graph is undirected, we create the edges in a stable way, so we get the same results
# for example:
# A -> B
# in graph theory is the same as
# B -> A
# in an undirected graph
# however, this can lead to downstream issues because sometimes
# consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A]
# but they base some of their logic on the order of the nodes, so the order ends up being important
# so we sort the nodes in the edge in a stable way, so that we always get the same order
if not graph.is_directed():

def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data

edges = [_sort_source_target(edge) for edge in edges]

def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"

edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

fixed_graph.add_edges_from(edges)
return fixed_graph


def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
"""Normalize node names."""
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
return nx.relabel_nodes(graph, node_mapping)


def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
graph = normalize_node_names(graph)
return _stabilize_graph(graph)


def _compute_leiden_communities(
graph: nx.Graph | nx.DiGraph,
max_cluster_size: int,
use_lcc: bool,
seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
if use_lcc:
graph = stable_largest_connected_component(graph)

community_mapping = hierarchical_leiden(
graph, max_cluster_size=max_cluster_size, random_seed=seed
)
results: dict[int, dict[str, int]] = {}
for partition in community_mapping:
results[partition.level] = results.get(partition.level, {})
results[partition.level][partition.node] = partition.cluster

return results


def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
"""Run method definition."""
max_cluster_size = args.get("max_cluster_size", 12)
use_lcc = args.get("use_lcc", True)
if args.get("verbose", False):
log.info(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
if not graph.nodes(): return {}

node_id_to_community_map = _compute_leiden_communities(
graph=graph,
max_cluster_size=max_cluster_size,
use_lcc=use_lcc,
seed=args.get("seed", 0xDEADBEEF),
)
levels = args.get("levels")

# If they don't pass in levels, use them all
if levels is None:
levels = sorted(node_id_to_community_map.keys())

results_by_level: dict[int, dict[str, list[str]]] = {}
for level in levels:
result = {}
results_by_level[level] = result
for node_id, raw_community_id in node_id_to_community_map[level].items():
community_id = str(raw_community_id)
if community_id not in result:
result[community_id] = {"weight": 0, "nodes": []}
result[community_id]["nodes"].append(node_id)
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
weights = [comm["weight"] for _, comm in result.items()]
if not weights:continue
max_weight = max(weights)
for _, comm in result.items(): comm["weight"] /= max_weight

return results_by_level


def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
for n in nodes:
if "communities" not in graph.nodes[n]:
graph.nodes[n]["communities"] = []
graph.nodes[n]["communities"].append(community_title)

+ 137
- 0
graphrag/mind_map_extractor.py

@@ -0,0 +1,137 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any

from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json
from functools import reduce
from rag.utils import num_tokens_from_string


@dataclass
class MindMapResult:
"""Unipartite Mind Graph result class definition."""
output: dict


class MindMapExtractor:

_llm: CompletionLLM
_input_text_key: str
_mind_map_prompt: str
_on_error: ErrorHandlerFn

def __init__(
self,
llm_invoker: CompletionLLM,
prompt: str | None = None,
input_text_key: str | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._input_text_key = input_text_key or "input_text"
self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)

def __call__(
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
) -> MindMapResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}

        merge_json = {}
        try:
exe = ThreadPoolExecutor(max_workers=12)
threads = []
token_count = self._llm.max_length * 0.7
texts = []
res = []
cnt = 0
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))

for i, _ in enumerate(threads):
res.append(_.result())

merge_json = reduce(self._merge, res)
merge_json = self._list_to_kv(merge_json)
except Exception as e:
logging.exception("error mind graph")
self._on_error(
e,
traceback.format_exc(), None
)

return MindMapResult(output=merge_json)

def _merge(self, d1, d2):
for k in d1:
if k in d2:
if isinstance(d1[k], dict) and isinstance(d2[k], dict):
self._merge(d1[k], d2[k])
elif isinstance(d1[k], list) and isinstance(d2[k], list):
d2[k].extend(d1[k])
else:
d2[k] = d1[k]
else:
d2[k] = d1[k]

return d2

def _list_to_kv(self, data):
for key, value in data.items():
if isinstance(value, dict):
self._list_to_kv(value)
elif isinstance(value, list):
new_value = {}
for i in range(len(value)):
if isinstance(value[i], list):
new_value[value[i - 1]] = value[i][0]
data[key] = new_value
else:
continue
return data

def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
variables = {
**prompt_variables,
self._input_text_key: text,
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
response = self._llm.chat(text, [], gen_conf)
print(response)
print("---------------------------------------------------\n", markdown_to_json.dictify(response))
return dict(markdown_to_json.dictify(response))
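
A minimal sketch; llm_bdl is an LLMBundle as elsewhere in this PR (it must expose max_length), and the section texts are illustrative:

import json

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graphrag.mind_map_extractor import MindMapExtractor

llm_bdl = LLMBundle("<tenant_id>", LLMType.CHAT)    # placeholder tenant id
mindmap = MindMapExtractor(llm_bdl)
tree = mindmap(["<section 1 text>", "<section 2 text>"]).output
print(json.dumps(tree, ensure_ascii=False, indent=2))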

+ 42
- 0
graphrag/mind_map_prompt.py

@@ -0,0 +1,42 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
MIND_MAP_EXTRACTION_PROMPT = """
- Role: You're a talented text processor.

- Steps of task:
1. Generate a title for user's 'TEXT'.
2. Classify the 'TEXT' into sections as you see fit.
3. If the subject matter is really complex, split it into sub-sections.

- Output requirement:
- Always try to maximize the number of sub-sections.
- In language of the 'TEXT'.
- MUST BE IN MARKDOWN FORMAT.
Output:
## <Title>
<Section Name>
<Section Name>
<Subsection Name>
<Subsection Name>
<Section Name>
<Subsection Name>
-TEXT-
{input_text}

Output:
"""

+ 109
- 0
graphrag/search.py

@@ -0,0 +1,109 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from copy import deepcopy

import pandas as pd
from elasticsearch_dsl import Q, Search

from rag.nlp.search import Dealer


class KGSearch(Dealer):
    def search(self, req, idxnm, emb_mdl=None):
        def merge_into_first(sres, title=""):
            df, texts = [], []
            for d in sres["hits"]["hits"]:
                try:
                    df.append(json.loads(d["_source"]["content_with_weight"]))
                except Exception as e:
                    texts.append(d["_source"]["content_with_weight"])
                    pass
            if not df and not texts: return False
            if df:
                try:
                    sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
                except Exception as e:
                    pass
            else:
                sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
            return True

        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
                                 "weight_int", "weight_flt", "rank_int"
                                 ])

        qst = req.get("question", "")
        binary_query, keywords = self.qryr.question(qst, min_match="5%")
        binary_query = self._add_filters(binary_query, req)

        ## Entity retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
        s = Search()
        s = s.query(bqry)[0: 32]

        s = s.to_dict()
        q_vec = []
        if req.get("vector"):
            assert emb_mdl, "No embedding model selected"
            s["knn"] = self._vector(
                qst, emb_mdl, req.get(
                    "similarity", 0.1), 1024)
            s["knn"]["filter"] = bqry.to_dict()
            q_vec = s["knn"]["query_vector"]

        ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
        ent_ids = self.es.getDocIds(ent_res)
        if merge_into_first(ent_res, "-Entities-"):
            ent_ids = ent_ids[0:1]

        ## Community retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", entities_kwd=entities))
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
        s = Search()
        s = s.query(bqry)[0: 32]
        s = s.to_dict()
        comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        comm_ids = self.es.getDocIds(comm_res)
        if merge_into_first(comm_res, "-Community Report-"):
            comm_ids = comm_ids[0:1]

        ## Text content retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
        s = Search()
        s = s.query(bqry)[0: 6]
        s = s.to_dict()
        txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        txt_ids = self.es.getDocIds(txt_res)
        if merge_into_first(txt_res, "-Original Content-"):
            txt_ids = txt_ids[0:1]

        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            aggregation=None,
            highlight=None,
            field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
            keywords=[]
        )
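A hedged usage sketch, given an already-initialized `KGSearch` instance `kg_search` (construction follows `rag.nlp.search.Dealer` and is not shown in this diff; the index name and ids below are placeholders):

```python
req = {
    "question": "Who founded the company and where is it headquartered?",
    "kb_ids": ["<kb_id>"],
    # pass a truthy "vector" together with an embedding model to enable the KNN branch
}
res = kg_search.search(req, idxnm="<tenant_index>", emb_mdl=None)
# res.ids then holds at most one merged "-Entities-" hit, at most one merged
# "-Community Report-" hit, and the matching original text chunks, in that order.
```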


+ 52
- 0
graphrag/smoke.py View file

@@ -0,0 +1,52 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import json
from graphrag import leiden
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor
from graphrag.leiden import add_community_info2graph

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler

ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])]
graph = ex(docs)

er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
graph = er(graph.output)

comm = leiden.run(graph.output, {})
add_community_info2graph(graph.output, comm)

# print(json.dumps(nx.node_link_data(graph.output), ensure_ascii=False,indent=2))
print(json.dumps(comm, ensure_ascii=False, indent=2))

cr = CommunityReportsExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
cr = cr(graph.output)
print("------------------ COMMUNITY REPORT ----------------------\n", cr.output)
print(json.dumps(cr.structured_output, ensure_ascii=False, indent=2))
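`graphrag/smoke.py` is a quick end-to-end sanity check of the pipeline (graph extraction → entity resolution → Leiden community detection → community reports) against a document that has already been ingested. Both flags are required, e.g. `python -m graphrag.smoke -t <tenant_id> -d <doc_id>`, typically run from the repository root so that `api` and `graphrag` are importable.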

+ 74
- 0
graphrag/utils.py View file

@@ -0,0 +1,74 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

import html
import re
from collections.abc import Callable
from typing import Any

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]


def perform_variable_replacements(
    input: str, history: list[dict] = [], variables: dict | None = {}
) -> str:
    """Perform variable replacements on the input string and in a chat log."""
    result = input

    def replace_all(input: str) -> str:
        result = input
        if variables:
            for entry in variables:
                result = result.replace(f"{{{entry}}}", variables[entry])
        return result

    result = replace_all(result)
    for i in range(len(history)):
        entry = history[i]
        if entry.get("role") == "system":
            history[i]["content"] = replace_all(entry.get("content") or "")

    return result


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)


def dict_has_keys_with_types(
    data: dict, expected_fields: list[tuple[str, type]]
) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False

        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True
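Quick, self-contained examples of the three helpers (values are illustrative):

```python
from graphrag.utils import perform_variable_replacements, clean_str, dict_has_keys_with_types

perform_variable_replacements("Entity types: {entity_types}",
                              variables={"entity_types": "person, organization"})
# -> "Entity types: person, organization"

clean_str('&quot;Apple Inc.&quot;\x07')
# -> 'Apple Inc.'  (HTML entities unescaped; quotes and control characters removed)

dict_has_keys_with_types({"title": "Q3 report", "findings": []},
                         [("title", str), ("findings", list)])
# -> True
```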


+ 30
- 0
rag/app/knowledge_graph.py View file

@@ -0,0 +1,30 @@
import re

from graphrag.index import build_knowlege_graph_chunks
from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize_chunks


def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": False})
    eng = lang.lower() == "english"

    parser_config["layout_recognize"] = False
    sections = naive.chunk(filename, binary, from_page=from_page, to_page=to_page, section_only=True, parser_config=parser_config)
    chunks = build_knowlege_graph_chunks(tenant_id, sections, callback,
                                         parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
                                         )
    for c in chunks: c["docnm_kwd"] = filename

    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
        "knowledge_graph_kwd": "text"
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    chunks.extend(tokenize_chunks(sections, doc, eng))

    return chunks
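A hedged invocation sketch (the file name, tenant id, and callback signature are assumptions for illustration; the callback convention mirrors the other `rag.app` parsers):

```python
def progress(prog=None, msg=""):
    print(prog, msg)

with open("annual_report.docx", "rb") as f:
    cks = chunk("annual_report.docx", f.read(), tenant_id="<tenant_id>",
                lang="English", callback=progress,
                parser_config={"chunk_token_num": 512,
                               "entity_types": ["organization", "person", "location", "event", "time"]})
# cks mixes the graph-derived chunks produced by build_knowlege_graph_chunks
# with the plain text chunks appended at the end.
```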

+ 3
- 0
rag/app/naive.py View file

@@ -273,6 +273,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
        raise NotImplementedError(
            "file type not supported yet(pdf, xlsx, doc, docx, txt supported)")
    if kwargs.get("section_only", False):
        return [t for t, _ in sections]
    st = timer()
    chunks = naive_merge(
        sections, int(parser_config.get(
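With `section_only=True`, `naive.chunk` now short-circuits after sectioning and returns just the text part of each section tuple, skipping merging and tokenization; this is the entry point that `rag/app/knowledge_graph.py` above relies on to feed raw sections into the graph builder.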

+ 1
- 1
rag/nlp/__init__.py View file

@@ -228,7 +228,7 @@ def tokenize(d, t, eng):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def tokenize_chunks(chunks, doc, eng, pdf_parser):
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ck in chunks:
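Making `pdf_parser` optional lets callers that have no PDF parser at hand, such as `knowledge_graph.chunk` above, call `tokenize_chunks(chunks, doc, eng)` with three arguments.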

+ 18
- 17
rag/nlp/search.py View file

@@ -64,24 +64,25 @@ class Dealer:
"query_vector": [float(v) for v in qv]
}

def _add_filters(self, bqry, req):
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
if req.get("doc_ids"):
bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
if req.get("knowledge_graph_kwd"):
bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"]))
if "available_int" in req:
if req["available_int"] == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry

def search(self, req, idxnm, emb_mdl=None):
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst)
def add_filters(bqry):
nonlocal req
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
if req.get("doc_ids"):
bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
if "available_int" in req:
if req["available_int"] == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry

bqry = add_filters(bqry)
bqry = self._add_filters(bqry, req)
bqry.boost = 0.05

s = Search()
@@ -89,7 +90,7 @@ class Dealer:
topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd",
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])

s = s.query(bqry)[pg * ps:(pg + 1) * ps]
@@ -137,7 +138,7 @@ class Dealer:
        es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
        if self.es.getTotal(res) == 0 and "knn" in s:
            bqry, _ = self.qryr.question(qst, min_match="10%")
            bqry = add_filters(bqry)
            bqry = self._add_filters(bqry, req)
            s["query"] = bqry.to_dict()
            s["knn"]["filter"] = bqry.to_dict()
            s["knn"]["similarity"] = 0.17

+ 3
- 2
rag/svr/task_executor.py View file

@@ -45,7 +45,7 @@ from rag.nlp import search, rag_tokenizer
from io import BytesIO
import pandas as pd

from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph

from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@@ -68,7 +68,8 @@ FACTORY = {
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio
    ParserType.AUDIO.value: audio,
    ParserType.KG.value: knowledge_graph
}
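Registering `ParserType.KG.value` here means documents whose knowledge base is configured with the knowledge-graph parser are dispatched by the task executor to `rag.app.knowledge_graph.chunk`, exactly like the existing parser types.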


