mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 21:05:19 +00:00
Update model for classification app
- Use model referred from our models.md list. - Download both float/quant PiperOrigin-RevId: 224394551
This commit is contained in:
parent
e33bca07de
commit
7a9363e165
@ -34,7 +34,7 @@ android_binary(
|
||||
# to reduce APK size.
|
||||
assets = [
|
||||
"//tensorflow/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
|
||||
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
|
||||
"@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite",
|
||||
"@tflite_conv_actions_frozen//:conv_actions_frozen.tflite",
|
||||
"//tensorflow/lite/examples/android/app/src/main/assets:conv_actions_labels.txt",
|
||||
"@tflite_mobilenet_ssd//:mobilenet_ssd.tflite",
|
||||
|
@ -8,13 +8,12 @@
|
||||
* 3 model files will be downloaded into given folder of ext.ASSET_DIR
|
||||
*/
|
||||
// hard coded model files
|
||||
// LINT.IfChange
|
||||
|
||||
def models = ['conv_actions_tflite.zip',
|
||||
'mobilenet_ssd_tflite_v1.zip',
|
||||
'mobilenet_v1_224_android_quant_2017_11_08.zip',
|
||||
'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip']
|
||||
// LINT.ThenChange(//tensorflow/lite/examples/android/BUILD)
|
||||
def models = ['https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip',
|
||||
'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip',
|
||||
'https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip',
|
||||
'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz',
|
||||
'http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz']
|
||||
|
||||
// Root URL for model archives
|
||||
def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite'
|
||||
@ -30,9 +29,9 @@ buildscript {
|
||||
|
||||
import de.undercouch.gradle.tasks.download.Download
|
||||
task downloadFile(type: Download){
|
||||
for (f in models) {
|
||||
def modelUrl = MODEL_URL + "/" + f
|
||||
println "Downloading ${f} from ${modelUrl}"
|
||||
for (modelUrl in models) {
|
||||
def localFile = modelUrl.split("/")[-1]
|
||||
println "Downloading ${localFile} from ${modelUrl}"
|
||||
src modelUrl
|
||||
}
|
||||
|
||||
@ -43,7 +42,12 @@ task downloadFile(type: Download){
|
||||
task extractModels(type: Copy) {
|
||||
for (f in models) {
|
||||
def localFile = f.split("/")[-1]
|
||||
from zipTree(project.ext.TMP_DIR + '/' + localFile)
|
||||
def localExt = localFile.split("[.]")[-1]
|
||||
if (localExt == "tgz") {
|
||||
from tarTree(project.ext.TMP_DIR + '/' + localFile)
|
||||
} else {
|
||||
from zipTree(project.ext.TMP_DIR + '/' + localFile)
|
||||
}
|
||||
}
|
||||
|
||||
into file(project.ext.ASSET_DIR)
|
||||
@ -63,6 +67,9 @@ task extractModels(type: Copy) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
tasks.whenTaskAdded { task ->
|
||||
if (task.name == 'assembleDebug') {
|
||||
task.dependsOn 'extractModels'
|
||||
|
@ -65,7 +65,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
|
||||
// --input_binary=true
|
||||
private static final int INPUT_SIZE = 224;
|
||||
|
||||
private static final String MODEL_FILE = "mobilenet_quant_v1_224.tflite";
|
||||
private static final String MODEL_FILE = "mobilenet_v1_1.0_224_quant.tflite";
|
||||
private static final String LABEL_FILE = "labels_mobilenet_quant_v1_224.txt";
|
||||
|
||||
private static final boolean MAINTAIN_ASPECT = true;
|
||||
|
@ -52,28 +52,60 @@ dependencies {
|
||||
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
|
||||
}
|
||||
|
||||
def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
|
||||
def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip"
|
||||
def targetFolder = "src/main/assets"
|
||||
def modelFloatDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
|
||||
def modelQuantDownloadUrl = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
|
||||
def localCacheFloat = "build/intermediates/mobilenet_v1_1.0_224.tgz"
|
||||
def localCacheQuant = "build/intermediates/mmobilenet_v1_1.0_224_quant.tgz"
|
||||
|
||||
task downloadModel(type: DownloadUrlTask) {
|
||||
|
||||
task downloadModelFloat(type: DownloadUrlTask) {
|
||||
doFirst {
|
||||
println "Downloading ${modelDownloadUrl}"
|
||||
println "Downloading ${modelFloatDownloadUrl}"
|
||||
}
|
||||
sourceUrl = "${modelDownloadUrl}"
|
||||
target = file("${localCache}")
|
||||
sourceUrl = "${modelFloatDownloadUrl}"
|
||||
target = file("${localCacheFloat}")
|
||||
}
|
||||
|
||||
task unzipModel(type: Copy, dependsOn: 'downloadModel') {
|
||||
task downloadModelQuant(type: DownloadUrlTask) {
|
||||
doFirst {
|
||||
println "Unzipping ${localCache}"
|
||||
println "Downloading ${modelQuantDownloadUrl}"
|
||||
}
|
||||
from zipTree("${localCache}")
|
||||
sourceUrl = "${modelQuantDownloadUrl}"
|
||||
target = file("${localCacheQuant}")
|
||||
}
|
||||
|
||||
task unzipModelFloat(type: Copy, dependsOn: 'downloadModelFloat') {
|
||||
doFirst {
|
||||
println "Unzipping ${localCacheFloat}"
|
||||
}
|
||||
from tarTree("${localCacheFloat}")
|
||||
into "${targetFolder}"
|
||||
}
|
||||
|
||||
task unzipModelQuant(type: Copy, dependsOn: 'downloadModelQuant') {
|
||||
doFirst {
|
||||
println "Unzipping ${localCacheQuant}"
|
||||
}
|
||||
from tarTree("${localCacheQuant}")
|
||||
into "${targetFolder}"
|
||||
}
|
||||
|
||||
task cleanUnusedFiles(type: Delete, dependsOn: ['unzipModelFloat', 'unzipModelQuant']) {
|
||||
delete fileTree("${targetFolder}").matching {
|
||||
include "*.pb"
|
||||
include "*.ckpt.*"
|
||||
include "*.pbtxt.*"
|
||||
include "*.quant_info.*"
|
||||
include "*.meta"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Ensure the model file is downloaded and extracted before every build
|
||||
preBuild.dependsOn unzipModel
|
||||
preBuild.dependsOn unzipModelFloat
|
||||
preBuild.dependsOn unzipModelQuant
|
||||
preBuild.dependsOn cleanUnusedFiles
|
||||
|
||||
class DownloadUrlTask extends DefaultTask {
|
||||
@Input
|
||||
@ -87,3 +119,4 @@ class DownloadUrlTask extends DefaultTask {
|
||||
ant.get(src: sourceUrl, dest: target)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,8 @@ android_binary(
|
||||
aapt_version = "aapt",
|
||||
assets = [
|
||||
"//tensorflow/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
|
||||
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
|
||||
"@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite",
|
||||
"@tflite_mobilenet_float//:mobilenet_v1_1.0_224.tflite",
|
||||
],
|
||||
assets_dir = "",
|
||||
custom_package = "com.example.android.tflitecamerademo",
|
||||
|
@ -42,8 +42,9 @@ public class ImageClassifierQuantizedMobileNet extends ImageClassifier {
|
||||
@Override
|
||||
protected String getModelPath() {
|
||||
// you can download this file from
|
||||
// https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
|
||||
return "mobilenet_quant_v1_224.tflite";
|
||||
// see build.gradle for where to obtain this file. It should be auto
|
||||
// downloaded into assets.
|
||||
return "mobilenet_v1_1.0_224_quant.tflite";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -173,6 +173,7 @@ tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
@ -198,6 +199,7 @@ tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
|
@ -734,12 +734,22 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
)
|
||||
|
||||
tf_http_archive(
|
||||
name = "tflite_mobilenet",
|
||||
build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
|
||||
sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
|
||||
name = "tflite_mobilenet_float",
|
||||
build_file = clean_dep("//third_party:tflite_mobilenet_float.BUILD"),
|
||||
sha256 = "2fadeabb9968ec6833bee903900dda6e61b3947200535874ce2fe42a8493abc0",
|
||||
urls = [
|
||||
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
|
||||
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
|
||||
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
|
||||
],
|
||||
)
|
||||
|
||||
tf_http_archive(
|
||||
name = "tflite_mobilenet_quant",
|
||||
build_file = clean_dep("//third_party:tflite_mobilenet_quant.BUILD"),
|
||||
sha256 = "d32432d28673a936b2d6281ab0600c71cf7226dfe4cdcef3012555f691744166",
|
||||
urls = [
|
||||
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
|
||||
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
|
||||
],
|
||||
)
|
||||
|
||||
|
12
third_party/tflite_mobilenet_float.BUILD
vendored
Normal file
12
third_party/tflite_mobilenet_float.BUILD
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(
|
||||
glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"BUILD",
|
||||
],
|
||||
),
|
||||
)
|
12
third_party/tflite_mobilenet_quant.BUILD
vendored
Normal file
12
third_party/tflite_mobilenet_quant.BUILD
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(
|
||||
glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"BUILD",
|
||||
],
|
||||
),
|
||||
)
|
Loading…
Reference in New Issue
Block a user