formatting: run precommit on all files

master
lanvent 2023-04-22 12:01:29 +08:00
parent eaf4e9174f
commit 618c94edb8
40 changed files with 229 additions and 647 deletions

1
app.py
View File

@ -19,6 +19,7 @@ def sigterm_handler_wrap(_signo):
if callable(old_handler): # check old_handler if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame) return old_handler(_signo, _stack_frame)
sys.exit(0) sys.exit(0)
signal.signal(_signo, func) signal.signal(_signo, func)

View File

@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType
class BaiduUnitBot(Bot): class BaiduUnitBot(Bot):
def reply(self, query, context=None): def reply(self, query, context=None):
token = self.get_token() token = self.get_token()
url = ( url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ token
)
post_data = ( post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query + query
@ -32,12 +29,7 @@ class BaiduUnitBot(Bot):
def get_token(self): def get_token(self):
access_key = "YOUR_ACCESS_KEY" access_key = "YOUR_ACCESS_KEY"
secret_key = "YOUR_SECRET_KEY" secret_key = "YOUR_SECRET_KEY"
host = ( host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+ access_key
+ "&client_secret="
+ secret_key
)
response = requests.get(host) response = requests.get(host)
if response: if response:
print(response.json()) print(response.json())

View File

@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt"): if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
self.sessions = SessionManager( self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
)
self.args = { self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数 # "max_tokens":4096, # 回复最大的字符数
"top_p": 1, "top_p": 1,
"frequency_penalty": conf().get( "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"frequency_penalty", 0.0 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "request_timeout": conf().get("request_timeout", None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
} }
@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage):
reply_content["completion_tokens"], reply_content["completion_tokens"],
) )
) )
if ( if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply_content["completion_tokens"] == 0
and len(reply_content["content"]) > 0
):
reply = Reply(ReplyType.ERROR, reply_content["content"]) reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0: elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply( self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply_content["content"], session_id, reply_content["total_tokens"]
)
reply = Reply(ReplyType.TEXT, reply_content["content"]) reply = Reply(ReplyType.TEXT, reply_content["content"])
else: else:
reply = Reply(ReplyType.ERROR, reply_content["content"]) reply = Reply(ReplyType.ERROR, reply_content["content"])
@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used # if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
api_key=api_key, messages=session.messages, **self.args
)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return { return {
"total_tokens": response["usage"]["total_tokens"], "total_tokens": response["usage"]["total_tokens"],

View File

@ -25,9 +25,7 @@ class ChatGPTSession(Session):
precise = False precise = False
if cur_tokens is None: if cur_tokens is None:
raise e raise e
logger.debug( logger.debug("Exception when counting tokens precisely for query: {}".format(e))
"Exception when counting tokens precisely for query: {}".format(e)
)
while cur_tokens > max_tokens: while cur_tokens > max_tokens:
if len(self.messages) > 2: if len(self.messages) > 2:
self.messages.pop(1) self.messages.pop(1)
@ -39,16 +37,10 @@ class ChatGPTSession(Session):
cur_tokens = cur_tokens - max_tokens cur_tokens = cur_tokens - max_tokens
break break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user": elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn( logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
)
break break
else: else:
logger.debug( logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
"max_tokens={}, total_tokens={}, len(messages)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
break break
if precise: if precise:
cur_tokens = self.calc_tokens() cur_tokens = self.calc_tokens()
@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
elif model == "gpt-4": elif model == "gpt-4":
return num_tokens_from_messages(messages, model="gpt-4-0314") return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301": elif model == "gpt-3.5-turbo-0301":
tokens_per_message = ( tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_name = -1 # if there's a name, the role is omitted tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314": elif model == "gpt-4-0314":
tokens_per_message = 3 tokens_per_message = 3
tokens_per_name = 1 tokens_per_name = 1
else: else:
logger.warn( logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
)
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
num_tokens = 0 num_tokens = 0
for message in messages: for message in messages:

View File

@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage):
if proxy: if proxy:
openai.proxy = proxy openai.proxy = proxy
self.sessions = SessionManager( self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
OpenAISession, model=conf().get("model") or "text-davinci-003"
)
self.args = { self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称 "model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens": 1200, # 回复最大的字符数 "max_tokens": 1200, # 回复最大的字符数
"top_p": 1, "top_p": 1,
"frequency_penalty": conf().get( "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"frequency_penalty", 0.0 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "request_timeout": conf().get("request_timeout", None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"], "stop": ["\n\n\n"],
} }
@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage):
result["content"], result["content"],
) )
logger.debug( logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
str(session), session_id, reply_content, completion_tokens
)
) )
if total_tokens == 0: if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content) reply = Reply(ReplyType.ERROR, reply_content)
else: else:
self.sessions.session_reply( self.sessions.session_reply(reply_content, session_id, total_tokens)
reply_content, session_id, total_tokens
)
reply = Reply(ReplyType.TEXT, reply_content) reply = Reply(ReplyType.TEXT, reply_content)
return reply return reply
elif context.type == ContextType.IMAGE_CREATE: elif context.type == ContextType.IMAGE_CREATE:
@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage):
def reply_text(self, session: OpenAISession, retry_count=0): def reply_text(self, session: OpenAISession, retry_count=0):
try: try:
response = openai.Completion.create(prompt=str(session), **self.args) response = openai.Completion.create(prompt=str(session), **self.args)
res_content = ( res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
)
total_tokens = response["usage"]["total_tokens"] total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"] completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content)) logger.info("[OPEN_AI] reply={}".format(res_content))

View File

@ -23,9 +23,7 @@ class OpenAIImage(object):
response = openai.Image.create( response = openai.Image.create(
prompt=query, # 图片描述 prompt=query, # 图片描述
n=1, # 每次生成图片的数量 n=1, # 每次生成图片的数量
size=conf().get( size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
"image_create_size", "256x256"
), # 图片大小,可选有 256x256, 512x512, 1024x1024
) )
image_url = response["data"][0]["url"] image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
@ -34,11 +32,7 @@ class OpenAIImage(object):
logger.warn(e) logger.warn(e)
if retry_count < 1: if retry_count < 1:
time.sleep(5) time.sleep(5)
logger.warn( logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
retry_count + 1
)
)
return self.create_img(query, retry_count + 1) return self.create_img(query, retry_count + 1)
else: else:
return False, "提问太快啦,请休息一下再问我吧" return False, "提问太快啦,请休息一下再问我吧"

View File

@ -36,9 +36,7 @@ class OpenAISession(Session):
precise = False precise = False
if cur_tokens is None: if cur_tokens is None:
raise e raise e
logger.debug( logger.debug("Exception when counting tokens precisely for query: {}".format(e))
"Exception when counting tokens precisely for query: {}".format(e)
)
while cur_tokens > max_tokens: while cur_tokens > max_tokens:
if len(self.messages) > 1: if len(self.messages) > 1:
self.messages.pop(0) self.messages.pop(0)
@ -50,18 +48,10 @@ class OpenAISession(Session):
cur_tokens = len(str(self)) cur_tokens = len(str(self))
break break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user": elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
logger.warn( logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
"user question exceed max_tokens. total_tokens={}".format(
cur_tokens
)
)
break break
else: else:
logger.debug( logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
break break
if precise: if precise:
cur_tokens = self.calc_tokens() cur_tokens = self.calc_tokens()

View File

@ -55,9 +55,7 @@ class SessionManager(object):
return self.sessioncls(session_id, system_prompt, **self.session_args) return self.sessioncls(session_id, system_prompt, **self.session_args)
if session_id not in self.sessions: if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls( self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
session_id, system_prompt, **self.session_args
)
elif system_prompt is not None: # 如果有新的system_prompt更新并重置session elif system_prompt is not None: # 如果有新的system_prompt更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt) self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id] session = self.sessions[session_id]
@ -71,9 +69,7 @@ class SessionManager(object):
total_tokens = session.discard_exceeding(max_tokens, None) total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens)) logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e: except Exception as e:
logger.debug( logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
"Exception when counting tokens precisely for prompt: {}".format(str(e))
)
return session return session
def session_reply(self, reply, session_id, total_tokens=None): def session_reply(self, reply, session_id, total_tokens=None):
@ -82,17 +78,9 @@ class SessionManager(object):
try: try:
max_tokens = conf().get("conversation_max_tokens", 1000) max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug( logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
"raw total_tokens={}, savesession tokens={}".format(
total_tokens, tokens_cnt
)
)
except Exception as e: except Exception as e:
logger.debug( logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
"Exception when counting tokens precisely for session: {}".format(
str(e)
)
)
return session return session
def clear_session(self, session_id): def clear_session(self, session_id):

View File

@ -60,6 +60,4 @@ class Context:
del self.kwargs[key] del self.kwargs[key]
def __str__(self): def __str__(self):
return "Context(type={}, content={}, kwargs={})".format( return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
self.type, self.content, self.kwargs
)

View File

@ -53,9 +53,7 @@ class ChatChannel(Channel):
group_id = cmsg.other_user_id group_id = cmsg.other_user_id
group_name_white_list = config.get("group_name_white_list", []) group_name_white_list = config.get("group_name_white_list", [])
group_name_keyword_white_list = config.get( group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
"group_name_keyword_white_list", []
)
if any( if any(
[ [
group_name in group_name_white_list, group_name in group_name_white_list,
@ -63,9 +61,7 @@ class ChatChannel(Channel):
check_contain(group_name, group_name_keyword_white_list), check_contain(group_name, group_name_keyword_white_list),
] ]
): ):
group_chat_in_one_session = conf().get( group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
"group_chat_in_one_session", []
)
session_id = cmsg.actual_user_id session_id = cmsg.actual_user_id
if any( if any(
[ [
@ -81,17 +77,11 @@ class ChatChannel(Channel):
else: else:
context["session_id"] = cmsg.other_user_id context["session_id"] = cmsg.other_user_id
context["receiver"] = cmsg.other_user_id context["receiver"] = cmsg.other_user_id
e_context = PluginManager().emit_event( e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
EventContext(
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
)
)
context = e_context["context"] context = e_context["context"]
if e_context.is_pass() or context is None: if e_context.is_pass() or context is None:
return context return context
if cmsg.from_user_id == self.user_id and not config.get( if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
"trigger_by_self", True
):
logger.debug("[WX]self message skipped") logger.debug("[WX]self message skipped")
return None return None
@ -119,19 +109,13 @@ class ChatChannel(Channel):
if not flag: if not flag:
if context["origin_ctype"] == ContextType.VOICE: if context["origin_ctype"] == ContextType.VOICE:
logger.info( logger.info("[WX]receive group voice, but checkprefix didn't match")
"[WX]receive group voice, but checkprefix didn't match"
)
return None return None
else: # 单聊 else: # 单聊
match_prefix = check_prefix( match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
content, conf().get("single_chat_prefix", [""])
)
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, "", 1).strip() content = content.replace(match_prefix, "", 1).strip()
elif ( elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
context["origin_ctype"] == ContextType.VOICE
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass pass
else: else:
return None return None
@ -143,18 +127,10 @@ class ChatChannel(Channel):
else: else:
context.type = ContextType.TEXT context.type = ContextType.TEXT
context.content = content.strip() context.content = content.strip()
if ( if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
"desire_rtype" not in context
and conf().get("always_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
context["desire_rtype"] = ReplyType.VOICE context["desire_rtype"] = ReplyType.VOICE
elif context.type == ContextType.VOICE: elif context.type == ContextType.VOICE:
if ( if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
"desire_rtype" not in context
and conf().get("voice_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
context["desire_rtype"] = ReplyType.VOICE context["desire_rtype"] = ReplyType.VOICE
return context return context
@ -182,15 +158,8 @@ class ChatChannel(Channel):
) )
reply = e_context["reply"] reply = e_context["reply"]
if not e_context.is_pass(): if not e_context.is_pass():
logger.debug( logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
"[WX] ready to handle context: type={}, content={}".format( if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
context.type, context.content
)
)
if (
context.type == ContextType.TEXT
or context.type == ContextType.IMAGE_CREATE
): # 文字和图片消息
reply = super().build_reply_content(context.content, context) reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息 elif context.type == ContextType.VOICE: # 语音消息
cmsg = context["msg"] cmsg = context["msg"]
@ -214,9 +183,7 @@ class ChatChannel(Channel):
# logger.warning("[WX]delete temp file error: " + str(e)) # logger.warning("[WX]delete temp file error: " + str(e))
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
new_context = self._compose_context( new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
ContextType.TEXT, reply.content, **context.kwargs
)
if new_context: if new_context:
reply = self._generate_reply(new_context) reply = self._generate_reply(new_context)
else: else:
@ -246,48 +213,24 @@ class ChatChannel(Channel):
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
reply_text = reply.content reply_text = reply.content
if ( if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
desire_rtype == ReplyType.VOICE
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
reply = super().build_text_to_voice(reply.content) reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply) return self._decorate_reply(context, reply)
if context.get("isgroup", False): if context.get("isgroup", False):
reply_text = ( reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
"@" reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
+ context["msg"].actual_user_nickname
+ " "
+ reply_text.strip()
)
reply_text = (
conf().get("group_chat_reply_prefix", "") + reply_text
)
else: else:
reply_text = ( reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
conf().get("single_chat_reply_prefix", "") + reply_text
)
reply.content = reply_text reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = "[" + str(reply.type) + "]\n" + reply.content reply.content = "[" + str(reply.type) + "]\n" + reply.content
elif ( elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
reply.type == ReplyType.IMAGE_URL
or reply.type == ReplyType.VOICE
or reply.type == ReplyType.IMAGE
):
pass pass
else: else:
logger.error("[WX] unknown reply type: {}".format(reply.type)) logger.error("[WX] unknown reply type: {}".format(reply.type))
return return
if ( if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
desire_rtype logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
and desire_rtype != reply.type
and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
):
logger.warning(
"[WX] desire_rtype: {}, but reply type: {}".format(
context.get("desire_rtype"), reply.type
)
)
return reply return reply
def _send_reply(self, context: Context, reply: Reply): def _send_reply(self, context: Context, reply: Reply):
@ -300,9 +243,7 @@ class ChatChannel(Channel):
) )
reply = e_context["reply"] reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type: if not e_context.is_pass() and reply and reply.type:
logger.debug( logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
"[WX] ready to send reply: {}, context: {}".format(reply, context)
)
self._send(reply, context) self._send(reply, context)
def _send(self, reply: Reply, context: Context, retry_cnt=0): def _send(self, reply: Reply, context: Context, retry_cnt=0):
@ -328,9 +269,7 @@ class ChatChannel(Channel):
try: try:
worker_exception = worker.exception() worker_exception = worker.exception()
if worker_exception: if worker_exception:
self._fail_callback( self._fail_callback(session_id, exception=worker_exception, **kwargs)
session_id, exception=worker_exception, **kwargs
)
else: else:
self._success_callback(session_id, **kwargs) self._success_callback(session_id, **kwargs)
except CancelledError as e: except CancelledError as e:
@ -366,24 +305,14 @@ class ChatChannel(Channel):
if not context_queue.empty(): if not context_queue.empty():
context = context_queue.get() context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context)) logger.debug("[WX] consume context: {}".format(context))
future: Future = self.handler_pool.submit( future: Future = self.handler_pool.submit(self._handle, context)
self._handle, context future.add_done_callback(self._thread_pool_callback(session_id, context=context))
)
future.add_done_callback(
self._thread_pool_callback(session_id, context=context)
)
if session_id not in self.futures: if session_id not in self.futures:
self.futures[session_id] = [] self.futures[session_id] = []
self.futures[session_id].append(future) self.futures[session_id].append(future)
elif ( elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
semaphore._initial_value == semaphore._value + 1 self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 assert len(self.futures[session_id]) == 0, "thread pool error"
self.futures[session_id] = [
t for t in self.futures[session_id] if not t.done()
]
assert (
len(self.futures[session_id]) == 0
), "thread pool error"
del self.sessions[session_id] del self.sessions[session_id]
else: else:
semaphore.release() semaphore.release()
@ -397,9 +326,7 @@ class ChatChannel(Channel):
future.cancel() future.cancel()
cnt = self.sessions[session_id][0].qsize() cnt = self.sessions[session_id][0].qsize()
if cnt > 0: if cnt > 0:
logger.info( logger.info("Cancel {} messages in session {}".format(cnt, session_id))
"Cancel {} messages in session {}".format(cnt, session_id)
)
self.sessions[session_id][0] = Dequeue() self.sessions[session_id][0] = Dequeue()
def cancel_all_session(self): def cancel_all_session(self):
@ -409,9 +336,7 @@ class ChatChannel(Channel):
future.cancel() future.cancel()
cnt = self.sessions[session_id][0].qsize() cnt = self.sessions[session_id][0].qsize()
if cnt > 0: if cnt > 0:
logger.info( logger.info("Cancel {} messages in session {}".format(cnt, session_id))
"Cancel {} messages in session {}".format(cnt, session_id)
)
self.sessions[session_id][0] = Dequeue() self.sessions[session_id][0] = Dequeue()

View File

@ -77,9 +77,7 @@ class TerminalChannel(ChatChannel):
if check_prefix(prompt, trigger_prefixs) is None: if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
context = self._compose_context( context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
)
if context: if context:
self.produce(context) self.produce(context)
else: else:

View File

@ -56,10 +56,7 @@ def _check(func):
return return
self.receivedMsgs[msgId] = cmsg self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间戳 create_time = cmsg.create_time # 消息时间戳
if ( if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
conf().get("hot_reload") == True
and int(create_time) < int(time.time()) - 60
): # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId)) logger.debug("[WX]history message {} skipped".format(msgId))
return return
return func(self, cmsg) return func(self, cmsg)
@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode):
url = f"https://login.weixin.qq.com/l/{uuid}" url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2 = ( qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
url
)
)
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
url
)
print("You can also scan QRCode in any website below:") print("You can also scan QRCode in any website below:")
print(qr_api3) print(qr_api3)
print(qr_api4) print(qr_api4)
@ -134,18 +125,12 @@ class WechatChannel(ChatChannel):
logger.error("Hot reload failed, try to login without hot reload") logger.error("Hot reload failed, try to login without hot reload")
itchat.logout() itchat.logout()
os.remove(status_path) os.remove(status_path)
itchat.auto_login( itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
)
else: else:
raise e raise e
self.user_id = itchat.instance.storageClass.userName self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName self.name = itchat.instance.storageClass.nickName
logger.info( logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
"Wechat login success, user_id: {}, nickname: {}".format(
self.user_id, self.name
)
)
# start message listener # start message listener
itchat.run() itchat.run()
@ -173,16 +158,10 @@ class WechatChannel(ChatChannel):
elif cmsg.ctype == ContextType.PATPAT: elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content)) logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT: elif cmsg.ctype == ContextType.TEXT:
logger.debug( logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
"[WX]receive text msg: {}, cmsg={}".format(
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
)
)
else: else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg)) logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
context = self._compose_context( context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
)
if context: if context:
self.produce(context) self.produce(context)
@ -202,9 +181,7 @@ class WechatChannel(ChatChannel):
pass pass
else: else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content)) logger.debug("[WX]receive group msg: {}".format(cmsg.content))
context = self._compose_context( context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
)
if context: if context:
self.produce(context) self.produce(context)

View File

@ -27,37 +27,23 @@ class WeChatMessage(ChatMessage):
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content) self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
if is_group and ( if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
"加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]
):
self.ctype = ContextType.JOIN_GROUP self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"] self.content = itchat_msg["Content"]
# 这里只能得到nickname actual_user_id还是机器人的id # 这里只能得到nickname actual_user_id还是机器人的id
if "加入了群聊" in itchat_msg["Content"]: if "加入了群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall( self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
r"\"(.*?)\"", itchat_msg["Content"]
)[-1]
elif "加入群聊" in itchat_msg["Content"]: elif "加入群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall( self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
r"\"(.*?)\"", itchat_msg["Content"]
)[0]
elif "拍了拍我" in itchat_msg["Content"]: elif "拍了拍我" in itchat_msg["Content"]:
self.ctype = ContextType.PATPAT self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"] self.content = itchat_msg["Content"]
if is_group: if is_group:
self.actual_user_nickname = re.findall( self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
r"\"(.*?)\"", itchat_msg["Content"]
)[0]
else: else:
raise NotImplementedError( raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
"Unsupported note message: " + itchat_msg["Content"]
)
else: else:
raise NotImplementedError( raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
"Unsupported message type: Type:{} MsgType:{}".format(
itchat_msg["Type"], itchat_msg["MsgType"]
)
)
self.from_user_id = itchat_msg["FromUserName"] self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg["ToUserName"] self.to_user_id = itchat_msg["ToUserName"]

View File

@ -60,13 +60,9 @@ class WechatyChannel(ChatChannel):
receiver_id = context["receiver"] receiver_id = context["receiver"]
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if context["isgroup"]: if context["isgroup"]:
receiver = asyncio.run_coroutine_threadsafe( receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
self.bot.Room.find(receiver_id), loop
).result()
else: else:
receiver = asyncio.run_coroutine_threadsafe( receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
self.bot.Contact.find(receiver_id), loop
).result()
msg = None msg = None
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
msg = reply.content msg = reply.content
@ -83,9 +79,7 @@ class WechatyChannel(ChatChannel):
voiceLength = int(any_to_sil(file_path, sil_file)) voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000: if voiceLength >= 60000:
voiceLength = 60000 voiceLength = 60000
logger.info( logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
"[WX] voice too long, length={}, set to 60s".format(voiceLength)
)
# 发送语音 # 发送语音
t = int(time.time()) t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + ".sil") msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
@ -98,9 +92,7 @@ class WechatyChannel(ChatChannel):
os.remove(sil_file) os.remove(sil_file)
except Exception as e: except Exception as e:
pass pass
logger.info( logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
)
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content img_url = reply.content
t = int(time.time()) t = int(time.time())
@ -111,9 +103,7 @@ class WechatyChannel(ChatChannel):
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
t = int(time.time()) t = int(time.time())
msg = FileBox.from_base64( msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
base64.b64encode(image_storage.read()), str(t) + ".png"
)
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage, receiver={}".format(receiver)) logger.info("[WX] sendImage, receiver={}".format(receiver))

View File

@ -45,16 +45,12 @@ class WechatyMessage(ChatMessage, aobject):
def func(): def func():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
voice_file.to_file(self.content), loop
).result()
self._prepare_fn = func self._prepare_fn = func
else: else:
raise NotImplementedError( raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
"Unsupported message type: {}".format(wechaty_msg.type())
)
from_contact = wechaty_msg.talker() # 获取消息的发送者 from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id self.from_user_id = from_contact.contact_id
@ -73,9 +69,7 @@ class WechatyMessage(ChatMessage, aobject):
self.to_user_id = to_contact.contact_id self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name self.to_user_nickname = to_contact.name
if ( if self.is_group or wechaty_msg.is_self(): # 如果是群消息other_user设置为群如果是私聊消息而且自己发的就设置成对方。
self.is_group or wechaty_msg.is_self()
): # 如果是群消息other_user设置为群如果是私聊消息而且自己发的就设置成对方。
self.other_user_id = self.to_user_id self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname self.other_user_nickname = self.to_user_nickname
else: else:

View File

@ -1,16 +1,17 @@
import time import time
import web import web
from wechatpy import parse_message
from wechatpy.replies import create_reply
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel from channel.wechatmp.wechatmp_channel import WechatMPChannel
from wechatpy import parse_message from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger from common.log import logger
from config import conf from config import conf
from wechatpy.replies import create_reply
# This class is instantiated once per query # This class is instantiated once per query
class Query: class Query:
@ -50,29 +51,19 @@ class Query:
) )
) )
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context( context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
)
else: else:
context = channel._compose_context( context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
)
if context: if context:
# set private openai_api_key # set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel # if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user) user_data = conf().get_user_data(from_user)
context["openai_api_key"] = user_data.get( context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key
"openai_api_key"
) # None or user openai_api_key
channel.produce(context) channel.produce(context)
# The reply will be sent by channel.send() in another thread # The reply will be sent by channel.send() in another thread
return "success" return "success"
elif msg.type == "event": elif msg.type == "event":
logger.info( logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
"[wechatmp] Event {} from {}".format(
msg.event, msg.source
)
)
if msg.event in ["subscribe", "subscribe_scan"]: if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg() reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)

View File

@ -1,10 +1,12 @@
import textwrap import textwrap
import web
from config import conf import web
from wechatpy.utils import check_signature
from wechatpy.crypto import WeChatCrypto from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import InvalidSignatureException from wechatpy.exceptions import InvalidSignatureException
from wechatpy.utils import check_signature
from config import conf
MAX_UTF8_LEN = 2048 MAX_UTF8_LEN = 2048

View File

@ -1,17 +1,18 @@
import time
import asyncio import asyncio
import time
import web import web
from wechatpy import parse_message
from wechatpy.replies import ImageReply, VoiceReply, create_reply
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel from channel.wechatmp.wechatmp_channel import WechatMPChannel
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger from common.log import logger
from config import conf from config import conf
from wechatpy import parse_message
from wechatpy.replies import create_reply, ImageReply, VoiceReply
# This class is instantiated once per query # This class is instantiated once per query
class Query: class Query:
@ -54,16 +55,10 @@ class Query:
): ):
# The first query begin # The first query begin
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context( context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
)
else: else:
context = channel._compose_context( context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
)
logger.debug(
"[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)
)
if supported and context: if supported and context:
# set private openai_api_key # set private openai_api_key
@ -98,19 +93,13 @@ class Query:
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())
# Wechat official server will request 3 times (5 seconds each), with the same message_id. # Wechat official server will request 3 times (5 seconds each), with the same message_id.
# Because the interval is 5 seconds, here assumed that do not have multithreading problems. # Because the interval is 5 seconds, here assumed that do not have multithreading problems.
request_cnt = channel.request_cnt.get(message_id, 0) + 1 request_cnt = channel.request_cnt.get(message_id, 0) + 1
channel.request_cnt[message_id] = request_cnt channel.request_cnt[message_id] = request_cnt
logger.info( logger.info(
"[wechatmp] Request {} from {} {} {}:{}\n{}".format( "[wechatmp] Request {} from {} {} {}:{}\n{}".format(
request_cnt, request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
from_user,
message_id,
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
content
) )
) )
@ -140,10 +129,7 @@ class Query:
channel.request_cnt.pop(message_id) channel.request_cnt.pop(message_id)
# no return because of bandwords or other reasons # no return because of bandwords or other reasons
if ( if from_user not in channel.cache_dict and from_user not in channel.running:
from_user not in channel.cache_dict
and from_user not in channel.running
):
return "success" return "success"
# Only one request can access to the cached data # Only one request can access to the cached data
@ -152,7 +138,7 @@ class Query:
except KeyError: except KeyError:
return "success" return "success"
if (reply_type == "text"): if reply_type == "text":
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN: if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
reply_text = reply_content reply_text = reply_content
else: else:
@ -177,7 +163,7 @@ class Query:
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())
elif (reply_type == "voice"): elif reply_type == "voice":
media_id = reply_content media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info( logger.info(
@ -193,7 +179,7 @@ class Query:
replyPost.media_id = media_id replyPost.media_id = media_id
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())
elif (reply_type == "image"): elif reply_type == "image":
media_id = reply_content media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info( logger.info(
@ -210,11 +196,7 @@ class Query:
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())
elif msg.type == "event": elif msg.type == "event":
logger.info( logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
"[wechatmp] Event {} from {}".format(
msg.event, msg.source
)
)
if msg.event in ["subscribe", "subscribe_scan"]: if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg() reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)

View File

@ -1,24 +1,26 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio
import imghdr
import io import io
import os import os
import time
import imghdr
import requests
import asyncio
import threading import threading
from config import conf import time
import requests
import web
from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import WeChatClientException
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from common.log import logger
from common.singleton import singleton
from voice.audio_convert import any_to_mp3
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_client import WechatMPClient from channel.wechatmp.wechatmp_client import WechatMPClient
from wechatpy.exceptions import WeChatClientException from common.log import logger
from wechatpy.crypto import WeChatCrypto from common.singleton import singleton
from config import conf
from voice.audio_convert import any_to_mp3
import web
# If using SSL, uncomment the following lines, and modify the certificate path. # If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer # from cheroot.server import HTTPServer
# from cheroot.ssl.builtin import BuiltinSSLAdapter # from cheroot.ssl.builtin import BuiltinSSLAdapter
@ -54,7 +56,6 @@ class WechatMPChannel(ChatChannel):
t.setDaemon(True) t.setDaemon(True)
t.start() t.start()
def startup(self): def startup(self):
if self.passive_reply: if self.passive_reply:
urls = ("/wx", "channel.wechatmp.passive_reply.Query") urls = ("/wx", "channel.wechatmp.passive_reply.Query")
@ -84,7 +85,7 @@ class WechatMPChannel(ChatChannel):
elif reply.type == ReplyType.VOICE: elif reply.type == ReplyType.VOICE:
try: try:
voice_file_path = reply.content voice_file_path = reply.content
with open(voice_file_path, 'rb') as f: with open(voice_file_path, "rb") as f:
# support: <2M, <60s, mp3/wma/wav/amr # support: <2M, <60s, mp3/wma/wav/amr
response = self.client.material.add("voice", f) response = self.client.material.add("voice", f)
logger.debug("[wechatmp] upload voice response: {}".format(response)) logger.debug("[wechatmp] upload voice response: {}".format(response))
@ -107,7 +108,7 @@ class WechatMPChannel(ChatChannel):
image_storage.write(block) image_storage.write(block)
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.material.add("image", (filename, image_storage, content_type)) response = self.client.material.add("image", (filename, image_storage, content_type))
@ -122,7 +123,7 @@ class WechatMPChannel(ChatChannel):
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.material.add("image", (filename, image_storage, content_type)) response = self.client.material.add("image", (filename, image_storage, content_type))
@ -137,7 +138,7 @@ class WechatMPChannel(ChatChannel):
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
reply_text = reply.content reply_text = reply.content
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN) texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
if len(texts)>1: if len(texts) > 1:
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts))) logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
for text in texts: for text in texts:
self.client.message.send_text(receiver, text) self.client.message.send_text(receiver, text)
@ -174,7 +175,7 @@ class WechatMPChannel(ChatChannel):
image_storage.write(block) image_storage.write(block)
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.media.upload("image", (filename, image_storage, content_type)) response = self.client.media.upload("image", (filename, image_storage, content_type))
@ -188,7 +189,7 @@ class WechatMPChannel(ChatChannel):
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.media.upload("image", (filename, image_storage, content_type)) response = self.client.media.upload("image", (filename, image_storage, content_type))
@ -201,20 +202,12 @@ class WechatMPChannel(ChatChannel):
return return
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
logger.debug( logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
"[wechatmp] Success to generate reply, msgId={}".format(
context["msg"].msg_id
)
)
if self.passive_reply: if self.passive_reply:
self.running.remove(session_id) self.running.remove(session_id)
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception( logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
context["msg"].msg_id, exception
)
)
if self.passive_reply: if self.passive_reply:
assert session_id not in self.cache_dict assert session_id not in self.cache_dict
self.running.remove(session_id) self.running.remove(session_id)

View File

@ -1,17 +1,16 @@
import time
import threading import threading
from channel.wechatmp.common import * import time
from wechatpy.client import WeChatClient from wechatpy.client import WeChatClient
from common.log import logger
from wechatpy.exceptions import APILimitedException from wechatpy.exceptions import APILimitedException
from channel.wechatmp.common import *
from common.log import logger
class WechatMPClient(WeChatClient): class WechatMPClient(WeChatClient):
def __init__(self, appid, secret, access_token=None, def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
session=None, timeout=None, auto_retry=True): super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
super(WechatMPClient, self).__init__(
appid, secret, access_token, session, timeout, auto_retry
)
self.fetch_access_token_lock = threading.Lock() self.fetch_access_token_lock = threading.Lock()
def clear_quota(self): def clear_quota(self):

View File

@ -6,7 +6,6 @@ from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
class WeChatMPMessage(ChatMessage): class WeChatMPMessage(ChatMessage):
def __init__(self, msg, client=None): def __init__(self, msg, client=None):
super().__init__(msg) super().__init__(msg)
@ -18,12 +17,9 @@ class WeChatMPMessage(ChatMessage):
self.ctype = ContextType.TEXT self.ctype = ContextType.TEXT
self.content = msg.content self.content = msg.content
elif msg.type == "voice": elif msg.type == "voice":
if msg.recognition == None: if msg.recognition == None:
self.ctype = ContextType.VOICE self.ctype = ContextType.VOICE
self.content = ( self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
TmpDir().path() + msg.media_id + "." + msg.format
) # content直接存临时目录路径
def download_voice(): def download_voice():
# 如果响应状态码是200则将响应内容写入本地文件 # 如果响应状态码是200则将响应内容写入本地文件
@ -32,9 +28,7 @@ class WeChatMPMessage(ChatMessage):
with open(self.content, "wb") as f: with open(self.content, "wb") as f:
f.write(response.content) f.write(response.content)
else: else:
logger.info( logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
f"[wechatmp] Failed to download voice file, {response.content}"
)
self._prepare_fn = download_voice self._prepare_fn = download_voice
else: else:
@ -43,6 +37,7 @@ class WeChatMPMessage(ChatMessage):
elif msg.type == "image": elif msg.type == "image":
self.ctype = ContextType.IMAGE self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径 self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
def download_image(): def download_image():
# 如果响应状态码是200则将响应内容写入本地文件 # 如果响应状态码是200则将响应内容写入本地文件
response = client.media.download(msg.media_id) response = client.media.download(msg.media_id)
@ -50,15 +45,11 @@ class WeChatMPMessage(ChatMessage):
with open(self.content, "wb") as f: with open(self.content, "wb") as f:
f.write(response.content) f.write(response.content)
else: else:
logger.info( logger.info(f"[wechatmp] Failed to download image file, {response.content}")
f"[wechatmp] Failed to download image file, {response.content}"
)
self._prepare_fn = download_image self._prepare_fn = download_image
else: else:
raise NotImplementedError( raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
"Unsupported message type: Type:{} ".format(msg.type)
)
self.from_user_id = msg.source self.from_user_id = msg.source
self.to_user_id = msg.target self.to_user_id = msg.target

View File

@ -13,23 +13,15 @@ def time_checker(f):
if chat_time_module: if chat_time_module:
chat_start_time = _config.get("chat_start_time", "00:00") chat_start_time = _config.get("chat_start_time", "00:00")
chat_stopt_time = _config.get("chat_stop_time", "24:00") chat_stopt_time = _config.get("chat_stop_time", "24:00")
time_regex = re.compile( time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配包含24:00
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
) # 时间匹配包含24:00
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
# 时间格式检查 # 时间格式检查
if not ( if not (starttime_format_check and stoptime_format_check and chat_time_check):
starttime_format_check and stoptime_format_check and chat_time_check logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
):
logger.warn(
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
starttime_format_check, stoptime_format_check
)
)
if chat_start_time > "23:59": if chat_start_time > "23:59":
logger.error("启动时间可能存在问题,请修改!") logger.error("启动时间可能存在问题,请修改!")

View File

@ -158,9 +158,7 @@ def load_config():
for name, value in os.environ.items(): for name, value in os.environ.items():
name = name.lower() name = name.lower()
if name in available_setting: if name in available_setting:
logger.info( logger.info("[INIT] override config by environ args: {}={}".format(name, value))
"[INIT] override config by environ args: {}={}".format(name, value)
)
try: try:
config[name] = eval(value) config[name] = eval(value)
except: except:

View File

@ -50,9 +50,7 @@ class Banwords(Plugin):
self.reply_action = conf.get("reply_action", "ignore") self.reply_action = conf.get("reply_action", "ignore")
logger.info("[Banwords] inited") logger.info("[Banwords] inited")
except Exception as e: except Exception as e:
logger.warn( logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
)
raise e raise e
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
@ -72,9 +70,7 @@ class Banwords(Plugin):
return return
elif self.action == "replace": elif self.action == "replace":
if self.searchr.ContainsAny(content): if self.searchr.ContainsAny(content):
reply = Reply( reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
)
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
@ -94,9 +90,7 @@ class Banwords(Plugin):
return return
elif self.reply_action == "replace": elif self.reply_action == "replace":
if self.searchr.ContainsAny(content): if self.searchr.ContainsAny(content):
reply = Reply( reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
)
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.CONTINUE e_context.action = EventAction.CONTINUE
return return

View File

@ -76,9 +76,7 @@ class BDunit(Plugin):
Returns: Returns:
string: access_token string: access_token
""" """
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
self.api_key, self.secret_key
)
payload = "" payload = ""
headers = {"Content-Type": "application/json", "Accept": "application/json"} headers = {"Content-Type": "application/json", "Accept": "application/json"}
@ -94,10 +92,7 @@ class BDunit(Plugin):
:returns: UNIT 解析结果如果解析失败返回 None :returns: UNIT 解析结果如果解析失败返回 None
""" """
url = ( url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ self.access_token
)
request = { request = {
"query": query, "query": query,
"user_id": str(get_mac())[:32], "user_id": str(get_mac())[:32],
@ -124,10 +119,7 @@ class BDunit(Plugin):
:param query: 用户的指令字符串 :param query: 用户的指令字符串
:returns: UNIT 解析结果如果解析失败返回 None :returns: UNIT 解析结果如果解析失败返回 None
""" """
url = ( url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
+ self.access_token
)
request = {"query": query, "user_id": str(get_mac())[:32]} request = {"query": query, "user_id": str(get_mac())[:32]}
body = { body = {
"log_id": str(uuid.uuid1()), "log_id": str(uuid.uuid1()),
@ -170,11 +162,7 @@ class BDunit(Plugin):
if parsed and "result" in parsed and "response_list" in parsed["result"]: if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"] response_list = parsed["result"]["response_list"]
for response in response_list: for response in response_list:
if ( if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
return True return True
return False return False
else: else:
@ -198,12 +186,7 @@ class BDunit(Plugin):
logger.warning(e) logger.warning(e)
return [] return []
for response in response_list: for response in response_list:
if ( if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
"schema" in response
and "intent" in response["schema"]
and "slots" in response["schema"]
and response["schema"]["intent"] == intent
):
return response["schema"]["slots"] return response["schema"]["slots"]
return [] return []
else: else:
@ -239,11 +222,7 @@ class BDunit(Plugin):
if ( if (
"schema" in response "schema" in response
and "intent_confidence" in response["schema"] and "intent_confidence" in response["schema"]
and ( and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
not answer
or response["schema"]["intent_confidence"]
> answer["schema"]["intent_confidence"]
)
): ):
answer = response answer = response
return answer["action_list"][0]["say"] return answer["action_list"][0]["say"]
@ -267,11 +246,7 @@ class BDunit(Plugin):
logger.warning(e) logger.warning(e)
return "" return ""
for response in response_list: for response in response_list:
if ( if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
try: try:
return response["action_list"][0]["say"] return response["action_list"][0]["say"]
except Exception as e: except Exception as e:

View File

@ -84,9 +84,7 @@ class Dungeon(Plugin):
if len(clist) > 1: if len(clist) > 1:
story = clist[1] story = clist[1]
else: else:
story = ( story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
)
self.games[sessionid] = StoryTeller(bot, sessionid, story) self.games[sessionid] = StoryTeller(bot, sessionid, story)
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
e_context["reply"] = reply e_context["reply"] = reply
@ -102,11 +100,7 @@ class Dungeon(Plugin):
if kwargs.get("verbose") != True: if kwargs.get("verbose") != True:
return help_text return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = ( help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
f"{trigger_prefix}开始冒险 "
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
+ f"{trigger_prefix}停止冒险: 结束游戏。\n"
)
if kwargs.get("verbose") == True: if kwargs.get("verbose") == True:
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
return help_text return help_text

View File

@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup):
if plugins[plugin].enabled and not plugins[plugin].hidden: if plugins[plugin].enabled and not plugins[plugin].hidden:
namecn = plugins[plugin].namecn namecn = plugins[plugin].namecn
help_text += "\n%s:" % namecn help_text += "\n%s:" % namecn
help_text += ( help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
PluginManager().instances[plugin].get_help_text(verbose=False).strip()
)
if ADMIN_COMMANDS and isadmin: if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n" help_text += "\n\n管理员指令:\n"
@ -191,9 +189,7 @@ class Godcmd(Plugin):
COMMANDS["reset"]["alias"].append(custom_command) COMMANDS["reset"]["alias"].append(custom_command)
self.password = gconf["password"] self.password = gconf["password"]
self.admin_users = gconf[ self.admin_users = gconf["admin_users"] # 预存的管理员账号这些账号不需要认证。itchat的用户名每次都会变不可用
"admin_users"
] # 预存的管理员账号这些账号不需要认证。itchat的用户名每次都会变不可用
self.isrunning = True # 机器人是否运行中 self.isrunning = True # 机器人是否运行中
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
@ -248,11 +244,7 @@ class Godcmd(Plugin):
if not plugincls.enabled: if not plugincls.enabled:
continue continue
if query_name == name or query_name == plugincls.namecn: if query_name == name or query_name == plugincls.namecn:
ok, result = True, PluginManager().instances[ ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
name
].get_help_text(
isgroup=isgroup, isadmin=isadmin, verbose=True
)
break break
if not ok: if not ok:
result = "插件不存在或未启用" result = "插件不存在或未启用"
@ -285,11 +277,7 @@ class Godcmd(Plugin):
if isgroup: if isgroup:
ok, result = False, "群聊不可执行管理员指令" ok, result = False, "群聊不可执行管理员指令"
else: else:
cmd = next( cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
c
for c, info in ADMIN_COMMANDS.items()
if cmd in info["alias"]
)
if cmd == "stop": if cmd == "stop":
self.isrunning = False self.isrunning = False
ok, result = True, "服务已暂停" ok, result = True, "服务已暂停"
@ -325,18 +313,14 @@ class Godcmd(Plugin):
PluginManager().activate_plugins() PluginManager().activate_plugins()
if len(new_plugins) > 0: if len(new_plugins) > 0:
result += "\n发现新插件:\n" result += "\n发现新插件:\n"
result += "\n".join( result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
[f"{p.name}_v{p.version}" for p in new_plugins]
)
else: else:
result += ", 未发现新插件" result += ", 未发现新插件"
elif cmd == "setpri": elif cmd == "setpri":
if len(args) != 2: if len(args) != 2:
ok, result = False, "请提供插件名和优先级" ok, result = False, "请提供插件名和优先级"
else: else:
ok = PluginManager().set_plugin_priority( ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
args[0], int(args[1])
)
if ok: if ok:
result = "插件" + args[0] + "优先级已设置为" + args[1] result = "插件" + args[0] + "优先级已设置为" + args[1]
else: else:

View File

@ -33,9 +33,7 @@ class Hello(Plugin):
if e_context["context"].type == ContextType.JOIN_GROUP: if e_context["context"].type == ContextType.JOIN_GROUP:
e_context["context"].type = ContextType.TEXT e_context["context"].type = ContextType.TEXT
msg: ChatMessage = e_context["context"]["msg"] msg: ChatMessage = e_context["context"]["msg"]
e_context[ e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
"context"
].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
return return
@ -53,9 +51,7 @@ class Hello(Plugin):
reply.type = ReplyType.TEXT reply.type = ReplyType.TEXT
msg: ChatMessage = e_context["context"]["msg"] msg: ChatMessage = e_context["context"]["msg"]
if e_context["context"]["isgroup"]: if e_context["context"]["isgroup"]:
reply.content = ( reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
)
else: else:
reply.content = f"Hello, {msg.from_user_nickname}" reply.content = f"Hello, {msg.from_user_nickname}"
e_context["reply"] = reply e_context["reply"] = reply

View File

@ -41,9 +41,7 @@ class Keyword(Plugin):
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[keyword] inited.") logger.info("[keyword] inited.")
except Exception as e: except Exception as e:
logger.warn( logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
"[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ."
)
raise e raise e
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):

View File

@ -31,23 +31,14 @@ class PluginManager:
plugincls.desc = kwargs.get("desc") plugincls.desc = kwargs.get("desc")
plugincls.author = kwargs.get("author") plugincls.author = kwargs.get("author")
plugincls.path = self.current_plugin_path plugincls.path = self.current_plugin_path
plugincls.version = ( plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
kwargs.get("version") if kwargs.get("version") != None else "1.0" plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
) plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
plugincls.namecn = (
kwargs.get("namecn") if kwargs.get("namecn") != None else name
)
plugincls.hidden = (
kwargs.get("hidden") if kwargs.get("hidden") != None else False
)
plugincls.enabled = True plugincls.enabled = True
if self.current_plugin_path == None: if self.current_plugin_path == None:
raise Exception("Plugin path not set") raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls self.plugins[name.upper()] = plugincls
logger.info( logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
"Plugin %s_v%s registered, path=%s"
% (name, plugincls.version, plugincls.path)
)
return wrapper return wrapper
@ -62,9 +53,7 @@ class PluginManager:
if os.path.exists("./plugins/plugins.json"): if os.path.exists("./plugins/plugins.json"):
with open("./plugins/plugins.json", "r", encoding="utf-8") as f: with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f) pconf = json.load(f)
pconf["plugins"] = SortedDict( pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
lambda k, v: v["priority"], pconf["plugins"], reverse=True
)
else: else:
modified = True modified = True
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
@ -90,26 +79,16 @@ class PluginManager:
if plugin_path in self.loaded: if plugin_path in self.loaded:
if self.loaded[plugin_path] == None: if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name) logger.info("reload module %s" % plugin_name)
self.loaded[plugin_path] = importlib.reload( self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
sys.modules[import_path] dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
)
dependent_module_names = [
name
for name in sys.modules.keys()
if name.startswith(import_path + ".")
]
for name in dependent_module_names: for name in dependent_module_names:
logger.info("reload module %s" % name) logger.info("reload module %s" % name)
importlib.reload(sys.modules[name]) importlib.reload(sys.modules[name])
else: else:
self.loaded[plugin_path] = importlib.import_module( self.loaded[plugin_path] = importlib.import_module(import_path)
import_path
)
self.current_plugin_path = None self.current_plugin_path = None
except Exception as e: except Exception as e:
logger.exception( logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
"Failed to import plugin %s: %s" % (plugin_name, e)
)
continue continue
pconf = self.pconf pconf = self.pconf
news = [self.plugins[name] for name in self.plugins] news = [self.plugins[name] for name in self.plugins]
@ -119,9 +98,7 @@ class PluginManager:
rawname = plugincls.name rawname = plugincls.name
if rawname not in pconf["plugins"]: if rawname not in pconf["plugins"]:
modified = True modified = True
logger.info( logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
"Plugin %s not found in pconfig, adding to pconfig..." % name
)
pconf["plugins"][rawname] = { pconf["plugins"][rawname] = {
"enabled": plugincls.enabled, "enabled": plugincls.enabled,
"priority": plugincls.priority, "priority": plugincls.priority,
@ -136,9 +113,7 @@ class PluginManager:
def refresh_order(self): def refresh_order(self):
for event in self.listening_plugins.keys(): for event in self.listening_plugins.keys():
self.listening_plugins[event].sort( self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
key=lambda name: self.plugins[name].priority, reverse=True
)
def activate_plugins(self): # 生成新开启的插件实例 def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = [] failed_plugins = []
@ -184,13 +159,8 @@ class PluginManager:
def emit_event(self, e_context: EventContext, *args, **kwargs): def emit_event(self, e_context: EventContext, *args, **kwargs):
if e_context.event in self.listening_plugins: if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]: for name in self.listening_plugins[e_context.event]:
if ( if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
self.plugins[name].enabled logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
and e_context.action == EventAction.CONTINUE
):
logger.debug(
"Plugin %s triggered by event %s" % (name, e_context.event)
)
instance = self.instances[name] instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs) instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context return e_context
@ -262,9 +232,7 @@ class PluginManager:
source = json.load(f) source = json.load(f)
if repo in source["repo"]: if repo in source["repo"]:
repo = source["repo"][repo]["url"] repo = source["repo"][repo]["url"]
match = re.match( match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
)
if not match: if not match:
return False, "安装插件失败source中的仓库地址不合法" return False, "安装插件失败source中的仓库地址不合法"
else: else:

View File

@ -69,13 +69,9 @@ class Role(Plugin):
logger.info("[Role] inited") logger.info("[Role] inited")
except Exception as e: except Exception as e:
if isinstance(e, FileNotFoundError): if isinstance(e, FileNotFoundError):
logger.warn( logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
else: else:
logger.warn( logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
raise e raise e
def get_role(self, name, find_closest=True, min_sim=0.35): def get_role(self, name, find_closest=True, min_sim=0.35):
@ -143,9 +139,7 @@ class Role(Plugin):
else: else:
help_text = f"未知角色类型。\n" help_text = f"未知角色类型。\n"
help_text += "目前的角色类型有: \n" help_text += "目前的角色类型有: \n"
help_text += ( help_text += "".join([self.tags[tag][0] for tag in self.tags]) + "\n"
"".join([self.tags[tag][0] for tag in self.tags]) + "\n"
)
else: else:
help_text = f"请输入角色类型。\n" help_text = f"请输入角色类型。\n"
help_text += "目前的角色类型有: \n" help_text += "目前的角色类型有: \n"
@ -158,9 +152,7 @@ class Role(Plugin):
return return
logger.debug("[Role] on_handle_context. content: %s" % content) logger.debug("[Role] on_handle_context. content: %s" % content)
if desckey is not None: if desckey is not None:
if len(clist) == 1 or ( if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
):
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
@ -178,9 +170,7 @@ class Role(Plugin):
self.roles[role][desckey], self.roles[role][desckey],
self.roles[role].get("wrapper", "%s"), self.roles[role].get("wrapper", "%s"),
) )
reply = Reply( reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
)
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
elif customize == True: elif customize == True:
@ -199,17 +189,10 @@ class Role(Plugin):
if not verbose: if not verbose:
return help_text return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = ( help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
f"使用方法:\n{trigger_prefix}角色"
+ " 预设角色名: 设定角色为{预设角色名}。\n"
+ f"{trigger_prefix}role"
+ " 预设角色名: 同上,但使用英文设定。\n"
)
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
help_text += ( help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
)
help_text += "\n目前的角色类型有: \n" help_text += "\n目前的角色类型有: \n"
help_text += "".join([self.tags[tag][0] for tag in self.tags]) + "\n" help_text += "".join([self.tags[tag][0] for tag in self.tags]) + "\n"
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"

View File

@ -82,9 +82,7 @@ class Tool(Plugin):
return return
elif content_list[1].startswith("reset"): elif content_list[1].startswith("reset"):
logger.debug("[tool]: remind") logger.debug("[tool]: remind")
e_context[ e_context["context"].content = "请你随机用一种聊天风格提醒用户如果想重置tool插件reset之后不要加任何字符"
"context"
].content = "请你随机用一种聊天风格提醒用户如果想重置tool插件reset之后不要加任何字符"
e_context.action = EventAction.BREAK e_context.action = EventAction.BREAK
return return
@ -93,18 +91,14 @@ class Tool(Plugin):
# Don't modify bot name # Don't modify bot name
all_sessions = Bridge().get_bot("chat").sessions all_sessions = Bridge().get_bot("chat").sessions
user_session = all_sessions.session_query( user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
query, e_context["context"]["session_id"]
).messages
# chatgpt-tool-hub will reply you with many tools # chatgpt-tool-hub will reply you with many tools
logger.debug("[tool]: just-go") logger.debug("[tool]: just-go")
try: try:
_reply = self.app.ask(query, user_session) _reply = self.app.ask(query, user_session)
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
all_sessions.session_reply( all_sessions.session_reply(_reply, e_context["context"]["session_id"])
_reply, e_context["context"]["session_id"]
)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
logger.error(str(e)) logger.error(str(e))

View File

@ -33,6 +33,7 @@ def get_pcm_from_wav(wav_path):
wav = wave.open(wav_path, "rb") wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes()) return wav.readframes(wav.getnframes())
def any_to_mp3(any_path, mp3_path): def any_to_mp3(any_path, mp3_path):
""" """
把任意格式转成mp3文件 把任意格式转成mp3文件
@ -40,16 +41,13 @@ def any_to_mp3(any_path, mp3_path):
if any_path.endswith(".mp3"): if any_path.endswith(".mp3"):
shutil.copy2(any_path, mp3_path) shutil.copy2(any_path, mp3_path)
return return
if ( if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
sil_to_wav(any_path, any_path) sil_to_wav(any_path, any_path)
any_path = mp3_path any_path = mp3_path
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)
audio.export(mp3_path, format="mp3") audio.export(mp3_path, format="mp3")
def any_to_wav(any_path, wav_path): def any_to_wav(any_path, wav_path):
""" """
把任意格式转成wav文件 把任意格式转成wav文件
@ -57,11 +55,7 @@ def any_to_wav(any_path, wav_path):
if any_path.endswith(".wav"): if any_path.endswith(".wav"):
shutil.copy2(any_path, wav_path) shutil.copy2(any_path, wav_path)
return return
if ( if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
return sil_to_wav(any_path, wav_path) return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav") audio.export(wav_path, format="wav")
@ -71,11 +65,7 @@ def any_to_sil(any_path, sil_path):
""" """
把任意格式转成sil文件 把任意格式转成sil文件
""" """
if ( if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
shutil.copy2(any_path, sil_path) shutil.copy2(any_path, sil_path)
return 10000 return 10000
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)

View File

@ -40,57 +40,33 @@ class AzureVoice(Voice):
config = json.load(fr) config = json.load(fr)
self.api_key = conf().get("azure_voice_api_key") self.api_key = conf().get("azure_voice_api_key")
self.api_region = conf().get("azure_voice_region") self.api_region = conf().get("azure_voice_region")
self.speech_config = speechsdk.SpeechConfig( self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
subscription=self.api_key, region=self.api_region self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
) self.speech_config.speech_recognition_language = config["speech_recognition_language"]
self.speech_config.speech_synthesis_voice_name = config[
"speech_synthesis_voice_name"
]
self.speech_config.speech_recognition_language = config[
"speech_recognition_language"
]
except Exception as e: except Exception as e:
logger.warn("AzureVoice init failed: %s, ignore " % e) logger.warn("AzureVoice init failed: %s, ignore " % e)
def voiceToText(self, voice_file): def voiceToText(self, voice_file):
audio_config = speechsdk.AudioConfig(filename=voice_file) audio_config = speechsdk.AudioConfig(filename=voice_file)
speech_recognizer = speechsdk.SpeechRecognizer( speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
speech_config=self.speech_config, audio_config=audio_config
)
result = speech_recognizer.recognize_once() result = speech_recognizer.recognize_once()
if result.reason == speechsdk.ResultReason.RecognizedSpeech: if result.reason == speechsdk.ResultReason.RecognizedSpeech:
logger.info( logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
"[Azure] voiceToText voice file name={} text={}".format(
voice_file, result.text
)
)
reply = Reply(ReplyType.TEXT, result.text) reply = Reply(ReplyType.TEXT, result.text)
else: else:
logger.error( logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
"[Azure] voiceToText error, result={}, canceldetails={}".format(
result, result.cancellation_details
)
)
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
return reply return reply
def textToVoice(self, text): def textToVoice(self, text):
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
audio_config = speechsdk.AudioConfig(filename=fileName) audio_config = speechsdk.AudioConfig(filename=fileName)
speech_synthesizer = speechsdk.SpeechSynthesizer( speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
speech_config=self.speech_config, audio_config=audio_config
)
result = speech_synthesizer.speak_text(text) result = speech_synthesizer.speak_text(text)
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
logger.info( logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
"[Azure] textToVoice text={} voice file name={}".format(text, fileName)
)
reply = Reply(ReplyType.VOICE, fileName) reply = Reply(ReplyType.VOICE, fileName)
else: else:
logger.error( logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
"[Azure] textToVoice error, result={}, canceldetails={}".format(
result, result.cancellation_details
)
)
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply return reply

View File

@ -85,9 +85,7 @@ class BaiduVoice(Voice):
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
with open(fileName, "wb") as f: with open(fileName, "wb") as f:
f.write(result) f.write(result)
logger.info( logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
)
reply = Reply(ReplyType.VOICE, fileName) reply = Reply(ReplyType.VOICE, fileName)
else: else:
logger.error("[Baidu] textToVoice error={}".format(result)) logger.error("[Baidu] textToVoice error={}".format(result))

View File

@ -24,11 +24,7 @@ class GoogleVoice(Voice):
audio = self.recognizer.record(source) audio = self.recognizer.record(source)
try: try:
text = self.recognizer.recognize_google(audio, language="zh-CN") text = self.recognizer.recognize_google(audio, language="zh-CN")
logger.info( logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
"[Google] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
reply = Reply(ReplyType.TEXT, text) reply = Reply(ReplyType.TEXT, text)
except speech_recognition.UnknownValueError: except speech_recognition.UnknownValueError:
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@ -42,9 +38,7 @@ class GoogleVoice(Voice):
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
tts = gTTS(text=text, lang="zh") tts = gTTS(text=text, lang="zh")
tts.save(mp3File) tts.save(mp3File)
logger.info( logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
"[Google] textToVoice text={} voice file name={}".format(text, mp3File)
)
reply = Reply(ReplyType.VOICE, mp3File) reply = Reply(ReplyType.VOICE, mp3File)
except Exception as e: except Exception as e:
reply = Reply(ReplyType.ERROR, str(e)) reply = Reply(ReplyType.ERROR, str(e))

View File

@ -22,11 +22,7 @@ class OpenaiVoice(Voice):
result = openai.Audio.transcribe("whisper-1", file) result = openai.Audio.transcribe("whisper-1", file)
text = result["text"] text = result["text"]
reply = Reply(ReplyType.TEXT, text) reply = Reply(ReplyType.TEXT, text)
logger.info( logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
"[Openai] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
except Exception as e: except Exception as e:
reply = Reply(ReplyType.ERROR, str(e)) reply = Reply(ReplyType.ERROR, str(e))
finally: finally:

View File

@ -5,6 +5,7 @@ pytts voice service (offline)
import os import os
import sys import sys
import time import time
import pyttsx3 import pyttsx3
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
@ -12,6 +13,7 @@ from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
from voice.voice import Voice from voice.voice import Voice
class PyttsVoice(Voice): class PyttsVoice(Voice):
engine = pyttsx3.init() engine = pyttsx3.init()
@ -20,7 +22,7 @@ class PyttsVoice(Voice):
self.engine.setProperty("rate", 125) self.engine.setProperty("rate", 125)
# 音量 # 音量
self.engine.setProperty("volume", 1.0) self.engine.setProperty("volume", 1.0)
if sys.platform == 'win32': if sys.platform == "win32":
for voice in self.engine.getProperty("voices"): for voice in self.engine.getProperty("voices"):
if "Chinese" in voice.name: if "Chinese" in voice.name:
self.engine.setProperty("voice", voice.id) self.engine.setProperty("voice", voice.id)
@ -33,13 +35,13 @@ class PyttsVoice(Voice):
def textToVoice(self, text): def textToVoice(self, text):
try: try:
# avoid the same filename # avoid the same filename
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7fffffff) + ".wav" wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
wavFile = TmpDir().path() + wavFileName wavFile = TmpDir().path() + wavFileName
logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)) logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))
self.engine.save_to_file(text, wavFile) self.engine.save_to_file(text, wavFile)
if sys.platform == 'win32': if sys.platform == "win32":
self.engine.runAndWait() self.engine.runAndWait()
else: else:
# In ubuntu, runAndWait do not really wait until the file created. # In ubuntu, runAndWait do not really wait until the file created.