From ceb99938276651655cb4b6f41f4c1d7d993eb4c2 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 1 Jun 2017 06:03:29 -0500
Subject: [PATCH] Add speculative model serialisation fix

---
 thinc/neural/_classes/model.py | 37 ++++++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/thinc/neural/_classes/model.py b/thinc/neural/_classes/model.py
index 28bd146e7..41ec6c63d 100644
--- a/thinc/neural/_classes/model.py
+++ b/thinc/neural/_classes/model.py
@@ -278,10 +278,10 @@ def to_bytes(self):
         for layer in queue:
             if hasattr(layer, '_mem'):
                 weights.append({
-                    'dims': normalize_string_keys(getattr(layer, '_dims', {})),
-                    'params': []})
+                    b'dims': normalize_string_keys(getattr(layer, '_dims', {})),
+                    b'params': []})
                 if hasattr(layer, 'seed'):
-                    weights[-1]['seed'] = layer.seed
+                    weights[-1][b'seed'] = layer.seed
                 for (id_, name), (start, row, shape) in layer._mem._offsets.items():
                     if row == 1:
                         continue
@@ -289,33 +289,36 @@ def to_bytes(self):
                     param = layer._mem.get((id_, name))
                     if not isinstance(layer._mem.weights, numpy.ndarray):
                         param = param.get()
-                    weights[-1]['params'].append(
+                    weights[-1][b'params'].append(
                         {
-                            'name': name,
-                            'offset': start,
-                            'shape': shape,
-                            'value': param,
+                            b'name': name,
+                            b'offset': start,
+                            b'shape': shape,
+                            b'value': param,
                         }
                     )
                 i += 1
             if hasattr(layer, '_layers'):
                 queue.extend(layer._layers)
-        return msgpack.dumps({'weights': weights})
+        return msgpack.dumps({b'weights': weights})
 
     def from_bytes(self, bytes_data):
         data = msgpack.loads(bytes_data)
-        weights = data['weights']
+        weights = data[b'weights']
         queue = [self]
         i = 0
         for layer in queue:
-            if hasattr(layer, '_mem'):
-                if 'seed' in weights[i]:
-                    layer.seed = weights[i]['seed']
-                for dim, value in weights[i]['dims'].items():
+            if hasattr(layer, b'_mem'):
+                if b'seed' in weights[i]:
+                    layer.seed = weights[i][b'seed']
+                for dim, value in weights[i][b'dims'].items():
                     setattr(layer, dim, value)
-                for param in weights[i]['params']:
-                    dest = getattr(layer, param['name'])
-                    copy_array(dest, param['value'])
+                for param in weights[i][b'params']:
+                    name = param[b'name']
+                    if isinstance(name, bytes):
+                        name = name.decode('utf8')
+                    dest = getattr(layer, name)
+                    copy_array(dest, param[b'value'])
                 i += 1
             if hasattr(layer, '_layers'):
                 queue.extend(layer._layers)
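
Note: the following is a minimal, self-contained sketch of the round-trip behaviour the patch works around, not thinc's actual API; ToyLayer, toy_to_bytes and toy_from_bytes are invented for illustration. The point it demonstrates is the one the patch relies on: 2017-era msgpack-python hands map keys and stored strings back as byte strings, so the payload is written with b'...' keys up front, and any stored name that will later be passed to getattr() is decoded back to str first.

    import msgpack


    class ToyLayer(object):
        # Stand-in for a thinc layer with one named parameter.
        def __init__(self):
            self.W = [1.0, 2.0, 3.0]  # hypothetical parameter; a plain list, not a numpy array


    def toy_to_bytes(layer):
        # Mirror the patch: byte-string keys, so the payload looks the same
        # whether msgpack returns keys as bytes (old default) or as str.
        weights = [{b'params': [{b'name': 'W', b'value': layer.W}]}]
        return msgpack.dumps({b'weights': weights})


    def toy_from_bytes(layer, data):
        weights = msgpack.loads(data)[b'weights']
        for param in weights[0][b'params']:
            name = param[b'name']
            if isinstance(name, bytes):      # older msgpack hands the name back as bytes
                name = name.decode('utf8')   # getattr() needs a str attribute name
            getattr(layer, name)[:] = param[b'value']  # in-place copy, like copy_array()


    fresh = ToyLayer()
    fresh.W = [0.0, 0.0, 0.0]
    toy_from_bytes(fresh, toy_to_bytes(ToyLayer()))
    assert fresh.W == [1.0, 2.0, 3.0]

The isinstance() check is what keeps the sketch working across msgpack versions: newer releases decode stored strings to str on load, older ones return bytes, and on Python 3 getattr() only accepts str names.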