### What problem does this PR solve?

Add support for LocalLLM.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
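The change has two halves: a new Jina deployment (`rag/svr/jina_server.py`, added below) that serves a HuggingFace chat model over gRPC, and a rewritten `LocalLLM` chat-model class that streams from it on port 12345. As a rough orientation before the diff, this is how the two halves are meant to be used together. It is a hedged sketch, not part of the diff: the import path `rag.llm.chat_model` and the model name are assumptions for illustration, and the server must already be running (e.g. `python rag/svr/jina_server.py --model_name <hf-model-or-path> --port 12345`).

```python
# Hedged usage sketch, not part of the diff. Assumes rag/svr/jina_server.py is
# already serving a model on localhost:12345 and that LocalLLM lives in
# rag.llm.chat_model (import path assumed).
from rag.llm.chat_model import LocalLLM

llm = LocalLLM(key=None, model_name="local-model")  # neither value is used for routing

# Blocking call: returns the full answer plus an estimated token count.
answer, tokens = llm.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Say hello in one sentence."}],
    gen_conf={"max_tokens": 64},  # renamed to max_new_tokens client-side
)
print(answer, tokens)

# Streaming call: yields progressively longer partial answers, then a token count.
for chunk in llm.chat_streamly(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Count to five."}],
    gen_conf={"max_tokens": 32},
):
    print(chunk)
```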
@@ -27,6 +27,8 @@ from groq import Groq
 import os
 import json
 import requests
+import asyncio
+from rag.svr.jina_server import Prompt, Generation


 class Base(ABC):
     def __init__(self, key, model_name, base_url):
@@ -381,8 +383,10 @@ class LocalLLM(Base):
         def __conn(self):
             from multiprocessing.connection import Client

-            self._connection = Client(
-                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+            self._connection = Client(
+                (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
+            )

         def __getattr__(self, name):
             import pickle
@@ -390,8 +394,7 @@ class LocalLLM(Base):
             def do_rpc(*args, **kwargs):
                 for _ in range(3):
                     try:
-                        self._connection.send(
-                            pickle.dumps((name, args, kwargs)))
+                        self._connection.send(pickle.dumps((name, args, kwargs)))
                         return pickle.loads(self._connection.recv())
                     except Exception as e:
                         self.__conn()
@@ -399,35 +402,45 @@ class LocalLLM(Base):

             return do_rpc

-    def __init__(self, key, model_name="glm-3-turbo"):
-        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+    def __init__(self, key, model_name):
+        from jina import Client
+
+        self.client = Client(port=12345, protocol="grpc", asyncio=True)

-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            ans = self.client.chat(
-                history,
-                gen_conf
-            )
-            return ans, num_tokens_from_string(ans)
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0

-    def chat_streamly(self, system, history, gen_conf):
+    def _prepare_prompt(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        token_count = 0
+        if "max_tokens" in gen_conf:
+            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
+        return Prompt(message=history, gen_conf=gen_conf)
+
+    def _stream_response(self, endpoint, prompt):
         answer = ""
         try:
-            for ans in self.client.chat_streamly(history, gen_conf):
-                answer += ans
-                token_count += 1
-                yield answer
+            res = self.client.stream_doc(
+                on=endpoint, inputs=prompt, return_type=Generation
+            )
+            loop = asyncio.get_event_loop()
+            try:
+                while True:
+                    answer = loop.run_until_complete(res.__anext__()).text
+                    yield answer
+            except StopAsyncIteration:
+                pass
         except Exception as e:
             yield answer + "\n**ERROR**: " + str(e)
+        yield num_tokens_from_string(answer)

-        yield token_count
+    def chat(self, system, history, gen_conf):
+        prompt = self._prepare_prompt(system, history, gen_conf)
+        chat_gen = self._stream_response("/chat", prompt)
+        ans = next(chat_gen)
+        total_tokens = next(chat_gen)
+        return ans, total_tokens
+
+    def chat_streamly(self, system, history, gen_conf):
+        prompt = self._prepare_prompt(system, history, gen_conf)
+        return self._stream_response("/stream", prompt)


 class VolcEngineChat(Base):
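The client side of this rewrite drives Jina's asynchronous `stream_doc` generator from synchronous code: `_stream_response` pulls one item at a time with `loop.run_until_complete(res.__anext__())` until `StopAsyncIteration`, then yields a final token count. Below is a stripped-down, hedged illustration of that sync-over-async pattern in isolation; `fake_stream` is a stand-in for the real Jina stream, not anything from this PR.

```python
import asyncio


async def fake_stream():
    # Stand-in for the async generator returned by client.stream_doc(...).
    for partial in ["Hel", "Hello", "Hello, world"]:
        await asyncio.sleep(0)
        yield partial


def consume_sync(agen):
    # Same pattern as LocalLLM._stream_response: advance an async generator
    # from synchronous code, one item per run_until_complete() call.
    loop = asyncio.get_event_loop()
    try:
        while True:
            yield loop.run_until_complete(agen.__anext__())
    except StopAsyncIteration:
        pass


for chunk in consume_sync(fake_stream()):
    print(chunk)
```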
New file `rag/svr/jina_server.py` (referenced by the import above):

@@ -0,0 +1,93 @@
from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
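Once the deployment is running (for example `python rag/svr/jina_server.py --model_name <hf-model-or-path> --port 12345`), the executor can also be exercised without going through RAGFlow's `LocalLLM` wrapper. A minimal, hedged sketch using the asyncio flavor of the Jina client, mirroring what `LocalLLM._stream_response` sends over the wire; the prompt content and generation settings are placeholders:

```python
import asyncio

from jina import Client

from rag.svr.jina_server import Prompt, Generation


async def main():
    # Same connection settings the rewritten LocalLLM uses.
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(
        message=[{"role": "user", "content": "Write one sentence about rivers."}],
        gen_conf={"max_new_tokens": 64},
    )
    # "/stream" yields a Generation per decoding step; "/chat" returns one full answer.
    async for doc in client.stream_doc(
        on="/stream", inputs=prompt, return_type=Generation
    ):
        print(doc.text)


if __name__ == "__main__":
    asyncio.run(main())
```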