Nano Hash - криптовалюты, майнинг, программирование

Потери Tensorflow RNN LM не уменьшаются

Я пытаюсь обучить базовую однонаправленную языковую модель LSTM RNN в банке PennTree. Моя нейросеть работает, но потери на тестовом наборе совсем не уменьшаются. Мне интересно, почему это?

Параметры сети:

V = 10000
batch_size = 20
hidden_size = 650
embed_size = hidden_size
num_unrollings = 35
max_epoch = 6
learning_rate = 1.0

Определение графика:

graph = tf.Graph()
with graph.as_default():
  cell_state = tf.placeholder(tf.float32, shape=(batch_size, hidden_size), name="CellState")
  hidden_state = tf.placeholder(tf.float32, shape=(batch_size, hidden_size), name="HiddenState")
  curr_batch = tf.placeholder(tf.int32, shape=[num_unrollings + 1, batch_size])
  lstm = tf.contrib.rnn.BasicLSTMCell(hidden_size)
  embeddings = tf.Variable(tf.truncated_normal([V, embed_size], -0.1, 0.1), trainable=True, dtype=tf.float32)
  W = tf.Variable(tf.truncated_normal([hidden_size, V], -0.1, 0.1))
  b = tf.Variable(tf.zeros(V))

  inputs = curr_batch[:num_unrollings,:] # num_unrollings x batch_size
  labels = curr_batch[1:, :] # num_unrollings x batch_size

  input_list = list()
  for t in range(num_unrollings):
    emb = tf.nn.embedding_lookup(embeddings, inputs[t,:])
    input_list.append(emb)

  outputs, states = tf.nn.static_rnn(lstm, input_list, initial_state=[cell_state, hidden_state])  # outputs: num_unrollings x batch_size x hidden
  cell_state, hidden_state = states
  outputs_flat = tf.reshape(outputs, [-1, lstm.output_size]) # output_flat: (num_unrollings x batch_size) x hidden
  logits = tf.nn.softmax(tf.matmul(outputs_flat, W) + b)   # logits_tensor: (num_unrollings x batch_size) x V
  logits_tensor = tf.reshape(logits, [batch_size, num_unrollings, V])

  targets = tf.transpose(labels)  # targets: batch_size x num_unrollings
  weights = tf.ones([batch_size, num_unrollings]) # weights: batch_size x num_unrollings
  loss = tf.reduce_sum(tf.contrib.seq2seq.sequence_loss(logits_tensor, targets, weights, average_across_timesteps=False, average_across_batch=True))
  optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

Сессия:

with tf.Session(graph=graph) as session:
  tf.global_variables_initializer().run()
  cstate = np.zeros([batch_size, hidden_size]).astype(np.float32)
  hstate = np.zeros([batch_size, hidden_size]).astype(np.float32)
  for epoch in range(max_epoch):
    CURSOR_train = 0
    epoch_over = False
    steps = 0
    average_loss = 0.0
    while not epoch_over:
      new_batch, epoch_over = nextBatch()
      feed_data = {curr_batch: new_batch, "CellState:0": cstate, "HiddenState:0": hstate}
      _, l, new_cell_state, new_hidden_state = session.run([optimizer, loss, cell_state, hidden_state], feed_dict=feed_data)
      cstate = new_cell_state
      hstate = new_hidden_state
      average_loss += l
      PRINT_INTERVAL = 200
      if steps % PRINT_INTERVAL == 0:
        print("Avg loss for last {0} batches: {1}".format(PRINT_INTERVAL, average_loss / PRINT_INTERVAL))
        average_loss = 0

      TEST_INTERVAL = 600
      if steps % TEST_INTERVAL == 0:
        # Evaluate the model
        test_over = False
        test_loss = 0.0
        test_batch_num = 0
        print("Testing ... ")
        while not test_over:
          test_batch_num += 1
          test_batch, test_over = nextBatch(setup='test')
          feed_data_test = { curr_batch: test_batch, "CellState:0": cstate, "HiddenState:0": hstate }
          tl, d1, d2 = session.run([loss, cell_state, hidden_state], feed_dict=feed_data_test)
          test_loss += tl
        test_loss = test_loss / test_batch_num
        print("Avg loss on test set: {0}".format(test_loss))

      steps += 1
      sys.stdout.write('\rStep: {0}'.format(steps))

Потеря на тестовом наборе всегда 320,2430792614422, независимо от того, как долго я его тренирую. Потери на тренировочном наборе действительно меняются. Заранее спасибо!

20.04.2018

  • Почему вы добавляете 1 к своим средним потерям на каждом этапе обучения? 20.04.2018
  • @chrisz, ты имеешь в виду medium_loss += l? Это строчная буква L для потери :) 20.04.2018
  • Ничего себе, я должен был лечь спать раньше прошлой ночью. Если это так, скорость обучения может быть слишком высокой. 20.04.2018

Ответы:


1

ваша скорость обучения слишком высока, попробуйте скорость обучения 0.0005 и просто настройте это число.

20.04.2018
Новые материалы

Кластеризация: более глубокий взгляд
Кластеризация — это метод обучения без учителя, в котором мы пытаемся найти группы в наборе данных на основе некоторых известных или неизвестных свойств, которые могут существовать. Независимо от..

Как написать эффективное резюме
Предложения по дизайну и макету, чтобы представить себя профессионально Вам не позвонили на собеседование после того, как вы несколько раз подали заявку на работу своей мечты? У вас может..

Частный метод Python: улучшение инкапсуляции и безопасности
Введение Python — универсальный и мощный язык программирования, известный своей простотой и удобством использования. Одной из ключевых особенностей, отличающих Python от других языков, является..

Как я автоматизирую тестирование с помощью Jest
Шутка для победы, когда дело касается автоматизации тестирования Одной очень важной частью разработки программного обеспечения является автоматизация тестирования, поскольку она создает..

Работа с векторными символическими архитектурами, часть 4 (искусственный интеллект)
Hyperseed: неконтролируемое обучение с векторными символическими архитектурами (arXiv) Автор: Евгений Осипов , Сачин Кахавала , Диланта Хапутантри , Тимал Кемпития , Дасвин Де Сильва ,..

Понимание расстояния Вассерштейна: мощная метрика в машинном обучении
В обширной области машинного обучения часто возникает необходимость сравнивать и измерять различия между распределениями вероятностей. Традиционные метрики расстояния, такие как евклидово..

Обеспечение масштабируемости LLM: облачный анализ с помощью AWS Fargate и Copilot
В динамичной области искусственного интеллекта все большее распространение получают модели больших языков (LLM). Они жизненно важны для различных приложений, таких как интеллектуальные..