|
|
|
|
|
|
|
|
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8') |
|
|
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8') |
|
|
if not streaming: |
|
|
if not streaming: |
|
|
try: |
|
|
try: |
|
|
|
|
|
message_result = {} |
|
|
for message in pubsub.listen(): |
|
|
for message in pubsub.listen(): |
|
|
if message["type"] == "message": |
|
|
if message["type"] == "message": |
|
|
result = message["data"].decode('utf-8') |
|
|
result = message["data"].decode('utf-8') |
|
|
|
|
|
|
|
|
if result.get('error'): |
|
|
if result.get('error'): |
|
|
cls.handle_error(result) |
|
|
cls.handle_error(result) |
|
|
if result['event'] == 'message' and 'data' in result: |
|
|
if result['event'] == 'message' and 'data' in result: |
|
|
return cls.get_message_response_data(result.get('data')) |
|
|
|
|
|
|
|
|
message_result['message'] = result.get('data') |
|
|
|
|
|
if result['event'] == 'message_end' and 'data' in result: |
|
|
|
|
|
message_result['message_end'] = result.get('data') |
|
|
|
|
|
return cls.get_blocking_message_response_data(message_result) |
|
|
except ValueError as e: |
|
|
except ValueError as e: |
|
|
if e.args[0] != "I/O operation on closed file.": # ignore this error |
|
|
if e.args[0] != "I/O operation on closed file.": # ignore this error |
|
|
raise CompletionStoppedError() |
|
|
raise CompletionStoppedError() |
|
|
|
|
|
|
|
|
if event == "end": |
|
|
if event == "end": |
|
|
logging.debug("{} finished".format(generate_channel)) |
|
|
logging.debug("{} finished".format(generate_channel)) |
|
|
break |
|
|
break |
|
|
|
|
|
|
|
|
if event == 'message': |
|
|
if event == 'message': |
|
|
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" |
|
|
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" |
|
|
elif event == 'chain': |
|
|
elif event == 'chain': |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return response_data |
|
|
return response_data |
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
|
|
|
def get_blocking_message_response_data(cls, data: dict): |
|
|
|
|
|
message = data.get('message') |
|
|
|
|
|
response_data = { |
|
|
|
|
|
'event': 'message', |
|
|
|
|
|
'task_id': message.get('task_id'), |
|
|
|
|
|
'id': message.get('message_id'), |
|
|
|
|
|
'answer': message.get('text'), |
|
|
|
|
|
'metadata': {}, |
|
|
|
|
|
'created_at': int(time.time()) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if message.get('mode') == 'chat': |
|
|
|
|
|
response_data['conversation_id'] = message.get('conversation_id') |
|
|
|
|
|
if 'message_end' in data: |
|
|
|
|
|
message_end = data.get('message_end') |
|
|
|
|
|
if 'retriever_resources' in message_end: |
|
|
|
|
|
response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources') |
|
|
|
|
|
|
|
|
|
|
|
return response_data |
|
|
|
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
def get_message_end_data(cls, data: dict): |
|
|
def get_message_end_data(cls, data: dict): |
|
|
response_data = { |
|
|
response_data = { |