From 7a9363e16590c3a713ce06dbf5857149b460b8aa Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Thu, 6 Dec 2018 13:17:22 -0800 Subject: [PATCH] Update model for classification app - Use model referred from our models.md list. - Download both float/quant PiperOrigin-RevId: 224394551 --- tensorflow/lite/examples/android/BUILD | 2 +- .../android/app/download-models.gradle | 27 ++++++---- .../tensorflow/demo/ClassifierActivity.java | 2 +- tensorflow/lite/java/demo/app/build.gradle | 53 +++++++++++++++---- tensorflow/lite/java/demo/app/src/main/BUILD | 3 +- .../ImageClassifierQuantizedMobileNet.java | 5 +- tensorflow/opensource_only.files | 2 + tensorflow/workspace.bzl | 20 +++++-- third_party/tflite_mobilenet_float.BUILD | 12 +++++ third_party/tflite_mobilenet_quant.BUILD | 12 +++++ 10 files changed, 108 insertions(+), 30 deletions(-) create mode 100644 third_party/tflite_mobilenet_float.BUILD create mode 100644 third_party/tflite_mobilenet_quant.BUILD diff --git a/tensorflow/lite/examples/android/BUILD b/tensorflow/lite/examples/android/BUILD index 761a60314e8..80cefd415a5 100644 --- a/tensorflow/lite/examples/android/BUILD +++ b/tensorflow/lite/examples/android/BUILD @@ -34,7 +34,7 @@ android_binary( # to reduce APK size. assets = [ "//tensorflow/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", - "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + "@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite", "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", "//tensorflow/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", diff --git a/tensorflow/lite/examples/android/app/download-models.gradle b/tensorflow/lite/examples/android/app/download-models.gradle index d2f03db5f63..36bd177a1fd 100644 --- a/tensorflow/lite/examples/android/app/download-models.gradle +++ b/tensorflow/lite/examples/android/app/download-models.gradle @@ -8,13 +8,12 @@ * 3 model files will be downloaded into given folder of ext.ASSET_DIR */ // hard coded model files -// LINT.IfChange -def models = ['conv_actions_tflite.zip', - 'mobilenet_ssd_tflite_v1.zip', - 'mobilenet_v1_224_android_quant_2017_11_08.zip', - 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip'] -// LINT.ThenChange(//tensorflow/lite/examples/android/BUILD) +def models = ['https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip', + 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip', + 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip', + 'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz', + 'http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz'] // Root URL for model archives def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' @@ -30,9 +29,9 @@ buildscript { import de.undercouch.gradle.tasks.download.Download task downloadFile(type: Download){ - for (f in models) { - def modelUrl = MODEL_URL + "/" + f - println "Downloading ${f} from ${modelUrl}" + for (modelUrl in models) { + def localFile = modelUrl.split("/")[-1] + println "Downloading ${localFile} from ${modelUrl}" src modelUrl } @@ -43,7 +42,12 @@ task downloadFile(type: Download){ task extractModels(type: Copy) { for (f in models) { def localFile = f.split("/")[-1] - from zipTree(project.ext.TMP_DIR + '/' + localFile) + def localExt = localFile.split("[.]")[-1] + if (localExt == "tgz") { + from tarTree(project.ext.TMP_DIR + '/' + localFile) + } else { + from zipTree(project.ext.TMP_DIR + '/' + localFile) + } } into file(project.ext.ASSET_DIR) @@ -63,6 +67,9 @@ task extractModels(type: Copy) { } } + + + tasks.whenTaskAdded { task -> if (task.name == 'assembleDebug') { task.dependsOn 'extractModels' diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java index dcbbefbeab6..698251d8b4a 100644 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java +++ b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java @@ -65,7 +65,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab // --input_binary=true private static final int INPUT_SIZE = 224; - private static final String MODEL_FILE = "mobilenet_quant_v1_224.tflite"; + private static final String MODEL_FILE = "mobilenet_v1_1.0_224_quant.tflite"; private static final String LABEL_FILE = "labels_mobilenet_quant_v1_224.txt"; private static final boolean MAINTAIN_ASPECT = true; diff --git a/tensorflow/lite/java/demo/app/build.gradle b/tensorflow/lite/java/demo/app/build.gradle index 05301ebf88c..5e50ed4b941 100644 --- a/tensorflow/lite/java/demo/app/build.gradle +++ b/tensorflow/lite/java/demo/app/build.gradle @@ -52,28 +52,60 @@ dependencies { compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' } -def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" -def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip" def targetFolder = "src/main/assets" +def modelFloatDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz" +def modelQuantDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" +def localCacheFloat = "build/intermediates/mobilenet_v1_1.0_224.tgz" +def localCacheQuant = "build/intermediates/mmobilenet_v1_1.0_224_quant.tgz" -task downloadModel(type: DownloadUrlTask) { + +task downloadModelFloat(type: DownloadUrlTask) { doFirst { - println "Downloading ${modelDownloadUrl}" + println "Downloading ${modelFloatDownloadUrl}" } - sourceUrl = "${modelDownloadUrl}" - target = file("${localCache}") + sourceUrl = "${modelFloatDownloadUrl}" + target = file("${localCacheFloat}") } -task unzipModel(type: Copy, dependsOn: 'downloadModel') { +task downloadModelQuant(type: DownloadUrlTask) { doFirst { - println "Unzipping ${localCache}" + println "Downloading ${modelQuantDownloadUrl}" } - from zipTree("${localCache}") + sourceUrl = "${modelQuantDownloadUrl}" + target = file("${localCacheQuant}") +} + +task unzipModelFloat(type: Copy, dependsOn: 'downloadModelFloat') { + doFirst { + println "Unzipping ${localCacheFloat}" + } + from tarTree("${localCacheFloat}") into "${targetFolder}" } +task unzipModelQuant(type: Copy, dependsOn: 'downloadModelQuant') { + doFirst { + println "Unzipping ${localCacheQuant}" + } + from tarTree("${localCacheQuant}") + into "${targetFolder}" +} + +task cleanUnusedFiles(type: Delete, dependsOn: ['unzipModelFloat', 'unzipModelQuant']) { + delete fileTree("${targetFolder}").matching { + include "*.pb" + include "*.ckpt.*" + include "*.pbtxt.*" + include "*.quant_info.*" + include "*.meta" + } +} + + // Ensure the model file is downloaded and extracted before every build -preBuild.dependsOn unzipModel +preBuild.dependsOn unzipModelFloat +preBuild.dependsOn unzipModelQuant +preBuild.dependsOn cleanUnusedFiles class DownloadUrlTask extends DefaultTask { @Input @@ -87,3 +119,4 @@ class DownloadUrlTask extends DefaultTask { ant.get(src: sourceUrl, dest: target) } } + diff --git a/tensorflow/lite/java/demo/app/src/main/BUILD b/tensorflow/lite/java/demo/app/src/main/BUILD index df8a024a570..9a7c1d0b611 100644 --- a/tensorflow/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/lite/java/demo/app/src/main/BUILD @@ -10,7 +10,8 @@ android_binary( aapt_version = "aapt", assets = [ "//tensorflow/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", - "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + "@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite", + "@tflite_mobilenet_float//:mobilenet_v1_1.0_224.tflite", ], assets_dir = "", custom_package = "com.example.android.tflitecamerademo", diff --git a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java index e164ac75543..6310a561683 100644 --- a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java +++ b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java @@ -42,8 +42,9 @@ public class ImageClassifierQuantizedMobileNet extends ImageClassifier { @Override protected String getModelPath() { // you can download this file from - // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip - return "mobilenet_quant_v1_224.tflite"; + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224_quant.tflite"; } @Override diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 0c29ac6a307..688a837dac3 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -173,6 +173,7 @@ tensorflow/third_party/common.bzl tensorflow/third_party/com_google_absl.BUILD tensorflow/third_party/pprof.BUILD tensorflow/third_party/BUILD +tensorflow/third_party/tflite_mobilenet_quant.BUILD tensorflow/third_party/lmdb.BUILD tensorflow/third_party/git/BUILD.tpl tensorflow/third_party/git/BUILD @@ -198,6 +199,7 @@ tensorflow/third_party/nanopb.BUILD tensorflow/third_party/gif.BUILD tensorflow/third_party/double_conversion.BUILD tensorflow/third_party/six.BUILD +tensorflow/third_party/tflite_mobilenet_float.BUILD tensorflow/third_party/repo.bzl tensorflow/third_party/codegen.BUILD tensorflow/third_party/cub.BUILD diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index e6b4a89e3b1..f475493446e 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -734,12 +734,22 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): ) tf_http_archive( - name = "tflite_mobilenet", - build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"), - sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b", + name = "tflite_mobilenet_float", + build_file = clean_dep("//third_party:tflite_mobilenet_float.BUILD"), + sha256 = "2fadeabb9968ec6833bee903900dda6e61b3947200535874ce2fe42a8493abc0", urls = [ - "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", + "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz", + "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz", + ], + ) + + tf_http_archive( + name = "tflite_mobilenet_quant", + build_file = clean_dep("//third_party:tflite_mobilenet_quant.BUILD"), + sha256 = "d32432d28673a936b2d6281ab0600c71cf7226dfe4cdcef3012555f691744166", + urls = [ + "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", + "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", ], ) diff --git a/third_party/tflite_mobilenet_float.BUILD b/third_party/tflite_mobilenet_float.BUILD new file mode 100644 index 00000000000..de47ed61f9d --- /dev/null +++ b/third_party/tflite_mobilenet_float.BUILD @@ -0,0 +1,12 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/third_party/tflite_mobilenet_quant.BUILD b/third_party/tflite_mobilenet_quant.BUILD new file mode 100644 index 00000000000..de47ed61f9d --- /dev/null +++ b/third_party/tflite_mobilenet_quant.BUILD @@ -0,0 +1,12 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +)