Merge pull request #28 from dou2/develop

提交机器学习模块的接口及实现
This commit is contained in:
feihu.wang
2019-12-06 18:06:24 +08:00
committed by GitHub
16 changed files with 620 additions and 0 deletions

View File

@@ -32,6 +32,7 @@
<mysql.version>5.1.47</mysql.version>
<springboot.version>2.1.7.RELEASE</springboot.version>
<tomcat.version>8.5.37</tomcat.version>
<tensorflow.version>1.12.0</tensorflow.version>
</properties>
@@ -240,6 +241,11 @@
<artifactId>mybatis</artifactId>
<version>3.4.6</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>${tensorflow.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

View File

@@ -0,0 +1,8 @@
package com.pgmmers.radar.dal.model;
import com.pgmmers.radar.vo.model.MoldVO;
/**
 * Data-access contract for reading machine-learning model ("mold") configurations.
 */
public interface MoldDal {
/**
 * Looks up a mold configuration by its own identifier.
 *
 * @param id identifier of the mold record
 * @return the matching configuration; the implementation in {@code MoldDalImpl} returns
 *         {@code null} when no record is found
 */
MoldVO get(Long id);
/**
 * Looks up the mold configuration attached to a given model.
 *
 * @param modelId identifier of the owning model
 * @return the matching configuration; the implementation in {@code MoldDalImpl} returns
 *         {@code null} when no record is found
 */
MoldVO getByModelId(Long modelId);
}

View File

@@ -0,0 +1,62 @@
package com.pgmmers.radar.dal.model.impl;
import com.pgmmers.radar.dal.model.MoldDal;
import com.pgmmers.radar.mapper.MoldMapper;
import com.pgmmers.radar.mapper.MoldParamMapper;
import com.pgmmers.radar.model.MoldPO;
import com.pgmmers.radar.model.MoldParamPO;
import com.pgmmers.radar.vo.model.MoldParamVO;
import com.pgmmers.radar.vo.model.MoldVO;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import tk.mybatis.mapper.entity.Example;
import javax.annotation.Resource;
import java.util.List;
import java.util.stream.Collectors;
/**
 * {@link MoldDal} implementation backed by the tk.mybatis mappers for the mold table and the
 * mold parameter table.
 */
@Service
public class MoldDalImpl implements MoldDal {

    @Resource
    private MoldMapper moldMapper;

    @Resource
    private MoldParamMapper moldParamMapper;

    /**
     * Loads a mold by primary key, including its parameter list.
     *
     * @param id mold identifier
     * @return the populated value object, or {@code null} when no row matches
     */
    @Override
    public MoldVO get(Long id) {
        return convert(moldMapper.selectByPrimaryKey(id));
    }

    /**
     * Loads the mold whose {@code modelId} property equals the given value, including its
     * parameter list.
     *
     * @param modelId identifier of the owning model
     * @return the populated value object, or {@code null} when no row matches
     */
    @Override
    public MoldVO getByModelId(Long modelId) {
        Example byModel = new Example(MoldPO.class);
        byModel.createCriteria().andEqualTo("modelId", modelId);
        return convert(moldMapper.selectOneByExample(byModel));
    }

    /**
     * Copies a persistent mold into a value object and attaches its parameters.
     *
     * @param moldPO row read from the database, may be {@code null}
     * @return the value object, or {@code null} when the input is {@code null}
     */
    private MoldVO convert(MoldPO moldPO) {
        if (moldPO == null) {
            return null;
        }
        MoldVO result = new MoldVO();
        BeanUtils.copyProperties(moldPO, result);
        fitParams(result);
        return result;
    }

    /**
     * Queries the parameters whose {@code moldId} equals the mold's id and stores them on the
     * given value object. Does nothing for a {@code null} argument.
     */
    private void fitParams(MoldVO mold) {
        if (mold == null) {
            return;
        }
        Example byMold = new Example(MoldParamPO.class);
        byMold.createCriteria().andEqualTo("moldId", mold.getId());
        List<MoldParamPO> rows = moldParamMapper.selectByExample(byMold);
        List<MoldParamVO> params = rows.stream()
                .map(MoldDalImpl::toParamVO)
                .collect(Collectors.toList());
        mold.setParams(params);
    }

    /** Copies one persistent parameter row into a fresh value object. */
    private static MoldParamVO toParamVO(MoldParamPO row) {
        MoldParamVO copy = new MoldParamVO();
        BeanUtils.copyProperties(row, copy);
        return copy;
    }
}

View File

@@ -0,0 +1,48 @@
package com.pgmmers.radar.vo.model;
import java.io.Serializable;
/**
 * <p>
 * One input parameter of a machine-learning model configuration. For now only discrete-valued
 * input layers are considered; word embeddings and convolutional layers are out of scope. The
 * discrete values are fetched through expressions from data handed over by the preceding flow.
 * </p>
 *
 * @author guor
 * @date 2019/11/28
 */
public class MoldParamVO implements Serializable {

    private Long id;

    /**
     * Key of the parameter.
     */
    private String feed;

    /**
     * Value expressions separated by ASCII commas, e.g.
     * {@code fields.deviceId,abstractions.log_uid_ip_1_day_qty}.
     */
    private String expressions;

    // ---- id ----

    public Long getId() {
        return this.id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    // ---- feed ----

    public String getFeed() {
        return this.feed;
    }

    public void setFeed(String feed) {
        this.feed = feed;
    }

    // ---- expressions ----

    public String getExpressions() {
        return this.expressions;
    }

    public void setExpressions(String expressions) {
        this.expressions = expressions;
    }
}

View File

@@ -0,0 +1,110 @@
package com.pgmmers.radar.vo.model;
import java.io.Serializable;
import java.util.Date;
import java.util.List;
/**
 * <p>
 * Machine-learning model configuration: defines the model file path and its parameters.
 * </p>
 *
 * @author guor
 * @date 2019/11/28
 */
public class MoldVO implements Serializable {

    /**
     * Auto-increment primary key.
     */
    private Long id;

    /**
     * Model name.
     */
    private String name;

    /**
     * Path of the model file.
     */
    private String path;

    /**
     * Tag set when the model was saved with TensorFlow; empty for non-TensorFlow models.
     */
    private String tag;

    /**
     * Parameter list.
     */
    private List<MoldParamVO> params;

    /**
     * Name of the model's output operation, e.g. the {@code name} given in
     * {@code predict_Y = tf.nn.softmax(softmax_before, name='predict')}.
     */
    private String operation;

    /**
     * Time the model was last updated.
     */
    private Date updateDate;

    private String type;

    // ---- identity ----

    public Long getId() {
        return this.id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getType() {
        return this.type;
    }

    public void setType(String type) {
        this.type = type;
    }

    // ---- model file ----

    public String getPath() {
        return this.path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public String getTag() {
        return this.tag;
    }

    public void setTag(String tag) {
        this.tag = tag;
    }

    public Date getUpdateDate() {
        return this.updateDate;
    }

    public void setUpdateDate(Date updateDate) {
        this.updateDate = updateDate;
    }

    // ---- inputs and output ----

    public List<MoldParamVO> getParams() {
        return this.params;
    }

    public void setParams(List<MoldParamVO> params) {
        this.params = params;
    }

    public String getOperation() {
        return this.operation;
    }

    public void setOperation(String operation) {
        this.operation = operation;
    }
}

View File

@@ -0,0 +1,7 @@
package com.pgmmers.radar.mapper;
import com.pgmmers.radar.model.MoldPO;
import tk.mybatis.mapper.common.Mapper;
/**
 * tk.mybatis mapper for {@link MoldPO}. It declares no methods of its own; everything available
 * on it is inherited from the generic {@code Mapper} interface.
 */
public interface MoldMapper extends Mapper<MoldPO> {
}

View File

@@ -0,0 +1,7 @@
package com.pgmmers.radar.mapper;
import com.pgmmers.radar.model.MoldParamPO;
import tk.mybatis.mapper.common.Mapper;
/**
 * tk.mybatis mapper for {@link MoldParamPO}. It declares no methods of its own; everything
 * available on it is inherited from the generic {@code Mapper} interface.
 */
public interface MoldParamMapper extends Mapper<MoldParamPO> {
}

View File

@@ -0,0 +1,95 @@
package com.pgmmers.radar.model;
import javax.persistence.Table;
import java.util.Date;
@Table(name = "engine_mold")
public class MoldPO {
/**
* 自增ID主键
*/
private Long id;
private Long modelId;
/**
* 模型名称
*/
private String name;
/**
* 模型文件路径
*/
private String path;
private String tag;
private String operation;
/**
* 模型更新时间
*/
private Date updateDate;
/**
* 模型类型
*/
private String type;
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getPath() {
return path;
}
public void setPath(String path) {
this.path = path;
}
public Date getUpdateDate() {
return updateDate;
}
public void setUpdateDate(Date updateDate) {
this.updateDate = updateDate;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public Long getModelId() {
return modelId;
}
public void setModelId(Long modelId) {
this.modelId = modelId;
}
public String getTag() {
return tag;
}
public void setTag(String tag) {
this.tag = tag;
}
public String getOperation() {
return operation;
}
public void setOperation(String operation) {
this.operation = operation;
}
}

View File

@@ -0,0 +1,49 @@
package com.pgmmers.radar.model;
import javax.persistence.Table;
@Table(name = "engine_mold_param")
public class MoldParamPO {
private Long id;
private Long moldId;
/**
* 参数的key
*/
private String feed;
/**
* 取数表达式英文逗号分隔fields.deviceIdabstractions.log_uid_ip_1_day_qty
*/
private String expressions;
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public Long getMoldId() {
return moldId;
}
public void setMoldId(Long moldId) {
this.moldId = moldId;
}
public String getFeed() {
return feed;
}
public void setFeed(String feed) {
this.feed = feed;
}
public String getExpressions() {
return expressions;
}
public void setExpressions(String expressions) {
this.expressions = expressions;
}
}

View File

@@ -47,5 +47,10 @@
<artifactId>spring-web</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,39 @@
package com.pgmmers.radar.service.impl.dnn;
import com.pgmmers.radar.service.dnn.Estimator;
import com.pgmmers.radar.service.model.MoldService;
import com.pgmmers.radar.vo.model.MoldVO;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.HashMap;
import java.util.Map;
/**
 * Registry of the available {@link Estimator} beans, indexed by the value each one reports from
 * {@link Estimator#getType()}.
 */
@Component
public class EstimatorContainer {

    private Map<String, Estimator> estimatorMap = new HashMap<>();

    @Resource
    private MoldService moldService;

    /**
     * Injection point: stores every supplied estimator under its type key. A later estimator
     * with the same type replaces an earlier one.
     *
     * @param estimators estimators to register
     */
    @Autowired
    public void set(Estimator[] estimators) {
        for (int i = 0; i < estimators.length; i++) {
            Estimator candidate = estimators[i];
            estimatorMap.put(candidate.getType(), candidate);
        }
    }

    /**
     * @param type estimator type key
     * @return the estimator registered under that key, or {@code null} if there is none
     */
    public Estimator getByType(String type) {
        return estimatorMap.get(type);
    }

    /**
     * Resolves the estimator for a model through the type of its mold configuration.
     *
     * @param modelId identifier of the model
     * @return the estimator for the mold's type, or {@code null} when the model has no mold
     *         configuration or no estimator is registered for its type
     */
    public Estimator getByModelId(Long modelId) {
        MoldVO configured = moldService.getByModelId(modelId);
        return configured == null ? null : getByType(configured.getType());
    }
}

View File

@@ -0,0 +1,100 @@
package com.pgmmers.radar.service.impl.dnn;
import com.pgmmers.radar.service.dnn.Estimator;
import com.pgmmers.radar.service.model.MoldService;
import com.pgmmers.radar.vo.model.MoldParamVO;
import com.pgmmers.radar.vo.model.MoldVO;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import javax.annotation.Resource;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * {@link Estimator} that scores an event with a TensorFlow SavedModel.
 * <p>
 * The model location, tag, input feeds and output operation come from the mold configuration
 * of the given model. Loaded bundles are cached per mold id and reused across calls.
 */
@Component
public class TensorDnnEstimator implements Estimator {

    private static final Logger LOGGER = LoggerFactory.getLogger(TensorDnnEstimator.class);

    @Resource
    private MoldService moldService;

    /** Loaded bundles keyed by mold id; only accessed inside the synchronized loader. */
    private Map<Long, SavedModelBundle> modelBundleMap = new HashMap<>();

    /**
     * Runs the model configured for {@code modelId} against the supplied data.
     *
     * @param modelId identifier of the model whose mold configuration should be used
     * @param data    two-level map: group name (the part of an expression before the dot) to a
     *                map of value name (the part after the dot) to value
     * @return the first element copied from the fetched output tensor, or {@code 0} when the
     *         configuration is missing, the model cannot be loaded, or the invocation fails
     */
    @Override
    public double predict(Long modelId, Map<String, Map<String, ?>> data) {
        MoldVO mold = moldService.getByModelId(modelId);
        if (mold == null) {
            LOGGER.debug("没有找到模型配置ModelId{}", modelId);
            return 0;
        }
        SavedModelBundle modelBundle = loadAndCacheModel(mold);
        if (modelBundle == null) {
            LOGGER.warn("模型文件不存在或加载失败ModelId{}", modelId);
            return 0;
        }
        List<MoldParamVO> params = mold.getParams();
        // Every tensor created for this call is tracked so it can be released in finally.
        Tensor<?>[] inputs = new Tensor<?>[params == null ? 0 : params.size()];
        List<Tensor<?>> outputs = null;
        try {
            // The session belongs to the cached bundle and is shared by all calls. It must NOT
            // be closed here, otherwise every later prediction on the same model would fail.
            Session.Runner runner = modelBundle.session().runner();
            for (int i = 0; i < inputs.length; i++) {
                MoldParamVO moldParam = params.get(i);
                inputs[i] = convert2Tensor(moldParam, data);
                runner.feed(moldParam.getFeed(), inputs[i]);
            }
            outputs = runner.fetch(mold.getOperation()).run();
            // NOTE(review): copyTo requires the fetched tensor to be a double tensor of shape
            // [1] - confirm this against the exported models.
            double[] results = new double[1];
            outputs.get(0).copyTo(results);
            return results[0];
        } catch (Exception e) {
            LOGGER.error("模型调用失败ModelId:" + modelId, e);
        } finally {
            // Tensors hold native memory that is only released by close().
            if (outputs != null) {
                for (Tensor<?> output : outputs) {
                    release(output);
                }
            }
            for (Tensor<?> input : inputs) {
                release(input);
            }
        }
        return 0;
    }

    /**
     * Builds the input tensor for one parameter: an {@code n x 1} matrix of doubles holding one
     * value per comma-separated expression, or a {@code 1 x 1} zero matrix when the parameter
     * has no expressions. The caller owns the returned tensor and must close it.
     */
    private Tensor<?> convert2Tensor(MoldParamVO moldParam, Map<String, Map<String, ?>> data) {
        String expressions = moldParam.getExpressions();
        if (StringUtils.isEmpty(expressions)) {
            return Tensor.create(new double[1][1]);
        }
        String[] expList = expressions.split(",");
        double[][] vec = new double[expList.length][1];
        for (int i = 0; i < expList.length; i++) {
            vec[i][0] = resolve(expList[i], data);
        }
        return Tensor.create(vec);
    }

    /**
     * Resolves one {@code group.name} expression (for example {@code fields.deviceId}) against
     * the data map.
     *
     * @return the numeric value, or {@code 0} when the group or the value is absent
     * @throws IllegalArgumentException if the expression has no {@code group.name} form or the
     *                                  resolved value is not a number
     */
    private static double resolve(String expression, Map<String, Map<String, ?>> data) {
        String[] parts = expression.split("\\.");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Malformed expression: " + expression);
        }
        Map<String, ?> group = data.get(parts[0]);
        if (group == null) {
            return 0;
        }
        Object value = group.get(parts[1]);
        if (value == null) {
            return 0;
        }
        // Accept any numeric type (Integer, Long, Float, ...), not only Double.
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        throw new IllegalArgumentException(
                "Non-numeric value for expression " + expression + ": " + value.getClass().getName());
    }

    /** Closes a tensor if it was created; {@code null} entries are skipped. */
    private static void release(Tensor<?> tensor) {
        if (tensor != null) {
            tensor.close();
        }
    }

    /**
     * Returns the cached bundle for the mold, loading it from the mold's path on first use.
     * Synchronized so the cache map is never accessed concurrently.
     *
     * @return the bundle, or {@code null} when the path is not an existing directory or loading
     *         fails
     */
    private synchronized SavedModelBundle loadAndCacheModel(MoldVO mold) {
        SavedModelBundle modelBundle = modelBundleMap.get(mold.getId());
        if (modelBundle == null) {
            File file = new File(mold.getPath());
            if (file.exists() && file.isDirectory()) {
                // Loading the model is comparatively slow.
                try {
                    modelBundle = SavedModelBundle.load(mold.getPath(), mold.getTag());
                    modelBundleMap.put(mold.getId(), modelBundle);
                } catch (Exception e) {
                    // Pass the exception as the last argument so the cause is logged.
                    LOGGER.warn("模型加载失败MoldId{}", mold.getId(), e);
                }
            }
        }
        return modelBundle;
    }

    @Override
    public String getType() {
        return Estimator.TYPE_TENSOR_DNN;
    }
}

View File

@@ -0,0 +1,19 @@
package com.pgmmers.radar.service.impl.dnn;
import com.pgmmers.radar.service.dnn.Estimator;
import com.pgmmers.radar.service.model.MoldService;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.Map;
/**
 * Placeholder {@link Estimator}: it performs no model invocation and always predicts {@code 0}.
 */
@Component
public class TensorFlowEstimator implements Estimator {

    /**
     * Type key reported by this estimator. It is deliberately different from
     * {@link Estimator#TYPE_TENSOR_DNN} so this placeholder does not replace the real
     * implementation when estimators are indexed by type.
     * NOTE(review): the key's value is a placeholder - confirm the intended type name.
     */
    private static final String TYPE = "TENSOR_FLOW";

    @Resource
    private MoldService moldService;

    /**
     * @param modelId identifier of the model (currently unused)
     * @param data    input data (currently unused)
     * @return always {@code 0}
     */
    @Override
    public double predict(Long modelId, Map<String, Map<String, ?>> data) {
        return 0;
    }

    /**
     * Required by {@link Estimator}; without this override the class does not compile.
     *
     * @return the type key of this estimator
     */
    @Override
    public String getType() {
        return TYPE;
    }
}

View File

@@ -0,0 +1,24 @@
package com.pgmmers.radar.service.impl.model;
import com.pgmmers.radar.dal.model.MoldDal;
import com.pgmmers.radar.service.model.MoldService;
import com.pgmmers.radar.vo.model.MoldVO;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
/**
 * {@link MoldService} implementation that forwards every lookup to {@link MoldDal}.
 */
@Service
public class MoldServiceImpl implements MoldService {

    @Resource
    private MoldDal moldDal;

    /**
     * @param id mold identifier
     * @return whatever {@link MoldDal#get(Long)} returns for that identifier
     */
    @Override
    public MoldVO get(Long id) {
        MoldVO found = moldDal.get(id);
        return found;
    }

    /**
     * @param modelId identifier of the owning model
     * @return whatever {@link MoldDal#getByModelId(Long)} returns for that identifier
     */
    @Override
    public MoldVO getByModelId(Long modelId) {
        MoldVO found = moldDal.getByModelId(modelId);
        return found;
    }
}

View File

@@ -0,0 +1,33 @@
package com.pgmmers.radar.service.dnn;
import java.util.Map;
/**
 * <p>
 * Executor interface for machine-learning models.
 * </p>
 * <p>
 * The interface ships with a built-in TensorFlow implementation. The current version only
 * considers input layers that take discrete values; word embeddings and convolutional layers
 * are not considered. The discrete values are fetched through expressions from data passed on
 * by the preceding flow, and the model's prediction is the event score.
 * </p>
 * <p>
 * Model abstraction: y = f(x)
 * </p>
 *
 * @author guor
 * @date 2019/11/28
 */
public interface Estimator {
/**
 * Linear regression model.
 */
String TYPE_REGRESSION = "REGRESSION";
/**
 * Neural-network model implemented on top of TensorFlow.
 */
String TYPE_TENSOR_DNN = "TENSOR_DNN";
/**
 * Computes the prediction of the model identified by {@code modelId} for the given data.
 *
 * @param modelId identifier of the model to evaluate
 * @param data    input values as a two-level map; {@code TensorDnnEstimator} reads it as
 *                group name to (value name to value)
 * @return the prediction
 */
double predict(Long modelId, Map<String, Map<String, ?>> data);
/**
 * @return the type key of this estimator; {@code EstimatorContainer} uses it as the lookup key
 */
String getType();
}

View File

@@ -0,0 +1,8 @@
package com.pgmmers.radar.service.model;
import com.pgmmers.radar.vo.model.MoldVO;
/**
 * Service-layer contract for reading machine-learning model ("mold") configurations.
 */
public interface MoldService {
/**
 * Looks up a mold configuration by its own identifier.
 *
 * @param id identifier of the mold record
 * @return the configuration as returned by the data-access layer
 */
MoldVO get(Long id);
/**
 * Looks up the mold configuration attached to a given model.
 *
 * @param modelId identifier of the owning model
 * @return the configuration as returned by the data-access layer
 */
MoldVO getByModelId(Long modelId);
}