node使用tensorflow.js实现垃圾分类练习
本练习包含两个文件:t1.js(训练脚本,在截断的 MobileNet 上做迁移学习)和 data.js(读取训练图片并生成数据集),完整代码如下。
·
t1.js
const tf = require('@tensorflow/tfjs-node-gpu');
const getData = require('./data');
const TRAIN_PATH = './垃圾分类/垃圾分类/train';
const OUT_PUT = 'output';
const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';
// Number of leading MobileNet layers reused (frozen) as the feature extractor.
const REUSED_LAYER_COUNT = 86;

// Transfer-learning entry point: load the training data, build a classifier on top of a
// truncated, frozen MobileNet, train it, and save the result into OUT_PUT.
// The promise is handled explicitly so a failure is printed and the process exits non-zero
// instead of ending as an unhandled rejection.
(async () => {
  // Build the training dataset; getData also writes the class list to OUT_PUT/classes.json.
  const { ds, classes } = await getData(TRAIN_PATH, OUT_PUT);

  // Load the pre-trained MobileNet served from the local HTTP server.
  const mobilenet = await tf.loadLayersModel(MOBILENET_URL);
  // Print the model architecture.
  mobilenet.summary();

  // Guard the magic layer count: a shorter model would otherwise add `undefined` layers.
  if (mobilenet.layers.length < REUSED_LAYER_COUNT) {
    throw new Error(
      `MobileNet has ${mobilenet.layers.length} layers, need at least ${REUSED_LAYER_COUNT}`,
    );
  }

  const model = tf.sequential();
  // Truncate the model: reuse the first REUSED_LAYER_COUNT layers and freeze their weights
  // so that only the new head added below is trained.
  for (const layer of mobilenet.layers.slice(0, REUSED_LAYER_COUNT)) {
    layer.trainable = false;
    model.add(layer);
  }
  // Flatten the extracted feature maps into one vector per sample.
  model.add(tf.layers.flatten());
  // Small fully connected hidden layer; ReLU adds non-linearity.
  model.add(tf.layers.dense({ units: 10, activation: 'relu' }));
  // Output layer: one unit per class, softmax for multi-class probabilities.
  model.add(tf.layers.dense({ units: classes.length, activation: 'softmax' }));

  // getData labels samples with integer class indices, hence the sparse cross-entropy loss.
  model.compile({
    loss: 'sparseCategoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics: ['acc'],
  });

  await model.fitDataset(ds, { epochs: 20 });

  // Save the trained model under <cwd>/OUT_PUT.
  await model.save(`file://${process.cwd()}/${OUT_PUT}`);
})().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
data.js
const fs = require('fs');
const tf = require("@tensorflow/tfjs-node-gpu");
/**
 * Read an image file and convert it into a normalized MobileNet input tensor.
 *
 * @param {string} imgPath - path to an image file decodable by tf.node.decodeImage.
 * @returns {tf.Tensor4D} float32 tensor of shape [1, 224, 224, 3] with values in [-1, 1].
 */
const img2x = (imgPath) => {
  const buffer = fs.readFileSync(imgPath);
  // tf.tidy disposes every intermediate tensor created here except the returned one.
  return tf.tidy(() => {
    // Force 3 channels: grayscale (1) or RGBA (4) images would otherwise break the reshape below.
    const imgt = tf.node.decodeImage(new Uint8Array(buffer), 3);
    // Resize to the 224x224 input size expected by the model.
    const imgResize = tf.image.resizeBilinear(imgt, [224, 224]);
    // Map pixel values from [0, 255] to [-1, 1] and add a leading batch dimension.
    return imgResize
      .cast('float32')
      .sub(255 / 2)
      .div(255 / 2)
      .reshape([1, 224, 224, 3]);
  });
};
// Image files accepted for training (case-insensitive .jpg / .jpeg).
const IMAGE_FILE_RE = /\.jpe?g$/i;

/**
 * Names of the class sub-directories of `traindir`, sorted so label indices are stable
 * across machines. Hidden entries (e.g. .DS_Store) and plain files are skipped.
 *
 * @param {string} traindir - root directory containing one sub-directory per class.
 * @returns {string[]} sorted class directory names.
 */
const listClasses = (traindir) =>
  fs.readdirSync(traindir)
    .filter((name) => !name.startsWith('.') && fs.statSync(`${traindir}/${name}`).isDirectory())
    .sort();

/**
 * Decode a batch of samples into tensors inside a single tidy scope.
 *
 * @param {{imgPath: string, dirIndex: number}[]} samples - non-empty list of samples.
 * @returns {{xs: tf.Tensor, ys: tf.Tensor}} xs: images concatenated along axis 0; ys: class indices.
 */
const makeBatch = (samples) => tf.tidy(() => {
  const xs = tf.concat(samples.map(({ imgPath }) => img2x(imgPath)));
  const ys = tf.tensor(samples.map(({ dirIndex }) => dirIndex));
  return { xs, ys };
});

/**
 * Build a batched, shuffled training dataset from a directory of class sub-folders.
 *
 * Each sub-directory of `traindir` is one class, and its index in the returned `classes`
 * array is the label used for its images. The class list is also written to
 * `./<output>/classes.json`.
 *
 * @param {string} traindir - root directory containing one sub-directory per class.
 * @param {string} output - directory (relative to cwd) that receives classes.json; created if missing.
 * @param {object} [options]
 * @param {number} [options.maxPerClass=1000] - at most this many images are used per class.
 * @param {number} [options.batchSize=32] - number of samples per yielded batch.
 * @returns {Promise<{ds: tf.data.Dataset, classes: string[]}>} the dataset and class names.
 * @throws {Error} when no class directories or no images are found.
 */
const getData = async (traindir, output, { maxPerClass = 1000, batchSize = 32 } = {}) => {
  const classes = listClasses(traindir);
  if (classes.length === 0) {
    throw new Error(`No class directories found in ${traindir}`);
  }
  // Create the output directory first; writeFileSync fails if it does not exist.
  fs.mkdirSync(`./${output}`, { recursive: true });
  fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes));

  // Only file paths are collected here; images are decoded lazily, one batch at a time.
  const data = [];
  classes.forEach((dir, dirIndex) => {
    fs.readdirSync(`${traindir}/${dir}`)
      // Skip hidden files such as macOS "._x.jpg" resource forks, which are not real images.
      .filter((name) => !name.startsWith('.') && IMAGE_FILE_RE.test(name))
      .slice(0, maxPerClass)
      .forEach((filename) => {
        data.push({ imgPath: `${traindir}/${dir}/${filename}`, dirIndex });
      });
  });
  if (data.length === 0) {
    throw new Error(`No JPEG images found under ${traindir}`);
  }

  const ds = tf.data.generator(function* () {
    // Reshuffle at the start of every pass so each epoch sees a different sample order.
    tf.util.shuffle(data);
    for (let start = 0; start < data.length; start += batchSize) {
      console.log('当前批次', start);
      yield makeBatch(data.slice(start, start + batchSize));
    }
  });

  return { ds, classes };
};

module.exports = getData;
代码和训练图片下载链接 https://www.ljkanka.com/index/t6
更多推荐
所有评论(0)