mirror of
https://gitee.com/freshday/radar.git
synced 2026-03-22 12:47:16 +08:00
6
pom.xml
6
pom.xml
@@ -32,6 +32,7 @@
|
||||
<mysql.version>5.1.47</mysql.version>
|
||||
<springboot.version>2.1.7.RELEASE</springboot.version>
|
||||
<tomcat.version>8.5.37</tomcat.version>
|
||||
<tensorflow.version>1.12.0</tensorflow.version>
|
||||
</properties>
|
||||
|
||||
|
||||
@@ -240,6 +241,11 @@
|
||||
<artifactId>mybatis</artifactId>
|
||||
<version>3.4.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow</artifactId>
|
||||
<version>${tensorflow.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.pgmmers.radar.dal.model;
|
||||
|
||||
import com.pgmmers.radar.vo.model.MoldVO;
|
||||
|
||||
public interface MoldDal {
|
||||
MoldVO get(Long id);
|
||||
MoldVO getByModelId(Long modelId);
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package com.pgmmers.radar.dal.model.impl;
|
||||
|
||||
import com.pgmmers.radar.dal.model.MoldDal;
|
||||
import com.pgmmers.radar.mapper.MoldMapper;
|
||||
import com.pgmmers.radar.mapper.MoldParamMapper;
|
||||
import com.pgmmers.radar.model.MoldPO;
|
||||
import com.pgmmers.radar.model.MoldParamPO;
|
||||
import com.pgmmers.radar.vo.model.MoldParamVO;
|
||||
import com.pgmmers.radar.vo.model.MoldVO;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import tk.mybatis.mapper.entity.Example;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class MoldDalImpl implements MoldDal {
|
||||
@Resource
|
||||
private MoldMapper moldMapper;
|
||||
@Resource
|
||||
private MoldParamMapper moldParamMapper;
|
||||
|
||||
@Override
|
||||
public MoldVO get(Long id) {
|
||||
MoldPO moldPO = moldMapper.selectByPrimaryKey(id);
|
||||
return convert(moldPO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MoldVO getByModelId(Long modelId) {
|
||||
Example example = new Example(MoldPO.class);
|
||||
example.createCriteria().andEqualTo("modelId", modelId);
|
||||
MoldPO moldPO = moldMapper.selectOneByExample(example);
|
||||
return convert(moldPO);
|
||||
}
|
||||
|
||||
private MoldVO convert(MoldPO moldPO) {
|
||||
MoldVO vo = null;
|
||||
if (moldPO != null) {
|
||||
vo = new MoldVO();
|
||||
BeanUtils.copyProperties(moldPO, vo);
|
||||
fitParams(vo);
|
||||
}
|
||||
return vo;
|
||||
}
|
||||
|
||||
private void fitParams(MoldVO mold) {
|
||||
if (mold != null) {
|
||||
Example example = new Example(MoldParamPO.class);
|
||||
example.createCriteria().andEqualTo("moldId", mold.getId());
|
||||
List<MoldParamPO> moldParamList = moldParamMapper.selectByExample(example);
|
||||
List<MoldParamVO> list = moldParamList.stream().map(moldParamPO -> {
|
||||
MoldParamVO moldParamVO = new MoldParamVO();
|
||||
BeanUtils.copyProperties(moldParamPO, moldParamVO);
|
||||
return moldParamVO;
|
||||
}).collect(Collectors.toList());
|
||||
mold.setParams(list);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.pgmmers.radar.vo.model;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* 机器学习模型配置,目前只考虑输入层为离散值的情况,不考虑需要词嵌入和融入卷积层,其中
|
||||
* 离散值通过表达式取数从前置流程传递过来.
|
||||
* </p>
|
||||
*
|
||||
* @author guor
|
||||
* @date 2019/11/28
|
||||
*/
|
||||
public class MoldParamVO implements Serializable {
|
||||
private Long id;
|
||||
/**
|
||||
* 参数的key
|
||||
*/
|
||||
private String feed;
|
||||
/**
|
||||
* 取数表达式,英文逗号分隔fields.deviceId,abstractions.log_uid_ip_1_day_qty
|
||||
*/
|
||||
private String expressions;
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public String getFeed() {
|
||||
return feed;
|
||||
}
|
||||
|
||||
public void setFeed(String feed) {
|
||||
this.feed = feed;
|
||||
}
|
||||
|
||||
public String getExpressions() {
|
||||
return expressions;
|
||||
}
|
||||
|
||||
public void setExpressions(String expressions) {
|
||||
this.expressions = expressions;
|
||||
}
|
||||
}
|
||||
110
radar-dal/src/main/java/com.pgmmers.radar/vo/model/MoldVO.java
Normal file
110
radar-dal/src/main/java/com.pgmmers.radar/vo/model/MoldVO.java
Normal file
@@ -0,0 +1,110 @@
|
||||
package com.pgmmers.radar.vo.model;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* 机器学习模型配置,定义模型文件路径和参数
|
||||
* </p>
|
||||
*
|
||||
* @author guor
|
||||
* @date 2019/11/28
|
||||
*/
|
||||
public class MoldVO implements Serializable {
|
||||
/**
|
||||
* 自增ID,主键
|
||||
*/
|
||||
private Long id;
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String name;
|
||||
/**
|
||||
* 模型文件路径
|
||||
*/
|
||||
private String path;
|
||||
/**
|
||||
* tensorflow框架保存模型时设置的tag,非tensorflow模型此字段为空
|
||||
*/
|
||||
private String tag;
|
||||
/**
|
||||
* 参数列表
|
||||
*/
|
||||
private List<MoldParamVO> params;
|
||||
/**
|
||||
* 模型输出操作名称,predict_Y = tf.nn.softmax(softmax_before, name='predict')
|
||||
*/
|
||||
private String operation;
|
||||
/**
|
||||
* 模型更新时间
|
||||
*/
|
||||
private Date updateDate;
|
||||
|
||||
private String type;
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public List<MoldParamVO> getParams() {
|
||||
return params;
|
||||
}
|
||||
|
||||
public void setParams(List<MoldParamVO> params) {
|
||||
this.params = params;
|
||||
}
|
||||
|
||||
public String getPath() {
|
||||
return path;
|
||||
}
|
||||
|
||||
public void setPath(String path) {
|
||||
this.path = path;
|
||||
}
|
||||
|
||||
public Date getUpdateDate() {
|
||||
return updateDate;
|
||||
}
|
||||
|
||||
public void setUpdateDate(Date updateDate) {
|
||||
this.updateDate = updateDate;
|
||||
}
|
||||
|
||||
public String getTag() {
|
||||
return tag;
|
||||
}
|
||||
|
||||
public void setTag(String tag) {
|
||||
this.tag = tag;
|
||||
}
|
||||
|
||||
public String getOperation() {
|
||||
return operation;
|
||||
}
|
||||
|
||||
public void setOperation(String operation) {
|
||||
this.operation = operation;
|
||||
}
|
||||
|
||||
public String getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
public void setType(String type) {
|
||||
this.type = type;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.pgmmers.radar.mapper;
|
||||
|
||||
import com.pgmmers.radar.model.MoldPO;
|
||||
import tk.mybatis.mapper.common.Mapper;
|
||||
|
||||
public interface MoldMapper extends Mapper<MoldPO> {
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.pgmmers.radar.mapper;
|
||||
|
||||
import com.pgmmers.radar.model.MoldParamPO;
|
||||
import tk.mybatis.mapper.common.Mapper;
|
||||
|
||||
public interface MoldParamMapper extends Mapper<MoldParamPO> {
|
||||
}
|
||||
95
radar-dao/src/main/java/com/pgmmers/radar/model/MoldPO.java
Normal file
95
radar-dao/src/main/java/com/pgmmers/radar/model/MoldPO.java
Normal file
@@ -0,0 +1,95 @@
|
||||
package com.pgmmers.radar.model;
|
||||
|
||||
import javax.persistence.Table;
|
||||
import java.util.Date;
|
||||
|
||||
@Table(name = "engine_mold")
|
||||
public class MoldPO {
|
||||
/**
|
||||
* 自增ID,主键
|
||||
*/
|
||||
private Long id;
|
||||
private Long modelId;
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String name;
|
||||
/**
|
||||
* 模型文件路径
|
||||
*/
|
||||
private String path;
|
||||
private String tag;
|
||||
private String operation;
|
||||
/**
|
||||
* 模型更新时间
|
||||
*/
|
||||
private Date updateDate;
|
||||
/**
|
||||
* 模型类型
|
||||
*/
|
||||
private String type;
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public String getPath() {
|
||||
return path;
|
||||
}
|
||||
|
||||
public void setPath(String path) {
|
||||
this.path = path;
|
||||
}
|
||||
|
||||
public Date getUpdateDate() {
|
||||
return updateDate;
|
||||
}
|
||||
|
||||
public void setUpdateDate(Date updateDate) {
|
||||
this.updateDate = updateDate;
|
||||
}
|
||||
|
||||
public String getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
public void setType(String type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
public Long getModelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
public void setModelId(Long modelId) {
|
||||
this.modelId = modelId;
|
||||
}
|
||||
|
||||
public String getTag() {
|
||||
return tag;
|
||||
}
|
||||
|
||||
public void setTag(String tag) {
|
||||
this.tag = tag;
|
||||
}
|
||||
|
||||
public String getOperation() {
|
||||
return operation;
|
||||
}
|
||||
|
||||
public void setOperation(String operation) {
|
||||
this.operation = operation;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package com.pgmmers.radar.model;
|
||||
|
||||
import javax.persistence.Table;
|
||||
|
||||
@Table(name = "engine_mold_param")
|
||||
public class MoldParamPO {
|
||||
private Long id;
|
||||
private Long moldId;
|
||||
/**
|
||||
* 参数的key
|
||||
*/
|
||||
private String feed;
|
||||
/**
|
||||
* 取数表达式,英文逗号分隔fields.deviceId,abstractions.log_uid_ip_1_day_qty
|
||||
*/
|
||||
private String expressions;
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public Long getMoldId() {
|
||||
return moldId;
|
||||
}
|
||||
|
||||
public void setMoldId(Long moldId) {
|
||||
this.moldId = moldId;
|
||||
}
|
||||
|
||||
public String getFeed() {
|
||||
return feed;
|
||||
}
|
||||
|
||||
public void setFeed(String feed) {
|
||||
this.feed = feed;
|
||||
}
|
||||
|
||||
public String getExpressions() {
|
||||
return expressions;
|
||||
}
|
||||
|
||||
public void setExpressions(String expressions) {
|
||||
this.expressions = expressions;
|
||||
}
|
||||
}
|
||||
@@ -47,5 +47,10 @@
|
||||
<artifactId>spring-web</artifactId>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.pgmmers.radar.service.impl.dnn;
|
||||
|
||||
import com.pgmmers.radar.service.dnn.Estimator;
|
||||
import com.pgmmers.radar.service.model.MoldService;
|
||||
import com.pgmmers.radar.vo.model.MoldVO;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class EstimatorContainer {
|
||||
|
||||
private Map<String, Estimator> estimatorMap = new HashMap<>();
|
||||
|
||||
@Resource
|
||||
private MoldService moldService;
|
||||
|
||||
@Autowired
|
||||
public void set(Estimator[] estimators) {
|
||||
for (Estimator estimator : estimators) {
|
||||
estimatorMap.put(estimator.getType(), estimator);
|
||||
}
|
||||
}
|
||||
|
||||
public Estimator getByType(String type) {
|
||||
return estimatorMap.get(type);
|
||||
}
|
||||
|
||||
public Estimator getByModelId(Long modelId) {
|
||||
MoldVO mold = moldService.getByModelId(modelId);
|
||||
if (mold == null) {
|
||||
return null;
|
||||
}
|
||||
return getByType(mold.getType());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.pgmmers.radar.service.impl.dnn;
|
||||
|
||||
import com.pgmmers.radar.service.dnn.Estimator;
|
||||
import com.pgmmers.radar.service.model.MoldService;
|
||||
import com.pgmmers.radar.vo.model.MoldParamVO;
|
||||
import com.pgmmers.radar.vo.model.MoldVO;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.tensorflow.SavedModelBundle;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.File;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class TensorDnnEstimator implements Estimator {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TensorDnnEstimator.class);
|
||||
@Resource
|
||||
private MoldService moldService;
|
||||
private Map<Long, SavedModelBundle> modelBundleMap = new HashMap<>();
|
||||
|
||||
@Override
|
||||
public double predict(Long modelId, Map<String, Map<String, ?>> data) {
|
||||
MoldVO mold = moldService.getByModelId(modelId);
|
||||
if (mold == null) {
|
||||
LOGGER.debug("没有找到模型配置,ModelId:{}", modelId);
|
||||
return 0;
|
||||
}
|
||||
SavedModelBundle modelBundle = loadAndCacheModel(mold);
|
||||
if (modelBundle == null) {
|
||||
LOGGER.warn("模型文件不存在或加载失败,ModelId:{}", modelId);
|
||||
return 0;
|
||||
}
|
||||
Session tfSession = modelBundle.session();
|
||||
try {
|
||||
List<MoldParamVO> params = mold.getParams();
|
||||
Session.Runner runner = tfSession.runner();
|
||||
for (MoldParamVO moldParam : params) {
|
||||
runner.feed(moldParam.getFeed(), convert2Tensor(moldParam, data));
|
||||
}
|
||||
Tensor<?> output = runner.fetch(mold.getOperation()).run().get(0);
|
||||
double[] results = new double[1];
|
||||
output.copyTo(results);
|
||||
return results[0];
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("模型调用失败,ModelId:" + modelId, e);
|
||||
} finally {
|
||||
tfSession.close();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private Tensor<?> convert2Tensor(MoldParamVO moldParam, Map<String, Map<String, ?>> data) {
|
||||
String expressions = moldParam.getExpressions();
|
||||
if (StringUtils.isEmpty(expressions)) {
|
||||
return Tensor.create(new double[1][1]);
|
||||
}
|
||||
String[] expList = expressions.split(",");
|
||||
double[][] vec = new double[expList.length][1];
|
||||
int a = 0;
|
||||
for (String s : expList) {
|
||||
double xn = 0;
|
||||
String[] ss = s.split("\\.");//fields.deviceId,abstractions.log_uid_ip_1_day_qty
|
||||
Map<String, ?> stringMap = data.get(ss[0]);
|
||||
if (stringMap != null) {
|
||||
xn = (Double) stringMap.get(ss[1]);
|
||||
}
|
||||
vec[a++][0] = xn;
|
||||
}
|
||||
return Tensor.create(vec);
|
||||
}
|
||||
|
||||
private synchronized SavedModelBundle loadAndCacheModel(MoldVO mold) {
|
||||
SavedModelBundle modelBundle = modelBundleMap.get(mold.getId());
|
||||
if (modelBundle == null) {
|
||||
File file = new File(mold.getPath());
|
||||
if (file.exists() && file.isDirectory()) {
|
||||
// 模型加载,比较耗时
|
||||
try {
|
||||
modelBundle = SavedModelBundle.load(mold.getPath(), mold.getTag());
|
||||
modelBundleMap.put(mold.getId(), modelBundle);
|
||||
} catch (Exception e) {
|
||||
LOGGER.warn("模型加载失败,MoldId:{}", mold.getId());
|
||||
}
|
||||
}
|
||||
}
|
||||
return modelBundle;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getType() {
|
||||
return Estimator.TYPE_TENSOR_DNN;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.pgmmers.radar.service.impl.dnn;
|
||||
|
||||
import com.pgmmers.radar.service.dnn.Estimator;
|
||||
import com.pgmmers.radar.service.model.MoldService;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class TensorFlowEstimator implements Estimator {
|
||||
@Resource
|
||||
private MoldService moldService;
|
||||
|
||||
@Override
|
||||
public double predict(Long modelId, Map<String, Map<String, ?>> data) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.pgmmers.radar.service.impl.model;
|
||||
|
||||
import com.pgmmers.radar.dal.model.MoldDal;
|
||||
import com.pgmmers.radar.service.model.MoldService;
|
||||
import com.pgmmers.radar.vo.model.MoldVO;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
|
||||
@Service
|
||||
public class MoldServiceImpl implements MoldService {
|
||||
@Resource
|
||||
private MoldDal moldDal;
|
||||
|
||||
@Override
|
||||
public MoldVO get(Long id) {
|
||||
return moldDal.get(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MoldVO getByModelId(Long modelId) {
|
||||
return moldDal.getByModelId(modelId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.pgmmers.radar.service.dnn;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* 机器学习模型执行器接口
|
||||
* </p>
|
||||
* <p>
|
||||
* 该接口内置TensorFlow实现,目前版本只考虑输入层为离散值的情况,不考虑词嵌入和融入卷积层,其中
|
||||
* 离散值通过表达式取数从前置流程传递过来,模型的预测结果为事件评分。
|
||||
* </p>
|
||||
* <p>
|
||||
* 模型抽象:y=f(x)
|
||||
* </p>
|
||||
*
|
||||
* @author guor
|
||||
* @date 2019/11/28
|
||||
*/
|
||||
public interface Estimator {
|
||||
/**
|
||||
* 线性回归模型
|
||||
*/
|
||||
String TYPE_REGRESSION = "REGRESSION";
|
||||
/**
|
||||
* 基于TensorFlow实现的神经网络模型
|
||||
*/
|
||||
String TYPE_TENSOR_DNN = "TENSOR_DNN";
|
||||
|
||||
double predict(Long modelId, Map<String, Map<String, ?>> data);
|
||||
|
||||
String getType();
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.pgmmers.radar.service.model;
|
||||
|
||||
import com.pgmmers.radar.vo.model.MoldVO;
|
||||
|
||||
public interface MoldService {
|
||||
MoldVO get(Long id);
|
||||
MoldVO getByModelId(Long modelId);
|
||||
}
|
||||
Reference in New Issue
Block a user