First, in the Volcano Engine speech console, enable streaming voice dialogue for large models, obtain an API key for the Doubao model, and create a speech-synthesis project.
The Doubao model used here is Doubao-lite, which keeps latency lower.
Configuration notes
There are four files in total: the main FastAPI entry point, LLM, STT, and TTS.
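The imports used later pin down three of the names (LLM.py, STT.py, TTS.py); the entry file and the test page can be named anything, so those two names below are assumptions:

server.py   # FastAPI entry point exposing the /ws WebSocket endpoint (name assumed)
LLM.py      # streaming Doubao chat client
STT.py      # streaming speech recognition over the Volcano binary protocol
TTS.py      # streaming speech synthesis, relayed back over /ws
test.html   # frontend test page (name assumed)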
The following needs to be configured in TTS:
appid = "123"  # APPID from the console
token = "XXXX"  # Access Token from the console
cluster = "XXXXX"  # group id of the speech-synthesis project
voice_type = "BV034_streaming"  # the voice type to synthesize with
host = "openspeech.bytedance.com"  # do not change
api_url = f"wss://{host}/api/v1/tts/ws_binary"  # do not change
In LLM, configure:
# initialize the client with your API key
self.client = Ark(api_key="XXXX")
Around line 146 of STT, configure:
header = {
    "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
    "X-Api-Access-Key": "XXXXX",  # same value as the token in the TTS config
    "X-Api-App-Key": "123",       # same value as the appid in the TTS config
    "X-Api-Request-Id": reqid
}
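STT.py imports a settings object with format, framerate, bits and nchannels fields, but that module is never shown. Here is a minimal sketch of it; the concrete values are assumptions chosen to match the 16 kHz mono 16-bit WAV that the test page records:

# app/config/settings.py: minimal sketch; only the four fields STT.py reads
from dataclasses import dataclass

@dataclass
class Settings:
    format: str = "wav"      # assumed: the frontend wraps each PCM block in a WAV container
    framerate: int = 16000   # matches the 16000 Hz AudioContext in the test page
    bits: int = 16           # 16-bit PCM
    nchannels: int = 1       # mono

settings = Settings()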
Finally, in the frontend HTML, remember to point wsUrl (the const wsUrl = 'ws://127.0.0.1:8000/ws' line in the script below) at the IP where your service actually runs; note that browsers expose the microphone via getUserMedia only on secure contexts (HTTPS) or on localhost.
Frontend test HTML
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>WebSocket Audio Streaming Test</title>
<style>
  body { font-family: Arial, sans-serif; }
  #status { margin-bottom: 10px; }
  #messages { border: 1px solid #ccc; height: 200px; overflow-y: scroll; padding: 10px; }
  #controls { margin-top: 10px; }
  #controls button { margin-right: 5px; }
  #latency { margin-top: 10px; font-weight: bold; }
</style>
</head>
<body>
<h1>WebSocket Audio Streaming Test</h1>
<!-- current connection status -->
<div id="status">Status: disconnected</div>
<!-- log messages -->
<div id="messages"></div>
<!-- control buttons -->
<div id="controls">
  <button id="startButton">Start recording and sending</button>
  <button id="stopButton" disabled>Stop recording</button>
</div>
<!-- latency readout -->
<div id="latency"></div>
<script>
// WebSocket server address; replace with your own deployment
const wsUrl = 'ws://127.0.0.1:8000/ws';

// globals
let socket = null;                                           // WebSocket instance
const messagesDiv = document.getElementById('messages');     // log area
const statusDiv = document.getElementById('status');         // connection status area
const startButton = document.getElementById('startButton');  // start button
const stopButton = document.getElementById('stopButton');    // stop button
let recordingAudioContext; // recording AudioContext
let audioInput;            // audio input node
let processor;             // audio processing node

// playback state
let playbackAudioContext;
let playbackQueue = [];
let playbackTime = 0;
let isPlaying = false;

// latency measurement
let overSentTime = null;      // when 'over' was sent
let latencyMeasured = false;  // whether latency has already been measured

/**
 * Append a message to the log area.
 * @param {string} message - the message to log
 */
function logMessage(message) {
    const p = document.createElement('p');
    p.textContent = message;
    messagesDiv.appendChild(p);
    messagesDiv.scrollTop = messagesDiv.scrollHeight; // auto-scroll to the latest message
}

/** Create the playback AudioContext. */
function initializePlayback() {
    playbackAudioContext = new (window.AudioContext || window.webkitAudioContext)();
    logMessage('Playback AudioContext created');
}

/**
 * Decode received audio and append it to the playback queue.
 * @param {ArrayBuffer} data - received audio data
 */
function appendToPlaybackQueue(data) {
    playbackAudioContext.decodeAudioData(data, (audioBuffer) => {
        playbackQueue.push(audioBuffer);
        schedulePlayback();
    }, (error) => {
        logMessage('Error decoding audio data: ' + error);
    });
}

/** Schedule the next buffer in the playback queue. */
function schedulePlayback() {
    if (isPlaying) return;
    if (playbackQueue.length === 0) return;
    // take the next buffer
    const buffer = playbackQueue.shift();
    // create a buffer source for it
    const source = playbackAudioContext.createBufferSource();
    source.buffer = buffer;
    source.connect(playbackAudioContext.destination);
    // never schedule in the past
    if (playbackTime < playbackAudioContext.currentTime) {
        playbackTime = playbackAudioContext.currentTime;
    }
    // start playing at playbackTime
    source.start(playbackTime);
    logMessage(`Scheduled buffer to play at ${playbackTime.toFixed(2)}s`);
    // advance playbackTime
    playbackTime += buffer.duration;
    isPlaying = true;
    // when this buffer finishes, play the next one in the queue
    source.onended = () => {
        isPlaying = false;
        schedulePlayback();
    };
}

/** Create and connect the WebSocket. */
function createWebSocket() {
    if (socket !== null && (socket.readyState === WebSocket.OPEN || socket.readyState === WebSocket.CONNECTING)) {
        logMessage('WebSocket already connected or connecting');
        return;
    }
    socket = new WebSocket(wsUrl);
    socket.binaryType = 'arraybuffer';
    socket.onopen = function () {
        statusDiv.textContent = 'Status: connected';
        logMessage('WebSocket connection opened');
        startButton.disabled = false; // enable the start button
    };
    socket.onmessage = function (event) {
        // a string 'over' marks the end of the server's audio stream
        if (typeof event.data === 'string' && event.data === 'over') {
            logMessage('End signal received: over');
            return;
        }
        // binary data (ArrayBuffer) is synthesized audio
        if (event.data instanceof ArrayBuffer) {
            logMessage('Audio data received');
            // measure latency once after 'over' has been sent
            if (overSentTime !== null && !latencyMeasured) {
                let receiveTime = performance.now();
                let latency = receiveTime - overSentTime;
                logMessage(`Latency: ${latency.toFixed(2)} ms`);
                document.getElementById('latency').textContent = `Latency: ${latency.toFixed(2)} ms`;
                latencyMeasured = true;
            }
            appendToPlaybackQueue(event.data); // decode and queue for playback
        }
    };
    socket.onerror = function (error) {
        statusDiv.textContent = 'Status: connection error';
        logMessage('WebSocket error: ' + error.message);
    };
    socket.onclose = function (event) {
        // distinguish a clean close from an error by the close code
        if (event.code === 1000) { // normal closure
            statusDiv.textContent = 'Status: disconnected';
            logMessage('WebSocket closed normally');
        } else {
            statusDiv.textContent = 'Status: connection error';
            logMessage(`WebSocket closed, code: ${event.code}, reason: ${event.reason}`);
        }
        startButton.disabled = false; // enable the start button
        stopButton.disabled = true;   // disable the stop button
    };
}

/** Initialize audio playback. */
function initializeAudioPlayback() {
    initializePlayback();
}

/** Start recording and stream audio over the WebSocket. */
function startRecording() {
    // create and connect the WebSocket
    createWebSocket();
    // request microphone access
    navigator.mediaDevices.getUserMedia({ audio: true }).then(function (stream) {
        // create an AudioContext at a 16000 Hz sample rate
        recordingAudioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
        // create a source node connected to the microphone stream
        audioInput = recordingAudioContext.createMediaStreamSource(stream);
        // create a script processor node to handle raw audio
        processor = recordingAudioContext.createScriptProcessor(4096, 1, 1);
        // wire up the nodes
        audioInput.connect(processor);
        processor.connect(recordingAudioContext.destination);
        // fires whenever a block of audio is available
        processor.onaudioprocess = function (e) {
            const audioData = e.inputBuffer.getChannelData(0);                        // mono channel data
            const int16Data = floatTo16BitPCM(audioData);                             // float PCM -> 16-bit PCM
            const wavBuffer = encodeWAV(int16Data, recordingAudioContext.sampleRate); // wrap as WAV
            // send the WAV data if the socket is open
            if (socket && socket.readyState === WebSocket.OPEN) {
                socket.send(wavBuffer);
            }
        };
        logMessage('Recording started, sending audio data');
        startButton.disabled = true;  // disable the start button
        stopButton.disabled = false;  // enable the stop button
        initializeAudioPlayback();    // set up playback
    }).catch(function (err) {
        // microphone not accessible
        logMessage('Cannot access the microphone: ' + err.message);
    });
}

/** Stop recording and tear down the audio nodes. */
function stopRecording() {
    // disconnect and release the processor node
    if (processor) {
        processor.disconnect();
        processor = null;
    }
    // disconnect and release the input node
    if (audioInput) {
        audioInput.disconnect();
        audioInput = null;
    }
    // close the recording AudioContext
    if (recordingAudioContext) {
        recordingAudioContext.close();
        recordingAudioContext = null;
    }
    logMessage('Recording stopped');
    startButton.disabled = false; // enable the start button
    stopButton.disabled = true;   // disable the stop button
    // send the end signal over the WebSocket
    if (socket && socket.readyState === WebSocket.OPEN) {
        socket.send("over"); // end signal agreed with the backend
        // remember when 'over' was sent
        overSentTime = performance.now();
        latencyMeasured = false;
        logMessage('End signal "over" sent');
    }
}

/**
 * Convert float PCM data to 16-bit PCM.
 * @param {Float32Array} float32Array - float PCM data
 * @returns {Int16Array} 16-bit PCM data
 */
function floatTo16BitPCM(float32Array) {
    const int16Array = new Int16Array(float32Array.length);
    for (let i = 0; i < float32Array.length; i++) {
        // clamp to [-1, 1]
        let s = Math.max(-1, Math.min(1, float32Array[i]));
        // scale to a 16-bit integer
        int16Array[i] = s < 0 ? s * 0x8000 : s * 0x7FFF;
    }
    return int16Array;
}

/**
 * Encode PCM data as WAV.
 * @param {Int16Array} samples - 16-bit PCM data
 * @param {number} sampleRate - sample rate
 * @returns {ArrayBuffer} WAV data
 */
function encodeWAV(samples, sampleRate) {
    const buffer = new ArrayBuffer(44 + samples.length * 2); // 44-byte WAV header + PCM data
    const view = new DataView(buffer);
    /* RIFF identifier */
    writeString(view, 0, 'RIFF');
    /* file length */
    view.setUint32(4, 36 + samples.length * 2, true);
    /* RIFF type */
    writeString(view, 8, 'WAVE');
    /* format chunk identifier */
    writeString(view, 12, 'fmt ');
    /* format chunk length */
    view.setUint32(16, 16, true);
    /* audio format (1 = PCM) */
    view.setUint16(20, 1, true);
    /* channel count (1 = mono) */
    view.setUint16(22, 1, true);
    /* sample rate */
    view.setUint32(24, sampleRate, true);
    /* byte rate (sample rate * channels * bytes per sample) */
    view.setUint32(28, sampleRate * 2, true);
    /* block align (channels * bytes per sample) */
    view.setUint16(32, 2, true);
    /* bits per sample */
    view.setUint16(34, 16, true);
    /* data chunk identifier */
    writeString(view, 36, 'data');
    /* data chunk length */
    view.setUint32(40, samples.length * 2, true);
    // write the PCM samples
    let offset = 44;
    for (let i = 0; i < samples.length; i++, offset += 2) {
        view.setInt16(offset, samples[i], true);
    }
    return buffer;
}

/**
 * Write a string into a DataView.
 * @param {DataView} view - DataView instance
 * @param {number} offset - write offset
 * @param {string} string - string to write
 */
function writeString(view, offset, string) {
    for (let i = 0; i < string.length; i++) {
        view.setUint8(offset + i, string.charCodeAt(i));
    }
}

// event bindings
startButton.addEventListener('click', function () {
    startRecording();
});
stopButton.addEventListener('click', function () {
    stopRecording();
});

// do not auto-connect on page load; the connection is created
// when the user clicks "Start recording and sending"
window.onload = function () {
    logMessage('Click "Start recording and sending" to begin');
};
</script>
</body>
</html>
Backend FastAPI service entry point
import asyncio
import re
import time
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from loguru import logger
from STT import generate_Ws, segment_data_processor
from LLM import LLMDobaoClient
from TTS import long_sentence, create_tts_ws

router = APIRouter()


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    logger.info("WebSocket connection accepted")
    audio_result = ""
    audio_ws = None
    tts_ws = None
    seq = 1  # ASR sequence number, initialized outside the loop
    llm_client = LLMDobaoClient()
    llm_client.add_system_message("你是豆包,是由字节跳动开发的 AI 人工智能助手")
    try:
        while True:
            message = await websocket.receive()
            if 'bytes' in message:
                data = message['bytes']
                # print("received data size:", len(data))
                if audio_ws is None:
                    audio_ws = await generate_Ws()
                    tts_ws = await create_tts_ws()
                if data is not None:
                    audio_result = await segment_data_processor(audio_ws, data, seq)
                    if audio_result is not None:
                        print("ASR result:", audio_result)
                    seq += 1
            elif 'text' in message:
                # hand the recognized text to the LLM
                llm_client.add_user_message(audio_result)
                audio_result = ""  # clear the recognition result
                # To also save the synthesized audio, uncomment this and the
                # related lines in TTS:
                # file_to_save = open("test.mp3", "ab")  # append mode, so several segments accumulate
                file_to_save = "123"
                result = ""
                seq = 1
                for response in llm_client.stream_response():
                    result += response
                    if len(result) > 100:
                        # find the punctuation mark closest to 50 characters
                        cut_pos = find_nearest_punctuation(result, 50)
                        # take the text up to and including that punctuation mark
                        full_result = result[:cut_pos + 1]
                        print(full_result)
                        await long_sentence(full_result, file_to_save, tts_ws, websocket)  # synthesize this chunk
                        result = result[cut_pos + 1:]  # keep the unprocessed remainder
                # flush whatever is left when the stream ends
                if result:
                    print(result)
                    await long_sentence(result, file_to_save, tts_ws, websocket)
                await websocket.send_text("over")
                print("============ done ==============")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
        if audio_ws is not None:
            await audio_ws.close()
        if tts_ws is not None:  # the original closed audio_ws twice here
            await tts_ws.close()
    except Exception as e:
        logger.error(f"WebSocket Error: {e}")
        if audio_ws is not None:
            await audio_ws.close()
        if tts_ws is not None:
            await tts_ws.close()
        await websocket.close()


def find_nearest_punctuation(text, max_length):
    """Find the punctuation mark closest to max_length."""
    # collect the positions of all sentence punctuation marks
    punctuation_matches = [m.start() for m in re.finditer(r'[,。!?;]', text)]
    # with no punctuation at all, split at max_length
    if not punctuation_matches:
        return max_length
    # pick the punctuation mark nearest to (and not beyond) max_length
    nearest_punctuation = max_length
    for pos in punctuation_matches:
        if pos <= max_length:
            nearest_punctuation = pos
        else:
            break  # stop once past max_length
    return nearest_punctuation
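To actually serve this, the router still has to be mounted on a FastAPI application. A minimal launcher sketch, assuming the file above is saved as server.py (that module name is an assumption; everything else mirrors the code above):

# run.py: minimal launcher sketch
from fastapi import FastAPI
import uvicorn

from server import router  # hypothetical module name for the entry file above

app = FastAPI()
app.include_router(router)

if __name__ == "__main__":
    # bind to 0.0.0.0 so the test page can reach the service from another machine
    uvicorn.run(app, host="0.0.0.0", port=8000)

With this running, the test page's wsUrl points at ws://<server-ip>:8000/ws.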
LLM module code
from volcenginesdkarkruntime import Ark
class LLMDobaoClient:
    def __init__(self):
        # initialize the client with the API key
        self.client = Ark(api_key="")
        # list holding the conversation messages
        self.messages = []

    def add_user_message(self, content):
        """Append a user message to the conversation history."""
        self.messages.append({"role": "user", "content": content})

    def add_system_message(self, content):
        """Append a system message to the conversation history."""
        self.messages.append({"role": "system", "content": content})

    def add_assistant_message(self, content):
        """Append an assistant message to the conversation history."""
        self.messages.append({"role": "assistant", "content": content})

    def clear_messages(self):
        """Clear the conversation history."""
        self.messages = []

    def print_messages(self):
        """Print the conversation history."""
        print(self.messages)

    def stream_response(self, model="ep-20241013161850-bnqsx"):
        """Stream the model's reply to the current messages, yielding fragments,
        then append the complete reply to the history."""
        print("----- streaming request started -----")
        full_response = ""
        stream = self.client.chat.completions.create(
            model=model,
            messages=self.messages,
            stream=True,
        )
        for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:  # the delta may be empty
                full_response += content
                yield content
        # append the complete reply to the conversation history
        self.add_assistant_message(full_response)


if __name__ == "__main__":
    llm_client = LLMDobaoClient()
    # seed the conversation
    llm_client.add_system_message("你是豆包,是由字节跳动开发的 AI 人工智能助手")
    llm_client.add_user_message("请你讲个小故事")
    # consume the streamed response
    for response in llm_client.stream_response():
        pass  # handle each fragment here
    llm_client.print_messages()
    print("\nstreaming request finished.")
STT module code
import asyncio
import gzip
import json
import uuid
import traceback
import websockets
from app.config.settings import settings
# from settings import settings

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Message Type:
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001


# build the 4-byte request header
def generate_header(
    message_type=FULL_CLIENT_REQUEST,
    message_type_specific_flags=NO_SEQUENCE,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
):
    """
    protocol_version(4 bits), header_size(4 bits),
    message_type(4 bits), message_type_specific_flags(4 bits),
    serialization_method(4 bits), message_compression(4 bits),
    reserved(8 bits)
    """
    header = bytearray()
    header_size = 1
    header.append((PROTOCOL_VERSION << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    return header


# prepend the sequence number to the payload
def generate_before_payload(sequence: int):
    before_payload = bytearray()
    before_payload.extend(sequence.to_bytes(4, 'big', signed=True))  # sequence
    return before_payload


# parse a server response
def parse_response(res):
    """
    protocol_version(4 bits), header_size(4 bits),
    message_type(4 bits), message_type_specific_flags(4 bits),
    serialization_method(4 bits), message_compression(4 bits),
    reserved(8 bits),
    header_extensions (8 * 4 * (header_size - 1) bits),
    payload: the body, analogous to an HTTP request body
    """
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {
        'is_last_package': False,
    }
    payload_msg = None
    payload_size = 0
    if message_type_specific_flags & 0x01:
        # the frame carries a sequence number
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['payload_sequence'] = seq
        payload = payload[4:]
    if message_type_specific_flags & 0x02:
        # this is the last package
        result['is_last_package'] = True
    if message_type == FULL_SERVER_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result


# open the ASR WebSocket connection
async def generate_Ws():
    print("opening ASR WebSocket connection")
    try:
        reqid = str(uuid.uuid4())
        request_params = {
            "user": {
                "uid": reqid,
            },
            "audio": {
                'format': settings.format,
                "rate": settings.framerate,
                "bits": settings.bits,
                "channel": settings.nchannels,
                "codec": "raw",
            },
        }
        payload_bytes = gzip.compress(json.dumps(request_params).encode('utf-8'))
        full_client_request = bytearray(generate_header(message_type_specific_flags=POS_SEQUENCE))
        full_client_request.extend(generate_before_payload(sequence=1))
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
        full_client_request.extend(payload_bytes)
        header = {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Access-Key": "XXXXXXXX",
            "X-Api-App-Key": "XXXXXXX",
            "X-Api-Request-Id": reqid,
        }
        # await to obtain the actual ws object
        ws = await websockets.connect(
            "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
            extra_headers=header,
            max_size=1000000000,
        )
        print("connected")
        await ws.send(full_client_request)
        res = await ws.recv()
        result = parse_response(res)
        print("******************")
        print("sauc result", result)
        print("******************")
        return ws  # return the ws object
    except websockets.exceptions.ConnectionClosedError as e:
        print(f"WebSocket connection closed with error: {e}")
    except websockets.exceptions.InvalidStatusCode as e:
        print(f"WebSocket connection failed with status code: {e.status_code}")
    except Exception as e:
        print(f"An error occurred: {e}")
        print(f"Exception type: {type(e)}")
        print("Stack trace:")
        traceback.print_exc()


# send one audio segment and return the recognized text
async def segment_data_processor(ws, audio_data, seq):
    try:
        # gzip-compress the current audio segment
        payload_bytes = gzip.compress(audio_data)
    except OSError as e:
        print(f"error compressing audio data: {e}")
        return None
    try:
        # build the audio-only request header
        # (for the final segment the negative-sequence flag would be used instead)
        audio_only_request = bytearray(
            generate_header(message_type=AUDIO_ONLY_REQUEST,
                            message_type_specific_flags=POS_SEQUENCE))
        # if seq == -1:
        #     audio_only_request = bytearray(
        #         generate_header(message_type=AUDIO_ONLY_REQUEST,
        #                         message_type_specific_flags=NEG_WITH_SEQUENCE))
        # append the sequence number of this segment
        audio_only_request.extend(generate_before_payload(sequence=seq))
        # append the payload size (4 bytes)
        audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        # append the compressed audio and send the request
        audio_only_request.extend(payload_bytes)
        await ws.send(audio_only_request)
        # receive and parse the server response
        res = await ws.recv()
        result = parse_response(res)
        return result["payload_msg"]["result"]["text"]
    except websockets.exceptions.ConnectionClosedError as e:
        print(f"WebSocket closed, code: {e.code}, reason: {e.reason}")
        return None
    except websockets.exceptions.WebSocketException as e:
        print(f"WebSocket error: {e}")
        return None
    except Exception as e:
        print(f"unexpected error while processing the audio segment: {e}")
        return None


# receive data (standalone helper; returns after the first frame)
async def receive_data(ws):
    while True:
        res = await ws.recv()
        result = parse_response(res)
        print("sauc result", result)
        return result


# a small entry point that exercises generate_Ws
async def main():
    ws = await generate_Ws()
    if ws is not None:
        print("ws is not None")
        await ws.close()

if __name__ == "__main__":
    asyncio.run(main())
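As a quick sanity check of the binary framing, the default full-client-request header with the positive-sequence flag should come out as three 0x11 bytes plus a reserved zero. A small sketch (it only assumes STT.py and its settings import are importable):

from STT import generate_header, FULL_CLIENT_REQUEST, POS_SEQUENCE

hdr = generate_header(message_type=FULL_CLIENT_REQUEST,
                      message_type_specific_flags=POS_SEQUENCE)
assert hdr == bytearray(b'\x11\x11\x11\x00')
# 0x11: protocol version 1, header size 1 (i.e. 1 * 4 = 4 bytes)
# 0x11: message type 1 (full client request), flags 1 (positive sequence)
# 0x11: serialization 1 (JSON), compression 1 (gzip)
# 0x00: reserved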
TTS module code
import asyncio
import websockets
import uuid
import json
import gzip
import copy

MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response",
                 15: "error message from server"}
MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
                               2: "last message from server (seq < 0)", 3: "sequence number < 0"}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}

appid = "123"
token = "XXXXXX"
cluster = "XXXXXXX"
voice_type = "BV034_streaming"
host = "openspeech.bytedance.com"
api_url = f"wss://{host}/api/v1/tts/ws_binary"

# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (full client request) (4 bits)
# message type specific flags: b0000 (none) (4 bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4 bits)
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')

request_json = {
    "app": {
        "appid": appid,
        "token": "access_token",
        "cluster": cluster
    },
    "user": {
        "uid": "388808087185088"
    },
    "audio": {
        "voice_type": "xxx",
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
    },
    "request": {
        "reqid": "xxx",
        "text": "字节跳动语音合成。",
        "text_type": "plain",
        "operation": "xxx"
    }
}


# synthesize one chunk of text (long text could be split into segments here)
async def long_sentence(text, file, ws, websocket):
    # e.g. segments = [text[i:i+50] for i in range(0, len(text), 50)]
    request_json["request"]["text"] = text
    await test_submit(request_json, file, ws, websocket)


# open the TTS WebSocket connection
async def create_tts_ws():
    header = {"Authorization": f"Bearer; {token}"}
    ws = await websockets.connect(api_url, extra_headers=header, ping_interval=None)
    return ws


# submit a synthesis request and stream back the audio
async def test_submit(request_json, file, ws, websocket):
    submit_request_json = copy.deepcopy(request_json)
    submit_request_json["audio"]["voice_type"] = voice_type
    submit_request_json["request"]["reqid"] = str(uuid.uuid4())
    submit_request_json["request"]["operation"] = "submit"
    payload_bytes = str.encode(json.dumps(submit_request_json))
    payload_bytes = gzip.compress(payload_bytes)  # if no compression, comment this line
    full_client_request = bytearray(default_header)
    full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
    full_client_request.extend(payload_bytes)  # payload
    await ws.send(full_client_request)
    while True:
        res = await ws.recv()
        done = await parse_response(res, file, websocket)
        if done:
            break


# parse a server response; returns True when the stream is finished
async def parse_response(res, file, websocket):
    # decode the header fields
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    # dispatch on the message type
    if message_type == 0xb:  # audio-only server response
        if message_type_specific_flags == 0:  # ACK without a sequence number
            return False
        else:
            sequence_number = int.from_bytes(payload[:4], "big", signed=True)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload = payload[8:]
            # relay the audio chunk to the browser; when running standalone
            # (websocket is None) write it to the file instead
            if websocket is not None:
                await websocket.send_bytes(payload)
            else:
                file.write(payload)
            if sequence_number < 0:  # a negative sequence number marks the end
                return True
            else:
                return False
    elif message_type == 0xf:  # error message
        code = int.from_bytes(payload[:4], "big", signed=False)
        msg_size = int.from_bytes(payload[4:8], "big", signed=False)
        error_msg = payload[8:]
        if message_compression == 1:
            error_msg = gzip.decompress(error_msg)
        error_msg = str(error_msg, "utf-8")
        print(f"Error message code: {code}")
        print(f"Error message size: {msg_size} bytes")
        print(f"Error message: {error_msg}")
        return True
    elif message_type == 0xc:  # frontend message
        msg_size = int.from_bytes(payload[:4], "big", signed=False)
        payload = payload[4:]
        if message_compression == 1:
            payload = gzip.decompress(payload)
        print(f"Frontend message: {payload}")
    else:
        print("undefined message type!")
        return True


# standalone smoke test: synthesize one sentence into a local file
# (the original called an undefined test_long_sentence; this defines it)
async def test_long_sentence(text):
    ws = await create_tts_ws()
    with open("test_submit.mp3", "wb") as file_to_save:
        await long_sentence(text, file_to_save, ws, None)
    await ws.close()

if __name__ == '__main__':
    long_text = "这是一个很长的句子,需要分成多个段落来逐步合成语音,以便处理。"  # sample long sentence
    asyncio.run(test_long_sentence(long_text))
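The response parser can be exercised the same way with a hand-crafted audio-only frame. This sketch relies on the standalone file fallback added in parse_response above, and every byte follows the header layout documented in the comments:

import asyncio
from TTS import parse_response  # assumes the module above is saved as TTS.py

class Sink:
    """Stands in for a file object and just counts the 'audio' bytes."""
    def write(self, b):
        print(f"sink received {len(b)} bytes")

frame = bytearray(b'\x11\xb1\x10\x00')         # version 1/header size 1, type 0xb with sequence flag
frame += (-1).to_bytes(4, 'big', signed=True)  # negative sequence number = last chunk
frame += (3).to_bytes(4, 'big')                # payload size
frame += b'abc'                                # stand-in audio bytes

done = asyncio.run(parse_response(bytes(frame), Sink(), None))
assert done  # the negative sequence number ends the receive loop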