其他分享
首页 > 其他分享 > 基于ResNet的MSTAR数据集目标分类

基于ResNet的MSTAR数据集目标分类

作者:互联网

基于ResNet的MSTAR数据集目标分类

文章目录

说在前面

1. MSTAR数据集介绍

在这里插入图片描述

2. SAR目标分类网络

在这里插入图片描述

在这里插入图片描述

3. ResNet代码及训练

在这里插入图片描述

4. 结尾

附录(代码)

import tensorflow as tf 
import numpy as np
from ResNet import resnet50
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, callbacks
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, precision_score, f1_score, fbeta_score
from sklearn.metrics import roc_auc_score, roc_curve, auc, classification_report
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pathlib
import datetime
import seaborn as sns
def load_one_from_path_label(path, label):
    """Load a single JPEG file as a one-image batch with its one-hot label.

    Args:
        path: Path of the JPEG file to read.
        label: Integer class index; it is one-hot encoded with depth 10.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a NumPy array of shape
        (1, 128, 128, 1) holding the gamma-adjusted, resized image divided
        by 255. ``labels`` is the int32 one-hot tensor built from ``label``.
    """
    one_hot_label = tf.cast(tf.one_hot(label, depth=10), dtype=tf.int32)

    pixels = tf.image.decode_jpeg(tf.io.read_file(path))
    pixels = tf.image.adjust_gamma(pixels, 0.6)  # gamma adjustment
    pixels = tf.image.resize(pixels, [128, 128])  # resize to (128, 128)
    pixels = tf.cast(pixels, dtype=tf.float32) / 255.0  # normalise to [0, 1]

    # Wrap the single image in a batch dimension so it can be fed to a model.
    batch = np.zeros((1, 128, 128, 1))
    batch[0, :, :, :] = pixels
    return batch, one_hot_label

def load_from_path_label2(all_image_paths, all_image_labels):
    """Load every listed image into memory together with integer labels.

    Args:
        all_image_paths: Sequence of JPEG file paths.
        all_image_labels: Sequence of integer class indices, one per path.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a NumPy array of shape
        (len(all_image_paths), 128, 128, 1) with each image gamma-adjusted,
        resized and divided by 255. ``labels`` is an int32 tensor of the
        class indices (NOT one-hot encoded).
    """
    image_count = len(all_image_paths)
    images = np.zeros((image_count, 128, 128, 1))
    # Labels stay as plain class indices: main() compares them against
    # argmax predictions in sklearn's confusion_matrix.  No one-hot tensor
    # is built here because it would be discarded immediately.
    labels = tf.cast(all_image_labels, dtype=tf.int32)

    for i, image_path in enumerate(all_image_paths):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image)
        image = tf.image.adjust_gamma(image, 0.6)  # gamma adjustment
        image = tf.image.resize(image, [128, 128])  # resize to (128, 128)
        image = tf.cast(image, dtype=tf.float32) / 255.0  # normalise to [0, 1]

        images[i, :, :, :] = image

    return images, labels


def load_from_path_label(path, label):
    """Read one JPEG file and one-hot encode its class index.

    Args:
        path: Path of the JPEG file to read.
        label: Integer class index; it is one-hot encoded with depth 10.

    Returns:
        A tuple ``(image, label)`` of the decoded image tensor and the
        int32 one-hot label tensor.
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(path))
    one_hot = tf.cast(tf.one_hot(label, depth=10), dtype=tf.int32)
    return decoded, one_hot

def preprocess(image, label):
    """Apply the image preprocessing steps; the label passes through as is.

    Args:
        image: Decoded image tensor.
        label: Label belonging to ``image``; returned unchanged.

    Returns:
        A tuple ``(image, label)`` where the image has been gamma-adjusted
        (gamma 0.6), resized to 128x128 and divided by 255 as float32.
    """
    corrected = tf.image.adjust_gamma(image, 0.6)
    resized = tf.image.resize(corrected, [128, 128])
    scaled = tf.cast(resized, dtype=tf.float32) / 255.0  # normalise to [0, 1]
    return scaled, label

def show_image(db, row, col, title, is_preprocess=True):
    """Draw a ``row`` x ``col`` grid of samples, one class per grid row.

    The dataset is scanned in order; a sample is placed in the next free
    cell only when its class index equals the index of the grid row that
    cell belongs to.  Scanning stops once the grid is full or the dataset
    is exhausted.

    Args:
        db: Iterable of ``(image, one_hot_label)`` pairs (unbatched).
        row: Number of grid rows.
        col: Number of grid columns.
        title: Figure title.
        is_preprocess: When equal to ``True`` the image is multiplied by 255
            before being displayed.
    """
    plt.figure()
    plt.suptitle(title, fontsize=14)

    total_cells = row * col
    filled = 0
    for image, label in db:
        if filled == total_cells:
            break
        class_id = int(tf.argmax(label))
        # Skip samples that do not belong to the row currently being filled.
        if class_id != int(filled / col):
            continue
        if is_preprocess == True:  # equality test kept deliberately
            image = image * 255
        plt.subplot(row, col, filled + 1)
        plt.title("class" + str(class_id), fontsize=8)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
        filled += 1

    plt.tight_layout()

def get_datasets(path, train=True):
    """Build a batched dataset and in-memory arrays for one image folder.

    ``path`` must contain one sub-directory per class; the files inside
    each sub-directory are that class's images.

    Args:
        path: Root directory of the data split.
        train: Selects the "(Train)" or "(Test)" prefix used in the titles
            of the preview figures.  Preprocessing is applied either way.

    Returns:
        A tuple ``(db, images, labels, label_names)``:
        ``db`` is the preprocessed dataset, shuffled and batched by 16;
        ``images`` / ``labels`` are the arrays returned by
        ``load_from_path_label2``; ``label_names`` is the list of class
        directory names, in label-index order.
    """
    data_path = pathlib.Path(path)

    # Class names are taken from sub-directories only and sorted, so the
    # name -> index mapping is the same for every split and every platform
    # (an unsorted listing could number TRAIN and TEST classes differently).
    label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
    label_index = {name: index for index, name in enumerate(label_names)}

    # Image files live exactly one level below the class directories.
    all_image_paths = [
        str(image_path)
        for image_path in sorted(data_path.glob('*/*'))
        if image_path.is_file()
    ]
    image_count = len(all_image_paths)

    print(label_index)
    print(label_names)
    print(image_count)

    # Numeric label of every image, derived from its parent directory name.
    all_image_labels = [
        label_index[pathlib.Path(image_path).parent.name]
        for image_path in all_image_paths
    ]
    for image, label in zip(all_image_paths[:5], all_image_labels[:5]):
        print(image, ' --->  ', label)

    images, labels = load_from_path_label2(all_image_paths, all_image_labels)

    # Build the tf.data pipeline from the path/label lists.
    db = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
    db = db.map(load_from_path_label)

    split_name = 'Train' if train else 'Test'
    show_image(db, 5, 5, '({}) Raw SAR Image'.format(split_name), False)
    db = db.map(preprocess)
    show_image(db, 5, 5, '({}) Preprocessed SAR Image'.format(split_name), True)

    db = db.shuffle(1000).batch(16)
    return db, images, labels, label_names


def get_train_valid_datasets(path, train=True):
    """Split one image folder into training and validation datasets.

    ``path`` must contain one sub-directory per class; the files inside
    each sub-directory are that class's images.  80% of the images go to
    the training dataset and 20% to the validation dataset
    (``train_test_split`` with ``random_state=0``).

    Args:
        path: Root directory of the data.
        train: When true, preview figures of raw and preprocessed samples
            are drawn for both splits.  Preprocessing itself is always
            applied, so the returned datasets have the same element format
            regardless of this flag.

    Returns:
        A tuple ``(train_db, valid_db)`` of preprocessed datasets, each
        shuffled and batched by 16.
    """
    data_path = pathlib.Path(path)

    # Class names are taken from sub-directories only and sorted, so the
    # name -> index mapping does not depend on directory listing order.
    label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
    label_index = {name: index for index, name in enumerate(label_names)}

    # Sorting the paths makes the fixed-seed split below reproducible.
    all_image_paths = [
        str(image_path)
        for image_path in sorted(data_path.glob('*/*'))
        if image_path.is_file()
    ]
    image_count = len(all_image_paths)

    print(label_index)
    print(label_names)
    print(image_count)

    # Numeric label of every image, derived from its parent directory name.
    all_image_labels = [
        label_index[pathlib.Path(image_path).parent.name]
        for image_path in all_image_paths
    ]
    for image, label in zip(all_image_paths[:5], all_image_labels[:5]):
        print(image, ' --->  ', label)

    train_images, valid_images, train_labels, valid_labels = train_test_split(
        all_image_paths, all_image_labels, test_size=0.2, random_state=0)

    print('train counts -----> ', len(train_images))
    print('valid counts -----> ', len(valid_images))

    def _prepare(split_db, split_name):
        # Preprocess one split, optionally previewing it before and after.
        if train:
            show_image(split_db, 5, 5, '({}) Raw SAR Image'.format(split_name), False)
        split_db = split_db.map(preprocess)
        if train:
            show_image(split_db, 5, 5,
                       '({}) Preprocessed SAR Image'.format(split_name), True)
        return split_db

    train_db = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    train_db = _prepare(train_db.map(load_from_path_label), 'Train')

    valid_db = tf.data.Dataset.from_tensor_slices((valid_images, valid_labels))
    valid_db = _prepare(valid_db.map(load_from_path_label), 'Valid')

    train_db = train_db.shuffle(1000).batch(16)
    valid_db = valid_db.shuffle(1000).batch(16)
    return train_db, valid_db

def plot_confusion_matrix(matrix, class_labels, normalize=False):
    """Draw a confusion matrix as an annotated seaborn heat map.

    Args:
        matrix: Square confusion matrix (rows = true class, columns =
            predicted class), as an array or nested sequence.
        class_labels: Class names used as tick labels; one per matrix
            row/column.
        normalize: When true, every row is divided by its sum before
            plotting.  The row-normalised values are also printed as
            percentages (rounded to 5 decimals before scaling) and the
            plotted values are rounded to 2 decimals.
    """
    if normalize:
        counts = np.asarray(matrix, dtype=float)
        row_sums = counts.sum(axis=1, keepdims=True)
        # A class without any true samples has a zero row sum; leave such
        # rows at 0 instead of producing NaN through division by zero.
        matrix = np.divide(counts, row_sums,
                           out=np.zeros_like(counts), where=row_sums != 0)
        print(np.around(matrix, decimals=5) * 100)
        matrix = np.around(matrix, decimals=2)

    sns.set()
    _, ax = plt.subplots()
    # Heat-map cells are one unit wide, so cell centres sit at k + 0.5.
    tick_marks = np.arange(len(class_labels)) + 0.5
    sns.heatmap(matrix, annot=True, cmap="Blues", ax=ax)  # draw the heat map
    ax.set_title('confusion matrix')

    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels, rotation=0)
    ax.set_xlabel('Predict')  # x axis: predicted class
    ax.set_ylabel('True')  # y axis: true class
    plt.tight_layout()

def main(train_dir='E:\\SARimage\\TRAIN', test_dir='E:\\SARimage\\TEST'):
    """Train ResNet-50 on the SAR images, then evaluate and plot the results.

    Steps: load both splits, build and compile the model, train for 50
    epochs with TensorBoard logging (the test split doubles as validation
    data), evaluate, save the weights, and finally compute, print, plot and
    save the confusion matrix of the test split.

    Args:
        train_dir: Root directory of the training images (one
            sub-directory per class).
        test_dir: Root directory of the test images (same layout).
    """
    train_db, train_images, train_labels, train_label_names = get_datasets(train_dir, True)
    test_db, test_images, test_labels, test_label_names = get_datasets(test_dir, False)

    # ResNet-50 operating on 128x128 single-channel inputs.
    model = resnet50()
    model.build(input_shape=(None, 128, 128, 1))
    model.summary()

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/ResNet50_epoch50_Mstar_' + current_time
    tb_callback = callbacks.TensorBoard(log_dir=log_dir)

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that newer Keras releases reject.
    # NOTE(review): from_logits=True assumes resnet50() ends without a
    # softmax layer -- confirm against ResNet.py.
    model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    model.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1,
              callbacks=[tb_callback])
    model.evaluate(test_db)

    # Make sure the output directory exists before anything is written to it.
    pathlib.Path('./checkpoint').mkdir(parents=True, exist_ok=True)
    model.save_weights('./checkpoint/ResNet50_epoch50_weights.ckpt')
    print('save weights')

    # Predicted class = index of the largest model output per image.
    pred_labels = model.predict(test_images)
    pred_labels = tf.argmax(pred_labels, axis=1)
    con_matrix = confusion_matrix(test_labels, pred_labels, labels=list(range(10)))
    print(con_matrix)
    plot_confusion_matrix(con_matrix, test_label_names, normalize=True)
    # The saved file holds the raw (un-normalised) counts.
    np.savetxt('./checkpoint/ResNet50_epoch50_confusion_matrix.txt', con_matrix)
    plt.show()

标签:基于,image,labels,db,ResNet,label,tf,MSTAR,self
来源: https://blog.csdn.net/qq_40181592/article/details/118495082