Merge pull request #30 from dou2/develop

提交模型调用
This commit is contained in:
feihu.wang
2019-12-25 14:04:25 +08:00
committed by GitHub
25 changed files with 268 additions and 157 deletions

View File

@@ -2,9 +2,12 @@ package com.pgmmers;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
import org.springframework.boot.autoconfigure.elasticsearch.rest.RestClientAutoConfiguration;
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
import tk.mybatis.spring.annotation.MapperScan;
@SpringBootApplication
@SpringBootApplication(exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class, RestClientAutoConfiguration.class})
@MapperScan("com.pgmmers.radar.mapper")
public class AdminApplication
{

View File

@@ -0,0 +1,9 @@
package com.pgmmers.radar.dal.model;
import com.pgmmers.radar.vo.model.ModelConfVO;
public interface ModelConfDal {
ModelConfVO get(Long id);
ModelConfVO getByModelId(Long modelId);
}

View File

@@ -1,8 +0,0 @@
package com.pgmmers.radar.dal.model;
import com.pgmmers.radar.vo.model.MoldVO;
public interface MoldDal {
MoldVO get(Long id);
MoldVO getByModelId(Long modelId);
}

View File

@@ -0,0 +1,62 @@
package com.pgmmers.radar.dal.model.impl;
import com.pgmmers.radar.dal.model.ModelConfDal;
import com.pgmmers.radar.mapper.ModelConfMapper;
import com.pgmmers.radar.mapper.ModelConfParamMapper;
import com.pgmmers.radar.model.ModelConfPO;
import com.pgmmers.radar.model.ModelConfParamPO;
import com.pgmmers.radar.vo.model.ModelConfParamVO;
import com.pgmmers.radar.vo.model.ModelConfVO;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import tk.mybatis.mapper.entity.Example;
import javax.annotation.Resource;
import java.util.List;
import java.util.stream.Collectors;
@Service
public class ModelConfDalImpl implements ModelConfDal {
@Resource
private ModelConfMapper modelConfMapper;
@Resource
private ModelConfParamMapper modelConfParamMapper;
@Override
public ModelConfVO get(Long id) {
ModelConfPO modelConfPO = modelConfMapper.selectByPrimaryKey(id);
return convert(modelConfPO);
}
@Override
public ModelConfVO getByModelId(Long modelId) {
Example example = new Example(ModelConfPO.class);
example.createCriteria().andEqualTo("modelId", modelId);
ModelConfPO modelConfPO = modelConfMapper.selectOneByExample(example);
return convert(modelConfPO);
}
private ModelConfVO convert(ModelConfPO modelConfPO) {
ModelConfVO vo = null;
if (modelConfPO != null) {
vo = new ModelConfVO();
BeanUtils.copyProperties(modelConfPO, vo);
fitParams(vo);
}
return vo;
}
private void fitParams(ModelConfVO mold) {
if (mold != null) {
Example example = new Example(ModelConfParamPO.class);
example.createCriteria().andEqualTo("moldId", mold.getId());
List<ModelConfParamPO> moldParamList = modelConfParamMapper.selectByExample(example);
List<ModelConfParamVO> list = moldParamList.stream().map(modelConfParamPO -> {
ModelConfParamVO modelConfParamVO = new ModelConfParamVO();
BeanUtils.copyProperties(modelConfParamPO, modelConfParamVO);
return modelConfParamVO;
}).collect(Collectors.toList());
mold.setParams(list);
}
}
}

View File

@@ -1,62 +0,0 @@
package com.pgmmers.radar.dal.model.impl;
import com.pgmmers.radar.dal.model.MoldDal;
import com.pgmmers.radar.mapper.MoldMapper;
import com.pgmmers.radar.mapper.MoldParamMapper;
import com.pgmmers.radar.model.MoldPO;
import com.pgmmers.radar.model.MoldParamPO;
import com.pgmmers.radar.vo.model.MoldParamVO;
import com.pgmmers.radar.vo.model.MoldVO;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import tk.mybatis.mapper.entity.Example;
import javax.annotation.Resource;
import java.util.List;
import java.util.stream.Collectors;
@Service
public class MoldDalImpl implements MoldDal {
@Resource
private MoldMapper moldMapper;
@Resource
private MoldParamMapper moldParamMapper;
@Override
public MoldVO get(Long id) {
MoldPO moldPO = moldMapper.selectByPrimaryKey(id);
return convert(moldPO);
}
@Override
public MoldVO getByModelId(Long modelId) {
Example example = new Example(MoldPO.class);
example.createCriteria().andEqualTo("modelId", modelId);
MoldPO moldPO = moldMapper.selectOneByExample(example);
return convert(moldPO);
}
private MoldVO convert(MoldPO moldPO) {
MoldVO vo = null;
if (moldPO != null) {
vo = new MoldVO();
BeanUtils.copyProperties(moldPO, vo);
fitParams(vo);
}
return vo;
}
private void fitParams(MoldVO mold) {
if (mold != null) {
Example example = new Example(MoldParamPO.class);
example.createCriteria().andEqualTo("moldId", mold.getId());
List<MoldParamPO> moldParamList = moldParamMapper.selectByExample(example);
List<MoldParamVO> list = moldParamList.stream().map(moldParamPO -> {
MoldParamVO moldParamVO = new MoldParamVO();
BeanUtils.copyProperties(moldParamPO, moldParamVO);
return moldParamVO;
}).collect(Collectors.toList());
mold.setParams(list);
}
}
}

View File

@@ -11,7 +11,7 @@ import java.io.Serializable;
* @author guor
* @date 2019/11/28
*/
public class MoldParamVO implements Serializable {
public class ModelConfParamVO implements Serializable {
private Long id;
/**
* 参数的key

View File

@@ -12,7 +12,7 @@ import java.util.List;
* @author guor
* @date 2019/11/28
*/
public class MoldVO implements Serializable {
public class ModelConfVO implements Serializable {
/**
* 自增ID主键
*/
@@ -32,7 +32,7 @@ public class MoldVO implements Serializable {
/**
* 参数列表
*/
private List<MoldParamVO> params;
private List<ModelConfParamVO> params;
/**
* 模型输出操作名称predict_Y = tf.nn.softmax(softmax_before, name='predict')
*/
@@ -60,11 +60,11 @@ public class MoldVO implements Serializable {
this.name = name;
}
public List<MoldParamVO> getParams() {
public List<ModelConfParamVO> getParams() {
return params;
}
public void setParams(List<MoldParamVO> params) {
public void setParams(List<ModelConfParamVO> params) {
this.params = params;
}

View File

@@ -44,9 +44,32 @@
<plugin>
<groupId>org.mybatis.generator</groupId>
<artifactId>mybatis-generator-maven-plugin</artifactId>
<version>1.3.6</version>
<configuration>
<configurationFile>
${basedir}/src/main/resources/generator/generatorConfig.xml
</configurationFile>
<overwrite>true</overwrite>
<verbose>true</verbose>
</configuration>
<dependencies>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.47</version>
</dependency>
<dependency>
<groupId>tk.mybatis</groupId>
<artifactId>mapper</artifactId>
<version>4.0.0</version>
</dependency>
<dependency>
<groupId>com.pgmmers</groupId>
<artifactId>radar-commons</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
</plugin>
</plugins>
</build>

View File

@@ -0,0 +1,7 @@
package com.pgmmers.radar.mapper;
import com.pgmmers.radar.model.ModelConfPO;
import tk.mybatis.mapper.common.Mapper;
public interface ModelConfMapper extends Mapper<ModelConfPO> {
}

View File

@@ -0,0 +1,7 @@
package com.pgmmers.radar.mapper;
import com.pgmmers.radar.model.ModelConfParamPO;
import tk.mybatis.mapper.common.Mapper;
public interface ModelConfParamMapper extends Mapper<ModelConfParamPO> {
}

View File

@@ -1,7 +0,0 @@
package com.pgmmers.radar.mapper;
import com.pgmmers.radar.model.MoldPO;
import tk.mybatis.mapper.common.Mapper;
public interface MoldMapper extends Mapper<MoldPO> {
}

View File

@@ -1,7 +0,0 @@
package com.pgmmers.radar.mapper;
import com.pgmmers.radar.model.MoldParamPO;
import tk.mybatis.mapper.common.Mapper;
public interface MoldParamMapper extends Mapper<MoldParamPO> {
}

View File

@@ -3,8 +3,8 @@ package com.pgmmers.radar.model;
import javax.persistence.Table;
import java.util.Date;
@Table(name = "engine_mold")
public class MoldPO {
@Table(name = "engine_model_conf")
public class ModelConfPO {
/**
* 自增ID主键
*/

View File

@@ -2,8 +2,8 @@ package com.pgmmers.radar.model;
import javax.persistence.Table;
@Table(name = "engine_mold_param")
public class MoldParamPO {
@Table(name = "engine_model_conf_param")
public class ModelConfParamPO {
private Long id;
private Long moldId;
/**

View File

@@ -2,9 +2,12 @@ package com.pgmmers.radar;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
import org.springframework.boot.autoconfigure.elasticsearch.rest.RestClientAutoConfiguration;
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
import tk.mybatis.spring.annotation.MapperScan;
@SpringBootApplication
@SpringBootApplication(exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class, RestClientAutoConfiguration.class})
@MapperScan("com.pgmmers.radar.mapper")
public class EngineApplication {

View File

@@ -1,8 +1,8 @@
package com.pgmmers.radar.service.impl.dnn;
import com.pgmmers.radar.service.dnn.Estimator;
import com.pgmmers.radar.service.model.MoldService;
import com.pgmmers.radar.vo.model.MoldVO;
import com.pgmmers.radar.service.model.ModelConfService;
import com.pgmmers.radar.vo.model.ModelConfVO;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@@ -10,13 +10,13 @@ import javax.annotation.Resource;
import java.util.HashMap;
import java.util.Map;
//@Component
@Component
public class EstimatorContainer {
private Map<String, Estimator> estimatorMap = new HashMap<>();
@Resource
private MoldService moldService;
private ModelConfService modelConfService;
@Autowired
public void set(Estimator[] estimators) {
@@ -30,7 +30,7 @@ public class EstimatorContainer {
}
public Estimator getByModelId(Long modelId) {
MoldVO mold = moldService.getByModelId(modelId);
ModelConfVO mold = modelConfService.getByModelId(modelId);
if (mold == null) {
return null;
}

View File

@@ -1,9 +1,9 @@
package com.pgmmers.radar.service.impl.dnn;
import com.pgmmers.radar.service.dnn.Estimator;
import com.pgmmers.radar.service.model.MoldService;
import com.pgmmers.radar.vo.model.MoldParamVO;
import com.pgmmers.radar.vo.model.MoldVO;
import com.pgmmers.radar.service.model.ModelConfService;
import com.pgmmers.radar.vo.model.ModelConfParamVO;
import com.pgmmers.radar.vo.model.ModelConfVO;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -18,16 +18,16 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
//@Component
@Component
public class TensorDnnEstimator implements Estimator {
private static final Logger LOGGER = LoggerFactory.getLogger(TensorDnnEstimator.class);
@Resource
private MoldService moldService;
private ModelConfService modelConfService;
private Map<Long, SavedModelBundle> modelBundleMap = new HashMap<>();
@Override
public double predict(Long modelId, Map<String, Map<String, ?>> data) {
MoldVO mold = moldService.getByModelId(modelId);
public float predict(Long modelId, Map<String, Map<String, ?>> data) {
ModelConfVO mold = modelConfService.getByModelId(modelId);
if (mold == null) {
LOGGER.debug("没有找到模型配置ModelId{}", modelId);
return 0;
@@ -39,44 +39,42 @@ public class TensorDnnEstimator implements Estimator {
}
Session tfSession = modelBundle.session();
try {
List<MoldParamVO> params = mold.getParams();
List<ModelConfParamVO> params = mold.getParams();
Session.Runner runner = tfSession.runner();
for (MoldParamVO moldParam : params) {
for (ModelConfParamVO moldParam : params) {
runner.feed(moldParam.getFeed(), convert2Tensor(moldParam, data));
}
Tensor<?> output = runner.fetch(mold.getOperation()).run().get(0);
double[] results = new double[1];
float[][] results = new float[1][1];
output.copyTo(results);
return results[0];
return results[0][0];
} catch (Exception e) {
LOGGER.error("模型调用失败ModelId:" + modelId, e);
} finally {
tfSession.close();
}
return 0;
}
private Tensor<?> convert2Tensor(MoldParamVO moldParam, Map<String, Map<String, ?>> data) {
private Tensor<?> convert2Tensor(ModelConfParamVO moldParam, Map<String, Map<String, ?>> data) {
String expressions = moldParam.getExpressions();
if (StringUtils.isEmpty(expressions)) {
return Tensor.create(new double[1][1]);
return Tensor.create(new float[1][1]);
}
String[] expList = expressions.split(",");
double[][] vec = new double[expList.length][1];
float[][] vec = new float[1][expList.length];
int a = 0;
for (String s : expList) {
double xn = 0;
float xn = 0;
String[] ss = s.split("\\.");//fields.deviceIdabstractions.log_uid_ip_1_day_qty
Map<String, ?> stringMap = data.get(ss[0]);
if (stringMap != null) {
xn = (Double) stringMap.get(ss[1]);
xn = Float.parseFloat(String.valueOf(stringMap.get(ss[1])));
}
vec[a++][0] = xn;
vec[0][a++] = xn;
}
return Tensor.create(vec);
}
private synchronized SavedModelBundle loadAndCacheModel(MoldVO mold) {
private synchronized SavedModelBundle loadAndCacheModel(ModelConfVO mold) {
SavedModelBundle modelBundle = modelBundleMap.get(mold.getId());
if (modelBundle == null) {
File file = new File(mold.getPath());
@@ -97,4 +95,23 @@ public class TensorDnnEstimator implements Estimator {
public String getType() {
return Estimator.TYPE_TENSOR_DNN;
}
public static void main(String[] args) {
SavedModelBundle modelBundle = SavedModelBundle.load("d:/radar01", "serve");
Session tfSession = modelBundle.session();
try {
Session.Runner runner = tfSession.runner();
float[][] aa = new float[1][6];
aa[0] = new float[]{20f, 1f, 1f, 1f, 10f, 2f};
runner.feed("input_x", Tensor.create(aa));
Tensor<?> output = runner.fetch("output_y/BiasAdd").run().get(0);
float[][] results = new float[1][1];
output.copyTo(results);
System.out.println(results[0][0]);
} catch (Exception e) {
e.printStackTrace();
} finally {
tfSession.close();
}
}
}

View File

@@ -5,9 +5,11 @@ import com.pgmmers.radar.enums.AggregateType;
import com.pgmmers.radar.enums.FieldType;
import com.pgmmers.radar.enums.Operator;
import com.pgmmers.radar.enums.StatusType;
import com.pgmmers.radar.service.dnn.Estimator;
import com.pgmmers.radar.service.engine.AggregateCommand;
import com.pgmmers.radar.service.engine.AntiFraudEngine;
import com.pgmmers.radar.service.engine.vo.*;
import com.pgmmers.radar.service.impl.dnn.EstimatorContainer;
import com.pgmmers.radar.service.model.*;
import com.pgmmers.radar.util.DateUtils;
import com.pgmmers.radar.util.GroovyScriptUtil;
@@ -51,6 +53,9 @@ public class AntiFraudEngineImpl implements AntiFraudEngine {
@Autowired
private RuleService ruleService;
@Autowired
private EstimatorContainer estimatorContainer;
@Override
public AbstractionResult executeAbstraction(Long modelId, Map<String, Map<String, ?>> data) {
AbstractionResult result = new AbstractionResult();
@@ -243,9 +248,13 @@ public class AntiFraudEngineImpl implements AntiFraudEngine {
@Override
public AdaptationResult executeAdaptation(Long modelId, Map<String, Map<String, ?>> data) {
AdaptationResult result = new AdaptationResult();
// TODO Auto-generated method stub
Estimator estimator = estimatorContainer.getByModelId(modelId);
if(estimator != null) {
float score = estimator.predict(modelId, data);
result.getAdaptationMap().put("score", score);
}
result.setSuccess(true);
data.put("adapations", new HashMap<String, Object>());
data.put("adapations", result.getAdaptationMap());
return result;
}

View File

@@ -0,0 +1,24 @@
package com.pgmmers.radar.service.impl.model;
import com.pgmmers.radar.dal.model.ModelConfDal;
import com.pgmmers.radar.service.model.ModelConfService;
import com.pgmmers.radar.vo.model.ModelConfVO;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
@Service
public class ModelConfServiceImpl implements ModelConfService {
@Resource
private ModelConfDal modelConfDal;
@Override
public ModelConfVO get(Long id) {
return modelConfDal.get(id);
}
@Override
public ModelConfVO getByModelId(Long modelId) {
return modelConfDal.getByModelId(modelId);
}
}

View File

@@ -1,24 +0,0 @@
package com.pgmmers.radar.service.impl.model;
import com.pgmmers.radar.dal.model.MoldDal;
import com.pgmmers.radar.service.model.MoldService;
import com.pgmmers.radar.vo.model.MoldVO;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
@Service
public class MoldServiceImpl implements MoldService {
@Resource
private MoldDal moldDal;
@Override
public MoldVO get(Long id) {
return moldDal.get(id);
}
@Override
public MoldVO getByModelId(Long modelId) {
return moldDal.getByModelId(modelId);
}
}

View File

@@ -27,7 +27,7 @@ public interface Estimator {
*/
String TYPE_TENSOR_DNN = "TENSOR_DNN";
double predict(Long modelId, Map<String, Map<String, ?>> data);
float predict(Long modelId, Map<String, Map<String, ?>> data);
String getType();
}

View File

@@ -0,0 +1,9 @@
package com.pgmmers.radar.service.model;
import com.pgmmers.radar.vo.model.ModelConfVO;
public interface ModelConfService {
ModelConfVO get(Long id);
ModelConfVO getByModelId(Long modelId);
}

View File

@@ -1,8 +0,0 @@
package com.pgmmers.radar.service.model;
import com.pgmmers.radar.vo.model.MoldVO;
public interface MoldService {
MoldVO get(Long id);
MoldVO getByModelId(Long modelId);
}

BIN
resources/radar-tran-v1.zip Normal file

Binary file not shown.

54
sql/radar-1.0.3.sql Normal file
View File

@@ -0,0 +1,54 @@
/*
Navicat MySQL Data Transfer
Source Server : test@172.30.0.6
Source Server Version : 50726
Source Host : 172.30.0.6:3306
Source Database : radar
Target Server Type : MYSQL
Target Server Version : 50726
File Encoding : 65001
Date: 2019-12-24 18:02:12
*/
SET FOREIGN_KEY_CHECKS=0;
-- ----------------------------
-- Table structure for engine_model_conf
-- ----------------------------
DROP TABLE IF EXISTS `engine_model_conf`;
CREATE TABLE `engine_model_conf` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`model_id` int(11) DEFAULT NULL,
`name` varchar(255) DEFAULT NULL,
`path` varchar(255) DEFAULT NULL,
`tag` varchar(255) DEFAULT NULL,
`operation` varchar(255) DEFAULT NULL,
`update_date` datetime DEFAULT NULL,
`type` varchar(255) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8;
-- ----------------------------
-- Records of engine_model_conf
-- ----------------------------
INSERT INTO `engine_model_conf` VALUES ('1', '103', '交易ai模型', 'd:/radar01', 'serve', 'output_y/BiasAdd', '2019-12-24 17:38:38', 'TENSOR_DNN');
-- ----------------------------
-- Table structure for engine_model_conf_param
-- ----------------------------
DROP TABLE IF EXISTS `engine_model_conf_param`;
CREATE TABLE `engine_model_conf_param` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`mold_id` int(11) DEFAULT NULL,
`feed` varchar(255) DEFAULT NULL,
`expressions` varchar(255) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8;
-- ----------------------------
-- Records of engine_model_conf_param
-- ----------------------------
INSERT INTO `engine_model_conf_param` VALUES ('1', '1', 'input_x', 'abstractions.tran_uid_ip_1_day_qty,abstractions.tran_did_ip_1_day_qty,abstractions.tran_ip_1_day_qty,abstractions.tran_ip_1_hour_qty,abstractions.tran_ip_1_day_amt,abstractions.tran_did_1_day_qty');