### What problem does this PR solve?

Some models force thinking, so the opening `<think>` tag is absent from the returned content and the reasoning text is never stripped.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
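Each hunk below swaps the pattern `<think>.*</think>` for `^.*</think>`, so the reasoning prefix is removed even when the opening tag never appears. A minimal sketch of the difference between the two patterns, using invented model replies:

```python
import re

def strip_old(s: str) -> str:
    # Old pattern: needs a literal <think>...</think> pair to match.
    return re.sub(r"<think>.*</think>", "", s, flags=re.DOTALL)

def strip_new(s: str) -> str:
    # New pattern: drops everything up to and including the last </think>,
    # whether or not an opening <think> tag was emitted.
    return re.sub(r"^.*</think>", "", s, flags=re.DOTALL)

tagged = "<think>some reasoning...</think>Final answer"
forced = "some reasoning, no opening tag...</think>Final answer"

print(strip_old(tagged))  # Final answer
print(strip_old(forced))  # some reasoning, no opening tag...</think>Final answer  (leaks)
print(strip_new(tagged))  # Final answer
print(strip_new(forced))  # Final answer
```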
```diff
     component_name = "ExeSQL"

     def _refactor(self, ans):
-        ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
         match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
         if match:
             ans = match.group(1)  # Query content
```
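A standalone sketch of the patched `_refactor` path; only the lines visible in the hunk are reproduced (the real method may do more), and the model reply is invented:

````python
import re

def refactor(ans: str) -> str:
    # Strip a forced-reasoning prefix up to the closing </think>.
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    # If the model wrapped its query in a ```sql fence, keep only the query.
    match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
    if match:
        ans = match.group(1)  # Query content
    return ans

reply = "Let me count the rows.</think>\n```sql\nSELECT COUNT(*) FROM users;\n```"
print(refactor(reply))  # SELECT COUNT(*) FROM users;
````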
```diff
         if len(msg) < 2:
             msg.append({"role": "user", "content": "Output: "})
         ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
-        ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
         if self._param.cite and "chunks" in retrieval_res.columns:
             res = self.set_cite(retrieval_res, ans)
```
```diff
         ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": query}],
                             self._param.gen_conf())
-        ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
         ans = re.sub(r".*keyword:", "", ans).strip()
         logging.debug(f"ans: {ans}")
         return KeywordExtract.be_output(ans)
```
```diff
         msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
         for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
-            ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             if not ans:
                 continue
             query_think = ans
```
| [{"role": "user", | [{"role": "user", | ||||
| "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}], | "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}], | ||||
| {"temperature": 0.7}): | {"temperature": 0.7}): | ||||
| ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL) | |||||
| ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) | |||||
| if not ans: | if not ans: | ||||
| continue | continue | ||||
| summary_think = ans | summary_think = ans |
| answer = "" | answer = "" | ||||
| for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): | for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): | ||||
| if thought: | if thought: | ||||
| ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL) | |||||
| ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) | |||||
| answer = ans | answer = ans | ||||
| delta_ans = ans[len(last_ans) :] | delta_ans = ans[len(last_ans) :] | ||||
| if num_tokens_from_string(delta_ans) < 16: | if num_tokens_from_string(delta_ans) < 16: | ||||
```diff
     def get_table():
         nonlocal sys_prompt, user_prompt, question, tried_times
         sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
-        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
+        sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
         logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
         sql = re.sub(r".*select ", "select ", sql.lower())
```
```diff
             return response
         _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
         response = self._llm.chat(system_msg[0]["content"], hist, conf)
-        response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
+        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             logging.warning(f"Extractor._chat got error. response: {response}")
             return ""
```
```diff
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
     if isinstance(kwd, tuple):
         kwd = kwd[0]
-    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
+    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
     if kwd.find("**ERROR**") >= 0:
         return ""
     return kwd
```
```diff
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
     if isinstance(kwd, tuple):
         kwd = kwd[0]
-    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
+    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
     if kwd.find("**ERROR**") >= 0:
         return ""
     return kwd
```
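The two helpers patched just above share the same cleanup shape; here it is as a self-contained function, with an invented reply (the tuple's second element is assumed to be a token count):

```python
import re

def clean_llm_reply(raw) -> str:
    # Some chat wrappers return a tuple; keep only the first element (the text).
    if isinstance(raw, tuple):
        raw = raw[0]
    # Strip a forced-reasoning prefix, then treat provider error markers as "no result".
    raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
    if raw.find("**ERROR**") >= 0:
        return ""
    return raw

print(clean_llm_reply(("weighing candidates...</think>ragflow, retrieval, llm", 42)))
# ragflow, retrieval, llm
```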
```diff
 ###############
     """
     ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
-    ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
```
```diff
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
     if isinstance(kwd, tuple):
         kwd = kwd[0]
-    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
+    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
     if kwd.find("**ERROR**") >= 0:
         raise Exception(kwd)
```
```diff
         response = await trio.to_thread.run_sync(
             lambda: self._llm_model.chat(system, history, gen_conf)
         )
-        response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
+        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             raise Exception(response)
         set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
```
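The final hunk sits in an async path; a minimal sketch of the same offload-then-strip pattern, where `blocking_chat` is a hypothetical stand-in for the project's synchronous LLM client (requires the `trio` package):

```python
import re

import trio

def blocking_chat(system: str, history: list, gen_conf: dict) -> str:
    # Stand-in for a synchronous LLM call; the reply mimics forced reasoning
    # with no opening <think> tag.
    return 'Weighing the entities...</think>{"entities": []}'

async def main():
    # Offload the blocking call to a worker thread so the event loop stays responsive.
    response = await trio.to_thread.run_sync(
        lambda: blocking_chat("system prompt", [], {"temperature": 0.3})
    )
    response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
    print(response)  # {"entities": []}

trio.run(main)
```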