training code done

This commit is contained in:
wl-zhao
2024-03-10 13:05:02 +00:00
parent c9c57a17f4
commit 7ade7b740e
16 changed files with 1533 additions and 47 deletions

View File

@@ -9,9 +9,9 @@ from scipy.io.wavfile import read
import torch
import torchaudio
import librosa
from .text import cleaned_text_to_sequence, get_bert
from .text.cleaner import clean_text
from . import commons
from melo.text import cleaned_text_to_sequence, get_bert
from melo.text.cleaner import clean_text
from melo import commons
MATPLOTLIB_FLAG = False
@@ -60,8 +60,8 @@ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
iteration = checkpoint_dict.get("iteration", 0)
learning_rate = checkpoint_dict.get("learning_rate", 0.)
if (
optimizer is not None
and not skip_optimizer
@@ -92,6 +92,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
v.shape,
)
except Exception as e:
print(e)
# For upgrading from the old version
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
@@ -249,7 +250,9 @@ def get_hparams(init=True):
default="./configs/base.json",
help="JSON file for configuration",
)
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--world-size', type=int, default=1)
parser.add_argument('--port', type=int, default=10000)
parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
parser.add_argument('--pretrain_G', type=str, default=None,
help='pretrain model')
@@ -280,6 +283,7 @@ def get_hparams(init=True):
hparams.pretrain_G = args.pretrain_G
hparams.pretrain_D = args.pretrain_D
hparams.pretrain_dur = args.pretrain_dur
hparams.port = args.port
return hparams