選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import html
  8. import json
  9. import re
  10. from typing import Any, Callable
  11. import numpy as np
  12. import xxhash
  13. from rag.utils.redis_conn import REDIS_CONN
  14. ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
  15. def perform_variable_replacements(
  16. input: str, history: list[dict] | None = None, variables: dict | None = None
  17. ) -> str:
  18. """Perform variable replacements on the input string and in a chat log."""
  19. if history is None:
  20. history = []
  21. if variables is None:
  22. variables = {}
  23. result = input
  24. def replace_all(input: str) -> str:
  25. result = input
  26. for k, v in variables.items():
  27. result = result.replace(f"{{{k}}}", v)
  28. return result
  29. result = replace_all(result)
  30. for i, entry in enumerate(history):
  31. if entry.get("role") == "system":
  32. entry["content"] = replace_all(entry.get("content") or "")
  33. return result
  34. def clean_str(input: Any) -> str:
  35. """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
  36. # If we get non-string input, just give it back
  37. if not isinstance(input, str):
  38. return input
  39. result = html.unescape(input.strip())
  40. # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
  41. return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)
  42. def dict_has_keys_with_types(
  43. data: dict, expected_fields: list[tuple[str, type]]
  44. ) -> bool:
  45. """Return True if the given dictionary has the given keys with the given types."""
  46. for field, field_type in expected_fields:
  47. if field not in data:
  48. return False
  49. value = data[field]
  50. if not isinstance(value, field_type):
  51. return False
  52. return True
  53. def get_llm_cache(llmnm, txt, history, genconf):
  54. hasher = xxhash.xxh64()
  55. hasher.update(str(llmnm).encode("utf-8"))
  56. hasher.update(str(txt).encode("utf-8"))
  57. hasher.update(str(history).encode("utf-8"))
  58. hasher.update(str(genconf).encode("utf-8"))
  59. k = hasher.hexdigest()
  60. bin = REDIS_CONN.get(k)
  61. if not bin:
  62. return
  63. return bin
  64. def set_llm_cache(llmnm, txt, v: str, history, genconf):
  65. hasher = xxhash.xxh64()
  66. hasher.update(str(llmnm).encode("utf-8"))
  67. hasher.update(str(txt).encode("utf-8"))
  68. hasher.update(str(history).encode("utf-8"))
  69. hasher.update(str(genconf).encode("utf-8"))
  70. k = hasher.hexdigest()
  71. REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)
  72. def get_embed_cache(llmnm, txt):
  73. hasher = xxhash.xxh64()
  74. hasher.update(str(llmnm).encode("utf-8"))
  75. hasher.update(str(txt).encode("utf-8"))
  76. k = hasher.hexdigest()
  77. bin = REDIS_CONN.get(k)
  78. if not bin:
  79. return
  80. return np.array(json.loads(bin))
  81. def set_embed_cache(llmnm, txt, arr):
  82. hasher = xxhash.xxh64()
  83. hasher.update(str(llmnm).encode("utf-8"))
  84. hasher.update(str(txt).encode("utf-8"))
  85. k = hasher.hexdigest()
  86. arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
  87. REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)