Natural language image search with a Dual Encoder
How to build a dual encoder (also known as two-tower) neural network model to search for images using natural language.
Author: Khalid Salama
Original link: https://keras.io/examples/nlp/nl_image_search/
Introduction
This example demonstrates how to build a dual encoder (also known as two-tower) neural network model to search for images using natural language. The model is inspired by the CLIP approach introduced by Alec Radford et al.: the idea is to jointly train a vision encoder and a text encoder that project the representations of images and their captions into the same embedding space, so that the caption embeddings lie near the embeddings of the images they describe.
This example requires TensorFlow 2.4 or higher. In addition, TensorFlow Hub and TensorFlow Text are required for the BERT model, and TensorFlow Addons is required for the AdamW optimizer. These libraries can be installed using the following command:
pip install -q -U tensorflow-hub tensorflow-text tensorflow-addons
Setup
import os
import collections
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_hub as hub
import tensorflow_text as text
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm

# Suppressing tf.hub warnings
tf.get_logger().setLevel("ERROR")
Prepare the data
We use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train a dual encoder model for image search.
Download and extract the data
First, download the dataset, which consists of two compressed folders: one with the images, and the other with the associated image captions. Note that the compressed images folder is 13GB in size.
root_dir = "datasets"
annotations_dir = os.path.join(root_dir, "annotations")
images_dir = os.path.join(root_dir, "train2014")
tfrecords_dir = os.path.join(root_dir, "tfrecords")
annotation_file = os.path.join(annotations_dir, "captions_train2014.json")

# Download caption annotation files
if not os.path.exists(annotations_dir):
    annotation_zip = tf.keras.utils.get_file(
        "captions.zip",
        cache_dir=os.path.abspath("."),
        origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
        extract=True,
    )
    os.remove(annotation_zip)

# Download image files
if not os.path.exists(images_dir):
    image_zip = tf.keras.utils.get_file(
        "train2014.zip",
        cache_dir=os.path.abspath("."),
        origin="http://images.cocodataset.org/zips/train2014.zip",
        extract=True,
    )
    os.remove(image_zip)

print("Dataset is downloaded and extracted successfully.")

with open(annotation_file, "r") as f:
    annotations = json.load(f)["annotations"]

image_path_to_caption = collections.defaultdict(list)
for element in annotations:
    caption = f"{element['caption'].lower().rstrip('.')}"
    image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
    image_path_to_caption[image_path].append(caption)

image_paths = list(image_path_to_caption.keys())
print(f"Number of images: {len(image_paths)}")

Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip
252878848/252872794 [==============================] - 5s 0us/step
Downloading data from http://images.cocodataset.org/zips/train2014.zip
13510574080/13510573713 [==============================] - 394s 0us/step
Dataset is downloaded and extracted successfully.
Number of images: 82783

Process and save the data to TFRecord files
You can change the train_size parameter to control how many image-caption pairs will be used for training the dual encoder model. In this example we set train_size to 30,000 images, which is about 35% of the dataset. We use 2 captions for each image, thus producing 60,000 image-caption pairs. The size of the training set affects the quality of the produced encoders, but more examples mean a longer training time.
train_size = 30000
valid_size = 5000
captions_per_image = 2
images_per_file = 2000

train_image_paths = image_paths[:train_size]
num_train_files = int(np.ceil(train_size / images_per_file))
train_files_prefix = os.path.join(tfrecords_dir, "train")

valid_image_paths = image_paths[-valid_size:]
num_valid_files = int(np.ceil(valid_size / images_per_file))
valid_files_prefix = os.path.join(tfrecords_dir, "valid")

tf.io.gfile.makedirs(tfrecords_dir)


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def create_example(image_path, caption):
    feature = {
        "caption": bytes_feature(caption.encode()),
        "raw_image": bytes_feature(tf.io.read_file(image_path).numpy()),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def write_tfrecords(file_name, image_paths):
    caption_list = []
    image_path_list = []
    for image_path in image_paths:
        captions = image_path_to_caption[image_path][:captions_per_image]
        caption_list.extend(captions)
        image_path_list.extend([image_path] * len(captions))

    with tf.io.TFRecordWriter(file_name) as writer:
        for example_idx in range(len(image_path_list)):
            example = create_example(
                image_path_list[example_idx], caption_list[example_idx]
            )
            writer.write(example.SerializeToString())
    return example_idx + 1


def write_data(image_paths, num_files, files_prefix):
    example_counter = 0
    for file_idx in tqdm(range(num_files)):
        file_name = files_prefix + "-%02d.tfrecord" % (file_idx)
        start_idx = images_per_file * file_idx
        end_idx = start_idx + images_per_file
        example_counter += write_tfrecords(file_name, image_paths[start_idx:end_idx])
    return example_counter


train_example_count = write_data(train_image_paths, num_train_files, train_files_prefix)
print(f"{train_example_count} training examples were written to tfrecord files.")

valid_example_count = write_data(valid_image_paths, num_valid_files, valid_files_prefix)
print(f"{valid_example_count} evaluation examples were written to tfrecord files.")

100%|██████████| 15/15 [03:19<00:00, 13.27s/it]
60000 training examples were written to tfrecord files.
100%|██████████| 3/3 [00:33<00:00, 11.07s/it]
10000 evaluation examples were written to tfrecord files.

Create tf.data.Dataset for training and evaluation
feature_description = {
    "caption": tf.io.FixedLenFeature([], tf.string),
    "raw_image": tf.io.FixedLenFeature([], tf.string),
}


def read_example(example):
    features = tf.io.parse_single_example(example, feature_description)
    raw_image = features.pop("raw_image")
    features["image"] = tf.image.resize(
        tf.image.decode_jpeg(raw_image, channels=3), size=(299, 299)
    )
    return features


def get_dataset(file_pattern, batch_size):
    return (
        tf.data.TFRecordDataset(tf.data.Dataset.list_files(file_pattern))
        .map(
            read_example,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            deterministic=False,
        )
        .shuffle(batch_size * 10)
        .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        .batch(batch_size)
    )
實時投影頭
The projection head is used to transform the image and text embeddings into the same embedding space with the same dimensionality.
def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = tf.nn.gelu(projected_embeddings)
        x = layers.Dense(projection_dims)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.Add()([projected_embeddings, x])
        projected_embeddings = layers.LayerNormalization()(x)
    return projected_embeddings
實現(xiàn)視覺編碼器
In this example, we use Xception from Keras Applications as the base of the vision encoder, as sketched below.
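The sketch below shows one way create_vision_encoder could be written, consistent with how it is called in the training code later: a frozen, ImageNet-pretrained Xception base with average pooling, followed by the projection head defined above. Treat the exact layer and preprocessing choices as an approximation rather than a verbatim listing.

def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained Xception model to use as the base image encoder.
    xception = keras.applications.Xception(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Freeze the base encoder (only the projection head will be trained).
    xception.trainable = trainable
    # Receive the images as inputs.
    inputs = layers.Input(shape=(299, 299, 3), name="image_input")
    # Preprocess the pixel values for Xception.
    xception_input = keras.applications.xception.preprocess_input(inputs)
    # Generate the image embeddings with the Xception base.
    embeddings = xception(xception_input)
    # Project the embeddings into the shared embedding space.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model(inputs, outputs, name="vision_encoder")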
實現(xiàn)文本編碼器
We use BERT from TensorFlow Hub as the text encoder.
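A sketch of create_text_encoder is shown below: a TensorFlow Hub BERT preprocessing layer feeds a small pre-trained BERT encoder, and the pooled output goes through the same projection head. The specific TF Hub module handles and the choice of a small BERT variant are assumptions for illustration; any compatible preprocessing/encoder pair works, as long as the function signature matches its use in the training code.

def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load a BERT preprocessing module from TF Hub (handle is an assumption).
    preprocess = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2",
        name="text_preprocessing",
    )
    # Load a small pre-trained BERT model to use as the base text encoder.
    bert = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
        name="bert",
    )
    # Freeze the base encoder (only the projection head will be trained).
    bert.trainable = trainable
    # Receive the raw text as input.
    inputs = layers.Input(shape=(), dtype=tf.string, name="text_input")
    # Tokenize and preprocess the text.
    bert_inputs = preprocess(inputs)
    # Generate embeddings for the preprocessed text using BERT's pooled output.
    embeddings = bert(bert_inputs)["pooled_output"]
    # Project the embeddings into the shared embedding space.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")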
實現(xiàn)雙編碼器
To calculate the loss, we compute the pairwise dot-product similarity between each caption_i and image_j in the batch as the predictions. The target similarity between caption_i and image_j is computed as the average of the (dot-product similarity between caption_i and caption_j) and the (dot-product similarity between image_i and image_j). We then use crossentropy to compute the loss between the targets and the predictions.
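Expressed in symbols (this notation is added here for clarity and does not appear in the original example): with C and I the matrices of caption and image embeddings in a batch and \(\tau\) the temperature, the code below computes

\[
\text{logits} = \frac{C I^{\top}}{\tau}, \qquad
\text{targets} = \mathrm{softmax}\!\left(\frac{C C^{\top} + I I^{\top}}{2\tau}\right), \qquad
\mathcal{L} = \tfrac{1}{2}\left(\mathrm{CE}(\text{targets}, \text{logits}) + \mathrm{CE}(\text{targets}^{\top}, \text{logits}^{\top})\right),
\]

where the softmax and the cross-entropy CE are applied row-wise.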
class DualEncoder(keras.Model):
    def __init__(self, text_encoder, image_encoder, temperature=1.0, **kwargs):
        super(DualEncoder, self).__init__(**kwargs)
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.temperature = temperature
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def call(self, features, training=False):
        # Place each encoder on a separate GPU (if available).
        # TF will fall back on available devices if there are fewer than 2 GPUs.
        with tf.device("/gpu:0"):
            # Get the embeddings for the captions.
            caption_embeddings = self.text_encoder(features["caption"], training=training)
        with tf.device("/gpu:1"):
            # Get the embeddings for the images.
            image_embeddings = self.image_encoder(features["image"], training=training)
        return caption_embeddings, image_embeddings

    def compute_loss(self, caption_embeddings, image_embeddings):
        # logits[i][j] is the dot_similarity(caption_i, image_j).
        logits = (
            tf.matmul(caption_embeddings, image_embeddings, transpose_b=True)
            / self.temperature
        )
        # images_similarity[i][j] is the dot_similarity(image_i, image_j).
        images_similarity = tf.matmul(
            image_embeddings, image_embeddings, transpose_b=True
        )
        # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
        captions_similarity = tf.matmul(
            caption_embeddings, caption_embeddings, transpose_b=True
        )
        # targets[i][j] = average of dot_similarity(caption_i, caption_j) and dot_similarity(image_i, image_j).
        targets = keras.activations.softmax(
            (captions_similarity + images_similarity) / (2 * self.temperature)
        )
        # Compute the loss for the captions using crossentropy
        captions_loss = keras.losses.categorical_crossentropy(
            y_true=targets, y_pred=logits, from_logits=True
        )
        # Compute the loss for the images using crossentropy
        images_loss = keras.losses.categorical_crossentropy(
            y_true=tf.transpose(targets), y_pred=tf.transpose(logits), from_logits=True
        )
        # Return the mean of the loss over the batch.
        return (captions_loss + images_loss) / 2

    def train_step(self, features):
        with tf.GradientTape() as tape:
            # Forward pass
            caption_embeddings, image_embeddings = self(features, training=True)
            loss = self.compute_loss(caption_embeddings, image_embeddings)
        # Backward pass
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Monitor loss
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, features):
        caption_embeddings, image_embeddings = self(features, training=False)
        loss = self.compute_loss(caption_embeddings, image_embeddings)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
Train the dual encoder model
In this experiment, we freeze the base encoders for the text and the images, and train only the projection heads.
num_epochs = 5  # In practice, train for at least 30 epochs
batch_size = 256

vision_encoder = create_vision_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
text_encoder = create_text_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
dual_encoder = DualEncoder(text_encoder, vision_encoder, temperature=0.05)
dual_encoder.compile(
    optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001)
)

Note that training the model with 60,000 image-caption pairs and a batch size of 256 takes around 12 minutes per epoch on a V100 GPU accelerator. With 2 GPUs, each epoch takes around 8 minutes.
print(f"Number of GPUs: {len(tf.config.list_physical_devices('GPU'))}")
print(f"Number of examples (caption-image pairs): {train_example_count}")
print(f"Batch size: {batch_size}")
print(f"Steps per epoch: {int(np.ceil(train_example_count / batch_size))}")

train_dataset = get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), batch_size)
valid_dataset = get_dataset(os.path.join(tfrecords_dir, "valid-*.tfrecord"), batch_size)

# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.2, patience=3
)
# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)

history = dual_encoder.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=valid_dataset,
    callbacks=[reduce_lr, early_stopping],
)
print("Training completed. Saving vision and text encoders...")
vision_encoder.save("vision_encoder")
text_encoder.save("text_encoder")
print("Models are saved.")

Number of GPUs: 2
Number of examples (caption-image pairs): 60000
Batch size: 256
Steps per epoch: 235
Epoch 1/5
235/235 [==============================] - 573s 2s/step - loss: 60.8318 - val_loss: 9.0531
Epoch 2/5
235/235 [==============================] - 553s 2s/step - loss: 7.8959 - val_loss: 5.2654
Epoch 3/5
235/235 [==============================] - 541s 2s/step - loss: 4.6644 - val_loss: 4.9260
Epoch 4/5
235/235 [==============================] - 538s 2s/step - loss: 4.0188 - val_loss: 4.6312
Epoch 5/5
235/235 [==============================] - 539s 2s/step - loss: 3.5555 - val_loss: 4.3503
Training completed. Saving vision and text encoders...
Models are saved.

Plotting the training loss:
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend(["train", "valid"], loc="upper right")
plt.show()
Search for images using natural language queries
We can retrieve images corresponding to a natural language query via the following steps:
1. Feed the images into the vision_encoder to generate their embeddings.
2. Feed the natural language query into the text_encoder to generate a query embedding.
3. Compute the similarity between the query embedding and the image embeddings in the index to retrieve the indices of the top matches.
4. Look up the paths of the top matching images and display them.
Note that after training the dual encoder, only the fine-tuned vision_encoder and text_encoder models will be used, while the dual_encoder model is discarded.
Generate embeddings for the images
We load the images and feed them into the vision_encoder to generate their embeddings. In large-scale systems, this step would be performed with a parallel data processing framework such as Apache Spark or Apache Beam. Generating the image embeddings may take a few minutes.
print("Loading vision and text encoders...")
vision_encoder = keras.models.load_model("vision_encoder")
text_encoder = keras.models.load_model("text_encoder")
print("Models are loaded.")


def read_image(image_path):
    image_array = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
    return tf.image.resize(image_array, (299, 299))


print(f"Generating embeddings for {len(image_paths)} images...")
image_embeddings = vision_encoder.predict(
    tf.data.Dataset.from_tensor_slices(image_paths).map(read_image).batch(batch_size),
    verbose=1,
)
print(f"Image embeddings shape: {image_embeddings.shape}.")

Loading vision and text encoders...
Models are loaded.
Generating embeddings for 82783 images...
324/324 [==============================] - 437s 1s/step
Image embeddings shape: (82783, 256).

Retrieve relevant images
In this example, we use exact matching: we compute the dot-product similarity between the input query embedding and the image embeddings, and retrieve the top k matches. However, in real-world use cases, approximate matching with frameworks such as ScaNN, Annoy, or Faiss is preferred in order to scale to a large number of images.
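The exact-match implementation used in this example (find_matches) follows right below. As an illustration of the approximate alternative, here is a minimal sketch using Annoy; the library choice, the angular metric, and the number of trees are assumptions for illustration rather than part of the original example. It reuses image_embeddings, image_paths, and text_encoder from above.

from annoy import AnnoyIndex

# Build an approximate nearest-neighbour index over the image embeddings.
embedding_dim = image_embeddings.shape[1]
annoy_index = AnnoyIndex(embedding_dim, "angular")  # angular distance ~ cosine similarity
for item_id, embedding in enumerate(image_embeddings):
    annoy_index.add_item(item_id, embedding)
annoy_index.build(10)  # more trees -> better recall, slower build

# Query the index with the embedding of a natural language query.
query_embedding = text_encoder(tf.convert_to_tensor(["a plate of healthy food"]))[0].numpy()
approx_match_ids = annoy_index.get_nns_by_vector(query_embedding, 9)
approx_matches = [image_paths[idx] for idx in approx_match_ids]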
def find_matches(image_embeddings, queries, k=9, normalize=True):
    # Get the embedding for the query.
    query_embedding = text_encoder(tf.convert_to_tensor(queries))
    # Normalize the query and the image embeddings.
    if normalize:
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
    # Compute the dot product between the query and the image embeddings.
    dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)
    # Retrieve top k indices.
    results = tf.math.top_k(dot_similarity, k).indices.numpy()
    # Return matching image paths.
    return [[image_paths[idx] for idx in indices] for indices in results]

Set the query variable to the type of images you want to search for. Try things like: "a plate of healthy food", "a woman wearing a hat is walking down a sidewalk", "a bird sits near to the water", or "wild animals standing in a field".
query = "a family standing next to the ocean on a sandy beach with a surf board"
matches = find_matches(image_embeddings, [query], normalize=True)[0]

plt.figure(figsize=(20, 20))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(mpimg.imread(matches[i]))
    plt.axis("off")

Evaluate the retrieval quality
To evaluate the dual encoder model, we use the captions as queries. We use the out-of-training-sample images and captions to evaluate the retrieval quality, using top-k accuracy. A true prediction is counted if, for a given caption, its associated image is retrieved within the top k matches.
def compute_top_k_accuracy(image_paths, k=100):
    hits = 0
    num_batches = int(np.ceil(len(image_paths) / batch_size))
    for idx in tqdm(range(num_batches)):
        start_idx = idx * batch_size
        end_idx = start_idx + batch_size
        current_image_paths = image_paths[start_idx:end_idx]
        queries = [
            image_path_to_caption[image_path][0] for image_path in current_image_paths
        ]
        result = find_matches(image_embeddings, queries, k)
        hits += sum(
            [
                image_path in matches
                for (image_path, matches) in list(zip(current_image_paths, result))
            ]
        )
    return hits / len(image_paths)


print("Scoring training data...")
train_accuracy = compute_top_k_accuracy(train_image_paths)
print(f"Train accuracy: {round(train_accuracy * 100, 3)}%")

print("Scoring evaluation data...")
eval_accuracy = compute_top_k_accuracy(image_paths[train_size:])
print(f"Eval accuracy: {round(eval_accuracy * 100, 3)}%")

Scoring training data...
100%|██████████| 118/118 [04:12<00:00, 2.14s/it]
Train accuracy: 13.373%
Scoring evaluation data...
100%|██████████| 207/207 [07:23<00:00, 2.14s/it]
Eval accuracy: 6.235%

Final remarks
You can obtain better results by increasing the size of the training sample, training for more epochs, exploring other base encoders for the images and the text, setting the base encoders to be trainable, and tuning the hyperparameters, especially the temperature of the softmax in the loss computation.