training code done
This commit is contained in:
@@ -9,9 +9,9 @@ from scipy.io.wavfile import read
|
||||
import torch
|
||||
import torchaudio
|
||||
import librosa
|
||||
from .text import cleaned_text_to_sequence, get_bert
|
||||
from .text.cleaner import clean_text
|
||||
from . import commons
|
||||
from melo.text import cleaned_text_to_sequence, get_bert
|
||||
from melo.text.cleaner import clean_text
|
||||
from melo import commons
|
||||
|
||||
MATPLOTLIB_FLAG = False
|
||||
|
||||
@@ -60,8 +60,8 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
iteration = checkpoint_dict["iteration"]
|
||||
learning_rate = checkpoint_dict["learning_rate"]
|
||||
iteration = checkpoint_dict.get("iteration", 0)
|
||||
learning_rate = checkpoint_dict.get("learning_rate", 0.)
|
||||
if (
|
||||
optimizer is not None
|
||||
and not skip_optimizer
|
||||
@@ -92,6 +92,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
||||
v.shape,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# For upgrading from the old version
|
||||
if "ja_bert_proj" in k:
|
||||
v = torch.zeros_like(v)
|
||||
@@ -249,7 +250,9 @@ def get_hparams(init=True):
|
||||
default="./configs/base.json",
|
||||
help="JSON file for configuration",
|
||||
)
|
||||
parser.add_argument('--local-rank', type=int, default=0)
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument('--world-size', type=int, default=1)
|
||||
parser.add_argument('--port', type=int, default=10000)
|
||||
parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
|
||||
parser.add_argument('--pretrain_G', type=str, default=None,
|
||||
help='pretrain model')
|
||||
@@ -280,6 +283,7 @@ def get_hparams(init=True):
|
||||
hparams.pretrain_G = args.pretrain_G
|
||||
hparams.pretrain_D = args.pretrain_D
|
||||
hparams.pretrain_dur = args.pretrain_dur
|
||||
hparams.port = args.port
|
||||
return hparams
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user