向stax.serial对象添加新层。

huangapple go评论133阅读模式
英文:

Adding a new layer to a stax.serial object

问题

def mlp(L, n_list, activation, Cb, Cw):
    """Build a stax MLP equivalent to the Keras ``Sequential`` version.

    stax has no ``model.add()``: ``stax.serial`` takes every layer (each an
    ``(init_fun, apply_fun)`` pair) in a single call, so the layers are
    collected in a Python list first and then unpacked into ``stax.serial``.

    Args:
        L: index of the last hidden width; the network has ``L + 1`` Dense
            layers and ``max(L - 1, 0)`` activation layers.
        n_list: layer widths; ``n_list[0]`` is the input width and
            ``n_list[L + 1]`` the output width, so it needs at least
            ``L + 2`` entries.
        activation: a stax layer such as ``stax.Relu``, inserted after each
            middle Dense layer (Keras applies ``activation`` after ``Dense``).
        Cb: bias variance; biases are drawn from N(0, Cb).
        Cw: weight variance scale; a layer with fan-in ``n`` draws its
            weights from N(0, Cw / n).

    Returns:
        The ``(init_fun, apply_fun)`` pair returned by ``stax.serial``.
    """
    from jax.nn.initializers import normal  # zero-mean normal, scaled by stddev

    def dense(l):
        # Dense layer l maps width n_list[l] -> n_list[l + 1]; weights are
        # scaled by the fan-in n_list[l].
        return stax.Dense(
            n_list[l + 1],
            W_init=normal(stddev=math.sqrt(Cw / n_list[l])),
            b_init=normal(stddev=math.sqrt(Cb)),
        )

    # Input layer: no activation, as in the Keras model.
    layers = [dense(0)]
    for l in range(1, L):
        # Middle layers: Dense followed by the activation layer.
        layers += [dense(l), activation]
    # Output layer: no activation.
    layers.append(dense(L))
    return stax.serial(*layers)
英文:

I'd like to convert the following TensorFlow code to JAX:

def mlp(L, n_list, activation, Cb, Cw):
    """Build a Keras MLP whose weights use fan-in-scaled normal initializers.

    Args:
        L: index of the last hidden width; the model has ``L + 1`` Dense layers.
        n_list: layer widths; ``n_list[0]`` is the input width and
            ``n_list[L + 1]`` the output width (at least ``L + 2`` entries).
        activation: passed as ``activation=`` to each middle ``Dense`` layer.
        Cb: bias variance; biases are drawn from N(0, Cb).
        Cw: weight variance scale; a layer with fan-in ``n`` draws its
            kernel from N(0, Cw / n).

    Returns:
        The ``tf.keras.Sequential`` model (its summary is printed).
    """
    model = tf.keras.Sequential()

    def kernel_init(fan_in):
        # Zero-mean normal with stddev sqrt(Cw / fan_in).
        return tf.keras.initializers.RandomNormal(0, math.sqrt(Cw / fan_in))

    def bias_init():
        # A fresh instance per layer: recent Keras versions warn that reusing one
        # unseeded initializer instance may return identical values on each call.
        return tf.keras.initializers.RandomNormal(stddev=math.sqrt(Cb))

    # Input layer: n_list[0] -> n_list[1], no activation.
    model.add(tf.keras.layers.Dense(n_list[1], input_shape=[n_list[0]], use_bias=True,
                                    kernel_initializer=kernel_init(n_list[0]),
                                    bias_initializer=bias_init()))
    # Middle layers: n_list[l] -> n_list[l + 1], each with `activation`.
    for l in range(1, L):
        model.add(tf.keras.layers.Dense(n_list[l + 1], activation=activation, use_bias=True,
                                        kernel_initializer=kernel_init(n_list[l]),
                                        bias_initializer=bias_init()))
    # Output layer: n_list[L] -> n_list[L + 1], no activation.
    model.add(tf.keras.layers.Dense(n_list[L + 1], use_bias=True,
                                    kernel_initializer=kernel_init(n_list[L]),
                                    bias_initializer=bias_init()))
    # summary() prints the table itself and returns None; don't wrap it in print().
    model.summary()
    return model

In JAX, can I add a `stax.Dense` layer to the model returned by `stax.serial()`, using something equivalent to TensorFlow's `model.add()`? If so, how?

答案1

得分: 0

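是的,可以,但 stax 没有 `add()` 方法。可以先把各层依次追加到一个 Python 列表中(相当于 `model.add()`),再一次性调用 `stax.serial(*layers)`;或者利用 `stax.serial` 返回的 `(init_fun, apply_fun)` 本身就是一个层这一点,把它嵌套进去:`model = stax.serial(model, stax.Dense(n))`(此时参数也会相应嵌套)。下面是一个一次性声明整个网络的例子: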

# Create a new model with jax.example_libraries.stax.
# stax has no `model.add()`: stax.serial takes all layers at once. To build a
# network incrementally, append layers to a Python list and call
# stax.serial(*layers) once the list is complete.
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import (
    Conv, Dense, Flatten, LogSoftmax, MaxPool, Relu,
)

net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)

# init returns (output_shape, params); keep params for net_apply below.
output_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))

# Feedforward
# assumes batch_data is an (inputs, targets) pair supplied by your data pipeline
inputs, targets = batch_data
net_apply(params, inputs)

希望这个参考示例能帮到你。

英文:

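Yes, you can, although stax has no `add()` method. One option is to collect the layers in a Python list, appending as you go just like `model.add()`, and then pass them all to `stax.serial(*layers)` in a single call. Alternatively, the `(init_fun, apply_fun)` pair returned by `stax.serial` is itself a layer, so you can nest it: `model = stax.serial(model, stax.Dense(n))`. The parameters are then nested accordingly. Here is an example that declares a whole network in one call: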

# Create a new model with jax.example_libraries.stax.
# stax has no `model.add()`: stax.serial takes all layers at once. To build a
# network incrementally, append layers to a Python list and call
# stax.serial(*layers) once the list is complete.
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import (
    Conv, Dense, Flatten, LogSoftmax, MaxPool, Relu,
)

net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)

# init returns (output_shape, params); keep params for net_apply below.
output_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))

# Feedforward
# assumes batch_data is an (inputs, targets) pair supplied by your data pipeline
inputs, targets = batch_data
net_apply(params, inputs)

I hope this example helps.

huangapple
  • 本文由 发表于 2023年3月15日 18:32:33
  • 转载请务必保留本文链接:https://go.coder-hub.com/75743490.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定