training code done

This commit is contained in:
wl-zhao
2024-03-10 13:05:02 +00:00
parent c9c57a17f4
commit 7ade7b740e
16 changed files with 1533 additions and 47 deletions

View File

@@ -21,7 +21,9 @@ class TTS(nn.Module):
def __init__(self,
language,
device='auto',
use_hf=True):
use_hf=True,
config_path=None,
ckpt_path=None):
super().__init__()
if device == 'auto':
device = 'cpu'
@@ -31,7 +33,7 @@ class TTS(nn.Module):
assert torch.cuda.is_available()
# config_path =
hps = load_or_download_config(language, use_hf=use_hf)
hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
num_languages = hps.num_languages
num_tones = hps.num_tones
@@ -54,7 +56,7 @@ class TTS(nn.Module):
self.device = device
# load state_dict
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf)
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
language = language.split('_')[0]