Skip to content

Commit f65fcf7

Browse files
committed
Merge branch 'r0.3'
# Conflicts: # README.md # ndarray/README.md # ndarray/pom.xml # pom.xml # tensorflow-core/pom.xml # tensorflow-core/tensorflow-core-api/pom.xml # tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java # tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java # tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java # tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java # tensorflow-core/tensorflow-core-generator/pom.xml # tensorflow-core/tensorflow-core-platform-gpu/pom.xml # tensorflow-core/tensorflow-core-platform-mkl-gpu/pom.xml # tensorflow-core/tensorflow-core-platform-mkl/pom.xml # tensorflow-core/tensorflow-core-platform/pom.xml # tensorflow-framework/pom.xml
2 parents 242931c + 307b672 commit f65fcf7

File tree

8 files changed

+87
-69
lines changed

8 files changed

+87
-69
lines changed

README.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ systems, you should add the following dependencies:
5656
<dependency>
5757
<groupId>org.tensorflow</groupId>
5858
<artifactId>tensorflow-core-api</artifactId>
59-
<version>0.3.1</version>
59+
<version>0.3.3</version>
6060
</dependency>
6161
<dependency>
6262
<groupId>org.tensorflow</groupId>
6363
<artifactId>tensorflow-core-api</artifactId>
64-
<version>0.3.1</version>
64+
<version>0.3.3</version>
6565
<classifier>linux-x86_64${javacpp.platform.extension}</classifier>
6666
</dependency>
6767
```
@@ -72,24 +72,24 @@ native dependencies as follows:
7272
<dependency>
7373
<groupId>org.tensorflow</groupId>
7474
<artifactId>tensorflow-core-api</artifactId>
75-
<version>0.3.1</version>
75+
<version>0.3.3</version>
7676
</dependency>
7777
<dependency>
7878
<groupId>org.tensorflow</groupId>
7979
<artifactId>tensorflow-core-api</artifactId>
80-
<version>0.3.1</version>
80+
<version>0.3.3</version>
8181
<classifier>linux-x86_64${javacpp.platform.extension}</classifier>
8282
</dependency>
8383
<dependency>
8484
<groupId>org.tensorflow</groupId>
8585
<artifactId>tensorflow-core-api</artifactId>
86-
<version>0.3.1</version>
86+
<version>0.3.3</version>
8787
<classifier>macosx-x86_64${javacpp.platform.extension}</classifier>
8888
</dependency>
8989
<dependency>
9090
<groupId>org.tensorflow</groupId>
9191
<artifactId>tensorflow-core-api</artifactId>
92-
<version>0.3.1</version>
92+
<version>0.3.3</version>
9393
<classifier>windows-x86_64${javacpp.platform.extension}</classifier>
9494
</dependency>
9595
```
@@ -102,7 +102,7 @@ artifact includes transitively all the artifacts above as a single dependency:
102102
<dependency>
103103
<groupId>org.tensorflow</groupId>
104104
<artifactId>tensorflow-core-platform${javacpp.platform.extension}</artifactId>
105-
<version>0.3.1</version>
105+
<version>0.3.3</version>
106106
</dependency>
107107
```
108108

@@ -146,6 +146,8 @@ This table shows the mapping between different version of TensorFlow for Java an
146146
| 0.2.0 | 2.3.1 |
147147
| 0.3.0 | 2.4.1 |
148148
| 0.3.1 | 2.4.1 |
149+
| 0.3.2 | 2.4.1 |
150+
| 0.3.3 | 2.4.1 |
149151
| 0.4.0-SNAPSHOT | 2.5.0
150152

151153
## How to Contribute?

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
<javacpp.parser.skip>${native.build.skip}</javacpp.parser.skip>
2121
<javacpp.compiler.skip>${native.build.skip}</javacpp.compiler.skip>
2222
<java.module.name>org.tensorflow.core.api</java.module.name>
23-
<ndarray.version>0.3.1</ndarray.version>
23+
<ndarray.version>0.3.3</ndarray.version>
2424
<truth.version>1.0.1</truth.version>
2525
</properties>
2626

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,10 @@ public final class Ops {
381381

382382
public final SignalOps signal;
383383

384-
public final TrainOps train;
385-
386384
public final QuantizationOps quantization;
387385

386+
public final TrainOps train;
387+
388388
private final Scope scope;
389389

390390
private Ops(Scope scope) {
@@ -407,8 +407,8 @@ private Ops(Scope scope) {
407407
math = new MathOps(this);
408408
audio = new AudioOps(this);
409409
signal = new SignalOps(this);
410-
train = new TrainOps(this);
411410
quantization = new QuantizationOps(this);
411+
train = new TrainOps(this);
412412
}
413413

414414
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -858,19 +858,17 @@ public Output<?>[] whileLoop(
858858
synchronized SaverDef saverDef() {
859859
if (saverDef == null) {
860860
// Check to see if this graph has a restore operation
861-
if (operation("save/restore_all") == null) {
861+
if (operation(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP) == null) {
862862
// No saver, create one by mutating the graph
863863
saverDef = addVariableSaver(this);
864864
} else {
865865
// This graph already has saving/restoring operations,
866-
// regenerate SaverDef without mutating. The names mirror
867-
// the python implementation for compatibility.
868-
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
866+
// regenerate SaverDef without mutating.
869867
saverDef =
870868
SaverDef.newBuilder()
871-
.setFilenameTensorName("save/filename")
872-
.setSaveTensorName("save/control_dependency")
873-
.setRestoreOpName("save/restore_all")
869+
.setFilenameTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_FILENAME_OP + ":0")
870+
.setSaveTensorName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_SAVE_OP)
871+
.setRestoreOpName(SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP)
874872
.build();
875873
}
876874
}
@@ -981,6 +979,13 @@ public void remove() {
981979
private int position;
982980
}
983981

982+
// These names mirror the python implementation, to reduce the risk of incompatibility.
983+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
984+
private static final String SAVER_DEF_SCOPE = "save";
985+
private static final String SAVER_DEF_FILENAME_OP = "filename";
986+
private static final String SAVER_DEF_SAVE_OP = "control_dependency";
987+
private static final String SAVER_DEF_RESTORE_OP = "restore_all";
988+
984989
private static TF_Graph allocate() {
985990
return TF_NewGraph();
986991
}
@@ -1232,7 +1237,7 @@ private static Object[] whileLoop(
12321237
}
12331238

12341239
private static SaverDef addVariableSaver(Graph graph) {
1235-
Ops tf = Ops.create(graph).withSubScope("save");
1240+
Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);
12361241

12371242
List<String> varNames = new ArrayList<>();
12381243
List<Operand<?>> varOutputs = new ArrayList<>();
@@ -1247,36 +1252,35 @@ private static SaverDef addVariableSaver(Graph graph) {
12471252
}
12481253
}
12491254

1250-
Placeholder<TString> saveFilename = tf.withName("filename").placeholder(TString.class);
1255+
Placeholder<TString> filename = tf.withName(SAVER_DEF_FILENAME_OP).placeholder(TString.class);
1256+
Identity<TString> save = null;
1257+
NoOp restore = null;
12511258

12521259
if (varNames.isEmpty()) {
1253-
return SaverDef.newBuilder()
1254-
.setFilenameTensorName(saveFilename.op().name())
1255-
.setSaveTensorName(tf.withName("empty_save").identity(saveFilename).op().name())
1256-
.setRestoreOpName(tf.withName("restore_all").noOp().op().name())
1257-
.build();
1258-
}
1259-
1260-
// FIXME Need an easier way to initialize an NdArray from a list
1261-
String[] tmp = new String[varNames.size()];
1262-
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
1263-
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
1264-
Save saveVariables = tf.train.save(saveFilename, varNamesTensor, varSlices, varOutputs);
1265-
Identity<TString> id =
1266-
tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables))
1267-
.withName("control_dependency")
1268-
.identity(saveFilename);
1269-
Restore restoreVariables = tf.train.restore(saveFilename, varNamesTensor, varSlices, varTypes);
1270-
List<Op> restoreOps = new ArrayList<>(varOutputs.size());
1271-
for (int i = 0; i < varOutputs.size(); ++i) {
1272-
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
1260+
save = tf.withName(SAVER_DEF_SAVE_OP).identity(filename);
1261+
restore = tf.withName(SAVER_DEF_RESTORE_OP).noOp();
1262+
} else {
1263+
String[] tmp = new String[varNames.size()];
1264+
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
1265+
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
1266+
Save saveVars = tf.train.save(filename, varNamesTensor, varSlices, varOutputs);
1267+
List<Op> saveDeps = Arrays.asList(filename, saveVars);
1268+
Restore restoreVars = tf.train.restore(filename, varNamesTensor, varSlices, varTypes);
1269+
List<Op> restoreDeps = new ArrayList<>(varOutputs.size());
1270+
for (int i = 0; i < varOutputs.size(); ++i) {
1271+
restoreDeps.add(tf.assign(varOutputs.get(i), (Operand) restoreVars.tensors().get(i)));
1272+
}
1273+
save = tf.withControlDependencies(saveDeps).withName(SAVER_DEF_SAVE_OP).identity(filename);
1274+
restore = tf.withControlDependencies(restoreDeps).withName(SAVER_DEF_RESTORE_OP).noOp();
12731275
}
1274-
NoOp restoreAll = tf.withControlDependencies(restoreOps).withName("restore_all").noOp();
12751276

1277+
// 'Filename' must be the name of a tensor (i.e. with output index)
1278+
// 'Save' must be an operation name, even if the field name is confusing (see SaverDef doc)
1279+
// 'Restore' must be an operation name
12761280
return SaverDef.newBuilder()
1277-
.setFilenameTensorName(saveFilename.op().name())
1278-
.setSaveTensorName(id.op().name())
1279-
.setRestoreOpName(restoreAll.op().name())
1281+
.setFilenameTensorName(filename.output().name())
1282+
.setSaveTensorName(save.op().name())
1283+
.setRestoreOpName(restore.op().name())
12801284
.build();
12811285
}
12821286

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ public int index() {
3838
return index;
3939
}
4040

41+
/** Returns the full name of this Output (a.k.a. tensor name) */
42+
public String name() {
43+
return op().name() + ":" + index;
44+
}
45+
4146
/** Returns the DataType of the tensor referred to by this Output. */
4247
@SuppressWarnings("unchecked")
4348
public DataType dataType() {
@@ -48,7 +53,7 @@ public DataType dataType() {
4853
@SuppressWarnings("unchecked")
4954
@Override
5055
public Class<T> type() {
51-
return (Class<T>)TensorTypeRegistry.find(dataType()).type();
56+
return (Class<T>) TensorTypeRegistry.find(dataType()).type();
5257
}
5358

5459
/**
@@ -63,7 +68,10 @@ public Class<T> type() {
6368
public <U extends TType> Output<U> expect(Class<U> type) {
6469
if (type != type()) {
6570
throw new IllegalArgumentException(
66-
"Cannot cast from output of " + this.type().getSimpleName() + " to output of " + type.getSimpleName());
71+
"Cannot cast from output of "
72+
+ this.type().getSimpleName()
73+
+ " to output of "
74+
+ type.getSimpleName());
6775
}
6876
return ((Output<U>) this);
6977
}
@@ -80,17 +88,16 @@ public <U extends TType> Output<U> expect(Class<U> type) {
8088
*
8189
* @return tensor
8290
* @throws IllegalStateException if this output results from a graph
83-
* @throws ClassCastException if the type of the tensor and this output are unexpectedly incompatible
91+
* @throws ClassCastException if the type of the tensor and this output are unexpectedly
92+
* incompatible
8493
* @see EagerSession
8594
*/
8695
@SuppressWarnings("unchecked")
8796
public T asTensor() {
88-
return (T)operation.tensor(index);
97+
return (T) operation.tensor(index);
8998
}
9099

91-
/**
92-
* Returns the (possibly partially known) shape of the tensor referred to by this output.
93-
*/
100+
/** Returns the (possibly partially known) shape of the tensor referred to by this output. */
94101
@Override
95102
public Shape shape() {
96103
return operation.shape(index);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.List;
3232
import java.util.Map;
3333
import java.util.Map.Entry;
34+
import java.util.Objects;
3435
import java.util.stream.Collectors;
3536
import org.bytedeco.javacpp.BytePointer;
3637
import org.bytedeco.javacpp.PointerScope;
@@ -529,7 +530,7 @@ private static SavedModelBundle load(
529530
}
530531

531532
private static void validateTags(String[] tags) {
532-
if (tags == null || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) {
533+
if (tags == null || Arrays.stream(tags).anyMatch(Objects::isNull)) {
533534
throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
534535
}
535536
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ public Set<String> outputNames() {
208208

209209
@Override
210210
public String toString() {
211-
StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n");
211+
StringBuilder strBuilder = new StringBuilder("Signature for \"" + key + "\":\n");
212212
String methodName = methodName();
213213
if (methodName != null && !methodName.isEmpty()) {
214214
strBuilder.append("\tMethod: \"").append(methodName).append("\"\n");

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.tensorflow.proto.framework.RunOptions;
4444
import org.tensorflow.proto.framework.SignatureDef;
4545
import org.tensorflow.proto.framework.TensorInfo;
46+
import org.tensorflow.proto.util.SaverDef;
4647
import org.tensorflow.types.TFloat32;
4748

4849
/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */
@@ -171,7 +172,13 @@ public void exportFunctionWithVariables() throws IOException {
171172
try (SavedModelBundle savedModel =
172173
SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) {
173174
assertNotNull(savedModel.metaGraphDef());
174-
assertNotNull(savedModel.metaGraphDef().getSaverDef());
175+
176+
SaverDef saverDef = savedModel.metaGraphDef().getSaverDef();
177+
assertNotNull(saverDef);
178+
assertEquals("save/filename:0", saverDef.getFilenameTensorName());
179+
assertEquals("save/control_dependency", saverDef.getSaveTensorName());
180+
assertEquals("save/restore_all", saverDef.getRestoreOpName());
181+
175182
assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
176183
assertEquals(
177184
Signature.DEFAULT_KEY,
@@ -262,21 +269,18 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti
262269

263270
@Test
264271
public void cannotExportOrImportInvalidTags() {
265-
assertThrows(IllegalArgumentException.class, () -> SavedModelBundle.loader("/").withTags(null));
266-
assertThrows(
267-
IllegalArgumentException.class,
268-
() -> SavedModelBundle.loader("/").withTags(new String[] {"tag", null}));
269-
assertThrows(
270-
IllegalArgumentException.class,
271-
() -> SavedModelBundle.loader("/").withTags(new String[] {"tag", ""}));
272-
assertThrows(
273-
IllegalArgumentException.class, () -> SavedModelBundle.exporter("/").withTags(null));
274-
assertThrows(
275-
IllegalArgumentException.class,
276-
() -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", null}));
277-
assertThrows(
278-
IllegalArgumentException.class,
279-
() -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", ""}));
272+
assertThrows(IllegalArgumentException.class, () ->
273+
SavedModelBundle.loader("/").withTags(null)
274+
);
275+
assertThrows(IllegalArgumentException.class, () ->
276+
SavedModelBundle.loader("/").withTags(new String[]{"tag", null})
277+
);
278+
assertThrows(IllegalArgumentException.class, () ->
279+
SavedModelBundle.exporter("/").withTags(null)
280+
);
281+
assertThrows(IllegalArgumentException.class, () ->
282+
SavedModelBundle.exporter("/").withTags(new String[]{"tag", null})
283+
);
280284
}
281285

282286
@Test

0 commit comments

Comments
 (0)