[Project share] Point cloud segmentation with PointNet
Introduction
A "point cloud" is an important type of data structure for storing geometric shape data. Due to its irregular format, it is often transformed into regular 3D voxel grids or collections of images before being used in deep learning applications, a step that makes the data unnecessarily large. The PointNet family of models solves this problem by directly consuming point clouds, respecting the permutation-invariance property of point data. The PointNet family of models provides a simple, unified architecture for applications ranging from object classification and part segmentation to scene semantic parsing.
In this example, we demonstrate an implementation of the PointNet architecture for shape segmentation.
References
- PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation
- Point cloud classification with PointNet
- Spatial Transformer Networks
Imports
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt

Downloading the dataset
The ShapeNet dataset is an ongoing effort to establish a richly-annotated, large-scale dataset of 3D shapes. ShapeNetCore is a subset of the full ShapeNet dataset with clean, single 3D models and manually verified category and alignment annotations. It covers 55 common object categories, with about 51,300 unique 3D models.
For this example, we use one of the 12 object categories of PASCAL 3D+, included as part of the ShapeNetCore dataset.
dataset_url = "https://git.io/JiY4i"

dataset_path = keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)

Loading the dataset
We parse the dataset metadata in order to easily map model categories to their respective directories, and segmentation classes to colors for visualization.
with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)

{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}}

In this example, we train PointNet to segment the parts of an Airplane model. Note that the key "lables" is misspelled in the dataset's own metadata, so the code below has to use that spelling as-is.
points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]

VAL_SPLIT = 0.2
NUM_SAMPLE_POINTS = 1024
BATCH_SIZE = 32
EPOCHS = 60
INITIAL_LR = 1e-3

Structuring the dataset
We generate the following in-memory data structures from the Airplane point clouds and their labels (a loading sketch follows the list):
- point_clouds is a list of np.array objects that represent the point cloud data in the form of x, y and z coordinates. Axis 0 represents the number of points in the point cloud, while axis 1 represents the coordinates.
- test_point_clouds is in the same format as point_clouds, but does not have the corresponding point cloud labels.
- all_labels is a list of np.array objects that represent the label of each coordinate as a string (needed mainly for visualization purposes), corresponding to the point_clouds list.
- point_cloud_labels is a list of np.array objects that represent the point cloud labels for each coordinate in one-hot encoded form, corresponding to the point_clouds list.
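The code that builds these structures did not survive in this copy of the post. Below is a minimal sketch of how they could be populated, assuming the PartAnnotation layout: one .pts file of x, y, z rows per shape under points_dir, and one .seg file of 0/1 flags per part label under labels_dir/<label>/. The file extensions and the exact parsing here are our assumptions, not taken from the original text.

# A minimal sketch (not the original code): populate the four structures above.
point_clouds, test_point_clouds = [], []
point_cloud_labels, all_labels = [], []

points_files = glob(os.path.join(points_dir, "*.pts"))
for point_file in tqdm(points_files):
    point_cloud = np.loadtxt(point_file)
    # Skip shapes that are too small to sample NUM_SAMPLE_POINTS from later.
    if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
        continue
    file_id = os.path.basename(point_file).split(".")[0]

    # Collect the per-part binary masks that exist for this shape.
    label_data, num_labels = {}, 0
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = np.loadtxt(label_file).astype("float32")
            num_labels = len(label_data[label])

    if num_labels == 0:
        # Shapes without any label files are kept aside for inference only.
        test_point_clouds.append(point_cloud)
        continue

    # Convert the binary masks into one string label per point ("none" if unlabeled).
    label_map = ["none"] * num_labels
    for label in LABELS:
        for i, flag in enumerate(label_data.get(label, [])):
            if flag == 1:
                label_map[i] = label
    dense_labels = [
        LABELS.index(label) if label != "none" else len(LABELS) for label in label_map
    ]

    point_clouds.append(point_cloud)
    point_cloud_labels.append(
        keras.utils.to_categorical(dense_labels, num_classes=len(LABELS) + 1)
    )
    all_labels.append(np.array(label_map))

Shapes that carry no label files end up in test_point_clouds, which is why that list has no corresponding labels.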
Next, let's take a look at some samples from the in-memory arrays we just generated:
for _ in range(5):
    i = random.randint(0, len(point_clouds) - 1)
    # Note: the shapes are printed from index 0 here (not index i), which is why
    # every sample below reports the same (2602, ...) shapes.
    print(f"point_clouds[{i}].shape:", point_clouds[0].shape)
    print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[0].shape)
    for j in range(5):
        print(
            f"all_labels[{i}][{j}]:",
            all_labels[i][j],
            f"\tpoint_cloud_labels[{i}][{j}]:",
            point_cloud_labels[i][j],
            "\n",
        )

point_clouds[475].shape: (2602, 3)
point_cloud_labels[475].shape: (2602, 5)
all_labels[475][0]: body 	point_cloud_labels[475][0]: [0. 1. 0. 0. 0.]
all_labels[475][1]: engine 	point_cloud_labels[475][1]: [0. 0. 0. 1. 0.]
all_labels[475][2]: body 	point_cloud_labels[475][2]: [0. 1. 0. 0. 0.]
all_labels[475][3]: body 	point_cloud_labels[475][3]: [0. 1. 0. 0. 0.]
all_labels[475][4]: wing 	point_cloud_labels[475][4]: [1. 0. 0. 0. 0.]
point_clouds[2712].shape: (2602, 3)
point_cloud_labels[2712].shape: (2602, 5)
all_labels[2712][0]: tail 	point_cloud_labels[2712][0]: [0. 0. 1. 0. 0.]
all_labels[2712][1]: wing 	point_cloud_labels[2712][1]: [1. 0. 0. 0. 0.]
all_labels[2712][2]: engine 	point_cloud_labels[2712][2]: [0. 0. 0. 1. 0.]
all_labels[2712][3]: wing 	point_cloud_labels[2712][3]: [1. 0. 0. 0. 0.]
all_labels[2712][4]: wing 	point_cloud_labels[2712][4]: [1. 0. 0. 0. 0.]
point_clouds[1413].shape: (2602, 3)
point_cloud_labels[1413].shape: (2602, 5)
all_labels[1413][0]: body 	point_cloud_labels[1413][0]: [0. 1. 0. 0. 0.]
all_labels[1413][1]: tail 	point_cloud_labels[1413][1]: [0. 0. 1. 0. 0.]
all_labels[1413][2]: tail 	point_cloud_labels[1413][2]: [0. 0. 1. 0. 0.]
all_labels[1413][3]: tail 	point_cloud_labels[1413][3]: [0. 0. 1. 0. 0.]
all_labels[1413][4]: tail 	point_cloud_labels[1413][4]: [0. 0. 1. 0. 0.]
point_clouds[1207].shape: (2602, 3)
point_cloud_labels[1207].shape: (2602, 5)
all_labels[1207][0]: tail 	point_cloud_labels[1207][0]: [0. 0. 1. 0. 0.]
all_labels[1207][1]: wing 	point_cloud_labels[1207][1]: [1. 0. 0. 0. 0.]
all_labels[1207][2]: wing 	point_cloud_labels[1207][2]: [1. 0. 0. 0. 0.]
all_labels[1207][3]: body 	point_cloud_labels[1207][3]: [0. 1. 0. 0. 0.]
all_labels[1207][4]: body 	point_cloud_labels[1207][4]: [0. 1. 0. 0. 0.]
point_clouds[2492].shape: (2602, 3)
point_cloud_labels[2492].shape: (2602, 5)
all_labels[2492][0]: engine 	point_cloud_labels[2492][0]: [0. 0. 0. 1. 0.]
all_labels[2492][1]: body 	point_cloud_labels[2492][1]: [0. 1. 0. 0. 0.]
all_labels[2492][2]: body 	point_cloud_labels[2492][2]: [0. 1. 0. 0. 0.]
all_labels[2492][3]: body 	point_cloud_labels[2492][3]: [0. 1. 0. 0. 0.]
all_labels[2492][4]: engine 	point_cloud_labels[2492][4]: [0. 0. 0. 1. 0.]

Now, let's visualize some of the point clouds along with their labels.
def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

Preprocessing
Note that all the point clouds we have loaded consist of a variable number of points, which makes it difficult to batch them together. To overcome this problem, we randomly sample a fixed number of points from each point cloud. We also normalize the point clouds to make the data scale-invariant.
for index in tqdm(range(len(point_clouds))):
    current_point_cloud = point_clouds[index]
    current_label_cloud = point_cloud_labels[index]
    current_labels = all_labels[index]
    num_points = len(current_point_cloud)
    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
    # Sampling corresponding labels for visualization.
    sampled_labels = np.array([current_labels[i] for i in sampled_indices])
    # Normalizing sampled point cloud.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))
    point_clouds[index] = norm_point_cloud
    point_cloud_labels[index] = sampled_label_cloud
    all_labels[index] = sampled_labels

100%|██████████| 3694/3694 [00:07<00:00, 478.67it/s]

Let's visualize the sampled and normalized point clouds along with their corresponding labels.
visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

Creating TensorFlow datasets
We create tf.data.Dataset objects for the training and validation data. We also augment the training point clouds by applying random jitter to them.
def load_data(point_cloud_batch, label_cloud_batch):
    point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
    label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
    return point_cloud_batch, label_cloud_batch


def augment(point_cloud_batch, label_cloud_batch):
    noise = tf.random.uniform(
        tf.shape(label_cloud_batch), -0.005, 0.005, dtype=tf.float64
    )
    point_cloud_batch += noise[:, :, :3]
    return point_cloud_batch, label_cloud_batch


def generate_dataset(point_clouds, label_clouds, is_training=True):
    dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
    dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=BATCH_SIZE)
    dataset = (
        dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        if is_training
        else dataset
    )
    return dataset


split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
train_point_clouds = point_clouds[:split_index]
train_label_cloud = point_cloud_labels[:split_index]
total_training_examples = len(train_point_clouds)

val_point_clouds = point_clouds[split_index:]
val_label_cloud = point_cloud_labels[split_index:]

print("Num train point clouds:", len(train_point_clouds))
print("Num train point cloud labels:", len(train_label_cloud))
print("Num val point clouds:", len(val_point_clouds))
print("Num val point cloud labels:", len(val_label_cloud))

train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)

Num train point clouds: 2955
Num train point cloud labels: 2955
Num val point clouds: 739
Num val point cloud labels: 739
Train Dataset: <ParallelMapDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)>
Validation Dataset: <BatchDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)>

The PointNet model
A figure in the original paper depicts the internals of the PointNet model family.
Given that PointNet is meant to consume an unordered set of coordinates as its input data, its architecture needs to match the following characteristic properties of point cloud data:
Permutation invariance
Given the unstructured nature of point cloud data, a scan made up of n points has n! permutations. The subsequent data processing must be invariant to these different representations. To make PointNet invariant to input permutations, we apply a symmetric function (such as max-pooling) once the n input points have been mapped to a higher-dimensional space. The result is a global feature vector that aims to capture an aggregate signature of the n input points. The global feature vector is used alongside local point features for segmentation, as the small check below demonstrates.
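To make the symmetry argument concrete, here is a small self-contained check (our own snippet, not part of the original example) showing that max-pooling over the points axis yields the same global feature for any ordering of the points:

# Max-pooling over the points axis is a symmetric function, so shuffling
# the points leaves the pooled "global feature" unchanged.
points = tf.random.uniform((1, 8, 64))  # (batch, num_points, features)
shuffled = tf.gather(points, tf.random.shuffle(tf.range(8)), axis=1)

pooled = tf.reduce_max(points, axis=1)
pooled_shuffled = tf.reduce_max(shuffled, axis=1)
print(bool(tf.reduce_all(pooled == pooled_shuffled)))  # True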
Transformation invariance
The segmentation output should be unchanged if the object undergoes certain transformations, such as translation or scaling. For a given input point cloud, we apply an appropriate rigid or affine transformation to achieve pose normalization. Because each of the n input points is represented as a vector and mapped to the embedding space independently, applying a geometric transformation simply amounts to matrix-multiplying each point with a transformation matrix. This is motivated by the concept of Spatial Transformer Networks.
The operations comprising the T-Net are motivated by the higher-level architecture of PointNet. MLPs (or fully-connected layers) are used to map the input points independently and identically to a higher-dimensional space; max-pooling is used to encode a global feature vector whose dimensionality is then reduced with fully-connected layers. The input-dependent features at the final fully-connected layer are then combined with globally trainable weights and biases, resulting in a 3-by-3 transformation matrix.
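Because every point is embedded independently, applying the predicted matrix to an entire batch reduces to one batched matrix multiplication. A brief illustration (our own snippet; the identity matrix stands in for a learned transform):

# Applying a batch of 3x3 transformation matrices to a (batch, num_points, 3)
# point cloud is a single batched matmul -- the same computation that the
# layers.Dot(axes=(2, 1)) call in transformation_block below performs.
point_cloud = tf.random.uniform((32, 1024, 3))
transform = tf.eye(3, batch_shape=[32])  # placeholder for a learned transform

transformed = tf.linalg.matmul(point_cloud, transform)
print(transformed.shape)  # (32, 1024, 3)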
Point interactions
The interaction between neighboring points often carries useful information (i.e., a single point should not be treated in isolation). Whereas classification only needs to make use of global features, segmentation must be able to leverage local point features along with global point features.
Note: The figures presented in this section have been taken from the original paper.
Now that we know the pieces that compose the PointNet model, we can implement it. We start with the basic blocks, i.e., the convolutional block and the multi-layer perceptron block.
def conv_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)


def mlp_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Dense(filters, name=f"{name}_dense")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)

We implement a regularizer (taken from this example) to enforce orthogonality in the feature space. This is needed to ensure that the magnitudes of the transformed features do not vary too much.
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        # Penalize the deviation of x @ x^T from the identity matrix.
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.identity))

    def get_config(self):
        # Serialize the constructor arguments so the regularizer can be re-created.
        # (The original called super() on an undefined `TransformerEncoder` class.)
        return {"num_features": self.num_features, "l2reg": self.l2reg}
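As a quick sanity check of what this regularizer measures (our own snippet, not in the original): for an orthogonal weight matrix, x xᵀ equals the identity and the penalty is zero, while an arbitrary matrix is penalized.

# For an orthogonal matrix (here: the 3x3 identity, flattened the way the
# Dense layer below emits it), x @ x^T == I and the penalty vanishes.
reg = OrthogonalRegularizer(num_features=3)
orthogonal = tf.reshape(tf.eye(3), (1, 9))
arbitrary = tf.random.uniform((1, 9))
print(float(reg(orthogonal)))  # 0.0
print(float(reg(arbitrary)))   # > 0.0 (almost surely)

The next piece is the transformation network we explained earlier.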
def transformation_net(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    """
    Reference: https://keras.io/examples/vision/pointnet/#build-a-model.

    The `filters` values come from the original paper:
    https://arxiv.org/abs/1612.00593.
    """
    x = conv_block(inputs, filters=64, name=f"{name}_1")
    x = conv_block(x, filters=128, name=f"{name}_2")
    x = conv_block(x, filters=1024, name=f"{name}_3")
    x = layers.GlobalMaxPooling1D()(x)
    x = mlp_block(x, filters=512, name=f"{name}_1_1")
    x = mlp_block(x, filters=256, name=f"{name}_2_1")
    return layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
        activity_regularizer=OrthogonalRegularizer(num_features),
        name=f"{name}_final",
    )(x)


def transformation_block(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    transformed_features = transformation_net(inputs, num_features, name=name)
    transformed_features = layers.Reshape((num_features, num_features))(
        transformed_features
    )
    return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])

Finally, we piece the above blocks together and implement the segmentation model.
def get_shape_segmentation_model(num_points: int, num_classes: int) -> keras.Model:
    input_points = keras.Input(shape=(None, 3))

    # PointNet Classification Network.
    transformed_inputs = transformation_block(
        input_points, num_features=3, name="input_transformation_block"
    )
    features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
    features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
    features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
    transformed_features = transformation_block(
        features_128_2, num_features=128, name="transformed_features"
    )
    features_512 = conv_block(transformed_features, filters=512, name="features_512")
    features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
    global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
        features_2048
    )
    global_features = tf.tile(global_features, [1, num_points, 1])

    # Segmentation head.
    segmentation_input = layers.Concatenate(name="segmentation_input")(
        [
            features_64,
            features_128_1,
            features_128_2,
            transformed_features,
            features_512,
            global_features,
        ]
    )
    segmentation_features = conv_block(
        segmentation_input, filters=128, name="segmentation_features"
    )
    outputs = layers.Conv1D(
        num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
    )(segmentation_features)
    return keras.Model(input_points, outputs)

Instantiating the model
x, y = next(iter(train_dataset))

num_points = x.shape[1]
num_classes = y.shape[-1]

segmentation_model = get_shape_segmentation_model(num_points, num_classes)
segmentation_model.summary()

2021-10-25 01:26:33.563133: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                     Output Shape         Param #    Connected to
==================================================================================================
input_1 (InputLayer)             [(None, None, 3)]    0
input_transformation_block_1_co  (None, None, 64)     256        input_1[0][0]
input_transformation_block_1_ba  (None, None, 64)     256        input_transformation_block_1_conv
input_transformation_block_1_re  (None, None, 64)     0          input_transformation_block_1_batc
input_transformation_block_2_co  (None, None, 128)    8320       input_transformation_block_1_relu
input_transformation_block_2_ba  (None, None, 128)    512        input_transformation_block_2_conv
input_transformation_block_2_re  (None, None, 128)    0          input_transformation_block_2_batc
input_transformation_block_3_co  (None, None, 1024)   132096     input_transformation_block_2_relu
input_transformation_block_3_ba  (None, None, 1024)   4096       input_transformation_block_3_conv
input_transformation_block_3_re  (None, None, 1024)   0          input_transformation_block_3_batc
global_max_pooling1d (GlobalMax  (None, 1024)         0          input_transformation_block_3_relu
input_transformation_block_1_1_  (None, 512)          524800     global_max_pooling1d[0][0]
input_transformation_block_1_1_  (None, 512)          2048       input_transformation_block_1_1_de
input_transformation_block_1_1_  (None, 512)          0          input_transformation_block_1_1_ba
input_transformation_block_2_1_  (None, 256)          131328     input_transformation_block_1_1_re
input_transformation_block_2_1_  (None, 256)          1024       input_transformation_block_2_1_de
input_transformation_block_2_1_  (None, 256)          0          input_transformation_block_2_1_ba
input_transformation_block_fina  (None, 9)            2313       input_transformation_block_2_1_re
reshape (Reshape)                (None, 3, 3)         0          input_transformation_block_final[
input_transformation_block_mm (  (None, None, 3)      0          input_1[0][0]
                                                                 reshape[0][0]
features_64_conv (Conv1D)        (None, None, 64)     256        input_transformation_block_mm[0][
features_64_batch_norm (BatchNo  (None, None, 64)     256        features_64_conv[0][0]
features_64_relu (Activation)    (None, None, 64)     0          features_64_batch_norm[0][0]
features_128_1_conv (Conv1D)     (None, None, 128)    8320       features_64_relu[0][0]
features_128_1_batch_norm (Batc  (None, None, 128)    512        features_128_1_conv[0][0]
features_128_1_relu (Activation  (None, None, 128)    0          features_128_1_batch_norm[0][0]
features_128_2_conv (Conv1D)     (None, None, 128)    16512      features_128_1_relu[0][0]
features_128_2_batch_norm (Batc  (None, None, 128)    512        features_128_2_conv[0][0]
features_128_2_relu (Activation  (None, None, 128)    0          features_128_2_batch_norm[0][0]
transformed_features_1_conv (Co  (None, None, 64)     8256       features_128_2_relu[0][0]
transformed_features_1_batch_no  (None, None, 64)     256        transformed_features_1_conv[0][0]
transformed_features_1_relu (Ac  (None, None, 64)     0          transformed_features_1_batch_norm
transformed_features_2_conv (Co  (None, None, 128)    8320       transformed_features_1_relu[0][0]
transformed_features_2_batch_no  (None, None, 128)    512        transformed_features_2_conv[0][0]
transformed_features_2_relu (Ac  (None, None, 128)    0          transformed_features_2_batch_norm
transformed_features_3_conv (Co  (None, None, 1024)   132096     transformed_features_2_relu[0][0]
transformed_features_3_batch_no  (None, None, 1024)   4096       transformed_features_3_conv[0][0]
transformed_features_3_relu (Ac  (None, None, 1024)   0          transformed_features_3_batch_norm
global_max_pooling1d_1 (GlobalM  (None, 1024)         0          transformed_features_3_relu[0][0]
transformed_features_1_1_dense   (None, 512)          524800     global_max_pooling1d_1[0][0]
transformed_features_1_1_batch_  (None, 512)          2048       transformed_features_1_1_dense[0]
transformed_features_1_1_relu (  (None, 512)          0          transformed_features_1_1_batch_no
transformed_features_2_1_dense   (None, 256)          131328     transformed_features_1_1_relu[0][
transformed_features_2_1_batch_  (None, 256)          1024       transformed_features_2_1_dense[0]
transformed_features_2_1_relu (  (None, 256)          0          transformed_features_2_1_batch_no
transformed_features_final (Den  (None, 16384)        4210688    transformed_features_2_1_relu[0][
reshape_1 (Reshape)              (None, 128, 128)     0          transformed_features_final[0][0]
transformed_features_mm (Dot)    (None, None, 128)    0          features_128_2_relu[0][0]
                                                                 reshape_1[0][0]
features_512_conv (Conv1D)       (None, None, 512)    66048      transformed_features_mm[0][0]
features_512_batch_norm (BatchN  (None, None, 512)    2048       features_512_conv[0][0]
features_512_relu (Activation)   (None, None, 512)    0          features_512_batch_norm[0][0]
pre_maxpool_block_conv (Conv1D)  (None, None, 2048)   1050624    features_512_relu[0][0]
pre_maxpool_block_batch_norm (B  (None, None, 2048)   8192       pre_maxpool_block_conv[0][0]
pre_maxpool_block_relu (Activat  (None, None, 2048)   0          pre_maxpool_block_batch_norm[0][0
global_features (MaxPooling1D)   (None, None, 2048)   0          pre_maxpool_block_relu[0][0]
tf.tile (TFOpLambda)             (None, None, 2048)   0          global_features[0][0]
segmentation_input (Concatenate  (None, None, 3008)   0          features_64_relu[0][0]
                                                                 features_128_1_relu[0][0]
                                                                 features_128_2_relu[0][0]
                                                                 transformed_features_mm[0][0]
                                                                 features_512_relu[0][0]
                                                                 tf.tile[0][0]
segmentation_features_conv (Con  (None, None, 128)    385152     segmentation_input[0][0]
segmentation_features_batch_nor  (None, None, 128)    512        segmentation_features_conv[0][0]
segmentation_features_relu (Act  (None, None, 128)    0          segmentation_features_batch_norm[
segmentation_head (Conv1D)       (None, None, 5)      645        segmentation_features_relu[0][0]
==================================================================================================
Total params: 7,370,062
Trainable params: 7,356,110
Non-trainable params: 13,952
__________________________________________________________________________________________________

Training
For training, the authors recommend a learning rate schedule that decays the initial learning rate by half every 20 epochs. In this example, we use a 15-epoch decay interval instead.
training_step_size = total_training_examples // BATCH_SIZE
total_training_steps = training_step_size * EPOCHS
print(f"Total training steps: {total_training_steps}.")

# Halve the learning rate at epoch 15 and halve it again at epoch 30.
# (The original listed the same boundary twice, so the middle value was unreachable.)
lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[training_step_size * 15, training_step_size * 30],
    values=[INITIAL_LR, INITIAL_LR * 0.5, INITIAL_LR * 0.25],
)

steps = tf.range(total_training_steps, dtype=tf.int32)
lrs = [lr_schedule(step) for step in steps]

plt.plot(lrs)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()

Total training steps: 5520.

Finally, we implement a utility for running our experiments and launching model training.
def run_experiment(epochs):
    segmentation_model = get_shape_segmentation_model(num_points, num_classes)
    segmentation_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    history = segmentation_model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )

    segmentation_model.load_weights(checkpoint_filepath)
    return segmentation_model, history


segmentation_model, history = run_experiment(epochs=EPOCHS)

Epoch 1/60
93/93 [==============================] - 28s 127ms/step - loss: 5.3556 - accuracy: 0.7448 - val_loss: 5.8386 - val_accuracy: 0.7471
Epoch 2/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7077 - accuracy: 0.8181 - val_loss: 5.2614 - val_accuracy: 0.7793
Epoch 3/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6566 - accuracy: 0.8301 - val_loss: 4.7907 - val_accuracy: 0.8269
Epoch 4/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6059 - accuracy: 0.8406 - val_loss: 4.6031 - val_accuracy: 0.8482
Epoch 5/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5828 - accuracy: 0.8444 - val_loss: 4.7692 - val_accuracy: 0.8220
Epoch 6/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6150 - accuracy: 0.8408 - val_loss: 5.4460 - val_accuracy: 0.8192
Epoch 7/60
93/93 [==============================] - 11s 117ms/step - loss: 67.5943 - accuracy: 0.7378 - val_loss: 1617.1846 - val_accuracy: 0.5191
Epoch 8/60
93/93 [==============================] - 11s 117ms/step - loss: 15.2910 - accuracy: 0.6651 - val_loss: 8.1014 - val_accuracy: 0.7046
Epoch 9/60
93/93 [==============================] - 11s 117ms/step - loss: 6.8878 - accuracy: 0.7368 - val_loss: 14.2311 - val_accuracy: 0.6949
Epoch 10/60
93/93 [==============================] - 11s 117ms/step - loss: 5.8362 - accuracy: 0.7549 - val_loss: 14.6942 - val_accuracy: 0.6350
Epoch 11/60
93/93 [==============================] - 11s 117ms/step - loss: 5.4777 - accuracy: 0.7648 - val_loss: 44.1037 - val_accuracy: 0.6422
Epoch 12/60
93/93 [==============================] - 11s 117ms/step - loss: 5.2688 - accuracy: 0.7712 - val_loss: 4.9977 - val_accuracy: 0.7692
Epoch 13/60
93/93 [==============================] - 11s 117ms/step - loss: 5.1041 - accuracy: 0.7837 - val_loss: 6.0642 - val_accuracy: 0.7577
Epoch 14/60
93/93 [==============================] - 11s 117ms/step - loss: 5.0011 - accuracy: 0.7862 - val_loss: 4.9313 - val_accuracy: 0.7840
Epoch 15/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8910 - accuracy: 0.7953 - val_loss: 5.8368 - val_accuracy: 0.7725
Epoch 16/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8698 - accuracy: 0.8074 - val_loss: 73.0260 - val_accuracy: 0.7251
Epoch 17/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8299 - accuracy: 0.8109 - val_loss: 17.1503 - val_accuracy: 0.7415
Epoch 18/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8147 - accuracy: 0.8111 - val_loss: 62.2765 - val_accuracy: 0.7344
Epoch 19/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8316 - accuracy: 0.8141 - val_loss: 5.2200 - val_accuracy: 0.7890
Epoch 20/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7853 - accuracy: 0.8142 - val_loss: 5.7062 - val_accuracy: 0.7719
Epoch 21/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7753 - accuracy: 0.8157 - val_loss: 6.2089 - val_accuracy: 0.7839
Epoch 22/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7681 - accuracy: 0.8161 - val_loss: 5.1077 - val_accuracy: 0.8021
Epoch 23/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7554 - accuracy: 0.8187 - val_loss: 4.7912 - val_accuracy: 0.7912
Epoch 24/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7355 - accuracy: 0.8197 - val_loss: 4.9164 - val_accuracy: 0.7978
Epoch 25/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7483 - accuracy: 0.8197 - val_loss: 13.4724 - val_accuracy: 0.7631
Epoch 26/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7200 - accuracy: 0.8218 - val_loss: 8.3074 - val_accuracy: 0.7596
Epoch 27/60
93/93 [==============================] - 11s 118ms/step - loss: 4.7192 - accuracy: 0.8231 - val_loss: 12.4468 - val_accuracy: 0.7591
Epoch 28/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7151 - accuracy: 0.8241 - val_loss: 23.8681 - val_accuracy: 0.7689
Epoch 29/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7096 - accuracy: 0.8237 - val_loss: 4.9069 - val_accuracy: 0.8104
Epoch 30/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6991 - accuracy: 0.8257 - val_loss: 4.9858 - val_accuracy: 0.7950
Epoch 31/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6852 - accuracy: 0.8260 - val_loss: 5.0130 - val_accuracy: 0.7678
Epoch 32/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6630 - accuracy: 0.8286 - val_loss: 4.8523 - val_accuracy: 0.7676
Epoch 33/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6837 - accuracy: 0.8281 - val_loss: 5.4347 - val_accuracy: 0.8095
Epoch 34/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6571 - accuracy: 0.8296 - val_loss: 10.4595 - val_accuracy: 0.7410
Epoch 35/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6460 - accuracy: 0.8321 - val_loss: 4.9189 - val_accuracy: 0.8083
Epoch 36/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6430 - accuracy: 0.8327 - val_loss: 5.8674 - val_accuracy: 0.7911
Epoch 37/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6530 - accuracy: 0.8309 - val_loss: 4.7946 - val_accuracy: 0.8032
Epoch 38/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6391 - accuracy: 0.8318 - val_loss: 5.0111 - val_accuracy: 0.8024
Epoch 39/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6521 - accuracy: 0.8336 - val_loss: 8.1558 - val_accuracy: 0.7727
Epoch 40/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6443 - accuracy: 0.8329 - val_loss: 42.8513 - val_accuracy: 0.7688
Epoch 41/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6316 - accuracy: 0.8342 - val_loss: 5.0960 - val_accuracy: 0.8066
Epoch 42/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6322 - accuracy: 0.8335 - val_loss: 5.0634 - val_accuracy: 0.8158
Epoch 43/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8370 - val_loss: 6.0642 - val_accuracy: 0.8062
Epoch 44/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8371 - val_loss: 11.1805 - val_accuracy: 0.7790
Epoch 45/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6056 - accuracy: 0.8377 - val_loss: 4.7359 - val_accuracy: 0.8145
Epoch 46/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6108 - accuracy: 0.8383 - val_loss: 5.7125 - val_accuracy: 0.7713
Epoch 47/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6103 - accuracy: 0.8377 - val_loss: 6.3271 - val_accuracy: 0.8105
Epoch 48/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6020 - accuracy: 0.8383 - val_loss: 14.2876 - val_accuracy: 0.7529
Epoch 49/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6035 - accuracy: 0.8382 - val_loss: 4.8244 - val_accuracy: 0.8143
Epoch 50/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6076 - accuracy: 0.8381 - val_loss: 8.2636 - val_accuracy: 0.7528
Epoch 51/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8399 - val_loss: 4.6473 - val_accuracy: 0.8266
Epoch 52/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8408 - val_loss: 4.6443 - val_accuracy: 0.8276
Epoch 53/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5852 - accuracy: 0.8413 - val_loss: 5.1300 - val_accuracy: 0.7768
Epoch 54/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5787 - accuracy: 0.8426 - val_loss: 8.9590 - val_accuracy: 0.7582
Epoch 55/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5837 - accuracy: 0.8410 - val_loss: 5.1501 - val_accuracy: 0.8117
Epoch 56/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5875 - accuracy: 0.8422 - val_loss: 31.3518 - val_accuracy: 0.7590
Epoch 57/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5821 - accuracy: 0.8427 - val_loss: 4.8853 - val_accuracy: 0.8144
Epoch 58/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5751 - accuracy: 0.8446 - val_loss: 4.6653 - val_accuracy: 0.8222
Epoch 59/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5752 - accuracy: 0.8447 - val_loss: 6.0078 - val_accuracy: 0.8014
Epoch 60/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5695 - accuracy: 0.8452 - val_loss: 4.8178 - val_accuracy: 0.8192

Visualize the training landscape
def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("accuracy")

Inference
validation_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(validation_batch[0])
print(f"Validation prediction shape: {val_predictions.shape}")


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])


idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)

# Plotting with predicted labels.
visualize_single_point_cloud(validation_batch[0], val_predictions, idx)

Validation prediction shape: (32, 1024, 5)
Index selected: 24
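The unlabeled test_point_clouds collected earlier are never used above. As a sketch (our own, not from the original post), one could run the trained model on them too; note the cloud must first be sampled and normalized exactly like the training data:

# Our own sketch: predict part labels for one unlabeled test shape.
test_cloud = test_point_clouds[0]
sampled = test_cloud[
    np.random.choice(len(test_cloud), NUM_SAMPLE_POINTS, replace=False)
]
sampled = sampled - np.mean(sampled, axis=0)
sampled /= np.max(np.linalg.norm(sampled, axis=1))

label_map = LABELS + ["none"]
pred = segmentation_model.predict(sampled[None, ...])[0]  # (1024, 5)
visualize_data(sampled, [label_map[np.argmax(p)] for p in pred])

Final notes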
If you are interested in learning more about this topic, you may find this repository useful.