304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""
|
||
CosyVoice API 服务类
|
||
负责与CosyVoice API的交互
|
||
"""
|
||
import os
|
||
import logging
|
||
from typing import Optional, Dict, Any, Tuple
|
||
from gradio_client import Client, handle_file
|
||
|
||
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
class CosyVoiceService:
    """Thin wrapper around a CosyVoice Gradio server.

    Every public method connects lazily on first use and never raises.
    Failures are logged, and a neutral fallback is returned instead:
    ``[]``, ``""``, ``(None, None)``, ``42`` or a ``success=False`` dict.
    """

    def __init__(self, api_url: str = "http://127.0.0.1:8080/"):
        """Store the server URL; no network access happens here.

        Args:
            api_url: Base URL of the CosyVoice Gradio app.
        """
        self.api_url = api_url
        # Set by connect(); stays None until a connection succeeds.
        self.client = None

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def connect(self) -> bool:
        """Create a Gradio client for ``self.api_url``.

        Returns:
            True if the client was created, False otherwise (error is logged).
        """
        try:
            self.client = Client(self.api_url)
        except Exception as e:
            logger.error(f"连接CosyVoice服务失败: {str(e)}")
            return False
        logger.info(f"成功连接到CosyVoice服务: {self.api_url}")
        return True

    def _ensure_client(self) -> bool:
        """Return True if a client is available, connecting first if needed."""
        return self.client is not None or self.connect()

    # ------------------------------------------------------------------
    # Small stateless helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _choice_names(result: Any) -> Optional[list]:
        """Normalise a dropdown-refresh response to a list of names.

        Accepts either a dict carrying a ``choices`` sequence (element 0 of
        each choice is taken as the name) or a plain list. Returns None for
        any other shape so callers can log it.
        """
        if isinstance(result, dict) and 'choices' in result:
            return [choice[0] for choice in result['choices']]
        if isinstance(result, list):
            return list(result)
        return None

    @staticmethod
    def _split_audio_result(result: Any) -> Tuple[Optional[str], Optional[str]]:
        """Split a ``/generate_audio`` response into (streaming, full) paths.

        A list/tuple with at least two items yields its first two items.
        Anything else (including None) is returned for both slots.
        """
        if isinstance(result, (list, tuple)) and len(result) >= 2:
            return result[0], result[1]
        return result, result

    @staticmethod
    def _write_silent_wav(frames: int = 1600, sample_rate: int = 16000) -> str:
        """Write a mono 16-bit silent WAV to a new temp file and return its path.

        With the defaults this is 0.1 s of silence at 16 kHz. The caller owns
        the file and must delete it.
        """
        import tempfile
        import wave

        fd, path = tempfile.mkstemp(suffix='.wav')
        try:
            # Write through the mkstemp descriptor so no handle is left open.
            # wave does not close a file object it was handed, so closing
            # happens when the outer context exits.
            with os.fdopen(fd, 'wb') as raw, wave.open(raw, 'wb') as wav_file:
                wav_file.setnchannels(1)       # mono
                wav_file.setsampwidth(2)       # 16-bit samples
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(bytes(2 * frames))  # zero samples = silence
        except BaseException:
            os.unlink(path)
            raise
        return path

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Best-effort delete of a temp file; failures are logged, not raised."""
        try:
            os.unlink(path)
        except OSError as e:
            logger.warning(f"删除临时文件失败 {path}: {e}")

    # ------------------------------------------------------------------
    # Synthesis plumbing shared by the generate_* methods
    # ------------------------------------------------------------------

    def _predict_audio(
        self,
        *,
        text: str,
        mode: str,
        voice: str,
        prompt_text: str,
        prompt_wav: str,
        instruct_text: str,
        seed: int,
        stream: bool,
        speed: float,
    ) -> Tuple[Optional[str], Optional[str]]:
        """Call ``/generate_audio`` and split its result into two paths.

        ``prompt_wav`` is uploaded as both the upload and the record input.
        Exceptions propagate to the caller.
        """
        result = self.client.predict(
            tts_text=text,
            mode_checkbox_group=mode,
            sft_dropdown=voice,
            prompt_text=prompt_text,
            prompt_wav_upload=handle_file(prompt_wav),
            prompt_wav_record=handle_file(prompt_wav),
            instruct_text=instruct_text,
            seed=float(seed),
            # The endpoint receives the stream flag as the strings "True"/"False".
            stream="True" if stream else "False",
            speed=float(speed),
            api_name="/generate_audio",
        )
        return self._split_audio_result(result)

    def _predict_audio_with_placeholder(self, **kwargs: Any) -> Tuple[Optional[str], Optional[str]]:
        """Run ``_predict_audio`` with a short silent clip as the prompt audio.

        This is used by modes that need no reference audio. Prompt audio is
        still always sent, and silence is used as a placeholder. The temp
        file is removed even if the request fails.
        """
        placeholder = self._write_silent_wav()
        try:
            return self._predict_audio(prompt_wav=placeholder, **kwargs)
        finally:
            self._remove_quietly(placeholder)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_available_voices(self) -> list:
        """Return the preset speaker names from ``/refresh_sft_spk``.

        ``.ipynb_checkpoints`` entries are dropped. Returns ``[]`` on any
        failure or unknown response shape.
        """
        try:
            if not self._ensure_client():
                return []
            result = self.client.predict(api_name="/refresh_sft_spk")
            names = self._choice_names(result)
            if names is None:
                logger.error(f"未知的音色列表格式: {result}")
                return []
            return [name for name in names if name != '.ipynb_checkpoints']
        except Exception as e:
            logger.error(f"获取音色列表失败: {str(e)}")
            return []

    def get_reference_audios(self) -> list:
        """Return the reference-audio names from ``/refresh_prompt_wav``.

        Returns ``[]`` on any failure or unknown response shape.
        """
        try:
            if not self._ensure_client():
                return []
            result = self.client.predict(api_name="/refresh_prompt_wav")
            names = self._choice_names(result)
            if names is None:
                logger.error(f"未知的参考音频列表格式: {result}")
                return []
            return names
        except Exception as e:
            logger.error(f"获取参考音频列表失败: {str(e)}")
            return []

    def recognize_audio(self, audio_file_path: str) -> str:
        """Transcribe an audio file via ``/prompt_wav_recognition``.

        Returns:
            The recognised text, or ``""`` on failure or a non-string reply.
        """
        try:
            if not self._ensure_client():
                return ""
            text = self.client.predict(
                prompt_wav=handle_file(audio_file_path),
                api_name="/prompt_wav_recognition",
            )
            return text if isinstance(text, str) else ""
        except Exception as e:
            logger.error(f"语音识别失败: {str(e)}")
            return ""

    def generate_speech_with_preset_voice(
        self,
        text: str,
        voice: str = "中文女",
        seed: int = 42,
        speed: float = 1.0,
        stream: bool = False
    ) -> Tuple[Optional[str], Optional[str]]:
        """Synthesise ``text`` with a preset speaker ("预训练音色" mode).

        Returns:
            (streaming_audio_path, full_audio_path), or (None, None) on failure.
        """
        try:
            if not self._ensure_client():
                return None, None
            return self._predict_audio_with_placeholder(
                text=text,
                mode="预训练音色",
                voice=voice,
                prompt_text="",
                instruct_text="",
                seed=seed,
                stream=stream,
                speed=speed,
            )
        except Exception as e:
            logger.error(f"预训练音色语音生成失败: {str(e)}")
            return None, None

    def generate_speech_with_voice_cloning(
        self,
        text: str,
        reference_audio_path: str,
        reference_text: str = "",
        seed: int = 42
    ) -> Tuple[Optional[str], Optional[str]]:
        """Synthesise ``text`` by cloning the voice in a reference clip ("3s极速复刻").

        If ``reference_text`` is empty, it is first transcribed with
        ``recognize_audio``. If that fails, an empty prompt text is used.

        Returns:
            (streaming_audio_path, full_audio_path), or (None, None) on failure.
        """
        try:
            if not self._ensure_client():
                return None, None
            if not reference_text:
                # recognize_audio returns "" on failure, so this stays a str.
                reference_text = self.recognize_audio(reference_audio_path)
                if not reference_text:
                    logger.warning("参考音频识别失败,使用空文本")
            return self._predict_audio(
                text=text,
                mode="3s极速复刻",
                voice="中文女",
                prompt_text=reference_text,
                prompt_wav=reference_audio_path,
                instruct_text="",
                seed=seed,
                stream=False,
                speed=1.0,
            )
        except Exception as e:
            logger.error(f"语音克隆生成失败: {str(e)}")
            return None, None

    def generate_speech_with_natural_control(
        self,
        text: str,
        instruction: str = "请用温柔甜美的女声朗读",
        seed: int = 42
    ) -> Tuple[Optional[str], Optional[str]]:
        """Synthesise ``text`` steered by a natural-language ``instruction``.

        Uses the "自然语言控制" mode.

        Returns:
            (streaming_audio_path, full_audio_path), or (None, None) on failure.
        """
        try:
            if not self._ensure_client():
                return None, None
            return self._predict_audio_with_placeholder(
                text=text,
                mode="自然语言控制",
                voice="中文女",
                prompt_text="",
                instruct_text=instruction,
                seed=seed,
                stream=False,
                speed=1.0,
            )
        except Exception as e:
            logger.error(f"自然语言控制语音生成失败: {str(e)}")
            return None, None

    def generate_random_seed(self) -> int:
        """Ask the server for a random seed via ``/generate_random_seed``.

        Accepts a ``{'value': ...}`` dict, a number or a digit string.
        Falls back to 42 on failure or an unknown format.
        """
        try:
            if not self._ensure_client():
                return 42
            result = self.client.predict(api_name="/generate_random_seed")
            if isinstance(result, dict) and 'value' in result:
                return int(result['value'])
            if isinstance(result, (int, float)):
                return int(result)
            if isinstance(result, str) and result.isdigit():
                return int(result)
            logger.warning(f"未知的随机种子格式: {result}")
            return 42
        except Exception as e:
            logger.error(f"生成随机种子失败: {str(e)}")
            return 42

    def test_connection(self) -> Dict[str, Any]:
        """(Re)connect and probe the service by listing preset voices.

        Returns:
            A dict with ``success``, ``message`` and ``api_url``. On success it
            also has ``available_voices``, which may be ``[]`` if listing failed.
        """
        try:
            if not self.connect():
                return {
                    "success": False,
                    "message": "无法连接到CosyVoice服务",
                    "api_url": self.api_url
                }
            voices = self.get_available_voices()
            return {
                "success": True,
                "message": "CosyVoice服务连接成功",
                "api_url": self.api_url,
                "available_voices": voices
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"连接测试失败: {str(e)}",
                "api_url": self.api_url
            }
|
||
# Shared module-level service instance, built with the default API URL.
cosyvoice_service: CosyVoiceService = CosyVoiceService()
|