prefs: 机器学习加载路径。

Signed-off-by: feihu.wang <wfh45678@163.com>
This commit is contained in:
feihu.wang
2020-01-11 14:42:22 +08:00
parent 40f55fb130
commit e8d181e0d1

View File

@@ -7,6 +7,7 @@ import com.pgmmers.radar.vo.model.ModelConfVO;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
@@ -25,6 +26,9 @@ public class TensorDnnEstimator implements Estimator {
private ModelConfService modelConfService;
private Map<Long, SavedModelBundle> modelBundleMap = new HashMap<>();
@Value("${sys.conf.workdir}")
private String workDir;
@Override
public float predict(Long modelId, Map<String, Map<String, ?>> data) {
ModelConfVO mold = modelConfService.getByModelId(modelId);
@@ -77,7 +81,9 @@ public class TensorDnnEstimator implements Estimator {
private synchronized SavedModelBundle loadAndCacheModel(ModelConfVO mold) {
SavedModelBundle modelBundle = modelBundleMap.get(mold.getId());
if (modelBundle == null) {
File file = new File(mold.getPath());
String path = workDir + "\\" + mold.getPath();
String decomposePath = path.substring(0, path.lastIndexOf("."));
File file = new File(decomposePath);
if (file.exists() && file.isDirectory()) {
// 模型加载,比较耗时
try {