mirror of
https://gitee.com/freshday/radar.git
synced 2026-03-22 04:37:16 +08:00
@@ -7,6 +7,7 @@ import com.pgmmers.radar.vo.model.ModelConfVO;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.tensorflow.SavedModelBundle;
|
||||
import org.tensorflow.Session;
|
||||
@@ -25,6 +26,9 @@ public class TensorDnnEstimator implements Estimator {
|
||||
private ModelConfService modelConfService;
|
||||
private Map<Long, SavedModelBundle> modelBundleMap = new HashMap<>();
|
||||
|
||||
@Value("${sys.conf.workdir}")
|
||||
private String workDir;
|
||||
|
||||
@Override
|
||||
public float predict(Long modelId, Map<String, Map<String, ?>> data) {
|
||||
ModelConfVO mold = modelConfService.getByModelId(modelId);
|
||||
@@ -77,7 +81,9 @@ public class TensorDnnEstimator implements Estimator {
|
||||
private synchronized SavedModelBundle loadAndCacheModel(ModelConfVO mold) {
|
||||
SavedModelBundle modelBundle = modelBundleMap.get(mold.getId());
|
||||
if (modelBundle == null) {
|
||||
File file = new File(mold.getPath());
|
||||
String path = workDir + "\\" + mold.getPath();
|
||||
String decomposePath = path.substring(0, path.lastIndexOf("."));
|
||||
File file = new File(decomposePath);
|
||||
if (file.exists() && file.isDirectory()) {
|
||||
// 模型加载,比较耗时
|
||||
try {
|
||||
|
||||
Reference in New Issue
Block a user