Update model for classification app

- Use model referred from our models.md list.
- Download both float/quant

PiperOrigin-RevId: 224394551
This commit is contained in:
Andrew Selle 2018-12-06 13:17:22 -08:00 committed by TensorFlower Gardener
parent e33bca07de
commit 7a9363e165
10 changed files with 108 additions and 30 deletions

View File

@ -34,7 +34,7 @@ android_binary(
# to reduce APK size.
assets = [
"//tensorflow/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
"@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite",
"@tflite_conv_actions_frozen//:conv_actions_frozen.tflite",
"//tensorflow/lite/examples/android/app/src/main/assets:conv_actions_labels.txt",
"@tflite_mobilenet_ssd//:mobilenet_ssd.tflite",

View File

@ -8,13 +8,12 @@
* 3 model files will be downloaded into given folder of ext.ASSET_DIR
*/
// hard coded model files
// LINT.IfChange
def models = ['conv_actions_tflite.zip',
'mobilenet_ssd_tflite_v1.zip',
'mobilenet_v1_224_android_quant_2017_11_08.zip',
'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip']
// LINT.ThenChange(//tensorflow/lite/examples/android/BUILD)
def models = ['https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip',
'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip',
'https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip',
'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz',
'http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz']
// Root URL for model archives
def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite'
@ -30,9 +29,9 @@ buildscript {
import de.undercouch.gradle.tasks.download.Download
task downloadFile(type: Download){
for (f in models) {
def modelUrl = MODEL_URL + "/" + f
println "Downloading ${f} from ${modelUrl}"
for (modelUrl in models) {
def localFile = modelUrl.split("/")[-1]
println "Downloading ${localFile} from ${modelUrl}"
src modelUrl
}
@ -43,7 +42,12 @@ task downloadFile(type: Download){
task extractModels(type: Copy) {
for (f in models) {
def localFile = f.split("/")[-1]
from zipTree(project.ext.TMP_DIR + '/' + localFile)
def localExt = localFile.split("[.]")[-1]
if (localExt == "tgz") {
from tarTree(project.ext.TMP_DIR + '/' + localFile)
} else {
from zipTree(project.ext.TMP_DIR + '/' + localFile)
}
}
into file(project.ext.ASSET_DIR)
@ -63,6 +67,9 @@ task extractModels(type: Copy) {
}
}
tasks.whenTaskAdded { task ->
if (task.name == 'assembleDebug') {
task.dependsOn 'extractModels'

View File

@ -65,7 +65,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
// --input_binary=true
private static final int INPUT_SIZE = 224;
private static final String MODEL_FILE = "mobilenet_quant_v1_224.tflite";
private static final String MODEL_FILE = "mobilenet_v1_1.0_224_quant.tflite";
private static final String LABEL_FILE = "labels_mobilenet_quant_v1_224.txt";
private static final boolean MAINTAIN_ASPECT = true;

View File

@ -52,28 +52,60 @@ dependencies {
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
}
def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip"
def targetFolder = "src/main/assets"
def modelFloatDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
def modelQuantDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
def localCacheFloat = "build/intermediates/mobilenet_v1_1.0_224.tgz"
def localCacheQuant = "build/intermediates/mmobilenet_v1_1.0_224_quant.tgz"
task downloadModel(type: DownloadUrlTask) {
task downloadModelFloat(type: DownloadUrlTask) {
doFirst {
println "Downloading ${modelDownloadUrl}"
println "Downloading ${modelFloatDownloadUrl}"
}
sourceUrl = "${modelDownloadUrl}"
target = file("${localCache}")
sourceUrl = "${modelFloatDownloadUrl}"
target = file("${localCacheFloat}")
}
task unzipModel(type: Copy, dependsOn: 'downloadModel') {
task downloadModelQuant(type: DownloadUrlTask) {
doFirst {
println "Unzipping ${localCache}"
println "Downloading ${modelQuantDownloadUrl}"
}
from zipTree("${localCache}")
sourceUrl = "${modelQuantDownloadUrl}"
target = file("${localCacheQuant}")
}
task unzipModelFloat(type: Copy, dependsOn: 'downloadModelFloat') {
doFirst {
println "Unzipping ${localCacheFloat}"
}
from tarTree("${localCacheFloat}")
into "${targetFolder}"
}
task unzipModelQuant(type: Copy, dependsOn: 'downloadModelQuant') {
doFirst {
println "Unzipping ${localCacheQuant}"
}
from tarTree("${localCacheQuant}")
into "${targetFolder}"
}
task cleanUnusedFiles(type: Delete, dependsOn: ['unzipModelFloat', 'unzipModelQuant']) {
delete fileTree("${targetFolder}").matching {
include "*.pb"
include "*.ckpt.*"
include "*.pbtxt.*"
include "*.quant_info.*"
include "*.meta"
}
}
// Ensure the model file is downloaded and extracted before every build
preBuild.dependsOn unzipModel
preBuild.dependsOn unzipModelFloat
preBuild.dependsOn unzipModelQuant
preBuild.dependsOn cleanUnusedFiles
class DownloadUrlTask extends DefaultTask {
@Input
@ -87,3 +119,4 @@ class DownloadUrlTask extends DefaultTask {
ant.get(src: sourceUrl, dest: target)
}
}

View File

@ -10,7 +10,8 @@ android_binary(
aapt_version = "aapt",
assets = [
"//tensorflow/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
"@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite",
"@tflite_mobilenet_float//:mobilenet_v1_1.0_224.tflite",
],
assets_dir = "",
custom_package = "com.example.android.tflitecamerademo",

View File

@ -42,8 +42,9 @@ public class ImageClassifierQuantizedMobileNet extends ImageClassifier {
@Override
protected String getModelPath() {
// you can download this file from
// https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
return "mobilenet_quant_v1_224.tflite";
// see build.gradle for where to obtain this file. It should be auto
// downloaded into assets.
return "mobilenet_v1_1.0_224_quant.tflite";
}
@Override

View File

@ -173,6 +173,7 @@ tensorflow/third_party/common.bzl
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
@ -198,6 +199,7 @@ tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/gif.BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/six.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/repo.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD

View File

@ -734,12 +734,22 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
name = "tflite_mobilenet",
build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
name = "tflite_mobilenet_float",
build_file = clean_dep("//third_party:tflite_mobilenet_float.BUILD"),
sha256 = "2fadeabb9968ec6833bee903900dda6e61b3947200535874ce2fe42a8493abc0",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
],
)
tf_http_archive(
name = "tflite_mobilenet_quant",
build_file = clean_dep("//third_party:tflite_mobilenet_quant.BUILD"),
sha256 = "d32432d28673a936b2d6281ab0600c71cf7226dfe4cdcef3012555f691744166",
urls = [
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
],
)

View File

@ -0,0 +1,12 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(
glob(
["**/*"],
exclude = [
"BUILD",
],
),
)

View File

@ -0,0 +1,12 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(
glob(
["**/*"],
exclude = [
"BUILD",
],
),
)