8
def train():
# Model
model = Model()
# Loss, Optimizer
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
loss_fn = model.loss()
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)
# Summaries
summary_op = summaries(model, loss_fn)
with tf.Session(config=TrainConfig.session_conf) as sess:
# Initialized, Load state
sess.run(tf.global_variables_initializer())
model.load_state(sess, TrainConfig.CKPT_PATH)
writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)
# Input source
data = Data(TrainConfig.DATA_PATH)
loss = Diff()
for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP):
mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)
mixed_spec = to_spectrogram(mixed_wav)
mixed_mag = get_magnitude(mixed_spec)
src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)
src1_batch, _ = model.spec_to_batch(src1_mag)
src2_batch, _ = model.spec_to_batch(src2_mag)
mixed_batch, _ = model.spec_to_batch(mixed_mag)
# Initializae our callback.
#early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)
l, _, summary = sess.run([loss_fn, optimizer, summary_op],
feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
model.y_src2: src2_batch})
loss.update(l)
print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))
writer.add_summary(summary, global_step=step)
# Save state
if step % TrainConfig.CKPT_STEP == 0:
tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)
writer.close()
.wav dosyasındaki bir sesten müziği ayıran bu yapay ağ kodum var. Tren bölümünü durdurmak için erken bir durdurma algoritmasını nasıl tanıtabilirim? ValidationMonitor hakkında konuşacak bir proje görüyorum. Birisi bana yardım edebilir mi?tensorflow'da erken durdurmanın nasıl gerçekleştirileceği