training code done
This commit is contained in:
@@ -21,7 +21,9 @@ class TTS(nn.Module):
|
||||
def __init__(self,
|
||||
language,
|
||||
device='auto',
|
||||
use_hf=True):
|
||||
use_hf=True,
|
||||
config_path=None,
|
||||
ckpt_path=None):
|
||||
super().__init__()
|
||||
if device == 'auto':
|
||||
device = 'cpu'
|
||||
@@ -31,7 +33,7 @@ class TTS(nn.Module):
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
# config_path =
|
||||
hps = load_or_download_config(language, use_hf=use_hf)
|
||||
hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
|
||||
|
||||
num_languages = hps.num_languages
|
||||
num_tones = hps.num_tones
|
||||
@@ -54,7 +56,7 @@ class TTS(nn.Module):
|
||||
self.device = device
|
||||
|
||||
# load state_dict
|
||||
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf)
|
||||
checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
|
||||
self.model.load_state_dict(checkpoint_dict['model'], strict=True)
|
||||
|
||||
language = language.split('_')[0]
|
||||
|
||||
Reference in New Issue
Block a user