Adding a new layer to a stax.serial object

Question

    import math
    from jax.example_libraries import stax
    from jax.nn.initializers import normal

    def mlp(L, n_list, activation, Cb, Cw):
        # `activation` is a stax layer such as stax.Relu or stax.Selu.
        b_init = normal(stddev=math.sqrt(Cb))
        # The result of stax.serial is itself an (init_fun, apply_fun) layer, so it can be nested.
        model = stax.serial(
            stax.Dense(n_list[1], W_init=normal(stddev=math.sqrt(Cw / n_list[0])), b_init=b_init),
        )
        for l in range(1, L):
            model = stax.serial(
                model,
                stax.Dense(n_list[l + 1], W_init=normal(stddev=math.sqrt(Cw / n_list[l])), b_init=b_init),
                activation,
            )
        model = stax.serial(
            model,
            stax.Dense(n_list[L + 1], W_init=normal(stddev=math.sqrt(Cw / n_list[L])), b_init=b_init),
        )
        return model  # an (init_fun, apply_fun) pair

I'd like to convert the following TensorFlow code to JAX:

    import math
    import tensorflow as tf

    def mlp(L, n_list, activation, Cb, Cw):
        model = tf.keras.Sequential()
        # One kernel initializer per Dense layer, with stddev = sqrt(Cw / fan_in)
        kernel_initializers_list = []
        kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[0])))
        for l in range(1, L):
            kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[l])))
        kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[L])))
        bias_initializer = tf.keras.initializers.RandomNormal(stddev=math.sqrt(Cb))
        # Input layer (no activation)
        model.add(tf.keras.layers.Dense(n_list[1], input_shape=[n_list[0]], use_bias=True,
                                        kernel_initializer=kernel_initializers_list[0],
                                        bias_initializer=bias_initializer))
        # Hidden layers (with activation)
        for l in range(1, L):
            model.add(tf.keras.layers.Dense(n_list[l+1], activation=activation, use_bias=True,
                                            kernel_initializer=kernel_initializers_list[l],
                                            bias_initializer=bias_initializer))
        # Output layer (no activation)
        model.add(tf.keras.layers.Dense(n_list[L+1], use_bias=True,
                                        kernel_initializer=kernel_initializers_list[L],
                                        bias_initializer=bias_initializer))
        model.summary()  # summary() prints the table itself and returns None
        return model

In JAX, can I add a stax.Dense layer to the object returned by stax.serial(), using something equivalent to TensorFlow's model.add()? If so, how?

Answer 1

Score: 0


Yes, you can.

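    # Create a new model with stax
    from jax import random
    from jax.example_libraries import stax
    from jax.example_libraries.stax import (
        Conv, Dense, Flatten, LogSoftmax, MaxPool, Relu)

    net_init, net_apply = stax.serial(
        Conv(32, (3, 3), padding='SAME'),
        Relu,
        Conv(64, (3, 3), padding='SAME'),
        Relu,
        Conv(128, (3, 3), padding='SAME'),
        Relu,
        Conv(256, (3, 3), padding='SAME'),
        Relu,
        MaxPool((2, 2)),
        Flatten,
        Dense(128),
        Relu,
        Dense(10),
        LogSoftmax,
    )
    # net_init returns the output shape and the initial parameters
    out_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))

    # Forward pass; batch_data is assumed to be an (inputs, targets) pair from your data loader
    inputs, targets = batch_data
    predictions = net_apply(params, inputs)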

This is my reference to help you.
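
The example above builds a fixed network in a single call. To grow a network layer by layer, as model.add() does in Keras, one option is to collect the layers in a plain Python list and call stax.serial once at the end. The following is a minimal sketch under assumptions, not the answer author's code: the helper name `build_mlp` and the layer sizes are hypothetical, and stax is imported from `jax.example_libraries` (older JAX releases exposed it as `jax.experimental.stax`).

```python
import math

import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax
from jax.nn.initializers import normal


def build_mlp(L, n_list, activation, Cb, Cw):
    """Hypothetical helper mirroring the Keras mlp() from the question."""
    b_init = normal(stddev=math.sqrt(Cb))

    def dense(l):
        # Layer l maps n_list[l] -> n_list[l + 1]; weight stddev is sqrt(Cw / n_list[l]).
        return stax.Dense(
            n_list[l + 1],
            W_init=normal(stddev=math.sqrt(Cw / n_list[l])),
            b_init=b_init,
        )

    layers = [dense(0)]            # first layer, no activation
    for l in range(1, L):
        layers.append(dense(l))    # plays the role of model.add(...)
        layers.append(activation)  # Keras applies the activation after the Dense
    layers.append(dense(L))        # output layer, no activation
    return stax.serial(*layers)


# Assumed example sizes: L = 3 needs L + 2 = 5 entries in n_list.
init_fun, apply_fun = build_mlp(
    L=3, n_list=[4, 16, 16, 16, 2], activation=stax.Relu, Cb=0.1, Cw=2.0
)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 4))
outputs = apply_fun(params, jnp.ones((5, 4)))
```

Because the result of stax.serial is itself an (init_fun, apply_fun) pair, nesting it, as in stax.serial(model, stax.Dense(10)), is another way to append a layer; the parameter list is then nested in the same way.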

Posted by huangapple on March 15, 2023 at 18:32:33
When reposting, please keep the link to this article: https://go.coder-hub.com/75743490.html