# Download the ut-zap50k-images-square.zip dataset archive used for training below
wget https://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images-square.zip
解压数据集,供后面训练模型时使用
# Extract the archive (without -d, unzip writes into the current directory)
unzip ut-zap50k-images-square.zip
3. 代码工程

实验目的:基于 DJL 实现图片分类
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <!-- NOTE(review): no groupId/version declared here, so both are inherited from the parent above -->
    <artifactId>djl</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
        <!-- DJL BOM version, imported in dependencyManagement below -->
        <djl.version>0.23.0</djl.version>
        <!-- PyTorch native library version shared by every platform profile -->
        <pytorch.native.version>2.0.1</pytorch.native.version>
        <!-- pytorch-jni version = PyTorch native version followed by the DJL version -->
        <pytorch.jni.version>${pytorch.native.version}-${djl.version}</pytorch.jni.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- DJL (versions come from the ai.djl:bom import below) -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>

        <!-- PyTorch engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <scope>runtime</scope>
        </dependency>
    </dependencies>

    <!-- One profile per target platform; choose with -P<id>, e.g. -Plinux -->
    <profiles>
        <profile>
            <id>windows</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <!-- Windows CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>win-x86_64</classifier>
                    <version>${pytorch.native.version}</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>${pytorch.jni.version}</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>

        <profile>
            <id>centos7</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- Pre-CXX11 build (CentOS 7) -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <version>${pytorch.native.version}</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>${pytorch.jni.version}</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>

        <profile>
            <id>linux</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- Linux CPU -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu</artifactId>
                    <classifier>linux-x86_64</classifier>
                    <version>${pytorch.native.version}</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>${pytorch.jni.version}</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>

        <profile>
            <id>aarch64</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <!-- aarch64 build -->
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-native-cpu-precxx11</artifactId>
                    <classifier>linux-aarch64</classifier>
                    <version>${pytorch.native.version}</version>
                    <scope>runtime</scope>
                </dependency>
                <dependency>
                    <groupId>ai.djl.pytorch</groupId>
                    <artifactId>pytorch-jni</artifactId>
                    <version>${pytorch.jni.version}</version>
                    <scope>runtime</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
</project>
controller
package com.et.controller;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.et.service.ImageClassificationService;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.ContentDisposition;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;

/**
 * REST endpoints for training the image classifier, classifying an uploaded image,
 * and downloading a random sample image from a directory.
 *
 * <p>NOTE(review): every filesystem path used here (model directory, dataset root,
 * download directory) is taken directly from request parameters, so any caller can make
 * the server read from, or save a model into, an arbitrary directory. Acceptable for a
 * local demo; restrict these to a configured base directory before exposing the service.
 */
@RestController
@RequiredArgsConstructor
public class ImageClassificationController {

    private final ImageClassificationService imageClassificationService;

    /**
     * Classifies an uploaded image with the model saved in {@code modePath}.
     *
     * @param image    multipart part named {@code image}
     * @param modePath directory containing the trained model
     * @return the service's classification result
     */
    @PostMapping(path = "/analyze")
    public String predict(@RequestPart("image") MultipartFile image,
                          @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, MalformedModelException, IOException {
        return imageClassificationService.predict(image, modePath);
    }

    /**
     * Trains a model on the images under {@code datasetRoot} and saves it to {@code modePath}.
     * This call blocks until training finishes.
     *
     * @param datasetRoot root directory of the image dataset
     * @param modePath    directory the trained model is written to
     * @return the service's training summary
     */
    @PostMapping(path = "/training")
    public String training(@RequestParam(defaultValue = "/home/djl-test/images-test") String datasetRoot,
                           @RequestParam(defaultValue = "/home/djl-test/models") String modePath)
            throws TranslateException, IOException {
        return imageClassificationService.training(datasetRoot, modePath);
    }

    /**
     * Returns one randomly chosen regular file found (recursively) under {@code directoryPath}
     * as a JPEG attachment.
     *
     * @param directoryPath directory to pick a file from
     * @return 200 with the file; 404 when the directory holds no regular files;
     *         500 when the directory cannot be walked or the file size cannot be read
     */
    @GetMapping("/download")
    public ResponseEntity<Resource> downloadFile(
            @RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
        List<Path> candidates;
        try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
            // Only regular files are candidates; directories are skipped.
            candidates = paths.filter(Files::isRegularFile).toList();
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
        // nextInt(0) throws IllegalArgumentException, so an empty directory must be handled here.
        if (candidates.isEmpty()) {
            return ResponseEntity.notFound().build();
        }

        Path file = candidates.get(ThreadLocalRandom.current().nextInt(candidates.size()));
        Resource resource = new FileSystemResource(file.toFile());
        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }

        HttpHeaders headers = new HttpHeaders();
        // The builder quotes the file name, so names containing spaces or ';' stay intact.
        headers.setContentDisposition(ContentDisposition.attachment()
                .filename(file.getFileName().toString())
                .build());
        headers.setContentType(MediaType.IMAGE_JPEG);
        try {
            return ResponseEntity.ok()
                    .headers(headers)
                    .contentLength(resource.contentLength())
                    .body(resource);
        } catch (IOException e) {
            return ResponseEntity.status(500).build();
        }
    }
}
service
接口
package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;

/**
 * Image classification operations: train a model from an image dataset directory and
 * classify an uploaded image with a previously saved model.
 */
public interface ImageClassificationService {

    /**
     * Classifies {@code image} with the model stored under {@code modePath}.
     *
     * @param image    the uploaded image
     * @param modePath directory holding the saved model
     * @return the classification result, rendered as a string by the implementation
     * @throws IOException             if the image or model files cannot be read
     * @throws MalformedModelException if the saved model cannot be loaded
     * @throws TranslateException      if pre/post-processing or inference fails
     */
    String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException;

    /**
     * Trains a model on the dataset rooted at {@code datasetRoot} and saves it to {@code modePath}.
     *
     * @param datasetRoot root directory of the image dataset
     * @param modePath    directory the trained model is written to
     * @return a textual training summary, defined by the implementation
     * @throws TranslateException if dataset preparation or training fails
     * @throws IOException        if the dataset or model directory cannot be accessed
     */
    String training(String datasetRoot, String modePath) throws TranslateException, IOException;
}
实现类
package com.et.service;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.et.Models;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Locale;

/**
 * DJL-backed implementation of {@link ImageClassificationService}.
 *
 * <p>Training reads an {@link ImageFolder} dataset, fits the network returned by
 * {@code Models.getModel}, and saves the parameters plus the label list into the model
 * directory. Prediction loads that model again for every request.
 *
 * <p>NOTE(review): loading the model per request is simple but slow; cache a loaded model
 * if prediction latency matters.
 */
@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {

    // number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // number of passes over the complete dataset
    private static final int EPOCHS = 2;

    // number of classification labels: boots, sandals, shoes, slippers
    @Value("${djl.num-of-output:4}")
    public int numOfOutput;

    /**
     * Classifies the uploaded image with the model saved under {@code modePath}.
     *
     * @return the {@link Classifications} result serialized with {@code toJson()}
     * @throws IOException if the upload cannot be read, is not a format ImageIO can decode,
     *                     or the model files cannot be read
     */
    @Override
    public String predict(MultipartFile image, String modePath)
            throws IOException, MalformedModelException, TranslateException {
        Image img = readImage(image);
        Path modelDir = Paths.get(modePath);
        // Models.getModel builds the network; model.load then fills in the saved parameters.
        try (Model model = Models.getModel(numOfOutput)) {
            model.load(modelDir, Models.MODEL_NAME);
            try (Predictor<Image, Classifications> predictor = model.newPredictor(buildTranslator())) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                String json = predictResult.toJson();
                log.info("result={}", json);
                return json;
            }
        }
    }

    /**
     * Trains the model on the dataset under {@code datasetRoot} and saves it to {@code modePath}.
     *
     * @return the dataset's labels (synset), one per line
     */
    @Override
    public String training(String datasetRoot, String modePath) throws TranslateException, IOException {
        log.info("Image dataset training started...Image dataset address path:{}", datasetRoot);
        // the location to save the model
        Path modelDir = Paths.get(modePath);

        ImageFolder dataset = initDataset(datasetRoot);
        // split 8:2 into training and validation sets
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // loss function: compares predictions with the true labels; training minimizes it
        Loss loss = Loss.softmaxCrossEntropyLoss();
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(numOfOutput);
             Trainer trainer = model.newTrainer(config)) {
            // metrics collect key performance indicators such as accuracy
            trainer.setMetrics(new Metrics());
            // (batch, channels, height, width): ToTensor emits CHW tensors, so the last two
            // dimensions are height then width
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);
            trainer.initialize(inputShape);

            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // record the validation results as model properties; Locale.ROOT keeps the
            // decimal separator stable regardless of the JVM's default locale
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy",
                    String.format(Locale.ROOT, "%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format(Locale.ROOT, "%.5f", result.getValidateLoss()));

            // save the parameters for later inference, then the labels alongside them
            model.save(modelDir, Models.MODEL_NAME);
            Models.saveSynset(modelDir, dataset.getSynset());
            log.info("Image dataset training completed......");
            return String.join("\n", dataset.getSynset());
        }
    }

    /** Decodes the uploaded file into a DJL {@link Image}, closing the upload stream. */
    private static Image readImage(MultipartFile image) throws IOException {
        try (InputStream is = image.getInputStream()) {
            BufferedImage bi = ImageIO.read(is);
            // ImageIO.read returns null, rather than throwing, when no reader supports the format
            if (bi == null) {
                throw new IOException("Unsupported or unreadable image: " + image.getOriginalFilename());
            }
            return ImageFactory.getInstance().fromImage(bi);
        }
    }

    /**
     * Builds the inference translator. Its transforms mirror those in {@link #initDataset} so
     * that prediction inputs are preprocessed the same way as training inputs.
     * NOTE(review): no synset is passed here; presumably the translator picks up the labels
     * saved by Models.saveSynset from the model directory — confirm.
     */
    private static Translator<Image, Classifications> buildTranslator() {
        return ImageClassificationTranslator.builder()
                .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                .optApplySoftmax(true)
                .build();
    }

    /** Creates and prepares a shuffled, batched {@link ImageFolder} dataset rooted at {@code datasetRoot}. */
    private ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
        ImageFolder dataset = ImageFolder.builder()
                // retrieve the data
                .setRepositoryPath(Paths.get(datasetRoot))
                .optMaxDepth(10)
                .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                // random sampling; don't process the data in order
                .setSampling(BATCH_SIZE, true)
                .build();
        dataset.prepare();
        return dataset;
    }

    /** Training configuration: the given loss, an accuracy evaluator and default logging listeners. */
    private TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}
application.yaml
server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      # upper bounds for uploaded files / whole multipart requests (e.g. the /analyze image)
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher
启动类
package com.et;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Spring Boot entry point. Component scanning starts at {@code com.et}, which covers the
 * {@code com.et.controller} and {@code com.et.service} packages.
 */
@SpringBootApplication
public class DemoApplication {

    public static void main(String[] args) {
        SpringApplication.run(DemoApplication.class, args);
    }
}
以上只是部分关键代码,完整代码请参见下面的代码仓库
代码仓库:https://github.com/Harries/springboot-demo (DJL)

4. 测试

启动 Spring Boot 应用
训练模型:使用之前下载并解压的数据集(调用 /training 接口,datasetRoot 参数指向解压后的目录)
控制台输出日志如下。如果没有 GPU,训练会比较慢,需要等待一段时间(下面的日志中,在 CPU 上跑完第 1 个 epoch 用了约 1 小时 40 分钟)
2024-10-11T21:00:05.407+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] c.e.s.ImageClassificationServiceImpl : Image dataset training started...Image dataset address path:/Users/liuhaihua/ai/ut-zap50k-images-square2024-10-11T21:00:08.455+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.util.Platform : Ignore mismatching platform from: jar:file:/Users/liuhaihua/.m2/repository/ai/djl/pytorch/pytorch-native-cpu/2.0.1/pytorch-native-cpu-2.0.1-win-x86_64.jar!/native/lib/pytorch.properties2024-10-11T21:00:09.240+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization2024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of inter-op threads is 42024-10-11T21:00:09.241+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] ai.djl.pytorch.engine.PtEngine : Number of intra-op threads is 42024-10-11T21:00:09.287+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Training on: cpu().2024-10-11T21:00:09.290+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Load PyTorch Engine Version 1.13.1 in 0.044 ms.Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.38Validating: 100% |████████████████████████████████████████|2024-10-11T22:42:48.142+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Epoch 1 finished.2024-10-11T22:42:48.187+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Train: 
Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.382024-10-11T22:42:48.189+08:00 INFO 74606 --- [djl-image-classification-demo] [nio-8888-exec-1] a.d.t.listener.LoggingTrainingListener : Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.24Training: 5% |███ | Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
预测图片分类
使用上一步训练出来的模型进行预测(调用 /analyze 接口,以 image 字段上传图片)
从返回结果可以看到 Shoes 类别的概率最高,由此可知该图片所属的鞋类为 Shoes
5. 引用

- https://docs.djl.ai/master/docs/demos/footwear_classification/index.html#train-the-footwear-classification-model
- https://github.com/deepjavalibrary/djl