%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% 程序说明 % 实例8-1 % 功能:基于迁移学习对 ResNet50 网络进行微调, % 并用新冠肺炎(COVID-19)图片数据集对网络进行再训练, % 对输入图像进行检测/分类(这里是分类)。 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% 步骤1:加载图像数据,并将其划分为训练集和测试集(/验证集) % 数据集根目录(里面按“类别名文件夹”存放图片) dataset_path = 'dataset'; % 从文件夹创建图像数据存储(imageDatastore) % IncludeSubfolders=true:递归读取子文件夹 % LabelSource='foldernames':用“子文件夹名”作为标签(类别名) img_ds = imageDatastore(dataset_path, ... 'IncludeSubfolders', true, ... 'LabelSource', 'foldernames'); % 统计每个类别的样本数量(检查数据是否均衡、是否读对) 数一数每个类别有多少张 total_split = countEachLabel(img_ds); % 随机显示数据集中的部分图像(快速肉眼检查数据和标签) num_images = length(img_ds.Labels); perm = randperm(num_images, 6); % 将1~num_images随机打乱并取前6个索引 figure; %新建一个画图窗口。 for idx = 1:length(perm) subplot(2, 3, idx); %把窗口分成 2×3 的小格子,第 idx 格。 imshow(imread(img_ds.Files(perm(idx)))); %显示图片 title(sprintf('%s', string(img_ds.Labels(perm(idx))))); end % 随机取出5张图像作为测试样本,其余作为训练集 test_idx = randperm(num_images, 5); % 测试集索引(随机抽5张) img_ds_Test = subset(img_ds, test_idx); % 测试集 datastore train_idx = setdiff(1:length(img_ds.Files), test_idx); % 训练集索引=全集-测试集test_idx img_ds_Train = subset(img_ds, train_idx); % 训练集 datastore %% 步骤2:加载预训练好的网络(ResNet50) % 说明:首次调用可能需要下载预训练模型 net = resnet50; %% 步骤3:对网络结构进行调整(替换最后几层,使其适配你的类别数) % 将预训练网络转换为可编辑的层图结构(layer graph) LayerGraph = layerGraph(net); % 清理变量 net(可选,仅为了后面不混淆) clear net; % 确定你的数据集类别数量(例如:正常/肺炎 -> 2类) numClasses = numel(categories(img_ds_Train.Labels)); % 1) 替换最后的全连接层(fc1000)为“输出 numClasses 的新全连接层” % WeightLearnRateFactor / BiasLearnRateFactor = 10 % 含义:让新层学习更快(因为它是新初始化的) newLearnableLayer = fullyConnectedLayer(numClasses, ... 'Name', 'new_fc', ... 'WeightLearnRateFactor', 10, ... 'BiasLearnRateFactor', 10); %把 ResNet50 原来的 fc1000(输出 1000 类)替换成你的新层(输出 2 类)。 LayerGraph = replaceLayer(LayerGraph, 'fc1000', newLearnableLayer); % 2) 替换 softmax 层(保持名字一致更好管理) % Softmax 把输出变成“概率分布” newSoftmaxLayer = softmaxLayer('Name', 'new_softmax'); LayerGraph = replaceLayer(LayerGraph, 'fc1000_softmax', newSoftmaxLayer); % 3) 替换分类输出层(让分类层知道你现在是 numClasses 类) newClassLayer = classificationLayer('Name', 'new_classoutput'); LayerGraph = replaceLayer(LayerGraph, 'ClassificationLayer_fc1000', newClassLayer); %% 步骤4:按网络配置调整图像数据(预处理 + 数据增强 + 尺寸匹配) % ResNet50 输入是 224×224×3(RGB),你的数据可能是灰度/尺寸不同 % 因此给 datastore 指定 ReadFcn:读取每张图时自动做预处理 % 保证读出来的图片符合网络输入(比如灰度变 3 通道) img_ds_Train.ReadFcn = @(filename) preprocess(filename); img_ds_Test.ReadFcn = @(filename) preprocess(filename); % 数据增强器(让训练集更“多样”,提升泛化)(只用于训练集) % 让训练图片随机旋转/翻转/错切,等于“生成更多变化样本”,提高泛化。 augmenter = imageDataAugmenter( ... 'RandRotation', [-5 5], ... % 随机旋转 'RandXReflection', true, ... % 随机水平翻转 'RandYReflection', true, ... % 随机垂直翻转 'RandXShear', [-0.05 0.05], ... % 随机X方向错切 'RandYShear', [-0.05 0.05]); % 随机Y方向错切 % 将训练集图像调整为网络输入大小,并应用数据增强 aug_img_ds_train = augmentedImageDatastore([224 224], img_ds_Train, ... 'DataAugmentation', augmenter); % 测试集只做尺寸调整(通常不做增强,以保证评估公平) aug_img_ds_test = augmentedImageDatastore([224 224], img_ds_Test); %% 步骤5:对网络进行训练(设置训练超参数 + trainNetwork) % 设置训练选项 options = trainingOptions('adam', ... 'MaxEpochs', 30, ... % 最大训练轮数 'MiniBatchSize', 8, ... % 批大小 'Shuffle', 'every-epoch', ... % 每个epoch打乱 'InitialLearnRate', 1e-4, ... % 初始学习率 'Verbose', false, ... % 关闭命令行详细输出 'Plots', 'training-progress', ... % 显示训练曲线窗口 'ExecutionEnvironment', 'cpu'); % 使用CPU(可改为 'gpu') % 用训练集对网络训练/微调,训练好的网络保存到 netTransfer netTransfer = trainNetwork(aug_img_ds_train, LayerGraph, options); %% 步骤6:测试并查看结果(分类 + 可视化 + 指标计算) % 用训练好的网络(netTransfer)对测试集(aug_img_ds_test)进行分类 [YPred, scores] = classify(netTransfer, aug_img_ds_test); % 随机显示4张测试图片及其预测标签(直观看效果) idx = randperm(numel(img_ds_Test.Files), 4); figure; for i = 1:4 subplot(2, 2, i); I = readimage(img_ds_Test, idx(i)); % 注意:img_ds_Test.ReadFcn 已做预处理 imshow(I); label = YPred(idx(i)); title(string(label)); end % 计算分类准确率(Accuracy) YValidation = img_ds_Test.Labels; %YValidation:真实标签 accuracy = mean(YPred == YValidation) % YPred == YValidation:逐个比较是否预测对,得到 true/false % mean(...):true 当 1,false 当 0,所以均值就是准确率 % 绘制混淆矩阵(看每一类是否容易混淆) figure; confusionchart(YValidation, YPred); %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% preprocess.m(你需要单独写成一个函数文件,或放在脚本末尾当本地函数) % 作用:把任意输入图片变成 ResNet50 可用的格式(224×224×3,且类型正确) % % 常见预处理内容: % 1)读取图像 % 2)如果是灰度图,复制成3通道 % 3)转为 single / uint8(一般直接保持 imread 输出也行) % 4)必要时做归一化/增强(本例主要在 augmenter 做) %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % 如果你考试允许把函数写在脚本末尾,可以用下面这种“本地函数”写法: % (不允许的话,就新建 preprocess.m 文件并复制过去) % function I = preprocess(filename) % I = imread(filename); % 读图 % if size(I,3) == 1 % 如果是灰度图 % I = cat(3, I, I, I); % 复制成3通道RGB % end % % 这里不强制resize,因为 augmentedImageDatastore 已经做了尺寸调整 % % 如需统一数据类型,可加: % % I = im2uint8(I); % end