英文:
Adding a new layer to a stax.serial object
问题
def mlp(L, n_list, activation, Cb, Cw):
    """Build a fully connected network as a single ``stax.serial`` model.

    ``stax.serial`` combines layers, i.e. ``(init_fun, apply_fun)`` pairs,
    and offers no counterpart of Keras' ``model.add()``.  The layers are
    therefore collected in a plain Python list first and combined with one
    ``stax.serial`` call at the end.

    Args:
        L: Depth parameter; the network contains ``L + 1`` dense layers.
        n_list: Layer widths.  ``n_list[0]`` is the input width and
            ``n_list[i + 1]`` the output width of dense layer ``i``, so at
            least ``L + 2`` entries are required.
        activation: A stax layer (an ``(init_fun, apply_fun)`` pair such as
            ``stax.Relu``) placed in front of dense layers ``1 .. L - 1``.
        Cb: Variance of the normally distributed initial biases.
        Cw: Weight-variance scale; the weights of dense layer ``i`` are drawn
            with standard deviation ``sqrt(Cw / n_list[i])``.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.

    Raises:
        ValueError: If ``n_list`` has fewer than ``L + 2`` entries.
    """
    # Imported locally so the function does not need an extra file-level import.
    from jax.nn.initializers import normal

    if len(n_list) < L + 2:
        raise ValueError(
            f"n_list needs at least L + 2 = {L + 2} entries, got {len(n_list)}"
        )

    bias_stddev = math.sqrt(Cb)

    def dense(idx):
        # Dense layer mapping width n_list[idx] -> n_list[idx + 1].
        return stax.Dense(
            n_list[idx + 1],
            W_init=normal(math.sqrt(Cw / n_list[idx])),
            b_init=normal(bias_stddev),
        )

    layers = [dense(0)]
    for idx in range(1, L):
        # NOTE(review): the activation is kept in front of each hidden Dense
        # layer, as in the original attempt.  Keras' Dense(activation=...)
        # applies the activation after the affine map -- confirm which
        # ordering is intended.
        layers += [activation, dense(idx)]
    layers.append(dense(L))

    # Unpack the list: stax.serial takes the layers as positional arguments.
    return stax.serial(*layers)
英文:
I'd like to "convert" the following tensorflow code to jax:
def mlp(L, n_list, activation, Cb, Cw):
    """Build a Keras ``Sequential`` network of ``L + 1`` dense layers.

    Args:
        L: Depth parameter; the model contains ``L + 1`` dense layers.
        n_list: Layer widths.  ``n_list[0]`` is the input width and
            ``n_list[i + 1]`` the number of units of dense layer ``i``.
        activation: Activation passed to dense layers ``1 .. L - 1``; the
            first and the last layer are created without an activation.
        Cb: Variance of the normally distributed initial biases.
        Cw: Weight-variance scale; the kernel of dense layer ``i`` is drawn
            with mean 0 and standard deviation ``sqrt(Cw / n_list[i])``.

    Returns:
        The assembled ``tf.keras.Sequential`` model.  Its summary is printed
        as a side effect.
    """
    bias_stddev = math.sqrt(Cb)

    def dense(idx, **layer_kwargs):
        # Every layer gets its own initializer instances.  Recent Keras
        # versions document that an unseeded initializer reused across calls
        # may return identical values, so a single bias initializer is not
        # shared between layers.
        return tf.keras.layers.Dense(
            n_list[idx + 1],
            use_bias=True,
            kernel_initializer=tf.keras.initializers.RandomNormal(
                0, math.sqrt(Cw / n_list[idx])
            ),
            bias_initializer=tf.keras.initializers.RandomNormal(
                stddev=bias_stddev
            ),
            **layer_kwargs,
        )

    model = tf.keras.Sequential()
    model.add(dense(0, input_shape=[n_list[0]]))
    for idx in range(1, L):
        model.add(dense(idx, activation=activation))
    model.add(dense(L))

    # summary() prints the table itself and returns None, so wrapping it in
    # print() only appended a stray "None" line.
    model.summary()
    return model
In jax, can I add a `stax.Dense` layer to the object I get from calling `stax.serial()`, using something equivalent to tensorflow's `model.add()`? How can I do it?
答案1
得分: 0
是的,你可以。
# Create a new model with jax: stax.serial chains the listed layers and
# returns an (init_fun, apply_fun) pair.
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)
# net_init returns (output_shape, params).  Keep the parameters: the forward
# pass below needs them, and they were previously discarded.
_, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))
# Feed-forward pass on one batch (`batch_data` must be supplied by the caller).
inputs, targets = batch_data
predictions = net_apply(params, inputs)
这是我帮助你的参考资料。
英文:
Yes, you can.
# Create a new model with jax: stax.serial chains the listed layers and
# returns an (init_fun, apply_fun) pair.
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)
# net_init returns (output_shape, params).  Keep the parameters: the forward
# pass below needs them, and they were previously discarded.
_, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))
# Feed-forward pass on one batch (`batch_data` must be supplied by the caller).
inputs, targets = batch_data
predictions = net_apply(params, inputs)
This is my reference to help you.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论