diff --git a/README.md b/README.md
index 20f232b..e04fcbc 100644
--- a/README.md
+++ b/README.md
@@ -1,72 +1,72 @@
# Tensorflow Implementation of Yahoo's Open NSFW Model
This repository contains an implementation of [Yahoo's Open NSFW Classifier](https://github.com/yahoo/open_nsfw) rewritten in tensorflow.
The original caffe weights have been extracted using [Caffe to TensorFlow](https://github.com/ethereon/caffe-tensorflow). You can find them at `data/open_nsfw-weights.npy`.
## Prerequisites
-All code should be compatible with `Python 3.6` and `Tensorflow 1.x`. The model implementation can be found in `model.py`.
+All code should be compatible with `Python 3.6` and `Tensorflow 1.x` (tested with 1.12). The model implementation can be found in `model.py`.
### Usage
```
> python classify_nsfw.py -m data/open_nsfw-weights.npy test.jpg
Results for 'test.jpg'
SFW score: 0.9355766177177429
NSFW score: 0.06442338228225708
```
__Note:__ Currently only jpeg images are supported.
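Since only JPEG input is supported, a caller may want to reject other formats before invoking the classifier. The sketch below is not part of this repository. It assumes that checking for the JPEG Start-Of-Image marker (`FF D8 FF`) in the first bytes is a good enough filter, and the helper names `is_jpeg` and `file_is_jpeg` are hypothetical.

```python
# JPEG files begin with the Start-Of-Image marker 0xFFD8, followed by 0xFF
# (the start of the next marker segment).
JPEG_SOI = b"\xff\xd8\xff"


def is_jpeg(data: bytes) -> bool:
    """Return True if `data` starts with the JPEG SOI marker.

    Only the leading bytes are inspected; the rest of the file is not validated.
    """
    return data[:3] == JPEG_SOI


def file_is_jpeg(path: str) -> bool:
    """Hypothetical pre-check to run before calling classify_nsfw.py."""
    with open(path, "rb") as f:
        return is_jpeg(f.read(3))
```

A PNG (which starts with `\x89PNG`) or an empty file fails this check, while a truncated or corrupt JPEG may still pass it.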
`classify_nsfw.py` accepts some optional parameters you may want to play around with:
```
usage: classify_nsfw.py [-h] -m MODEL_WEIGHTS [-l {yahoo,tensorflow}]
                        [-t {tensor,base64_jpeg}]
                        input_jpeg_file

positional arguments:
  input_file            Path to the input image. Only jpeg images are
                        supported.

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL_WEIGHTS, --model_weights MODEL_WEIGHTS
                        Path to trained model weights file
  -l {yahoo,tensorflow}, --image_loader {yahoo,tensorflow}
                        image loading mechanism
  -t {tensor,base64_jpeg}, --input_type {tensor,base64_jpeg}
                        input type
```
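For reference, a parser matching the help text above could be sketched with `argparse` roughly as follows. This is a hypothetical reconstruction from the documented flags and defaults (`yahoo` loader, `tensor` input), not the actual code in `classify_nsfw.py`; `build_parser` and `effective_loader` are illustrative names. `effective_loader` encodes the rule, described under `-t/--input_type`, that `base64_jpeg` always uses the `tensorflow` loader.

```python
import argparse


def build_parser():
    """Hypothetical reconstruction of the documented classify_nsfw.py CLI."""
    parser = argparse.ArgumentParser(prog="classify_nsfw.py")
    parser.add_argument("input_file",
                        help="Path to the input image. Only jpeg images are supported.")
    parser.add_argument("-m", "--model_weights", required=True,
                        help="Path to trained model weights file")
    # Documented default: the yahoo loader.
    parser.add_argument("-l", "--image_loader", default="yahoo",
                        choices=["yahoo", "tensorflow"],
                        help="image loading mechanism")
    # Documented default: a float tensor input.
    parser.add_argument("-t", "--input_type", default="tensor",
                        choices=["tensor", "base64_jpeg"],
                        help="input type")
    return parser


def effective_loader(args):
    """base64_jpeg input forces the tensorflow loader, regardless of -l."""
    if args.input_type == "base64_jpeg":
        return "tensorflow"
    return args.image_loader
```

For example, `build_parser().parse_args(["-m", "data/open_nsfw-weights.npy", "test.jpg"])` yields the `yahoo` loader and `tensor` input.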
__-l/--image-loader__
The classification tool supports two different image loading mechanisms.
* `yahoo` (default) replicates yahoo's original image loading and preprocessing. Use this option if you want the same results as with the original implementation
* `tensorflow` is an image loader which uses tensorflow exclusively (no dependencies on `PIL`, `skimage`, etc.). Tries to replicate the image loading mechanism used by the original caffe implementation, differs a bit though due to different jpeg and resizing implementations. See [this issue](https://github.com/mdietrichstein/tensorflow-open_nsfw/issues/2#issuecomment-346125345) for details.
__Note:__ Classification results may vary depending on the selected image loader!
__-t/--input_type__
Determines if the model internally uses a float tensor (`tensor` - `[None, 224, 224, 3]` - default) or a base64 encoded string tensor (`base64_jpeg` - `[None, ]`) as input. If `base64_jpeg` is used, then the `tensorflow` image loader will be used, regardless of the _-l/--image-loader_ argument.
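With `base64_jpeg`, the model receives the image as a base64 encoded string rather than as pixels. A minimal sketch of producing such a string, assuming the web-safe alphabet (`-` and `_` instead of `+` and `/`) that `tools/create_predict_request.py` uses via `base64.urlsafe_b64encode` and that TensorFlow's `decode_base64` op expects; the function names are hypothetical:

```python
import base64


def encode_jpeg_for_model(jpeg_bytes: bytes) -> str:
    """Web-safe base64 encode raw JPEG bytes into an ASCII string."""
    # urlsafe_b64encode substitutes '-' for '+' and '_' for '/'.
    return base64.urlsafe_b64encode(jpeg_bytes).decode("ascii")


def decode_model_input(encoded: str) -> bytes:
    """Inverse of encode_jpeg_for_model, handy for sanity checks."""
    return base64.urlsafe_b64decode(encoded.encode("ascii"))
```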
### Tools
The `tools` folder contains some utility scripts to test the model.
__export_graph.py__
Exports the tensorflow graph and checkpoint. Freezes and optimizes the graph per default for improved inference and deployment usage (e.g. Android, iOS, etc.). Import the graph with `tf.import_graph_def`.
__export_savedmodel.py__
Exports the model using the tensorflow serving export api (`SavedModel`). The export can be used to deploy the model on [Google Cloud ML Engine](https://cloud.google.com/ml-engine/docs/concepts/prediction-overview), [Tensorflow Serving]() or on mobile (haven't tried that one yet).
__create_predict_request.py__
-Takes an input image and spits out an json file suitable for prediction requests to a Open NSFW Model deployed on [Google Cloud ML Engine](https://cloud.google.com/ml-engine/docs/concepts/prediction-overview) (`gcloud ml-engine predict`).
+Takes an input image and spits out JSON suitable for prediction requests to an Open NSFW model deployed with [Google Cloud ML Engine](https://cloud.google.com/ml-engine/docs/concepts/prediction-overview) (`gcloud ml-engine predict`) or [tensorflow-serving](https://www.tensorflow.org/serving/). Select the output format with `-t/--target {ml-engine,tf-serving}`.
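The two payload shapes produced by the updated script can be reproduced with the standard library alone. This sketch assumes `PREDICT_INPUTS` resolves to the string `"inputs"` (its value in TensorFlow's `signature_constants`) so that TensorFlow is not needed; `build_request_json` is a hypothetical helper, not part of the tool.

```python
import base64
import json

# Assumption: mirrors tensorflow's signature_constants.PREDICT_INPUTS ("inputs").
PREDICT_INPUTS = "inputs"


def build_request_json(jpeg_bytes: bytes, target: str) -> str:
    """Return the JSON body for the given target, like create_predict_request.py."""
    image_b64 = base64.urlsafe_b64encode(jpeg_bytes).decode("ascii")
    if target == "ml-engine":
        # Keyed by the signature's input name.
        payload = {PREDICT_INPUTS: image_b64}
    elif target == "tf-serving":
        # Row format of the tensorflow-serving REST API: a list of instances.
        payload = {"instances": [image_b64]}
    else:
        raise ValueError("unknown target: %r" % target)
    return json.dumps(payload)
```

Both variants carry the same web-safe base64 string; only the envelope differs.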
diff --git a/tools/create_predict_request.py b/tools/create_predict_request.py
index b35dc33..5a5a352 100644
--- a/tools/create_predict_request.py
+++ b/tools/create_predict_request.py
@@ -1,18 +1,25 @@
 import base64
 import json
 import argparse
 from tensorflow.python.saved_model.signature_constants import PREDICT_INPUTS
 """base64 encodes the given input jpeg and outputs json data suitable for
 'gcloud ml-engine predict' requests to a model generated with 'export-model.py'
 """
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("input_file", help="Path to the input image file")
+    parser.add_argument("-t", "--target", required=True,
+                        choices=['ml-engine', 'tf-serving'],
+                        help="Create json for ml-engine or tensorflow-serving")
     args = parser.parse_args()
+    target = args.target
     image_b64 = base64.urlsafe_b64encode(open(args.input_file, "rb").read())
-    print(json.dumps({PREDICT_INPUTS: image_b64.decode("ascii")}))
+    if target == "ml-engine":
+        print(json.dumps({PREDICT_INPUTS: image_b64.decode("ascii")}))
+    elif target == "tf-serving":
+        print(json.dumps({"instances": [image_b64.decode("ascii")]}))
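To query a model served by tensorflow-serving with the `tf-serving` output, the JSON can be POSTed to the REST predict endpoint `/v1/models/<name>:predict`. A sketch with `urllib`; the model name `open_nsfw`, host `localhost` and port `8501` (the REST port used in the TensorFlow Serving Docker examples) are assumptions that depend on how the server was started:

```python
import json
import urllib.request


def build_predict_request(body, model_name="open_nsfw", host="localhost", port=8501):
    """Build a POST request for the TF Serving REST predict API.

    model_name, host and port are assumptions; match them to your server.
    """
    url = "http://%s:%d/v1/models/%s:predict" % (host, port, model_name)
    return urllib.request.Request(
        url,
        data=body.encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_predict_request(request, timeout=10.0):
    """Send the request and return the decoded JSON response."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

For row-format requests such as `{"instances": [...]}`, the server answers with a `predictions` list.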
