Skip to content

Commit 8f7b2d6

Browse files
committed
fixing errors
1 parent fbad9af commit 8f7b2d6

File tree

4 files changed

+7
-9
lines changed

4 files changed

+7
-9
lines changed

neuralmonkey/decoders/autoregressive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def embedding_size(self) -> int:
176176
"size of the reused embeddings from the "
177177
"`embeddings_source`.")
178178

179-
return self.embeddings_source.dimension
179+
return self.embeddings_source.embedding_matrix.get_shape()[1].value
180180

181181
@tensor
182182
def go_symbols(self) -> tf.Tensor:

neuralmonkey/decoders/sequence_labeler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self,
2323
name: str,
2424
encoder: TemporalStateful,
2525
data_id: str,
26-
vocabulary: Vocabulary,
26+
vocabulary: Vocabulary = None,
2727
embeddings_source: EmbeddedSequence = None,
2828
dropout_keep_prob: float = 1.0,
2929
reuse: ModelPart = None,
@@ -40,12 +40,11 @@ def __init__(self,
4040
self.data_id = data_id
4141
self.dropout_keep_prob = dropout_keep_prob
4242

43-
# We provide only embedding_source when we want to input and output
43+
# We provide only embedding_source when we want to tie input and output
4444
# projections
45-
if self.embeddings_source is not None and self.vocabulary is not None:
46-
warn("Both `vocabulary` and `embedding_source` was provided. "
47-
"using `embedding_source.vocabulary` instead of provided "
48-
"`vocabulary`")
45+
if (self.embeddings_source is None) == (self.vocabulary is None):
46+
raise ValueError("You must specify either `vocabulary or` or "
47+
"`embeddings_source`, not both")
4948
self.vocabulary = self.embeddings_source.vocabulary
5049
# pylint: enable=too-many-arguments
5150

neuralmonkey/readers/string_vector_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray:
1313

1414
return np.array(numbers, dtype=dtype)
1515

16-
def reader(files: List[str])-> Iterable[List[np.ndarray]]:
16+
def reader(files: List[str]) -> Iterable[List[np.ndarray]]:
1717
for path in files:
1818
current_line = 0
1919

tests/bert.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ dropout_keep_prob=0.9
6666
class=decoders.sequence_labeler.SequenceLabeler
6767
name="labeler_bert"
6868
encoder=<encoder>
69-
vocabulary=<vocabulary>
7069
data_id="source_masked"
7170
dropout_keep_prob=0.5
7271
embeddings_source=<sequence>

0 commit comments

Comments
 (0)