Coverage for mlair/keras_legacy/interfaces.py: 20% (293 statements)
1"""Interface converters for Keras 1 support in Keras 2.
2"""
3from __future__ import absolute_import
4from __future__ import division
5from __future__ import print_function
7import six
8import warnings
9import functools
10import numpy as np

def generate_legacy_interface(allowed_positional_args=None,
                              conversions=None,
                              preprocessor=None,
                              value_conversions=None,
                              object_type='class'):
    if allowed_positional_args is None:
        check_positional_args = False
    else:
        check_positional_args = True
    allowed_positional_args = allowed_positional_args or []
    conversions = conversions or []
    value_conversions = value_conversions or []

    def legacy_support(func):
        @six.wraps(func)
        def wrapper(*args, **kwargs):
            if object_type == 'class':
                object_name = args[0].__class__.__name__
            else:
                object_name = func.__name__
            if preprocessor:
                args, kwargs, converted = preprocessor(args, kwargs)
            else:
                converted = []
            if check_positional_args:
                if len(args) > len(allowed_positional_args) + 1:
                    raise TypeError('`' + object_name +
                                    '` can accept only ' +
                                    str(len(allowed_positional_args)) +
                                    ' positional arguments ' +
                                    str(tuple(allowed_positional_args)) +
                                    ', but you passed the following '
                                    'positional arguments: ' +
                                    str(list(args[1:])))
            for key in value_conversions:
                if key in kwargs:
                    old_value = kwargs[key]
                    if old_value in value_conversions[key]:
                        kwargs[key] = value_conversions[key][old_value]
            for old_name, new_name in conversions:
                if old_name in kwargs:
                    value = kwargs.pop(old_name)
                    if new_name in kwargs:
                        raise_duplicate_arg_error(old_name, new_name)
                    kwargs[new_name] = value
                    converted.append((new_name, old_name))
            if converted:
                signature = '`' + object_name + '('
                for i, value in enumerate(args[1:]):
                    if isinstance(value, six.string_types):
                        signature += '"' + value + '"'
                    else:
                        if isinstance(value, np.ndarray):
                            str_val = 'array'
                        else:
                            str_val = str(value)
                        if len(str_val) > 10:
                            str_val = str_val[:10] + '...'
                        signature += str_val
                    if i < len(args[1:]) - 1 or kwargs:
                        signature += ', '
                for i, (name, value) in enumerate(kwargs.items()):
                    signature += name + '='
                    if isinstance(value, six.string_types):
                        signature += '"' + value + '"'
                    else:
                        if isinstance(value, np.ndarray):
                            str_val = 'array'
                        else:
                            str_val = str(value)
                        if len(str_val) > 10:
                            str_val = str_val[:10] + '...'
                        signature += str_val
                    if i < len(kwargs) - 1:
                        signature += ', '
                signature += ')`'
                warnings.warn('Update your `' + object_name + '` call to the ' +
                              'Keras 2 API: ' + signature, stacklevel=2)
            return func(*args, **kwargs)
        wrapper._original_function = func
        return wrapper
    return legacy_support

generate_legacy_method_interface = functools.partial(generate_legacy_interface,
                                                     object_type='method')
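
# --- Illustrative usage (a minimal sketch, not part of the original module) ---
# `generate_legacy_interface` returns a decorator that rewrites Keras 1 style
# keyword arguments into their Keras 2 names before the wrapped function runs.
# The names below (`_example_legacy_support`, `_ExampleLayer`) are hypothetical
# and exist only to show how such a decorator is typically applied to `__init__`.
_example_legacy_support = generate_legacy_interface(
    allowed_positional_args=['units'],
    conversions=[('output_dim', 'units')])


class _ExampleLayer(object):
    @_example_legacy_support
    def __init__(self, units=32, **kwargs):
        # A call such as `_ExampleLayer(output_dim=64)` emits a UserWarning
        # pointing at the Keras 2 API and arrives here with `units=64`.
        self.units = units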

def raise_duplicate_arg_error(old_arg, new_arg):
    raise TypeError('For the `' + new_arg + '` argument, '
                    'the layer received both '
                    'the legacy keyword argument '
                    '`' + old_arg + '` and the Keras 2 keyword argument '
                    '`' + new_arg + '`. Stick to the latter!')

legacy_dense_support = generate_legacy_interface(
    allowed_positional_args=['units'],
    conversions=[('output_dim', 'units'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')])

legacy_dropout_support = generate_legacy_interface(
    allowed_positional_args=['rate', 'noise_shape', 'seed'],
    conversions=[('p', 'rate')])

def embedding_kwargs_preprocessor(args, kwargs):
    converted = []
    if 'dropout' in kwargs:
        kwargs.pop('dropout')
        warnings.warn('The `dropout` argument is no longer supported in `Embedding`. '
                      'You can apply a `keras.layers.SpatialDropout1D` layer '
                      'right after the `Embedding` layer to get the same behavior.',
                      stacklevel=3)
    return args, kwargs, converted

legacy_embedding_support = generate_legacy_interface(
    allowed_positional_args=['input_dim', 'output_dim'],
    conversions=[('init', 'embeddings_initializer'),
                 ('W_regularizer', 'embeddings_regularizer'),
                 ('W_constraint', 'embeddings_constraint')],
    preprocessor=embedding_kwargs_preprocessor)

legacy_pooling1d_support = generate_legacy_interface(
    allowed_positional_args=['pool_size', 'strides', 'padding'],
    conversions=[('pool_length', 'pool_size'),
                 ('stride', 'strides'),
                 ('border_mode', 'padding')])

legacy_prelu_support = generate_legacy_interface(
    allowed_positional_args=['alpha_initializer'],
    conversions=[('init', 'alpha_initializer')])

legacy_gaussiannoise_support = generate_legacy_interface(
    allowed_positional_args=['stddev'],
    conversions=[('sigma', 'stddev')])

def recurrent_args_preprocessor(args, kwargs):
    converted = []
    if 'forget_bias_init' in kwargs:
        if kwargs['forget_bias_init'] == 'one':
            kwargs.pop('forget_bias_init')
            kwargs['unit_forget_bias'] = True
            converted.append(('forget_bias_init', 'unit_forget_bias'))
        else:
            kwargs.pop('forget_bias_init')
            warnings.warn('The `forget_bias_init` argument '
                          'has been ignored. Use `unit_forget_bias=True` '
                          'instead to initialize with ones.', stacklevel=3)
    if 'input_dim' in kwargs:
        input_length = kwargs.pop('input_length', None)
        input_dim = kwargs.pop('input_dim')
        input_shape = (input_length, input_dim)
        kwargs['input_shape'] = input_shape
        converted.append(('input_dim', 'input_shape'))
        warnings.warn('The `input_dim` and `input_length` arguments '
                      'in recurrent layers are deprecated. '
                      'Use `input_shape` instead.', stacklevel=3)
    return args, kwargs, converted
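
# Illustrative behaviour (a minimal sketch, not part of the original module):
# legacy `input_dim`/`input_length` keyword arguments are folded into a single
# Keras 2 `input_shape` tuple. `_demo_recurrent_preprocessing` is a hypothetical
# helper added only for illustration; calling it emits the deprecation warning.
def _demo_recurrent_preprocessing():
    layer = object()  # stands in for the layer instance (args[0])
    args, kwargs, converted = recurrent_args_preprocessor(
        (layer, 32), {'input_dim': 64, 'input_length': 10})
    # kwargs now contains input_shape=(10, 64); converted records the rename.
    return args, kwargs, converted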

legacy_recurrent_support = generate_legacy_interface(
    allowed_positional_args=['units'],
    conversions=[('output_dim', 'units'),
                 ('init', 'kernel_initializer'),
                 ('inner_init', 'recurrent_initializer'),
                 ('inner_activation', 'recurrent_activation'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('U_regularizer', 'recurrent_regularizer'),
                 ('dropout_W', 'dropout'),
                 ('dropout_U', 'recurrent_dropout'),
                 ('consume_less', 'implementation')],
    value_conversions={'consume_less': {'cpu': 0,
                                        'mem': 1,
                                        'gpu': 2}},
    preprocessor=recurrent_args_preprocessor)

legacy_gaussiandropout_support = generate_legacy_interface(
    allowed_positional_args=['rate'],
    conversions=[('p', 'rate')])

legacy_pooling2d_support = generate_legacy_interface(
    allowed_positional_args=['pool_size', 'strides', 'padding'],
    conversions=[('border_mode', 'padding'),
                 ('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_pooling3d_support = generate_legacy_interface(
    allowed_positional_args=['pool_size', 'strides', 'padding'],
    conversions=[('border_mode', 'padding'),
                 ('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_global_pooling_support = generate_legacy_interface(
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_upsampling1d_support = generate_legacy_interface(
    allowed_positional_args=['size'],
    conversions=[('length', 'size')])

legacy_upsampling2d_support = generate_legacy_interface(
    allowed_positional_args=['size'],
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_upsampling3d_support = generate_legacy_interface(
    allowed_positional_args=['size'],
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

def conv1d_args_preprocessor(args, kwargs):
    converted = []
    if 'input_dim' in kwargs:
        if 'input_length' in kwargs:
            length = kwargs.pop('input_length')
        else:
            length = None
        input_shape = (length, kwargs.pop('input_dim'))
        kwargs['input_shape'] = input_shape
        converted.append(('input_shape', 'input_dim'))
    return args, kwargs, converted

legacy_conv1d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('filter_length', 'kernel_size'),
                 ('subsample_length', 'strides'),
                 ('border_mode', 'padding'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    preprocessor=conv1d_args_preprocessor)

def conv2d_args_preprocessor(args, kwargs):
    converted = []
    if len(args) > 4:
        raise TypeError('Layer can receive at most 3 positional arguments.')
    elif len(args) == 4:
        if isinstance(args[2], int) and isinstance(args[3], int):
            new_keywords = ['padding', 'strides', 'data_format']
            for kwd in new_keywords:
                if kwd in kwargs:
                    raise ValueError(
                        'It seems that you are using the Keras 2 API '
                        'and you are passing both `kernel_size` and `strides` '
                        'as integer positional arguments. For safety reasons, '
                        'this is disallowed. Pass `strides` '
                        'as a keyword argument instead.')
            kernel_size = (args[2], args[3])
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'nb_row/nb_col'))
    elif len(args) == 3 and isinstance(args[2], int):
        if 'nb_col' in kwargs:
            kernel_size = (args[2], kwargs.pop('nb_col'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'nb_row/nb_col'))
    elif len(args) == 2:
        if 'nb_row' in kwargs and 'nb_col' in kwargs:
            kernel_size = (kwargs.pop('nb_row'), kwargs.pop('nb_col'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'nb_row/nb_col'))
    elif len(args) == 1:
        if 'nb_row' in kwargs and 'nb_col' in kwargs:
            kernel_size = (kwargs.pop('nb_row'), kwargs.pop('nb_col'))
            kwargs['kernel_size'] = kernel_size
            converted.append(('kernel_size', 'nb_row/nb_col'))
    return args, kwargs, converted
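
# Illustrative behaviour (a minimal sketch, not part of the original module):
# a Keras 1 style call such as `Conv2D(32, 3, 3)` reaches the preprocessor as
# four positional arguments; the two trailing integers are folded into a single
# Keras 2 `kernel_size` tuple. `_demo_conv2d_preprocessing` is a hypothetical
# helper added only for illustration.
def _demo_conv2d_preprocessing():
    fake_layer = object()  # stands in for the layer instance (args[0])
    args, kwargs, converted = conv2d_args_preprocessor(
        (fake_layer, 32, 3, 3), {})
    # args is now [fake_layer, 32, (3, 3)]; converted records the
    # ('kernel_size', 'nb_row/nb_col') rewrite.
    return args, kwargs, converted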

legacy_conv2d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=conv2d_args_preprocessor)


def separable_conv2d_args_preprocessor(args, kwargs):
    converted = []
    if 'init' in kwargs:
        init = kwargs.pop('init')
        kwargs['depthwise_initializer'] = init
        kwargs['pointwise_initializer'] = init
        converted.append(('init', 'depthwise_initializer/pointwise_initializer'))
    args, kwargs, _converted = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted

legacy_separable_conv2d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=separable_conv2d_args_preprocessor)


def deconv2d_args_preprocessor(args, kwargs):
    converted = []
    if len(args) == 5:
        if isinstance(args[4], tuple):
            args = args[:-1]
            converted.append(('output_shape', None))
    if 'output_shape' in kwargs:
        kwargs.pop('output_shape')
        converted.append(('output_shape', None))
    args, kwargs, _converted = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted

legacy_deconv2d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=deconv2d_args_preprocessor)

def conv3d_args_preprocessor(args, kwargs):
    converted = []
    if len(args) > 5:
        raise TypeError('Layer can receive at most 4 positional arguments.')
    if len(args) == 5:
        if all([isinstance(x, int) for x in args[2:5]]):
            kernel_size = (args[2], args[3], args[4])
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) == 4 and isinstance(args[3], int):
        if isinstance(args[2], int) and isinstance(args[3], int):
            new_keywords = ['padding', 'strides', 'data_format']
            for kwd in new_keywords:
                if kwd in kwargs:
                    raise ValueError(
                        'It seems that you are using the Keras 2 API '
                        'and you are passing both `kernel_size` and `strides` '
                        'as integer positional arguments. For safety reasons, '
                        'this is disallowed. Pass `strides` '
                        'as a keyword argument instead.')
        if 'kernel_dim3' in kwargs:
            kernel_size = (args[2], args[3], kwargs.pop('kernel_dim3'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) == 3:
        if all([x in kwargs for x in ['kernel_dim2', 'kernel_dim3']]):
            kernel_size = (args[2],
                           kwargs.pop('kernel_dim2'),
                           kwargs.pop('kernel_dim3'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) == 2:
        if all([x in kwargs for x in ['kernel_dim1', 'kernel_dim2', 'kernel_dim3']]):
            kernel_size = (kwargs.pop('kernel_dim1'),
                           kwargs.pop('kernel_dim2'),
                           kwargs.pop('kernel_dim3'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) == 1:
        if all([x in kwargs for x in ['kernel_dim1', 'kernel_dim2', 'kernel_dim3']]):
            kernel_size = (kwargs.pop('kernel_dim1'),
                           kwargs.pop('kernel_dim2'),
                           kwargs.pop('kernel_dim3'))
            kwargs['kernel_size'] = kernel_size
            converted.append(('kernel_size', 'kernel_dim*'))
    return args, kwargs, converted

legacy_conv3d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=conv3d_args_preprocessor)

def batchnorm_args_preprocessor(args, kwargs):
    converted = []
    if len(args) > 1:
        raise TypeError('The `BatchNormalization` layer '
                        'does not accept positional arguments. '
                        'Use keyword arguments instead.')
    if 'mode' in kwargs:
        value = kwargs.pop('mode')
        if value != 0:
            raise TypeError('The `mode` argument of `BatchNormalization` '
                            'no longer exists. `mode=1` and `mode=2` '
                            'are no longer supported.')
        converted.append(('mode', None))
    return args, kwargs, converted


def convlstm2d_args_preprocessor(args, kwargs):
    converted = []
    if 'forget_bias_init' in kwargs:
        value = kwargs.pop('forget_bias_init')
        if value == 'one':
            kwargs['unit_forget_bias'] = True
            converted.append(('forget_bias_init', 'unit_forget_bias'))
        else:
            warnings.warn('The `forget_bias_init` argument '
                          'has been ignored. Use `unit_forget_bias=True` '
                          'instead to initialize with ones.', stacklevel=3)
    args, kwargs, _converted = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted

legacy_convlstm2d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('inner_init', 'recurrent_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('U_regularizer', 'recurrent_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('inner_activation', 'recurrent_activation'),
                 ('dropout_W', 'dropout'),
                 ('dropout_U', 'recurrent_dropout'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=convlstm2d_args_preprocessor)

legacy_batchnorm_support = generate_legacy_interface(
    allowed_positional_args=[],
    conversions=[('beta_init', 'beta_initializer'),
                 ('gamma_init', 'gamma_initializer')],
    preprocessor=batchnorm_args_preprocessor)


def zeropadding2d_args_preprocessor(args, kwargs):
    converted = []
    if 'padding' in kwargs and isinstance(kwargs['padding'], dict):
        if set(kwargs['padding'].keys()) <= {'top_pad', 'bottom_pad',
                                             'left_pad', 'right_pad'}:
            top_pad = kwargs['padding'].get('top_pad', 0)
            bottom_pad = kwargs['padding'].get('bottom_pad', 0)
            left_pad = kwargs['padding'].get('left_pad', 0)
            right_pad = kwargs['padding'].get('right_pad', 0)
            kwargs['padding'] = ((top_pad, bottom_pad), (left_pad, right_pad))
            warnings.warn('The `padding` argument in the Keras 2 API no longer '
                          'accepts dict types. You can now pass the argument as: '
                          '`padding=((top_pad, bottom_pad), (left_pad, right_pad))`.',
                          stacklevel=3)
    elif len(args) == 2 and isinstance(args[1], dict):
        if set(args[1].keys()) <= {'top_pad', 'bottom_pad',
                                   'left_pad', 'right_pad'}:
            top_pad = args[1].get('top_pad', 0)
            bottom_pad = args[1].get('bottom_pad', 0)
            left_pad = args[1].get('left_pad', 0)
            right_pad = args[1].get('right_pad', 0)
            args = (args[0], ((top_pad, bottom_pad), (left_pad, right_pad)))
            warnings.warn('The `padding` argument in the Keras 2 API no longer '
                          'accepts dict types. You can now pass the argument as: '
                          '`padding=((top_pad, bottom_pad), (left_pad, right_pad))`.',
                          stacklevel=3)
    return args, kwargs, converted

legacy_zeropadding2d_support = generate_legacy_interface(
    allowed_positional_args=['padding'],
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=zeropadding2d_args_preprocessor)

legacy_zeropadding3d_support = generate_legacy_interface(
    allowed_positional_args=['padding'],
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_cropping2d_support = generate_legacy_interface(
    allowed_positional_args=['cropping'],
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_cropping3d_support = generate_legacy_interface(
    allowed_positional_args=['cropping'],
    conversions=[('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_spatialdropout1d_support = generate_legacy_interface(
    allowed_positional_args=['rate'],
    conversions=[('p', 'rate')])

legacy_spatialdropoutNd_support = generate_legacy_interface(
    allowed_positional_args=['rate'],
    conversions=[('p', 'rate'),
                 ('dim_ordering', 'data_format')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}})

legacy_lambda_support = generate_legacy_interface(
    allowed_positional_args=['function', 'output_shape'])

# Model methods

def generator_methods_args_preprocessor(args, kwargs):
    converted = []
    if len(args) < 3:
        if 'samples_per_epoch' in kwargs:
            samples_per_epoch = kwargs.pop('samples_per_epoch')
            if len(args) > 1:
                generator = args[1]
            else:
                generator = kwargs['generator']
            if hasattr(generator, 'batch_size'):
                kwargs['steps_per_epoch'] = samples_per_epoch // generator.batch_size
            else:
                kwargs['steps_per_epoch'] = samples_per_epoch
            converted.append(('samples_per_epoch', 'steps_per_epoch'))

    keras1_args = {'samples_per_epoch', 'val_samples',
                   'nb_epoch', 'nb_val_samples', 'nb_worker'}
    if keras1_args.intersection(kwargs.keys()):
        warnings.warn('The semantics of the Keras 2 argument '
                      '`steps_per_epoch` are not the same as the '
                      'Keras 1 argument `samples_per_epoch`. '
                      '`steps_per_epoch` is the number of batches '
                      'to draw from the generator at each epoch. '
                      'Basically steps_per_epoch = samples_per_epoch/batch_size. '
                      'Similarly, the `nb_val_samples`->`validation_steps` and '
                      '`val_samples`->`steps` arguments have changed. '
                      'Update your method calls accordingly.', stacklevel=3)

    return args, kwargs, converted
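
# Illustrative behaviour (a minimal sketch, not part of the original module):
# a legacy `samples_per_epoch` value is divided by the generator's batch size
# to obtain the Keras 2 `steps_per_epoch`. `_demo_generator_methods_preprocessing`
# and `_FakeGenerator` are hypothetical and added only for illustration.
def _demo_generator_methods_preprocessing():
    model = object()  # stands in for the model instance (args[0])

    class _FakeGenerator(object):
        batch_size = 32

    args, kwargs, _ = generator_methods_args_preprocessor(
        (model, _FakeGenerator()), {'samples_per_epoch': 3200})
    # kwargs now holds steps_per_epoch = 3200 // 32 = 100.
    return args, kwargs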

legacy_generator_methods_support = generate_legacy_method_interface(
    allowed_positional_args=['generator', 'steps_per_epoch', 'epochs'],
    conversions=[('samples_per_epoch', 'steps_per_epoch'),
                 ('val_samples', 'steps'),
                 ('nb_epoch', 'epochs'),
                 ('nb_val_samples', 'validation_steps'),
                 ('nb_worker', 'workers'),
                 ('pickle_safe', 'use_multiprocessing'),
                 ('max_q_size', 'max_queue_size')],
    preprocessor=generator_methods_args_preprocessor)

legacy_model_constructor_support = generate_legacy_interface(
    allowed_positional_args=None,
    conversions=[('input', 'inputs'),
                 ('output', 'outputs')])

legacy_input_support = generate_legacy_interface(
    allowed_positional_args=None,
    conversions=[('input_dtype', 'dtype')])


def add_weight_args_preprocessing(args, kwargs):
    if len(args) > 1:
        if isinstance(args[1], (tuple, list)):
            kwargs['shape'] = args[1]
            args = (args[0],) + args[2:]
            if len(args) > 1:
                if isinstance(args[1], six.string_types):
                    kwargs['name'] = args[1]
                    args = (args[0],) + args[2:]
    return args, kwargs, []
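
# Illustrative behaviour (a minimal sketch, not part of the original module):
# legacy positional `add_weight(shape, name, ...)` arguments are moved into
# keyword form so they can be matched against the Keras 2 signature.
# `_demo_add_weight_preprocessing` is a hypothetical helper added only for
# illustration.
def _demo_add_weight_preprocessing():
    layer = object()  # stands in for the layer instance (args[0])
    args, kwargs, _ = add_weight_args_preprocessing(
        (layer, (3, 3), 'kernel'), {})
    # args is now (layer,); kwargs is {'shape': (3, 3), 'name': 'kernel'}.
    return args, kwargs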

legacy_add_weight_support = generate_legacy_interface(
    allowed_positional_args=['name', 'shape'],
    preprocessor=add_weight_args_preprocessing)


def get_updates_arg_preprocessing(args, kwargs):
    # Old interface: (params, constraints, loss)
    # New interface: (loss, params)
    if len(args) > 4:
        raise TypeError('`get_updates` call received more arguments '
                        'than expected.')
    elif len(args) == 4:
        # Assuming old interface.
        opt, params, _, loss = args
        kwargs['loss'] = loss
        kwargs['params'] = params
        return [opt], kwargs, []
    elif len(args) == 3:
        if isinstance(args[1], (list, tuple)):
            assert isinstance(args[2], dict)
            assert 'loss' in kwargs
            opt, params, _ = args
            kwargs['params'] = params
            return [opt], kwargs, []
    return args, kwargs, []
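
# Illustrative behaviour (a minimal sketch, not part of the original module):
# a legacy `get_updates(params, constraints, loss)` call arrives as four
# positional arguments (including the optimizer) and is rewritten into the
# Keras 2 `get_updates(loss=..., params=...)` form.
# `_demo_get_updates_preprocessing` is a hypothetical helper added only for
# illustration.
def _demo_get_updates_preprocessing():
    optimizer = object()  # stands in for the optimizer instance (args[0])
    params, constraints, loss = ['w'], {}, 'loss_tensor'
    args, kwargs, _ = get_updates_arg_preprocessing(
        (optimizer, params, constraints, loss), {})
    # args is now [optimizer]; kwargs is {'loss': 'loss_tensor', 'params': ['w']}.
    return args, kwargs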

legacy_get_updates_support = generate_legacy_interface(
    allowed_positional_args=None,
    conversions=[],
    preprocessor=get_updates_arg_preprocessing)