英文:
Adding a new layer to a stax.serial object
问题
def mlp(L, n_list, activation, Cb, Cw):
    """Build a stax MLP equivalent to the Keras ``Sequential`` reference model.

    stax has no ``model.add()``. ``stax.serial`` combines a fixed sequence of
    layers, and every layer it accepts must be an ``(init_fun, apply_fun)``
    pair. The equivalent of repeated ``model.add`` calls is therefore to
    collect the layers in a Python list and call ``stax.serial(*layers)`` once.

    The network has ``L + 1`` Dense layers. Dense layer ``l`` maps
    ``n_list[l] -> n_list[l + 1]``. Its weights are drawn from a zero-mean
    normal with standard deviation ``sqrt(Cw / n_list[l])``, and its biases
    from a zero-mean normal with standard deviation ``sqrt(Cb)``.

    In the Keras reference, only the Dense layers added inside the loop
    receive ``activation=activation``. To match it, the activation layer is
    inserted after Dense layers ``1 .. L-1`` only, and the first and last
    Dense layers stay linear.

    Args:
        L: Index of the last Dense layer (the network has ``L + 1`` Dense
            layers). Must be >= 1.
        n_list: Layer widths with at least ``L + 2`` entries. ``n_list[0]``
            is the input dimension and ``n_list[L + 1]`` the output dimension.
        activation: A stax layer, i.e. an ``(init_fun, apply_fun)`` pair such
            as ``stax.Relu`` or ``stax.Selu``, not a bare function.
        Cb: Bias variance.
        Cw: Weight variance scale, divided by each layer's fan-in.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.

    Raises:
        ValueError: If ``L < 1`` or ``n_list`` has fewer than ``L + 2``
            entries.
    """
    from jax.nn.initializers import normal

    if L < 1:
        raise ValueError(f"L must be >= 1, got {L}")
    if len(n_list) < L + 2:
        raise ValueError(
            f"n_list needs at least L + 2 = {L + 2} entries, got {len(n_list)}"
        )

    # Initializers are stateless functions of (key, shape), so one bias
    # initializer can safely be shared by every layer.
    bias_init = normal(stddev=math.sqrt(Cb))
    layers = []
    for l in range(L + 1):
        layers.append(
            stax.Dense(
                n_list[l + 1],
                # Fan-in scaling: the input width of layer l is n_list[l].
                W_init=normal(stddev=math.sqrt(Cw / n_list[l])),
                b_init=bias_init,
            )
        )
        # Mirrors Keras Dense(..., activation=activation) on hidden layers 1..L-1.
        if 1 <= l < L:
            layers.append(activation)
    return stax.serial(*layers)
英文:
I'd like to convert the following TensorFlow code to JAX:
def mlp(L, n_list, activation, Cb, Cw):
    """Build a Keras ``Sequential`` MLP with fan-in-scaled normal initializers.

    The first Dense layer maps ``n_list[0] -> n_list[1]`` and has no
    activation. The Dense layers added in the loop (indices ``1 .. L-1``) map
    ``n_list[l] -> n_list[l + 1]`` with ``activation=activation``. The last
    Dense layer maps to ``n_list[L + 1]`` and has no activation.

    Each Dense layer's kernel is drawn from ``RandomNormal(0, sqrt(Cw / fan_in))``.
    Every bias is drawn from ``RandomNormal(stddev=sqrt(Cb))``.

    Args:
        L: Index of the last Dense layer.
        n_list: Layer widths; indices ``0 .. L + 1`` are read.
        activation: Keras activation for the hidden Dense layers.
        Cb: Bias variance.
        Cw: Weight variance scale, divided by each layer's fan-in.

    Returns:
        The constructed ``tf.keras.Sequential`` model. Its summary is printed
        as a side effect.
    """
    model = tf.keras.Sequential()
    kernel_initializers_list = []
    kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[0])))
    for l in range(1, L): 
        kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[l])))
    kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[L])))
    bias_initializer = tf.keras.initializers.RandomNormal(stddev=math.sqrt(Cb))
    model.add(tf.keras.layers.Dense(n_list[1], input_shape=[n_list[0]], use_bias = True, kernel_initializer = kernel_initializers_list[0],
          bias_initializer = bias_initializer))
    for l in range(1, L): 
        model.add(tf.keras.layers.Dense(n_list[l+1], activation=activation, use_bias = True, kernel_initializer = kernel_initializers_list[l],
              bias_initializer = bias_initializer))
    model.add(tf.keras.layers.Dense(n_list[L+1], use_bias = True, kernel_initializer = kernel_initializers_list[L],
              bias_initializer = bias_initializer))
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()
    return model
In JAX, can I add a stax.Dense layer to the model returned by stax.serial(), using something equivalent to TensorFlow's model.add()? If so, how?
答案1
得分: 0
是的，你可以。
# Build a new model with JAX's stax combinators.
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)
# Keep the initialized parameters: net_apply below needs them.
out_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))
# Forward pass.
# NOTE(review): assumes `batch_data` is an (inputs, targets) pair defined elsewhere.
inputs, targets = batch_data
net_apply(params, inputs)
这是一个供你参考的示例。
英文:
Yes, you can.
# Create a new model with JAX's stax combinators.
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)
# Keep the initialized parameters: net_apply below needs them.
out_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))
# Feedforward
# NOTE(review): assumes `batch_data` is an (inputs, targets) pair defined elsewhere.
inputs, targets = batch_data
net_apply(params, inputs)
Here is a reference example to help you.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。


评论