wake-up-neo.com

Was sind c_state und m_state in Tensorflow LSTM?

Die Dokumentation von Tensorflow r0.12 für tf.nn.rnn_cell.LSTMCell beschreibt dies als Init:

tf.nn.rnn_cell.LSTMCell.__call__(inputs, state, scope=None)

wobei state wie folgt lautet:

state: Wenn state_is_Tuple False ist, muss dies ein Zustand Tensor, 2-D, Batch x state_size sein. Wenn state_is_Tuple auf True gesetzt ist, muss dies ein Tuple von State Tensors sein, beide 2-D, mit den Spaltengrößen c_state und m_state.

Was sind c_state und m_state und wie passen sie in LSTMs? Ich kann nirgendwo in der Dokumentation einen Verweis darauf finden.

Hier ist ein Link zu dieser Seite in der Dokumentation.

15
Haziq Nordin

Ich bin auf dieselbe Frage gestoßen, wie ich sie verstehe! Minimalistisches LSTM-Beispiel:

import tensorflow as tf

sample_input = tf.constant([[1,2,3]],dtype=tf.float32)

LSTM_CELL_SIZE = 2

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_Tuple=True)
state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2

output, state_new = lstm_cell(sample_input, state)

init_op = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init_op)
print sess.run(output)

Beachten Sie, dass state_is_Tuple=True, wenn Sie state an diese cell übergeben, diese in der Tuple-Form sein muss. c_state und m_state sind wahrscheinlich "Memory State" und "Cell State", obwohl ich mir ehrlich gesagt NICHT sicher bin, da diese Begriffe nur in den Dokumenten erwähnt werden. Im Code und in den Papieren werden LSTM - h und c häufig verwendet, um "Ausgabewert" und "Zellenstatus" zu bezeichnen . http://colah.github.io/posts/2015-08-Understanding- LSTMs/ Diese Tensoren stellen den kombinierten internen Zustand der Zelle dar und sollten zusammen durchgelassen werden. Der alte Weg bestand darin, sie einfach zu verketten, und der neue Weg ist die Verwendung von Tupeln.

ALTER WEG:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_Tuple=False)
state = tf.zeros([1,LSTM_CELL_SIZE*2])

output, state_new = lstm_cell(sample_input, state)

NEUER WEG:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_Tuple=True)
state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2

output, state_new = lstm_cell(sample_input, state)

Im Grunde haben wir also state von 1 Tensor der Länge 4 in zwei Tensoren der Länge 2 geändert. Der Inhalt blieb gleich. [0,0,0,0] wird zu ([0,0],[0,0]). (Dies soll es schneller machen)

12
avloss

Ich stimme zu, dass die Dokumentation unklar ist. Ein Blick auf tf.nn.rnn_cell.LSTMCell.__call__ klärt (ich habe den Code von TensorFlow 1.0.0 übernommen):

def __call__(self, inputs, state, scope=None):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_Tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_Tuple` is True, this must be a
        Tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to "lstm_cell".

    Returns:
      A Tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_Tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with vs.variable_scope(scope or "lstm_cell",
                           initializer=self._initializer) as unit_scope:
      if self._num_unit_shards is not None:
        unit_scope.set_partitioner(
            partitioned_variables.fixed_size_partitioner(
                self._num_unit_shards))
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
                            scope=scope)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      # Diagonal connections
      if self._use_peepholes:
        with vs.variable_scope(unit_scope) as projection_scope:
          if self._num_unit_shards is not None:
            projection_scope.set_partitioner(None)
          w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
             self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection") as proj_scope:
          if self._num_proj_shards is not None:
            proj_scope.set_partitioner(
                partitioned_variables.fixed_size_partitioner(
                    self._num_proj_shards))
          m = _linear(m, self._num_proj, bias=False, scope=scope)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (LSTMStateTuple(c, m) if self._state_is_Tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state

Die wichtigsten Zeilen sind:

c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
         self._activation(j))

und 

m = sigmoid(o) * self._activation(c)

und 

new_state = (LSTMStateTuple(c, m) 

Wenn Sie den Code zur Berechnung von c und m mit den LSTM-Gleichungen (siehe unten) vergleichen, können Sie sehen, dass er dem Zellstatus (normalerweise mit c bezeichnet) und dem verborgenen Zustand (normalerweise mit h angegeben) entspricht:

 enter image description here

new_state = (LSTMStateTuple(c, m) gibt an, dass das erste Element des zurückgegebenen Zustands Tuple c ist (Zellenstatus a.k.a. c_state) und das zweite Element des zurückgegebenen Zustands Tuple m (verborgener Zustand a.k.a. m_state).

16

Vielleicht hilft dieser Auszug aus dem Code

def __call__(self, inputs, state, scope=None):
  """Long short-term memory cell (LSTM)."""
  with vs.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_Tuple:
      c, h = state
    else:
      c, h = array_ops.split(1, 2, state)
    concat = _linear([inputs, h], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(1, 4, concat)

    new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
             self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_Tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat(1, [new_c, new_h])
    return new_h, new_state
2

https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/python/ops/rnn_cell_impl.py

Zeile 308 - 314

klasse LSTMStateTuple (_LSTMStateTuple): "" "Tupel, das von LSTM-Zellen für state_size, zero_state und Ausgabestatus verwendet wird . Speichert zwei Elemente: (c, h) in dieser Reihenfolge . Wird nur verwendet, wenn state_is_Tuple=True." ""

0
Z Chen