Coverage for mlair/keras_legacy/interfaces.py: 20%

293 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2023-06-30 10:40 +0000

1"""Interface converters for Keras 1 support in Keras 2. 

2""" 

3from __future__ import absolute_import 

4from __future__ import division 

5from __future__ import print_function 

6 

7import six 

8import warnings 

9import functools 

10import numpy as np 

11 

12 

def _format_signature_value(value):
    """Render one argument value for the Keras 2 upgrade hint.

    Strings are double-quoted, NumPy arrays are shown as ``array`` and any
    other value uses ``str()``, cut to 10 characters plus ``'...'``.
    """
    if isinstance(value, six.string_types):
        return '"' + value + '"'
    if isinstance(value, np.ndarray):
        str_val = 'array'
    else:
        str_val = str(value)
    if len(str_val) > 10:
        str_val = str_val[:10] + '...'
    return str_val


def _format_call_signature(object_name, args, kwargs):
    """Build the backtick-quoted ``Name(arg, ..., key=value)`` call hint.

    ``args[0]`` is skipped; positional values are listed first, then the
    keyword arguments in ``kwargs`` iteration order, separated by ``', '``.
    """
    parts = [_format_signature_value(value) for value in args[1:]]
    parts.extend(name + '=' + _format_signature_value(value)
                 for name, value in kwargs.items())
    return '`' + object_name + '(' + ', '.join(parts) + ')`'


def generate_legacy_interface(allowed_positional_args=None,
                              conversions=None,
                              preprocessor=None,
                              value_conversions=None,
                              object_type='class'):
    """Build a decorator that lets a Keras 2 callable accept Keras 1 arguments.

    Each call of the decorated function is rewritten, in this order:

    1. ``preprocessor(args, kwargs)`` (if given) returns
       ``(args, kwargs, converted)``.
    2. If ``allowed_positional_args`` is not ``None``, passing more
       positional arguments than its length (``args[0]`` not counted)
       raises ``TypeError``.
    3. ``value_conversions`` replaces known old values of a keyword.
    4. ``conversions`` renames old keywords to new ones; passing both names
       raises ``TypeError`` through ``raise_duplicate_arg_error``.

    If anything was converted, a warning shows the equivalent Keras 2 call.

    Args:
        allowed_positional_args: names of the accepted positional
            arguments, or ``None`` to skip the positional-count check.
        conversions: sequence of ``(old_name, new_name)`` keyword renames.
        preprocessor: optional ``callable(args, kwargs)`` returning
            ``(args, kwargs, converted)``.
        value_conversions: ``{keyword: {old_value: new_value}}``; applied
            before the renames in ``conversions``.
        object_type: ``'class'`` names the call after the class of
            ``args[0]``; any other value uses the wrapped function's name.

    Returns:
        The decorator. Wrapped functions keep the original callable in
        ``_original_function``.
    """
    # `None` disables the check; an empty list means "no positional
    # arguments besides args[0]".
    check_positional_args = allowed_positional_args is not None
    allowed_positional_args = allowed_positional_args or []
    conversions = conversions or []
    value_conversions = value_conversions or {}

    def legacy_support(func):
        @six.wraps(func)
        def wrapper(*args, **kwargs):
            if object_type == 'class':
                object_name = args[0].__class__.__name__
            else:
                object_name = func.__name__
            if preprocessor:
                args, kwargs, converted = preprocessor(args, kwargs)
            else:
                converted = []
            if check_positional_args:
                # `+ 1` accounts for args[0] (the instance).
                if len(args) > len(allowed_positional_args) + 1:
                    raise TypeError('`' + object_name +
                                    '` can accept only ' +
                                    str(len(allowed_positional_args)) +
                                    ' positional arguments ' +
                                    str(tuple(allowed_positional_args)) +
                                    ', but you passed the following '
                                    'positional arguments: ' +
                                    str(list(args[1:])))
            for key in value_conversions:
                if key in kwargs:
                    old_value = kwargs[key]
                    if old_value in value_conversions[key]:
                        kwargs[key] = value_conversions[key][old_value]
            for old_name, new_name in conversions:
                if old_name in kwargs:
                    value = kwargs.pop(old_name)
                    if new_name in kwargs:
                        raise_duplicate_arg_error(old_name, new_name)
                    kwargs[new_name] = value
                    converted.append((new_name, old_name))
            if converted:
                signature = _format_call_signature(object_name, args, kwargs)
                # Warn here (not in a helper) so stacklevel=2 points at
                # the caller of the decorated function.
                warnings.warn('Update your `' + object_name + '` call to the ' +
                              'Keras 2 API: ' + signature, stacklevel=2)
            return func(*args, **kwargs)
        wrapper._original_function = func
        return wrapper
    return legacy_support

95 

96 

# Variant of `generate_legacy_interface` for methods: messages use the
# wrapped function's own name instead of the class name of args[0].
generate_legacy_method_interface = functools.partial(
    generate_legacy_interface,
    object_type='method',
)

99 

100 

def raise_duplicate_arg_error(old_arg, new_arg):
    """Reject a call that passes both a legacy keyword and its Keras 2 name.

    Args:
        old_arg: the Keras 1 keyword name.
        new_arg: the Keras 2 keyword name it maps to.

    Raises:
        TypeError: always; the message names both keywords.
    """
    message = 'For the `' + new_arg + '` argument, the layer received both '
    message += 'the legacy keyword argument `' + old_arg + '` and the Keras 2 '
    message += 'keyword argument `' + new_arg + '`. Stick to the latter!'
    raise TypeError(message)

107 

108 

# `Dense`: Keras 1 keyword names and their Keras 2 replacements.
legacy_dense_support = generate_legacy_interface(
    conversions=(('output_dim', 'units'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['units'])

# `Dropout`: the Keras 1 `p` keyword is now `rate`.
legacy_dropout_support = generate_legacy_interface(
    conversions=(('p', 'rate'),),
    allowed_positional_args=['rate', 'noise_shape', 'seed'])

122 

123 

def embedding_kwargs_preprocessor(args, kwargs):
    """Drop the Keras 1 ``dropout`` keyword of ``Embedding``.

    The keyword is removed from ``kwargs`` and a warning points the user to
    ``SpatialDropout1D`` instead.

    Args:
        args: positional arguments of the call (returned unchanged).
        kwargs: keyword arguments of the call; modified in place.

    Returns:
        ``(args, kwargs, converted)``; ``converted`` is always empty because
        the removal has no Keras 2 counterpart to report.
    """
    converted = []
    if 'dropout' in kwargs:
        kwargs.pop('dropout')
        # Fixed grammar: "no longer support" -> "no longer supported".
        warnings.warn('The `dropout` argument is no longer supported in `Embedding`. '
                      'You can apply a `keras.layers.SpatialDropout1D` layer '
                      'right after the `Embedding` layer to get the same behavior.',
                      stacklevel=3)
    return args, kwargs, converted

133 

# `Embedding`: weight keywords gained an `embeddings_` prefix; the removed
# `dropout` keyword is handled by `embedding_kwargs_preprocessor`.
legacy_embedding_support = generate_legacy_interface(
    preprocessor=embedding_kwargs_preprocessor,
    conversions=(('init', 'embeddings_initializer'),
                 ('W_regularizer', 'embeddings_regularizer'),
                 ('W_constraint', 'embeddings_constraint')),
    allowed_positional_args=['input_dim', 'output_dim'])

# 1D pooling layers.
legacy_pooling1d_support = generate_legacy_interface(
    conversions=(('pool_length', 'pool_size'),
                 ('stride', 'strides'),
                 ('border_mode', 'padding')),
    allowed_positional_args=['pool_size', 'strides', 'padding'])

# `PReLU`: `init` is now `alpha_initializer`.
legacy_prelu_support = generate_legacy_interface(
    conversions=(('init', 'alpha_initializer'),),
    allowed_positional_args=['alpha_initializer'])


# `GaussianNoise`: `sigma` is now `stddev`.
legacy_gaussiannoise_support = generate_legacy_interface(
    conversions=(('sigma', 'stddev'),),
    allowed_positional_args=['stddev'])

155 

156 

def recurrent_args_preprocessor(args, kwargs):
    """Translate Keras 1 recurrent-layer keywords to their Keras 2 form.

    * ``forget_bias_init='one'`` becomes ``unit_forget_bias=True``; any
      other value is dropped with a warning.
    * ``input_dim`` (with optional ``input_length``) is folded into
      ``input_shape=(input_length, input_dim)`` and a deprecation warning
      is emitted.

    Returns:
        ``(args, kwargs, converted)``; ``args`` is unchanged, ``kwargs`` is
        modified in place and ``converted`` lists the performed changes.
    """
    changes = []

    if 'forget_bias_init' in kwargs:
        bias_init = kwargs.pop('forget_bias_init')
        if bias_init == 'one':
            kwargs['unit_forget_bias'] = True
            changes.append(('forget_bias_init', 'unit_forget_bias'))
        else:
            warnings.warn('The `forget_bias_init` argument '
                          'has been ignored. Use `unit_forget_bias=True` '
                          'instead to initialize with ones.', stacklevel=3)

    if 'input_dim' in kwargs:
        steps = kwargs.pop('input_length', None)
        kwargs['input_shape'] = (steps, kwargs.pop('input_dim'))
        changes.append(('input_dim', 'input_shape'))
        warnings.warn('The `input_dim` and `input_length` arguments '
                      'in recurrent layers are deprecated. '
                      'Use `input_shape` instead.', stacklevel=3)

    return args, kwargs, changes

179 

# Recurrent layers. `consume_less` values are mapped to integer
# `implementation` modes before the keyword itself is renamed.
legacy_recurrent_support = generate_legacy_interface(
    preprocessor=recurrent_args_preprocessor,
    value_conversions={'consume_less': dict(cpu=0, mem=1, gpu=2)},
    conversions=(('output_dim', 'units'),
                 ('init', 'kernel_initializer'),
                 ('inner_init', 'recurrent_initializer'),
                 ('inner_activation', 'recurrent_activation'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('U_regularizer', 'recurrent_regularizer'),
                 ('dropout_W', 'dropout'),
                 ('dropout_U', 'recurrent_dropout'),
                 ('consume_less', 'implementation')),
    allowed_positional_args=['units'])

# `GaussianDropout`: `p` is now `rate`.
legacy_gaussiandropout_support = generate_legacy_interface(
    conversions=(('p', 'rate'),),
    allowed_positional_args=['rate'])

# In the tables below, `dim_ordering` values 'tf'/'th'/'default' become
# 'channels_last'/'channels_first'/None before the key is renamed to
# `data_format`.
legacy_pooling2d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('border_mode', 'padding'),
                 ('dim_ordering', 'data_format')),
    allowed_positional_args=['pool_size', 'strides', 'padding'])

legacy_pooling3d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('border_mode', 'padding'),
                 ('dim_ordering', 'data_format')),
    allowed_positional_args=['pool_size', 'strides', 'padding'])

# Global pooling performs no positional-argument check.
legacy_global_pooling_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),))

legacy_upsampling1d_support = generate_legacy_interface(
    conversions=(('length', 'size'),),
    allowed_positional_args=['size'])

legacy_upsampling2d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),),
    allowed_positional_args=['size'])

legacy_upsampling3d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),),
    allowed_positional_args=['size'])

240 

241 

def conv1d_args_preprocessor(args, kwargs):
    """Fold Keras 1 ``input_dim``/``input_length`` into ``input_shape``.

    When ``input_dim`` is present, ``kwargs`` gains
    ``input_shape=(input_length or None, input_dim)`` and both legacy keys
    are removed; otherwise the call is returned untouched.

    Returns:
        ``(args, kwargs, converted)`` with ``args`` unchanged.
    """
    if 'input_dim' not in kwargs:
        return args, kwargs, []
    steps = kwargs.pop('input_length', None)
    kwargs['input_shape'] = (steps, kwargs.pop('input_dim'))
    return args, kwargs, [('input_shape', 'input_dim')]

253 

# `Conv1D`: Keras 1 keyword renames; `input_dim`/`input_length` are
# handled by `conv1d_args_preprocessor`.
legacy_conv1d_support = generate_legacy_interface(
    preprocessor=conv1d_args_preprocessor,
    conversions=(('nb_filter', 'filters'),
                 ('filter_length', 'kernel_size'),
                 ('subsample_length', 'strides'),
                 ('border_mode', 'padding'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['filters', 'kernel_size'])

267 

268 

def conv2d_args_preprocessor(args, kwargs):
    """Fold Keras 1 ``nb_row``/``nb_col`` into a Keras 2 ``kernel_size``.

    ``args[0]`` is the layer instance and is counted in ``len(args)``.
    Recognised legacy call shapes:

    * ``(self, filters, nb_row, nb_col)`` with both sizes ``int``;
    * ``(self, filters, nb_row)`` with ``nb_col=`` as a keyword;
    * ``(self, filters)`` or ``(self,)`` with ``nb_row=`` and ``nb_col=``.

    Returns:
        ``(args, kwargs, converted)``. When the kernel size is placed
        positionally, ``args`` is rebuilt as a list.

    Raises:
        TypeError: ``args`` has more than four entries.
        ValueError: four entries with integer sizes while ``padding``,
            ``strides`` or ``data_format`` is also passed as a keyword.
    """
    marker = [('kernel_size', 'nb_row/nb_col')]
    n_args = len(args)

    if n_args > 4:
        raise TypeError('Layer can receive at most 3 positional arguments.')

    if n_args == 4:
        if not (isinstance(args[2], int) and isinstance(args[3], int)):
            return args, kwargs, []
        # Could also be a Keras 2 call with positional kernel_size and
        # strides; refuse to guess when Keras 2 keywords are present.
        if any(kwd in kwargs for kwd in ('padding', 'strides', 'data_format')):
            raise ValueError(
                'It seems that you are using the Keras 2 '
                'and you are passing both `kernel_size` and `strides` '
                'as integer positional arguments. For safety reasons, '
                'this is disallowed. Pass `strides` '
                'as a keyword argument instead.')
        return [args[0], args[1], (args[2], args[3])], kwargs, marker

    if n_args == 3:
        if isinstance(args[2], int) and 'nb_col' in kwargs:
            kernel_size = (args[2], kwargs.pop('nb_col'))
            return [args[0], args[1], kernel_size], kwargs, marker
        return args, kwargs, []

    if n_args in (1, 2) and 'nb_row' in kwargs and 'nb_col' in kwargs:
        kernel_size = (kwargs.pop('nb_row'), kwargs.pop('nb_col'))
        if n_args == 2:
            return [args[0], args[1], kernel_size], kwargs, marker
        kwargs['kernel_size'] = kernel_size
        return args, kwargs, marker

    return args, kwargs, []

303 

# `Conv2D`: keyword renames plus `dim_ordering` value mapping; kernel
# sizes are handled by `conv2d_args_preprocessor`.
legacy_conv2d_support = generate_legacy_interface(
    preprocessor=conv2d_args_preprocessor,
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['filters', 'kernel_size'])

320 

321 

def separable_conv2d_args_preprocessor(args, kwargs):
    """Split Keras 1 ``init`` into both separable-conv initializers.

    The single ``init`` value is copied to ``depthwise_initializer`` and
    ``pointwise_initializer``; the call is then passed through
    ``conv2d_args_preprocessor`` for the kernel-size handling.

    Returns:
        ``(args, kwargs, converted)`` with this step's change listed
        before those reported by ``conv2d_args_preprocessor``.
    """
    own_changes = []
    if 'init' in kwargs:
        initializer = kwargs.pop('init')
        kwargs.update(depthwise_initializer=initializer,
                      pointwise_initializer=initializer)
        own_changes.append(('init', 'depthwise_initializer/pointwise_initializer'))
    args, kwargs, conv2d_changes = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, own_changes + conv2d_changes

331 

# `SeparableConv2D`: `init` is split by the preprocessor, so only the bias
# and layout keywords are renamed here.
legacy_separable_conv2d_support = generate_legacy_interface(
    preprocessor=separable_conv2d_args_preprocessor,
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['filters', 'kernel_size'])

345 

346 

def deconv2d_args_preprocessor(args, kwargs):
    """Discard the Keras 1 ``output_shape`` argument of ``Deconvolution2D``.

    A fifth positional entry that is a tuple, and any ``output_shape``
    keyword, are dropped; the rest goes through ``conv2d_args_preprocessor``.

    Returns:
        ``(args, kwargs, converted)`` with one ``('output_shape', None)``
        entry per dropped value, followed by the conv2d conversions.
    """
    dropped = []
    if len(args) == 5 and isinstance(args[4], tuple):
        args = args[:-1]
        dropped.append(('output_shape', None))
    if 'output_shape' in kwargs:
        del kwargs['output_shape']
        dropped.append(('output_shape', None))
    args, kwargs, conv2d_changes = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, dropped + conv2d_changes

358 

# `Conv2DTranspose` (Keras 1 `Deconvolution2D`): same keyword table as
# `Conv2D`, with `output_shape` dropped by the preprocessor.
legacy_deconv2d_support = generate_legacy_interface(
    preprocessor=deconv2d_args_preprocessor,
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['filters', 'kernel_size'])

375 

376 

def conv3d_args_preprocessor(args, kwargs):
    """Fold Keras 1 ``kernel_dim1/2/3`` into a Keras 2 ``kernel_size``.

    ``args[0]`` is the layer instance and is counted in ``len(args)``.
    The three kernel dimensions may be passed positionally after
    ``filters``, partly positionally with the rest as ``kernel_dim*``
    keywords, or entirely as keywords.

    Returns:
        ``(args, kwargs, converted)``. When the kernel size is placed
        positionally, ``args`` is rebuilt as a list. Every successful merge
        is reported as ``('kernel_size', 'kernel_dim*')``.

    Raises:
        TypeError: ``args`` has more than five entries.
        ValueError: four entries with integer values while ``padding``,
            ``strides`` or ``data_format`` is also passed as a keyword.
    """
    converted = []
    kernel_dims = ('kernel_dim1', 'kernel_dim2', 'kernel_dim3')
    if len(args) > 5:
        raise TypeError('Layer can receive at most 4 positional arguments.')
    if len(args) == 5:
        if all(isinstance(x, int) for x in args[2:5]):
            kernel_size = (args[2], args[3], args[4])
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) == 4 and isinstance(args[3], int):
        # args[3] is already known to be an int here.
        if isinstance(args[2], int):
            for kwd in ('padding', 'strides', 'data_format'):
                if kwd in kwargs:
                    raise ValueError(
                        'It seems that you are using the Keras 2 '
                        'and you are passing both `kernel_size` and `strides` '
                        'as integer positional arguments. For safety reasons, '
                        'this is disallowed. Pass `strides` '
                        'as a keyword argument instead.')
        if 'kernel_dim3' in kwargs:
            kernel_size = (args[2], args[3], kwargs.pop('kernel_dim3'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) == 3:
        if all(x in kwargs for x in kernel_dims[1:]):
            kernel_size = (args[2],
                           kwargs.pop('kernel_dim2'),
                           kwargs.pop('kernel_dim3'))
            args = [args[0], args[1], kernel_size]
            converted.append(('kernel_size', 'kernel_dim*'))
    elif len(args) in (1, 2):
        if all(x in kwargs for x in kernel_dims):
            kernel_size = tuple(kwargs.pop(x) for x in kernel_dims)
            if len(args) == 2:
                args = [args[0], args[1], kernel_size]
            else:
                kwargs['kernel_size'] = kernel_size
            # The keyword-only form used to be tagged 'nb_row/nb_col'
            # (copied from conv2d); use the same tag as the other branches.
            converted.append(('kernel_size', 'kernel_dim*'))
    return args, kwargs, converted

423 

# `Conv3D`: same keyword table as `Conv2D`; kernel dimensions are merged
# by `conv3d_args_preprocessor`.
legacy_conv3d_support = generate_legacy_interface(
    preprocessor=conv3d_args_preprocessor,
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('W_constraint', 'kernel_constraint'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['filters', 'kernel_size'])

440 

441 

def batchnorm_args_preprocessor(args, kwargs):
    """Validate a ``BatchNormalization`` call and drop ``mode=0``.

    Returns:
        ``(args, kwargs, converted)``; ``converted`` holds
        ``('mode', None)`` when a ``mode=0`` keyword was removed.

    Raises:
        TypeError: any positional argument besides ``args[0]``, or a
            ``mode`` value other than 0.
    """
    if len(args) > 1:
        raise TypeError('The `BatchNormalization` layer '
                        'does not accept positional arguments. '
                        'Use keyword arguments instead.')
    if 'mode' not in kwargs:
        return args, kwargs, []
    if kwargs.pop('mode') != 0:
        raise TypeError('The `mode` argument of `BatchNormalization` '
                        'no longer exists. `mode=1` and `mode=2` '
                        'are no longer supported.')
    return args, kwargs, [('mode', None)]

456 

457 

def convlstm2d_args_preprocessor(args, kwargs):
    """Handle ``forget_bias_init`` and kernel sizes for ``ConvLSTM2D``.

    ``forget_bias_init='one'`` becomes ``unit_forget_bias=True``; any other
    value is dropped with a warning. The call then goes through
    ``conv2d_args_preprocessor`` for ``nb_row``/``nb_col`` handling.

    Returns:
        ``(args, kwargs, converted)`` with this step's change listed
        before the conv2d conversions.
    """
    bias_changes = []
    if 'forget_bias_init' in kwargs:
        if kwargs.pop('forget_bias_init') == 'one':
            kwargs['unit_forget_bias'] = True
            bias_changes.append(('forget_bias_init', 'unit_forget_bias'))
        else:
            warnings.warn('The `forget_bias_init` argument '
                          'has been ignored. Use `unit_forget_bias=True` '
                          'instead to initialize with ones.', stacklevel=3)
    args, kwargs, conv2d_changes = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, bias_changes + conv2d_changes

471 

# `ConvLSTM2D`: convolution keywords plus the recurrent-layer renames.
legacy_convlstm2d_support = generate_legacy_interface(
    preprocessor=convlstm2d_args_preprocessor,
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('init', 'kernel_initializer'),
                 ('inner_init', 'recurrent_initializer'),
                 ('W_regularizer', 'kernel_regularizer'),
                 ('U_regularizer', 'recurrent_regularizer'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('inner_activation', 'recurrent_activation'),
                 ('dropout_W', 'dropout'),
                 ('dropout_U', 'recurrent_dropout'),
                 ('bias', 'use_bias')),
    allowed_positional_args=['filters', 'kernel_size'])

# `BatchNormalization`: keyword-only (empty positional list).
legacy_batchnorm_support = generate_legacy_interface(
    preprocessor=batchnorm_args_preprocessor,
    conversions=(('beta_init', 'beta_initializer'),
                 ('gamma_init', 'gamma_initializer')),
    allowed_positional_args=[])

497 

498 

def _legacy_padding_dict_to_tuple(padding):
    """Convert a Keras 1 ``ZeroPadding2D`` padding dict to Keras 2 nesting.

    Returns ``((top_pad, bottom_pad), (left_pad, right_pad))``, with missing
    keys defaulting to 0. Returns ``None`` when the dict has keys other than
    ``top_pad``/``bottom_pad``/``left_pad``/``right_pad``; callers leave such
    values untouched.
    """
    if not set(padding) <= {'top_pad', 'bottom_pad', 'left_pad', 'right_pad'}:
        return None
    return ((padding.get('top_pad', 0), padding.get('bottom_pad', 0)),
            (padding.get('left_pad', 0), padding.get('right_pad', 0)))


def zeropadding2d_args_preprocessor(args, kwargs):
    """Convert a Keras 1 dict-style ``padding`` for ``ZeroPadding2D``.

    Both ``padding=`` as a keyword and as the single positional argument
    (``args[1]``) are handled. A recognised dict is replaced with the nested
    tuple form, and a warning shows the Keras 2 syntax.

    Returns:
        ``(args, kwargs, converted)``; ``converted`` is always empty.
    """
    converted = []
    # Previously the keyword path advertised a flat 4-tuple (not what is
    # produced) and both messages lacked a space ("no longeraccepts").
    message = ('The `padding` argument in the Keras 2 API no longer '
               'accepts dict types. You can now input argument as: '
               '`padding=((top_pad, bottom_pad), (left_pad, right_pad))`.')
    if 'padding' in kwargs and isinstance(kwargs['padding'], dict):
        padding = _legacy_padding_dict_to_tuple(kwargs['padding'])
        if padding is not None:
            kwargs['padding'] = padding
            warnings.warn(message, stacklevel=3)
    elif len(args) == 2 and isinstance(args[1], dict):
        padding = _legacy_padding_dict_to_tuple(args[1])
        if padding is not None:
            args = (args[0], padding)
            warnings.warn(message, stacklevel=3)
    return args, kwargs, converted

526 

# Padding / cropping / spatial-dropout layers. `dim_ordering` values are
# mapped before the key is renamed to `data_format`.
legacy_zeropadding2d_support = generate_legacy_interface(
    preprocessor=zeropadding2d_args_preprocessor,
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),),
    allowed_positional_args=['padding'])

legacy_zeropadding3d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),),
    allowed_positional_args=['padding'])

legacy_cropping2d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),),
    allowed_positional_args=['cropping'])

legacy_cropping3d_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('dim_ordering', 'data_format'),),
    allowed_positional_args=['cropping'])

legacy_spatialdropout1d_support = generate_legacy_interface(
    conversions=(('p', 'rate'),),
    allowed_positional_args=['rate'])

legacy_spatialdropoutNd_support = generate_legacy_interface(
    value_conversions={'dim_ordering': dict(tf='channels_last',
                                            th='channels_first',
                                            default=None)},
    conversions=(('p', 'rate'),
                 ('dim_ordering', 'data_format')),
    allowed_positional_args=['rate'])

# `Lambda`: only the positional-argument count is checked.
legacy_lambda_support = generate_legacy_interface(
    allowed_positional_args=['function', 'output_shape'])

570 

571 

# Model methods

def generator_methods_args_preprocessor(args, kwargs):
    """Convert ``samples_per_epoch`` for the ``*_generator`` model methods.

    ``args[0]`` is the model. When ``steps_per_epoch`` is not given
    positionally (fewer than three entries in ``args``) and
    ``samples_per_epoch`` is passed, it is replaced by ``steps_per_epoch``.
    The value is divided (floor) by the generator's ``batch_size`` when the
    generator has that attribute, and copied as-is otherwise. If Keras 1
    keywords remain afterwards, a warning explains the changed semantics.

    Returns:
        ``(args, kwargs, converted)`` with ``args`` unchanged.
    """
    converted = []
    steps_given_positionally = len(args) >= 3

    if not steps_given_positionally and 'samples_per_epoch' in kwargs:
        samples = kwargs.pop('samples_per_epoch')
        generator = args[1] if len(args) > 1 else kwargs['generator']
        kwargs['steps_per_epoch'] = (samples // generator.batch_size
                                     if hasattr(generator, 'batch_size')
                                     else samples)
        converted.append(('samples_per_epoch', 'steps_per_epoch'))

    legacy_names = frozenset(('samples_per_epoch', 'val_samples',
                              'nb_epoch', 'nb_val_samples', 'nb_worker'))
    if not legacy_names.isdisjoint(kwargs):
        warnings.warn('The semantics of the Keras 2 argument '
                      '`steps_per_epoch` is not the same as the '
                      'Keras 1 argument `samples_per_epoch`. '
                      '`steps_per_epoch` is the number of batches '
                      'to draw from the generator at each epoch. '
                      'Basically steps_per_epoch = samples_per_epoch/batch_size. '
                      'Similarly `nb_val_samples`->`validation_steps` and '
                      '`val_samples`->`steps` arguments have changed. '
                      'Update your method calls accordingly.', stacklevel=3)

    return args, kwargs, converted

603 

604 

# `fit_generator` / `evaluate_generator` / `predict_generator` style
# methods; messages use the method name instead of the class name.
legacy_generator_methods_support = generate_legacy_method_interface(
    preprocessor=generator_methods_args_preprocessor,
    conversions=(('samples_per_epoch', 'steps_per_epoch'),
                 ('val_samples', 'steps'),
                 ('nb_epoch', 'epochs'),
                 ('nb_val_samples', 'validation_steps'),
                 ('nb_worker', 'workers'),
                 ('pickle_safe', 'use_multiprocessing'),
                 ('max_q_size', 'max_queue_size')),
    allowed_positional_args=['generator', 'steps_per_epoch', 'epochs'])


# `Model(input=..., output=...)` -> `Model(inputs=..., outputs=...)`;
# no positional-argument check.
legacy_model_constructor_support = generate_legacy_interface(
    conversions=(('input', 'inputs'),
                 ('output', 'outputs')),
    allowed_positional_args=None)

# `Input(input_dtype=...)` -> `Input(dtype=...)`.
legacy_input_support = generate_legacy_interface(
    conversions=(('input_dtype', 'dtype'),),
    allowed_positional_args=None)

625 

626 

def add_weight_args_preprocessing(args, kwargs):
    """Move legacy positional ``(shape, name)`` arguments of ``add_weight``.

    ``args[0]`` is the layer instance. When the first user argument is a
    tuple/list it is taken as ``shape``; only in that case is a following
    string taken as ``name``. A call already in the Keras 2 positional order
    ``(name, shape)`` (the order in ``legacy_add_weight_support``'s
    ``allowed_positional_args``) is therefore left untouched. Checking for
    a string independently would move ``name`` to a keyword and leave
    ``shape`` in the ``name`` slot.

    Returns:
        ``(args, kwargs, [])``; nothing is reported as converted.
    """
    if len(args) > 1 and isinstance(args[1], (tuple, list)):
        kwargs['shape'] = args[1]
        args = (args[0],) + args[2:]
        # Only after a positional shape was consumed can a string be the
        # legacy positional `name`.
        if len(args) > 1 and isinstance(args[1], six.string_types):
            kwargs['name'] = args[1]
            args = (args[0],) + args[2:]
    return args, kwargs, []

637 

638 

# `Layer.add_weight`: positional order is (name, shape); legacy positional
# forms are rewritten by `add_weight_args_preprocessing`.
legacy_add_weight_support = generate_legacy_interface(
    preprocessor=add_weight_args_preprocessing,
    allowed_positional_args=['name', 'shape'])

642 

643 

def get_updates_arg_preprocessing(args, kwargs):
    """Map old ``get_updates`` call shapes onto ``loss``/``params`` keywords.

    Old interface: ``get_updates(params, constraints, loss)``.
    New interface: ``get_updates(loss, params)``.
    ``args[0]`` is the optimizer instance. The old ``constraints`` argument
    is discarded.

    Returns:
        ``(args, kwargs, [])``. For a recognised old-style call ``args`` is
        reduced to ``[optimizer]`` and ``kwargs`` gains ``params`` (and
        ``loss`` when it was positional); otherwise the call is unchanged.

    Raises:
        TypeError: more than three arguments after the optimizer, or an
            old-style ``(params, constraints)`` call whose ``constraints``
            is not a dict or that lacks a ``loss`` keyword.
    """
    if len(args) > 4:
        raise TypeError('`get_updates` call received more arguments '
                        'than expected.')
    elif len(args) == 4:
        # Assuming old interface.
        opt, params, _, loss = args
        kwargs['loss'] = loss
        kwargs['params'] = params
        return [opt], kwargs, []
    elif len(args) == 3:
        if isinstance(args[1], (list, tuple)):
            # Explicit checks instead of `assert`, which `python -O` strips.
            if not isinstance(args[2], dict):
                raise TypeError('Legacy `get_updates(params, constraints, loss)` '
                                'call expects `constraints` to be a dict.')
            if 'loss' not in kwargs:
                raise TypeError('Legacy `get_updates(params, constraints, loss)` '
                                'call requires `loss` as a keyword argument.')
            opt, params, _ = args
            kwargs['params'] = params
            return [opt], kwargs, []
    return args, kwargs, []

664 

# `Optimizer.get_updates`: the argument reshuffling happens entirely in
# `get_updates_arg_preprocessing`; no renames, no positional check.
legacy_get_updates_support = generate_legacy_interface(
    preprocessor=get_updates_arg_preprocessing,
    conversions=(),
    allowed_positional_args=None)