tf keras SimpleRNN Source Code Analysis
環(huán)境
| tensorflow | 2.3.0 |
| keras | 2.4.3 |
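Before reading the source, a minimal usage sketch helps fix the shapes involved (this example is mine, not part of the original excerpt; it only uses the public `tf.keras.layers.SimpleRNN` API):

```python
import numpy as np
import tensorflow as tf

# 4 sequences, 10 timesteps, 8 features per step.
x = np.random.rand(4, 10, 8).astype('float32')

# return_sequences=False (default): only the last output, shape (4, 16).
layer = tf.keras.layers.SimpleRNN(16)
print(layer(x).shape)  # (4, 16)

# return_sequences=True: one output per timestep, shape (4, 10, 16).
layer = tf.keras.layers.SimpleRNN(16, return_sequences=True)
print(layer(x).shape)  # (4, 10, 16)

# return_state=True: also return the final state. For SimpleRNN the
# state equals the last output.
layer = tf.keras.layers.SimpleRNN(16, return_state=True)
output, state = layer(x)
print(output.shape, state.shape)  # (4, 16) (4, 16)
```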
Source

The main parts of the source: the `RNN` base class in `tensorflow/python/keras/layers/recurrent.py`, which `SimpleRNN` subclasses.
```python
# Abridged imports from the top of recurrent.py (TF 2.3). StackedRNNCells,
# _standardize_args, _is_multiple_state and _generate_zero_filled_state are
# defined in the same module.
import numpy as np

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.tools.docs import doc_controls


class RNN(Layer):

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               time_major=False,
               **kwargs):
    if isinstance(cell, (list, tuple)):
      cell = StackedRNNCells(cell)
    # If True, the output for masked timestep will be zeros, whereas in the
    # False case, output from previous timestep is returned for masked
    # timestep.
    self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)

    if 'input_shape' not in kwargs and (
        'input_dim' in kwargs or 'input_length' in kwargs):
      input_shape = (kwargs.pop('input_length', None),
                     kwargs.pop('input_dim', None))
      kwargs['input_shape'] = input_shape

    super(RNN, self).__init__(**kwargs)
    self.cell = cell
    self.return_sequences = return_sequences
    self.return_state = return_state
    self.go_backwards = go_backwards
    self.stateful = stateful
    self.unroll = unroll
    self.time_major = time_major

    self.supports_masking = True
    self.input_spec = None
    self.state_spec = None
    self._states = None
    self.constants_spec = None
    self._num_constants = 0

    if stateful:
      if ds_context.has_strategy():
        raise ValueError('RNNs with stateful=True not yet supported with '
                         'tf.distribute.Strategy.')

  @property
  def states(self):
    if self._states is None:
      state = nest.map_structure(lambda _: None, self.cell.state_size)
      return state if nest.is_sequence(self.cell.state_size) else [state]
    return self._states

  @states.setter  # the @states.setter decorator was dropped in the pasted snippet
  @trackable.no_automatic_dependency_tracking
  def states(self, states):
    self._states = states

  def compute_mask(self, inputs, mask):
    # Time step masks must be the same for each input.
    # This is because the mask for an RNN is of size [batch, time_steps, 1],
    # and specifies which time steps should be skipped, and a time step
    # must be skipped for all inputs.
    # TODO(scottzhu): Should we accept multiple different masks?
    mask = nest.flatten(mask)[0]
    output_mask = mask if self.return_sequences else None
    if self.return_state:
      state_mask = [None for _ in self.states]
      return [output_mask] + state_mask
    else:
      return output_mask

  def build(self, input_shape):
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
      # The input_shape here could be a nest structure.

    # do the tensor_shape to shapes here. The input could be single tensor,
    # or a nested structure of tensors.
    def get_input_spec(shape):
      """Convert input shape to InputSpec."""
      if isinstance(shape, tensor_shape.TensorShape):
        input_spec_shape = shape.as_list()
      else:
        input_spec_shape = list(shape)
      batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
      if not self.stateful:
        input_spec_shape[batch_index] = None
      input_spec_shape[time_step_index] = None
      return InputSpec(shape=tuple(input_spec_shape))

    def get_step_input_shape(shape):
      if isinstance(shape, tensor_shape.TensorShape):
        shape = tuple(shape.as_list())
      # remove the timestep from the input_shape
      return shape[1:] if self.time_major else (shape[0],) + shape[2:]

    # Check whether the input shape contains any nested shapes. It could be
    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
    # numpy inputs.
    try:
      input_shape = tensor_shape.as_shape(input_shape)
    except (ValueError, TypeError):
      # A nested tensor input
      pass

    if not nest.is_sequence(input_shape):
      # This indicates that there is only one input.
      if self.input_spec is not None:
        self.input_spec[0] = get_input_spec(input_shape)
      else:
        self.input_spec = [get_input_spec(input_shape)]
      step_input_shape = get_step_input_shape(input_shape)
    else:
      if self.input_spec is not None:
        self.input_spec[0] = nest.map_structure(get_input_spec, input_shape)
      else:
        self.input_spec = generic_utils.to_list(
            nest.map_structure(get_input_spec, input_shape))
      step_input_shape = nest.map_structure(get_step_input_shape, input_shape)

    # allow cell (if layer) to build before we set or validate state_spec.
    if isinstance(self.cell, Layer) and not self.cell.built:
      with K.name_scope(self.cell.name):
        self.cell.build(step_input_shape)
        self.cell.built = True

    # set or validate state_spec
    if _is_multiple_state(self.cell.state_size):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      self._validate_state_spec(state_size, self.state_spec)
    else:
      self.state_spec = [
          InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list())
          for dim in state_size
      ]
    if self.stateful:
      self.reset_states()
    self.built = True

  @staticmethod
  def _validate_state_spec(cell_state_sizes, init_state_specs):
    """Validate the state spec between the initial_state and the state_size.

    Args:
      cell_state_sizes: list, the `state_size` attribute from the cell.
      init_state_specs: list, the `state_spec` from the initial_state that is
        passed in `call()`.

    Raises:
      ValueError: When initial state spec is not compatible with the state
        size.
    """
    validation_error = ValueError(
        'An `initial_state` was passed that is not compatible with '
        '`cell.state_size`. Received `state_spec`={}; '
        'however `cell.state_size` is '
        '{}'.format(init_state_specs, cell_state_sizes))
    flat_cell_state_sizes = nest.flatten(cell_state_sizes)
    flat_state_specs = nest.flatten(init_state_specs)

    if len(flat_cell_state_sizes) != len(flat_state_specs):
      raise validation_error
    for cell_state_spec, cell_state_size in zip(flat_state_specs,
                                                flat_cell_state_sizes):
      if not tensor_shape.TensorShape(
          # Ignore the first axis for init_state which is for batch
          cell_state_spec.shape[1:]).is_compatible_with(
              tensor_shape.TensorShape(cell_state_size)):
        raise validation_error

  @doc_controls.do_not_doc_inheritable
  def get_initial_state(self, inputs):
    get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)

    if nest.is_sequence(inputs):
      # The input are nested sequences. Use the first element in the seq to
      # get batch size and dtype.
      inputs = nest.flatten(inputs)[0]

    input_shape = array_ops.shape(inputs)
    batch_size = input_shape[1] if self.time_major else input_shape[0]
    dtype = inputs.dtype
    if get_initial_state_fn:
      init_state = get_initial_state_fn(
          inputs=None, batch_size=batch_size, dtype=dtype)
    else:
      init_state = _generate_zero_filled_state(batch_size,
                                               self.cell.state_size, dtype)
    # Keras RNN expects the states in a list, even if it's a single state
    # tensor.
    if not nest.is_sequence(init_state):
      init_state = [init_state]
    # Force the state to be a list in case it is a namedtuple eg
    # LSTMStateTuple.
    return list(init_state)

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if initial_state is None and constants is None:
      return super(RNN, self).__call__(inputs, **kwargs)

    # If any of `initial_state` or `constants` are specified and are Keras
    # tensors, then add them to the inputs and temporarily modify the
    # input_spec to include them.

    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      additional_inputs += initial_state
      self.state_spec = nest.map_structure(
          lambda s: InputSpec(shape=K.int_shape(s)), initial_state)
      additional_specs += self.state_spec
    if constants is not None:
      additional_inputs += constants
      self.constants_spec = [
          InputSpec(shape=K.int_shape(constant)) for constant in constants
      ]
      self._num_constants = len(constants)
      additional_specs += self.constants_spec
    # additional_inputs can be empty if initial_state or constants are
    # provided but empty (e.g. the cell is stateless).
    flat_additional_inputs = nest.flatten(additional_inputs)
    is_keras_tensor = K.is_keras_tensor(
        flat_additional_inputs[0]) if flat_additional_inputs else True
    for tensor in flat_additional_inputs:
      if K.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state or constants of an RNN'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state and constants
      full_input = [inputs] + additional_inputs
      if self.built:
        # Keep the input_spec since it has been populated in build() method.
        full_input_spec = self.input_spec + additional_specs
      else:
        # The original input_spec is None since there could be a nested
        # tensor input. Update the input_spec to match the inputs.
        full_input_spec = generic_utils.to_list(
            nest.map_structure(lambda _: None, inputs)) + additional_specs
      # Perform the call with temporarily replaced input_spec
      self.input_spec = full_input_spec
      output = super(RNN, self).__call__(full_input, **kwargs)
      # Remove the additional_specs from input spec and keep the rest. It is
      # important to keep since the input spec was populated by build(), and
      # will be reused in the stateful=True.
      self.input_spec = self.input_spec[:-len(additional_specs)]
      return output
    else:
      if initial_state is not None:
        kwargs['initial_state'] = initial_state
      if constants is not None:
        kwargs['constants'] = constants
      return super(RNN, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    # The input should be dense, padded with zeros. If a ragged input is fed
    # into the layer, it is padded and the row lengths are used for masking.
    inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
    is_ragged_input = (row_lengths is not None)
    self._validate_args_if_ragged(is_ragged_input, mask)

    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)

    self._maybe_reset_cell_dropout_mask(self.cell)
    if isinstance(self.cell, StackedRNNCells):
      for cell in self.cell.cells:
        self._maybe_reset_cell_dropout_mask(cell)

    if mask is not None:
      # Time step masks must be the same for each input.
      # TODO(scottzhu): Should we accept multiple different masks?
      mask = nest.flatten(mask)[0]

    if nest.is_sequence(inputs):
      # In the case of nested input, use the first element for shape check.
      input_shape = K.int_shape(nest.flatten(inputs)[0])
    else:
      input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]
    if self.unroll and timesteps is None:
      raise ValueError('Cannot unroll a RNN if the '
                       'time dimension is undefined. \n'
                       '- If using a Sequential model, '
                       'specify the time dimension by passing '
                       'an `input_shape` or `batch_input_shape` '
                       'argument to your first layer. If your '
                       'first layer is an Embedding, you can '
                       'also use the `input_length` argument.\n'
                       '- If using the functional API, specify '
                       'the time dimension by passing a `shape` '
                       'or `batch_shape` argument to your Input layer.')

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    # TF RNN cells expect single tensor as state instead of list wrapped
    # tensor.
    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
    # Use the __call__ function for callable objects, eg layers, so that it
    # will have the proper name scopes for the ops, etc.
    cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type

        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
        output, new_states = cell_call_fn(
            inputs, states, constants=constants, **kwargs)
        if not nest.is_sequence(new_states):
          new_states = [new_states]
        return output, new_states
    else:

      def step(inputs, states):
        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
        output, new_states = cell_call_fn(inputs, states, **kwargs)
        if not nest.is_sequence(new_states):
          new_states = [new_states]
        return output, new_states

    last_output, outputs, states = K.rnn(
        step,
        inputs,
        initial_state,
        constants=constants,
        go_backwards=self.go_backwards,
        mask=mask,
        unroll=self.unroll,
        input_length=row_lengths if row_lengths is not None else timesteps,
        time_major=self.time_major,
        zero_output_for_mask=self.zero_output_for_mask)

    if self.stateful:
      updates = [
          state_ops.assign(self_state, state) for self_state, state in zip(
              nest.flatten(self.states), nest.flatten(states))
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = K.maybe_convert_to_ragged(is_ragged_input, outputs,
                                         row_lengths)
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return generic_utils.to_list(output) + states
    else:
      return output

  def _process_inputs(self, inputs, initial_state, constants):
    # input shape: `(samples, time (padded with zeros), input_dim)`
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if (isinstance(inputs, collections_abc.Sequence)
        and not isinstance(inputs, tuple)):
      # get initial_state from full input spec
      # as they could be copied to multiple GPU.
      if not self._num_constants:
        initial_state = inputs[1:]
      else:
        initial_state = inputs[1:-self._num_constants]
        constants = inputs[-self._num_constants:]
      if len(initial_state) == 0:
        initial_state = None
      inputs = inputs[0]

    if self.stateful:
      if initial_state is not None:
        # When layer is stateful and initial_state is provided, check if the
        # recorded state is same as the default value (zeros). Use the
        # recorded state if it is not same as the default.
        non_zero_count = math_ops.add_n(
            [math_ops.count_nonzero_v2(s)
             for s in nest.flatten(self.states)])
        # Set strict = True to keep the original structure of the state.
        initial_state = control_flow_ops.cond(
            non_zero_count > 0,
            true_fn=lambda: self.states,
            false_fn=lambda: initial_state,
            strict=True)
      else:
        initial_state = self.states
    elif initial_state is None:
      initial_state = self.get_initial_state(inputs)

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' + str(len(initial_state)) +
                       ' initial states.')
    return inputs, initial_state, constants

  def _validate_args_if_ragged(self, is_ragged_input, mask):
    if not is_ragged_input:
      return

    if mask is not None:
      raise ValueError('The mask that was passed in was ' + str(mask) +
                       ' and cannot be applied to RaggedTensor inputs. Please '
                       'make sure that there is no mask passed in by upstream '
                       'layers.')
    if self.unroll:
      raise ValueError('The input received contains RaggedTensors and does '
                       'not support unrolling. Disable unrolling by passing '
                       '`unroll=False` in the RNN Layer constructor.')

  def reset_states(self, states=None):
    """Reset the recorded states for the stateful RNN layer.

    Can only be used when RNN layer is constructed with `stateful` = `True`.

    Args:
      states: Numpy arrays that contain the value for the initial state,
        which will be fed to cell at the first time step. When the value is
        None, zero filled numpy array will be created based on the cell
        state size.

    Raises:
      AttributeError: When the RNN layer is not stateful.
      ValueError: When the batch size of the RNN layer is unknown.
      ValueError: When the input numpy array is not compatible with the RNN
        layer state, either size wise or dtype wise.
    """
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    spec_shape = None
    if self.input_spec is not None:
      spec_shape = nest.flatten(self.input_spec[0])[0].shape
    if spec_shape is None:
      # It is possible to have spec shape to be None, eg when construct a RNN
      # with a custom cell, or standard RNN layers (LSTM/GRU) which we only
      # know it has 3 dim input, but not its full shape spec before build().
      batch_size = None
    else:
      batch_size = spec_shape[1] if self.time_major else spec_shape[0]
    if not batch_size:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.')
    # initialize state if None
    if nest.flatten(self.states)[0] is None:
      def create_state_variable(state):
        return K.zeros([batch_size] + tensor_shape.as_shape(state).as_list())

      self.states = nest.map_structure(create_state_variable,
                                       self.cell.state_size)
      if not nest.is_sequence(self.states):
        self.states = [self.states]
    elif states is None:
      for state, size in zip(nest.flatten(self.states),
                             nest.flatten(self.cell.state_size)):
        K.set_value(state,
                    np.zeros([batch_size] +
                             tensor_shape.as_shape(size).as_list()))
    else:
      flat_states = nest.flatten(self.states)
      flat_input_states = nest.flatten(states)
      if len(flat_input_states) != len(flat_states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(flat_states)) + ' states, '
                         'but it received ' + str(len(flat_input_states)) +
                         ' state values. Input received: ' + str(states))
      set_value_tuples = []
      for i, (value, state) in enumerate(zip(flat_input_states,
                                             flat_states)):
        if value.shape != state.shape:
          raise ValueError(
              'State ' + str(i) + ' is incompatible with layer ' +
              self.name + ': expected shape=' + str((batch_size, state)) +
              ', found shape=' + str(value.shape))
        set_value_tuples.append((state, value))
      K.batch_set_value(set_value_tuples)
```
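The heart of `call()` is `K.rnn`, which scans the `step` closure over the time axis. Below is a simplified pure-Python sketch of that loop, my own illustration rather than the actual `K.rnn` implementation; it ignores masking, unrolling, `go_backwards`, constants, and ragged inputs:

```python
import tensorflow as tf

def simple_rnn_scan(step, inputs, initial_states):
  """Minimal stand-in for K.rnn; inputs has shape (batch, time, features)."""
  states = initial_states
  outputs = []
  for t in range(inputs.shape[1]):  # iterate over the time axis
    output, states = step(inputs[:, t], states)
    outputs.append(output)
  last_output = outputs[-1]
  # Stack per-step outputs back into (batch, time, units).
  return last_output, tf.stack(outputs, axis=1), states

# A step closure like the one RNN.call builds around the cell:
cell = tf.keras.layers.SimpleRNNCell(16)
cell.build((None, 8))  # RNN.build does this with step_input_shape

def step(inputs, states):
  output, new_states = cell(inputs, states)
  return output, new_states

x = tf.random.normal((4, 10, 8))
init = [tf.zeros((4, 16))]  # zero state, as get_initial_state would produce
last, seq, states = simple_rnn_scan(step, x, init)
print(last.shape, seq.shape)  # (4, 16) (4, 10, 16)
```

Flow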
build:
- input_shape: the full input shape, `(batch, timesteps, input_dim)` by default, or `(timesteps, batch, input_dim)` when `time_major=True`
- step_input_shape: input_shape with the time axis removed; this is what `cell.build()` receives (traced in the sketch below)
- state_size: read from the cell once it is built, and used to set or validate `state_spec`
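The shape handling in `build()` can be traced with a small sketch; this is my own reproduction of the `get_step_input_shape` logic quoted above, not library code:

```python
def get_step_input_shape(shape, time_major=False):
  # Mirrors RNN.build: drop the time axis so the cell sees per-step shapes.
  return shape[1:] if time_major else (shape[0],) + shape[2:]

print(get_step_input_shape((32, 10, 8)))                   # (32, 8)
print(get_step_input_shape((10, 32, 8), time_major=True))  # (32, 8)

# cell.build then receives (32, 8). For SimpleRNNCell(units=16) this creates
# kernel (8, 16), recurrent_kernel (16, 16), bias (16,), and state_spec
# becomes [InputSpec(shape=(None, 16))] from state_size=16.
```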
Summary

The `RNN` base class, which `SimpleRNN` subclasses, handles input/state specs, masking, stateful bookkeeping, and the per-timestep loop via `K.rnn`, while delegating the actual recurrence computation to its cell (`SimpleRNNCell` in the case of `SimpleRNN`).