简介

Matlab降低了深度神经网络的开发难度,可以通过拖拽的模式设计网络,甚至训练的过程也是GUI操作。

实例

以高光谱图像分类为例,参考文献1 。构造一个卷积神经网络,输入为 9 × 9 × B 9\times 9\times B 9×9×B 的图像,其中 B B B为波段数,类标为中心像素的标签。

网络设计

在Matlab的APPS中搜索Deep Network工具箱,打开后,选择New来创建网络,在弹出的界面中可以选择创建空白网络,也可以选择预训练的网络。

在这里插入图片描述
进入设计洁面后,从左侧拖拽相应的模块,命名-->设置参数-->连接不同模块,网络搭建完成后,可以选择Analyze来分析下网络,看看有没有错误,没有错误责可以导出代码。

在这里插入图片描述

代码

主训练文件"train_cnn.m",主要完成加载数据、从图像中随机抽取小的图像块,构造训练集,验证集和测试集。注意,真值变量需要用categorical函数转换一下。

load('../data/WHU_Hi_HongHu_preprocessing_tensor_edgemap_7.mat')

rng(2022);

% In the experiments, the patch sizes of the three datasets were set as 
% 9 × 9 × d, where d denotes the band number of the remote sensing image.
Ntrain = 1000;
Nvalid = 500;
Ntest = 200;
ptcsize = [9, 9]; 
M = ones(size(Label));
nclass = length(unique(Label));
[X, Y] = sample_patchs(X, Label, M, ptcsize, nclass, Ntrain+Nvalid+Ntest);
Xtrain = X(:, :, :, 1:Ntrain); Ytrain = Y(1:Ntrain);
Xvalid = X(:, :, :, 1:Nvalid); Yvalid = Y(1:Nvalid);
Xtest = X(:, :, :, 1:Ntest); Ytest = Y(1:Ntest);
Ytrain = categorical(Ytrain);
Yvalid = categorical(Yvalid);
Ytest = categorical(Ytest);

layers = [
    imageInputLayer([9 9 270],"Name","imageinput")
    convolution2dLayer([3 3],128,"Name","conv1")
    batchNormalizationLayer("Name","batchnorm1")
    reluLayer("Name","relu1")
    convolution2dLayer([3 3],256,"Name","conv2")
    batchNormalizationLayer("Name","batchnorm2")
    reluLayer("Name","relu2")
    convolution2dLayer([3 3],256,"Name","conv3","Padding","same")
    batchNormalizationLayer("Name","batchnorm3")
    reluLayer("Name","relu3")
    convolution2dLayer([3 3],128,"Name","conv4")
    batchNormalizationLayer("Name","batchnorm4")
    reluLayer("Name","relu4")
    fullyConnectedLayer(128,"Name","fc1")
    batchNormalizationLayer("Name","batchnorm5")
    reluLayer("Name","relu5")
    fullyConnectedLayer(64,"Name","fc2")
    batchNormalizationLayer("Name","batchnorm6")
    reluLayer("Name","relu6")
    fullyConnectedLayer(nclass,"Name","fc3")
    softmaxLayer("Name","softmax")
    classificationLayer("Name", "classoutput")];

% plot(layerGraph(layers));

options = trainingOptions('adam', ...
    'ValidationData', {Xvalid, Yvalid}, ...
    'Plots', 'training-progress', ...
    'MaxEpochs', 100, ...
    'Shuffle', 'every-epoch', ...
    'InitialLearnRate', 1e-3, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 20, ...
    'ExecutionEnvironment', 'gpu', ...
    'MiniBatchSize', 32);
net = trainNetwork(Xtrain, Ytrain, layers, options);

Ptest = classify(net, Xtest);
precision = sum(Ptest==Ytest) / numel(Ptest);

disp(precision)

随机选图像块文件 “sample_patchs.m”

function [Xp, Yp] = sample_patchs(X, Y, M, ptcsize, nclass, nptcs)
% X: Data image
% Y: Label image
% M: mask: 1: candidate
% ptcsize: size (h, w) of patch
% nclass: number of classes
% nptcs: number of patchs


if isempty(ptcsize) 
    ptcsize = [9, 9];
end
if isempty(nptcs)
    nptcs = 100;
end

pH = ptcsize(1);
pW = ptcsize(2);
pH2 = floor(pH / 2.);
pW2 = floor(pW / 2.);
[xH, xW, C] = size(X);

M(1:pH2, :)  = 0; % boundary
M(xH-pH2:xH, :)  = 0; % boundary
M(:, 1:pW2)  = 0; % boundary
M(:, xW-pW2:xW)  = 0; % boundary

[rows, cols] = find(M==1);
npixel = length(rows);

idx = randi([1, npixel], nptcs, 1);
idxH = rows(idx);
idxW = cols(idx);

Xp = zeros(ptcsize(1), ptcsize(2), C, nptcs);
Yp = zeros(nptcs, 1);
% Yp = zeros(nptcs, nclass); % one-hot
for i = 1:nptcs
    Xp(:, :, :, i) = X(idxH(i) - pH2:idxH(i) + pH2, idxW(i) - pW2:idxW(i) + pW2, :);
    Yp(i, 1) = Y(idxH(i), idxW(i));
    %  Yp(i, Y(idxH(i), idxW(i)) + 1) = 1;  % one-hot
end

运行结果

下图为训练过程的日志结果,图中曲线和一些统计信息是Matlab自动绘制的,不需要自己额外添加代码。
在这里插入图片描述
此外,Matlab命令窗口也有相应的信息,如下:

>> train_cnn
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:08 |        0.00% |       30.60% |       3.6983 |       2.8994 |          0.0010 |
|       2 |          50 |       00:00:11 |       75.00% |       64.80% |       1.0919 |       1.2310 |          0.0010 |
|       4 |         100 |       00:00:14 |       65.62% |       65.20% |       0.9713 |       1.0583 |          0.0010 |
|       5 |         150 |       00:00:17 |       62.50% |       74.80% |       1.1589 |       0.8747 |          0.0010 |
|       7 |         200 |       00:00:19 |       62.50% |       73.00% |       0.9210 |       0.8468 |          0.0010 |
|       9 |         250 |       00:00:23 |       78.12% |       76.60% |       0.6505 |       0.7860 |          0.0010 |
|      10 |         300 |       00:00:25 |       78.12% |       77.80% |       0.7985 |       0.7317 |          0.0010 |
|      12 |         350 |       00:00:28 |       81.25% |       80.00% |       0.6691 |       0.6691 |          0.0010 |
|      13 |         400 |       00:00:31 |       71.88% |       80.20% |       0.9969 |       0.6473 |          0.0010 |
|      15 |         450 |       00:00:34 |       87.50% |       80.20% |       0.4374 |       0.6442 |          0.0010 |
|      17 |         500 |       00:00:37 |       84.38% |       81.20% |       0.4327 |       0.6272 |          0.0010 |
|      18 |         550 |       00:00:39 |       84.38% |       83.80% |       0.3872 |       0.5438 |          0.0010 |
|      20 |         600 |       00:00:42 |       81.25% |       83.00% |       0.6669 |       0.5028 |          0.0010 |
|      21 |         650 |       00:00:45 |       81.25% |       86.40% |       0.4656 |       0.4147 |          0.0001 |
|      23 |         700 |       00:00:48 |       78.12% |       88.00% |       0.6784 |       0.3880 |          0.0001 |
|      25 |         750 |       00:00:51 |       96.88% |       88.40% |       0.2379 |       0.3900 |          0.0001 |
|      26 |         800 |       00:00:53 |       93.75% |       88.20% |       0.3173 |       0.4199 |          0.0001 |
|      28 |         850 |       00:00:56 |       87.50% |       89.00% |       0.3716 |       0.3864 |          0.0001 |
|      30 |         900 |       00:00:59 |       87.50% |       89.20% |       0.3112 |       0.3499 |          0.0001 |
|      31 |         950 |       00:01:01 |       81.25% |       90.60% |       0.4589 |       0.3472 |          0.0001 |
|      33 |        1000 |       00:01:04 |       90.62% |       90.20% |       0.2410 |       0.3030 |          0.0001 |
|      34 |        1050 |       00:01:07 |       96.88% |       91.00% |       0.2589 |       0.3052 |          0.0001 |
|      36 |        1100 |       00:01:10 |       84.38% |       92.00% |       0.5322 |       0.2920 |          0.0001 |
|      38 |        1150 |       00:01:12 |       96.88% |       91.20% |       0.2072 |       0.2998 |          0.0001 |
|      39 |        1200 |       00:01:15 |       90.62% |       92.20% |       0.2447 |       0.2759 |          0.0001 |
|      41 |        1250 |       00:01:18 |       93.75% |       92.00% |       0.1627 |       0.2724 |      1.0000e-05 |
|      42 |        1300 |       00:01:20 |       96.88% |       92.40% |       0.1265 |       0.2751 |      1.0000e-05 |
|      44 |        1350 |       00:01:23 |       93.75% |       90.80% |       0.1679 |       0.3054 |      1.0000e-05 |
|      46 |        1400 |       00:01:26 |       96.88% |       93.40% |       0.1650 |       0.2544 |      1.0000e-05 |
|      47 |        1450 |       00:01:29 |       93.75% |       92.20% |       0.2000 |       0.2709 |      1.0000e-05 |
|      49 |        1500 |       00:01:32 |       93.75% |       92.40% |       0.1877 |       0.2520 |      1.0000e-05 |
|      50 |        1550 |       00:01:34 |       93.75% |       92.20% |       0.1618 |       0.2842 |      1.0000e-05 |
|      52 |        1600 |       00:01:37 |       93.75% |       91.80% |       0.3416 |       0.2809 |      1.0000e-05 |
|      54 |        1650 |       00:01:40 |       96.88% |       91.60% |       0.1159 |       0.2628 |      1.0000e-05 |
|      55 |        1700 |       00:01:43 |       90.62% |       94.00% |       0.2882 |       0.2346 |      1.0000e-05 |
|      57 |        1750 |       00:01:46 |       93.75% |       93.00% |       0.1924 |       0.2571 |      1.0000e-05 |
|      59 |        1800 |       00:01:48 |      100.00% |       94.40% |       0.0592 |       0.2273 |      1.0000e-05 |
|      60 |        1850 |       00:01:51 |       93.75% |       91.40% |       0.1993 |       0.2669 |      1.0000e-05 |
|      62 |        1900 |       00:01:54 |       87.50% |       91.00% |       0.3692 |       0.2943 |      1.0000e-06 |
|      63 |        1950 |       00:01:57 |       96.88% |       92.80% |       0.2041 |       0.2607 |      1.0000e-06 |
|      65 |        2000 |       00:02:00 |       93.75% |       91.60% |       0.2100 |       0.2653 |      1.0000e-06 |
|      67 |        2050 |       00:02:03 |       87.50% |       92.60% |       0.3792 |       0.2715 |      1.0000e-06 |
|      68 |        2100 |       00:02:06 |       93.75% |       91.80% |       0.1791 |       0.2868 |      1.0000e-06 |
|      70 |        2150 |       00:02:08 |       96.88% |       92.60% |       0.2040 |       0.2728 |      1.0000e-06 |
|      71 |        2200 |       00:02:11 |       90.62% |       93.20% |       0.2053 |       0.2353 |      1.0000e-06 |
|      73 |        2250 |       00:02:14 |       93.75% |       93.60% |       0.2120 |       0.2299 |      1.0000e-06 |
|      75 |        2300 |       00:02:17 |       90.62% |       93.20% |       0.2796 |       0.2405 |      1.0000e-06 |
|      76 |        2350 |       00:02:19 |       93.75% |       92.60% |       0.2731 |       0.2586 |      1.0000e-06 |
|      78 |        2400 |       00:02:22 |       93.75% |       91.80% |       0.1932 |       0.2732 |      1.0000e-06 |
|      80 |        2450 |       00:02:25 |       96.88% |       92.80% |       0.1315 |       0.2484 |      1.0000e-06 |
|      81 |        2500 |       00:02:28 |       93.75% |       93.60% |       0.2221 |       0.2730 |      1.0000e-07 |
|      83 |        2550 |       00:02:31 |       93.75% |       92.20% |       0.1957 |       0.2558 |      1.0000e-07 |
|      84 |        2600 |       00:02:34 |       96.88% |       91.80% |       0.1457 |       0.2807 |      1.0000e-07 |
|      86 |        2650 |       00:02:36 |       87.50% |       93.20% |       0.4540 |       0.2724 |      1.0000e-07 |
|      88 |        2700 |       00:02:39 |       93.75% |       93.40% |       0.2235 |       0.2315 |      1.0000e-07 |
|      89 |        2750 |       00:02:42 |      100.00% |       93.40% |       0.0892 |       0.2506 |      1.0000e-07 |
|      91 |        2800 |       00:02:45 |       93.75% |       92.00% |       0.2005 |       0.2666 |      1.0000e-07 |
|      92 |        2850 |       00:02:48 |      100.00% |       91.20% |       0.1301 |       0.2748 |      1.0000e-07 |
|      94 |        2900 |       00:02:51 |       96.88% |       92.20% |       0.1594 |       0.2691 |      1.0000e-07 |
|      96 |        2950 |       00:02:53 |       93.75% |       93.00% |       0.1665 |       0.2548 |      1.0000e-07 |
|      97 |        3000 |       00:02:56 |       93.75% |       94.00% |       0.2878 |       0.2366 |      1.0000e-07 |
|      99 |        3050 |       00:02:59 |       90.62% |       92.00% |       0.1891 |       0.2761 |      1.0000e-07 |
|     100 |        3100 |       00:03:02 |       93.75% |       92.00% |       0.1937 |       0.2665 |      1.0000e-07 |
|======================================================================================================================|
    0.9500

参考文献


  1. WHU-Hi: UAV-borne hyperspectral with high spatial resolution (H2) benchmark datasets and classifier for precise crop identification based on deep convolutional neural network with CRF ↩︎

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐