tf keras Dense source code analysis
Environment
| Package | Version |
|---|---|
| tensorflow | 2.3.0 |
| keras | 2.4.3 |
Source code
```python
# Internal imports used by tensorflow/python/keras/layers/core.py in TF 2.3
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.ops import core as core_ops


class Dense(Layer):

    def __init__(self,
                 units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(Dense, self).__init__(
            activity_regularizer=activity_regularizer, **kwargs)
        # Accept non-int values (e.g. numpy integers) for `units`.
        self.units = int(units) if not isinstance(units, int) else units
        # Resolve string identifiers / callables into the actual objects.
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        dtype = dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError('Unable to build `Dense` layer with non-floating point '
                            'dtype %s' % (dtype,))
        input_shape = tensor_shape.TensorShape(input_shape)
        last_dim = tensor_shape.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError('The last dimension of the inputs to `Dense` '
                             'should be defined. Found `None`.')
        # From now on, inputs must have this exact last dimension.
        self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
        self.kernel = self.add_weight(
            'kernel',
            shape=[last_dim, self.units],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True)
        if self.use_bias:
            self.bias = self.add_weight(
                'bias',
                shape=[self.units,],
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs):
        return core_ops.dense(
            inputs,
            self.kernel,
            self.bias,
            self.activation,
            dtype=self._compute_dtype_object)

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        input_shape = input_shape.with_rank_at_least(2)
        if tensor_shape.dimension_value(input_shape[-1]) is None:
            raise ValueError(
                'The innermost dimension of input_shape must be defined, but saw: %s'
                % input_shape)
        return input_shape[:-1].concatenate(self.units)

    def get_config(self):
        config = super(Dense, self).get_config()
        config.update({
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        })
        return config
```

Reading the source, the basic Dense layer has five methods in total, described one by one below.
`__init__`
Meaning of each constructor argument
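| Parameter | Meaning |
|---|---|
| units | Number of output units, i.e. the size of the last dimension of the output |
| activation | Activation function applied to the output (`None` means no activation, i.e. linear) |
| use_bias | Whether the layer adds a bias vector |
| kernel_initializer / bias_initializer | How the weight matrix / bias vector is initialized (defaults: `'glorot_uniform'` and `'zeros'`) |
| kernel_regularizer / bias_regularizer / activity_regularizer | Regularization penalty applied to the weight matrix / the bias vector / the layer output |
| kernel_constraint / bias_constraint | Constraint function applied to the weight matrix / bias vector |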
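The default initializers and typical regularizer/constraint choices can be illustrated with plain NumPy. The sketch below assumes NumPy is available and re-implements the formulas that, to my understanding, Keras' `GlorotUniform`, `L2` and `MaxNorm` use; the function names `glorot_uniform`, `zeros`, `l2_penalty` and `max_norm` here are hypothetical helpers, not the Keras code itself.

```python
import numpy as np


def glorot_uniform(fan_in, fan_out, rng):
    """Sample a [fan_in, fan_out] kernel from U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))


def zeros(units):
    """Like the default bias initializer: a zero vector of length `units`."""
    return np.zeros(units)


def l2_penalty(weights, l2=0.01):
    """L2 weight regularization: l2 * sum(w ** 2), added to the training loss."""
    return l2 * np.sum(np.square(weights))


def max_norm(weights, max_value=2.0, axis=0, eps=1e-7):
    """Rescale each column so its L2 norm is at most `max_value` (applied after each weight update)."""
    norms = np.sqrt(np.sum(np.square(weights), axis=axis, keepdims=True))
    desired = np.clip(norms, 0, max_value)
    return weights * (desired / (eps + norms))


rng = np.random.default_rng(0)
# For a Dense kernel of shape [last_dim, units]: fan_in = last_dim, fan_out = units.
kernel = glorot_uniform(fan_in=8, fan_out=4, rng=rng)
bias = zeros(4)
print(kernel.shape, bias.shape)  # (8, 4) (4,)
print(round(float(l2_penalty(kernel)), 4))
constrained = max_norm(kernel * 10)
print(np.all(np.linalg.norm(constrained, axis=0) <= 2.0))  # True
```

Glorot scaling keeps the variance of activations roughly stable across layers, which is why it is a sensible default for a layer whose fan-in and fan-out are known only at build time.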
build
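Once the input shape is known, `build` creates the weight matrix (`kernel`, shape `[last_dim, units]`) and the bias vector (`bias`, shape `[units]`), mainly via the `add_weight` method. It also checks that the dtype is floating point (or complex) and that the last input dimension is defined.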
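The lazy build step can be sketched without TensorFlow. In Keras, the base `Layer.__call__` builds the layer on its first call, when the input shape is finally known. The hypothetical `MiniDense` class below (NumPy only, not the real implementation) mimics that order: `__init__` stores hyperparameters, the first call triggers `build`, which checks the last input dimension and creates a `[last_dim, units]` kernel and a `[units]` bias.

```python
import numpy as np


class MiniDense:
    """Hypothetical NumPy stand-in for keras Dense, illustrating __init__ -> build -> call."""

    def __init__(self, units, use_bias=True, seed=0):
        self.units = int(units)
        self.use_bias = use_bias
        self.rng = np.random.default_rng(seed)
        self.kernel = None
        self.bias = None
        self.built = False

    def build(self, input_shape):
        last_dim = input_shape[-1]
        if last_dim is None:
            raise ValueError("The last dimension of the inputs to `MiniDense` should be defined.")
        limit = np.sqrt(6.0 / (last_dim + self.units))  # glorot_uniform
        self.kernel = self.rng.uniform(-limit, limit, size=(last_dim, self.units))
        self.bias = np.zeros(self.units) if self.use_bias else None
        self.built = True

    def call(self, inputs):
        outputs = inputs @ self.kernel
        if self.bias is not None:
            outputs = outputs + self.bias
        return outputs

    def __call__(self, inputs):
        # Like Layer.__call__: build once, lazily, from the first input's shape.
        if not self.built:
            self.build(inputs.shape)
        return self.call(inputs)


layer = MiniDense(units=3)
print(layer.built)  # False
y = layer(np.ones((2, 5)))
print(layer.kernel.shape, layer.bias.shape, y.shape)  # (5, 3) (3,) (2, 3)
```

Deferring weight creation this way is what lets `Dense(units)` be declared without knowing the input size.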
call
The computation is delegated to `core_ops.dense`. Here is the source of `dense`:
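```python
# Internal imports used by tensorflow/python/keras/layers/ops/core.py
from tensorflow.python.eager import context
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import standard_ops


def dense(inputs, kernel, bias=None, activation=None, dtype=None):
    # Cast the inputs to the layer's compute dtype if needed.
    if dtype:
        if inputs.dtype.base_dtype != dtype.base_dtype:
            inputs = math_ops.cast(inputs, dtype=dtype)

    rank = inputs.shape.rank
    if rank == 2 or rank is None:
        if isinstance(inputs, sparse_tensor.SparseTensor):
            outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)
        else:
            outputs = gen_math_ops.mat_mul(inputs, kernel)
    # Broadcast kernel to inputs.
    else:
        outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])
        # Reshape the output back to the original ndim of the input.
        if not context.executing_eagerly():
            shape = inputs.shape.as_list()
            output_shape = shape[:-1] + [kernel.shape[-1]]
            outputs.set_shape(output_shape)

    if bias is not None:
        outputs = nn_ops.bias_add(outputs, bias)

    if activation is not None:
        outputs = activation(outputs)

    return outputs
```

TODO: difference between the multiplication ops

Here `inputs` is a tensor, so there is a `rank` variable: the number of dimensions of the tensor. The rank decides which multiplication is used.
One is `sparse_ops.sparse_tensor_dense_matmul` / `gen_math_ops.mat_mul`, used when the rank is 2 (or unknown): the former for a `SparseTensor` input, the latter for a dense one.
The other is `standard_ops.tensordot`, used when the rank is greater than 2: it contracts the last axis of `inputs` with the first axis of `kernel`, so the same projection is applied at every position along the leading axes.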
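The two dense multiplication paths can be compared with a NumPy sketch. `dense_sketch` and `relu` are hypothetical names, NumPy is assumed to be available, and the dtype cast and `SparseTensor` branch are omitted; this is an illustrative analogue, not TensorFlow's code. It checks that the rank-3 `tensordot` path gives the same result as flattening the leading axes, doing one matrix product, and reshaping back.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def dense_sketch(inputs, kernel, bias=None, activation=None):
    """NumPy analogue of core_ops.dense (dtype casting and SparseTensor branch omitted)."""
    rank = inputs.ndim
    if rank == 2:
        outputs = np.matmul(inputs, kernel)  # like gen_math_ops.mat_mul
    else:
        # like standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])
        outputs = np.tensordot(inputs, kernel, axes=([rank - 1], [0]))
    if bias is not None:
        outputs = outputs + bias  # like nn_ops.bias_add: added along the last axis
    if activation is not None:
        outputs = activation(outputs)
    return outputs


rng = np.random.default_rng(0)
kernel = rng.normal(size=(5, 3))
bias = np.array([0.1, -0.2, 0.3])

x2 = rng.normal(size=(4, 5))     # [batch, features]
x3 = rng.normal(size=(2, 4, 5))  # [batch, time, features]

y2 = dense_sketch(x2, kernel, bias, relu)
y3 = dense_sketch(x3, kernel, bias, relu)
print(y2.shape, y3.shape)  # (4, 3) (2, 4, 3)

# Rank-3 result == flatten leading axes, one matmul, reshape back.
flat = dense_sketch(x3.reshape(-1, 5), kernel, bias, relu).reshape(2, 4, 3)
print(np.allclose(y3, flat))  # True
```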
compute_output_shape
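Computes the output shape from the input shape and `units`: every dimension except the last is kept, and the last one becomes `units`. The input must have rank at least 2 and a defined last dimension, otherwise a `ValueError` is raised.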
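The shape rule fits in a few lines of plain Python. In the hypothetical helper below, shapes are tuples and `None` marks an unknown dimension, standing in for `TensorShape`; like the original, it requires rank at least 2 and a defined innermost dimension.

```python
def dense_output_shape(input_shape, units):
    """Mirror Dense.compute_output_shape on plain tuples (None = unknown dim)."""
    if len(input_shape) < 2:
        raise ValueError(f"Input must have rank at least 2, got shape {input_shape}")
    if input_shape[-1] is None:
        raise ValueError(
            f"The innermost dimension of input_shape must be defined, but saw: {input_shape}")
    # Keep all leading dims, replace the last one with `units`.
    return tuple(input_shape[:-1]) + (units,)


print(dense_output_shape((None, 784), 10))      # (None, 10)
print(dense_output_shape((32, None, 128), 64))  # (32, None, 64)
```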
get_config
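Returns the layer's configuration as a dict: the base `Layer` config updated with Dense's constructor arguments, with the activation, initializers, regularizers and constraints serialized, so that the layer can be recreated from it.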
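`get_config` exists so a layer can be rebuilt from plain data; by default `Layer.from_config(config)` essentially calls `cls(**config)`. The sketch below shows that round trip with a hypothetical `ConfigurableDense` class in plain Python. Plain values such as the string `'relu'` are a simplifying assumption standing in for what the real `*.serialize` calls produce (initializers, for instance, serialize to nested dicts).

```python
import json


class ConfigurableDense:
    """Hypothetical stand-in showing the get_config / from_config contract."""

    def __init__(self, units, activation=None, use_bias=True, name="dense", trainable=True):
        self.units = int(units)
        self.activation = activation
        self.use_bias = use_bias
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # Base-layer fields first (in Keras: name, trainable, dtype, ...) ...
        config = {"name": self.name, "trainable": self.trainable}
        # ... then the Dense-specific arguments, mirroring config.update({...}).
        config.update({
            "units": self.units,
            "activation": self.activation,
            "use_bias": self.use_bias,
        })
        return config

    @classmethod
    def from_config(cls, config):
        # Same idea as Keras' default Layer.from_config: cls(**config).
        return cls(**config)


layer = ConfigurableDense(units=16, activation="relu", name="hidden")
config = layer.get_config()
# The config survives a JSON round trip, which is what makes model saving possible.
restored = ConfigurableDense.from_config(json.loads(json.dumps(config)))
print(restored.get_config() == config)  # True
```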