Browse Source

less text, better extraction (#1869)

### What problem does this PR solve?

#1861

### Type of change

- [x] Refactoring
tags/v0.10.0
Kevin Hu 1 year ago
parent
commit
db8f83104f
No account linked to committer's email address
1 changed file with 6 additions and 5 deletions
  1. graphrag/index.py (6 additions, 5 deletions)

+ 6
- 5
graphrag/index.py View File

 llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
 ext = GraphExtractor(llm_bdl)
 left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
-left_token_count = max(llm_bdl.max_length * 0.8, left_token_count)
+left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)

 assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"

+BATCH_SIZE=1
 texts, graphs = [], []
 cnt = 0
 threads = []
 for i in range(len(chunks)):
     tkn_cnt = num_tokens_from_string(chunks[i])
     if cnt+tkn_cnt >= left_token_count and texts:
-        for b in range(0, len(texts), 16):
-            threads.append(exe.submit(ext, ["\n".join(texts[b:b+16])], {"entity_types": entity_types}, callback))
+        for b in range(0, len(texts), BATCH_SIZE):
+            threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
         texts = []
         cnt = 0
     texts.append(chunks[i])
     cnt += tkn_cnt
 if texts:
-    for b in range(0, len(texts), 16):
-        threads.append(exe.submit(ext, ["\n".join(texts[b:b+16])], {"entity_types": entity_types}, callback))
+    for b in range(0, len(texts), BATCH_SIZE):
+        threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))

 callback(0.5, "Extracting entities.")
 graphs = []

Loading…
Cancel
Save