关于sklearn机器学习模型在Android设备的跨平台部署
基于Pmml-Android实现scikit-learn模型在Android设备的部署
关于sklearn机器学习模型的部署
部署目的
我们已经有一个由python训练的sklearn模型以及一个Android系统APP,我们的目的是将python导出的决策树pmml模型部署到Android设备上使用。
部署流程
当部署pmml到Android设备时,需要两步的工作,第一步是将离线训练得到的模型转化为pmml模型文件,这已经在训练中完成,第二步是将PMML模型文件载入在线预测环境,进行预测。这两步都需要相关的库支持
1.PMML模型导出
对于一个决策树模型,我们使用如下的语言来保存期PMML格式,其中我们包的版本为:
包名 | 版本 |
---|---|
scikit-learn | 0.23.1 |
scikit2pmml | 0.63.1 |
scikit-pandas | 2.0.2 |
模型保存在res/DecisionTreeClassifier.pmml目录下,注意x_train以及y_train是训练数据的输入与输出,保存的pmml模型可以用于任何平台。
# 导出pmml模型
clf2 = DecisionTreeClassifier()
pipeline = PMMLPipeline([("classifier", clf2)])
pipeline.fit(x_train, y_train)
sklearn2pmml(pipeline, 'res/DecisionTreeClassifier.pmml', with_repr=True)
2.PMML模型序列化
在使用机器学习APP项目前,首先需要考虑如何将之前训练的scikit-learn决策树模型部署到手机上,由于之前的决策树模型已经实现sklearn2pmml包导出为pmml模式,pmml模型通过相关库的支持可以直接在计算机Java平台上运行,但是Android不提供自己的JAXB(Java Architecture for XML Binding (JAXB))运行时,通常不适合使用标准JAXB运行时间,如GlassFish Metro或EclipseLink MOXy。因此,org.dmg.pmml.pmml实例必须通过其他方式获得。 建议的解决方法是使用Java序列化来传输模型。这里通过pmml-android项目首先将pmml模型文件序列化为ser文件后在Android设备上执行,具体转化步骤如下:
(1) 将生成的pmml文件中的xml头部的PMML-Version以及jpmml-model版本化为一致的4.3版本(现在一般生成的是4.4版本,但是因为pmml-android库支持的是4.3版本,所以需要改成4.3):
即将:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_4" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.4">
修改为:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
(2) sklearn模型转换pmml格式后,需要先进行序列化操作。此处需要下载pmml-android项目,项目地址为https://github.com/loopGod/pmml-android,项目编译支持Android版本3.0以上,jdk版本为版本8以下,版本9及以上需要自己集成JAXB,这里建议使用jdk-1.8,项目目录如下:
C:.
│ LICENSE.txt
│ pom.xml # 库引入文件
│ README.md
│
├─pmml-android # Android操作系统的PMML evaluator库
│ │ pom.xml # 库引入文件
│ │
│ ├─src
│ │ └─main
│ │ ├─android
│ │ │ AndroidManifest.xml #清单文件
│ │ │
│ │ └─java
│ │ └─org
│ │ └─jpmml
│ │ └─android
│ │ EvaluatorUtil.java # 序列化pmml文件
│ │
│ └─target
│ │ original-pmml-android-1.0-SNAPSHOT.jar # 生成的JPMML-Android原始库
│ │ pmml-android-1.0-SNAPSHOT.jar
# 生成的JPMML-Android库
│ │
│ ├─classes
│ │ └─org
│ │ └─jpmml
│ │ └─android
│ │ EvaluatorUtil.class
│ │
│ ├─maven-archiver
│ │ pom.properties
│ │
│ └─maven-status目录
│ └─maven-compiler-plugin # maven编译相关
│ └─compile
│ └─default-compile
│ createdFiles.lst
│ inputFiles.lst
│
└─pmml-android-example # android=pmml示例程序
│ pom.xml
│
├─src
│ └─main
│ ├─android
│ │ │ AndroidManifest.xml #清单文件
│ │ │
│ │ └─resources # 资源文件
│ │ ├─…
│ ├─java
│ │ └─org
│ │ └─jpmml
│ │ └─android #主界面实现文件
│ │ MainActivity.java
│ │
│ └─pmml # 待使用pmml文件
│ model.pmml
│
└─target
│ AndroidManifest.xml
│ classes.dex
│ pmml-android-example-1.0-SNAPSHOT.apk #生成的示例apk
│ pmml-android-example-1.0-SNAPSHOT.ap_
│ pmml-android-example-1.0-SNAPSHOT.jar
│ R.txt
│
├─classes
│ └─org
│ └─jpmml
│ └─android #相关类文件目录
│
├─generated-sources
│ ├─combined-assets
│ │ model.pmml.ser #生成的由pmml序列化之后的文件
│ │
│ └─r
│ └─org
│ └─jpmml
│ └─android
│ BuildConfig.java
│ R.java
│
├─…其他maven相关文件
进入根目录后,首先将自己的pmml文件放入pmml-android-example/src/main/pmml/目录下,在项目根目录下执行命令,使用mvn clean install编译项目,生成文件包括:pmml-android/target/pmml-android-1.0-SNAPSHOT.jar以及pmml-android-example/target/pmml-android-example-1.0-SNAPSHOT.apk,一个JAR库和一个示例APK。在目录pmml-android-example/target/generated-sources/combined-assets/下能够找到生成的序列化之后的.ser文件。
3.Android APP使用序列化后文件进行预测
将上述生成的pmml-android-1.0-SNAPSHOT.jar库放在之后的性能异常检测APP的app/libs目录下,使用implementate进行引入使用,使得作为决策树模型作为资源文件部署到本地使用(放在assets资源目录下)。使用决策树进行性能异常检测的类如下(这是一个根据不同应用类型选择决策树的示例):
加载模型:
public Evaluator loadPmml(String pak,int a){
AssetManager assetManager = mContext.getAssets();
InputStream inputStream = null;
try {
if (pak.equals("com.taobao.taobao")) {
if (a == 1)
inputStream = assetManager.open("taobaocpu.pmml.ser");
if (a == 2)
inputStream = assetManager.open("taobaogpu.pmml.ser");
if (a == 3)
inputStream = assetManager.open("taobaomem.pmml.ser");
}else if(pak.equals("com.ss.android.ugc.aweme")) {
if (a == 1)
inputStream = assetManager.open("douyincpu.pmml.ser");
if (a == 2)
inputStream = assetManager.open("douyingpu.pmml.ser");
if (a == 3)
inputStream = assetManager.open("douyinmem.pmml.ser");
}else if(pak.equals("com.miHoYo.ys.mi")) {
if (a == 1)
inputStream = assetManager.open("yuanshencpu.pmml.ser");
if (a == 2)
inputStream = assetManager.open("yuanshengpu.pmml.ser");
if (a == 3)
inputStream = assetManager.open("yuanshenmem.pmml.ser");
}else{
if (a == 1)
inputStream = assetManager.open("entirecpu.pmml.ser");
if (a == 2)
inputStream = assetManager.open("entiregpu.pmml.ser");
if (a == 3)
inputStream = assetManager.open("entiremem.pmml.ser");
}
Log.d("model","Success: return model");
return EvaluatorUtil.createEvaluator(inputStream);
}catch (Exception e)
{
Log.d("model","Error: return model null");
e.printStackTrace();
return null;
}
}
使用模型进行预测,函数输入中的map为输入,需要与决策树中的输入字段对应:
public int predict(String pak, Evaluator evaluator,Map<String, Double> data0,Map<String, Double> data1)
{
Map<String, Double> datax = new HashMap<String, Double>();
if(pak.equals("com.taobao.taobao")||pak.equals("com.ss.android.ugc.aweme")||pak.equals("com.miHoYo.ys.mi")) {
datax = data0;
}else{
datax = data1;
}
List<InputField> inputFields = evaluator.getInputFields();
//过模型的原始特征,从画像中获取数据,作为模型输入
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object rawValue = data.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields();
TargetField targetField = targetFields.get(0);
FieldName targetFieldName = targetField.getName();
Object targetFieldValue = results.get(targetFieldName);
Log.d("model","target: " + targetFieldName.getValue().toString() + " value: " + targetFieldValue.toString());
int primitiveValue = -1;
if (targetFieldValue instanceof Computable) {
Computable computable = (Computable) targetFieldValue;
primitiveValue = (Integer) computable.getResult();
}
Log.d("model","result: " + String.valueOf(primitiveValue));
return primitiveValue;
}
成功实现决策树部署:
更多推荐
所有评论(0)