2017-09-26 52 views
8
def train(): 
# Model 
model = Model() 

# Loss, Optimizer 
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step') 
loss_fn = model.loss() 
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step) 

# Summaries 
summary_op = summaries(model, loss_fn) 

with tf.Session(config=TrainConfig.session_conf) as sess: 

    # Initialized, Load state 
    sess.run(tf.global_variables_initializer()) 
    model.load_state(sess, TrainConfig.CKPT_PATH) 

    writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph) 

    # Input source 
    data = Data(TrainConfig.DATA_PATH) 

    loss = Diff() 
    for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP): 

      mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step) 

      mixed_spec = to_spectrogram(mixed_wav) 
      mixed_mag = get_magnitude(mixed_spec) 

      src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav) 
      src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec) 

      src1_batch, _ = model.spec_to_batch(src1_mag) 
      src2_batch, _ = model.spec_to_batch(src2_mag) 
      mixed_batch, _ = model.spec_to_batch(mixed_mag) 

      # Initializae our callback. 
      #early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5) 


      l, _, summary = sess.run([loss_fn, optimizer, summary_op], 
            feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch, 
               model.y_src2: src2_batch}) 

      loss.update(l) 
      print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value)) 

      writer.add_summary(summary, global_step=step) 

      # Save state 
      if step % TrainConfig.CKPT_STEP == 0: 
       tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step) 

    writer.close() 

.wav dosyasındaki bir sesten müziği ayıran bu yapay ağ kodum var. Tren bölümünü durdurmak için erken bir durdurma algoritmasını nasıl tanıtabilirim? ValidationMonitor hakkında konuşacak bir proje görüyorum. Birisi bana yardım edebilir mi?tensorflow'da erken durdurmanın nasıl gerçekleştirileceği

cevap

0

DoğrulamaMonitör kullanımdan kaldırıldı olarak işaretlendi. tavsiye edilmez. ama yine de kullanabilirsin. Burada benim benim uygulanmasını

validation_monitor = monitors.ValidationMonitor(
     input_fn=functools.partial(input_fn, subset="evaluation"), 
     eval_steps=128, 
     every_n_steps=88, 
     early_stopping_metric="accuracy", 
     early_stopping_rounds = 1000 
    ) 

ve kendiniz uygulayabilirsiniz:

  if (loss_value < self.best_loss): 
      self.stopping_step = 0 
      self.best_loss = loss_value 
      else: 
      self.stopping_step += 1 
      if self.stopping_step >= FLAGS.early_stopping_step: 
      self.should_stop = True 
      print("Early stopping is trigger at step: {} loss:{}".format(global_step,loss_value)) 
      run_context.request_stop() 
burada biri nasıl oluşturulacağı bir örnektir