PyTorch and JAX networks give different accuracy with the same settings

Question

The issue in the JAX code might be related to several factors. Here are some potential reasons why the JAX code is performing poorly compared to the PyTorch code:

  1. Data Preprocessing:

    • In the JAX code, the target variables y_train and y_test are reshaped into column vectors of shape (N, 1) with jnp.reshape. This is unusual and can lead to unexpected broadcasting later on, so make sure the reshaping is really what you want.
  2. Scaling of Features:

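    • In the JAX code, the scaler is fitted on the features flattened into a single column (X_train_reshaped) and is then applied to the targets (y_test_reshaped) instead of the test features.
    • Moreover, the scaled results are never used: X_train_array and X_test_array are built from the raw X_train and X_test. Make sure the model receives features scaled per column, as in the PyTorch code.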
  3. Loss Function Implementation:

    • The implementation of the cross-entropy loss and softmax in the JAX code might have subtle differences from the PyTorch implementation. Double-check that they are functionally equivalent.
  4. Optimization Algorithm:

    • Both codes use Adam with lr=0.001, but the JAX update step always passes the step index 0 to opt_update, so Adam's bias correction (and any step-size schedule) never advances. Make sure the optimizer is driven the same way in both codes.
  5. Random Initialization:

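    • The JAX code draws every weight and bias from an unscaled standard normal distribution and reuses the same key for a layer's weights and biases, whereas PyTorch's nn.Linear by default samples from a uniform range scaled by 1/sqrt(fan_in). Try a comparable initialization strategy or reduce the scale of the random values.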
  6. Debugging Information:

    • Print out intermediate values, such as the loss during training, to see if they are behaving as expected.
  7. Batching:

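    • With batch_size equal to the size of the training set, each epoch is a single full batch. Note also that random.permutation is called with the same rng_key in every epoch, so the order never changes. Make sure the batches are formed and processed as intended.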
  8. Activation Function:

    • The activation function used (ReLU) is consistent in both codes, but double-check if there's any subtle difference in implementation.
  9. Check for Nans or Infs:

    • Sometimes, NaNs or Infs in the gradients can lead to poor performance. Use jnp.isnan and jnp.isinf to check for these.
  10. Learning Rate Schedule:

    • Adjusting the learning rate schedule could have an impact on convergence.
  11. Debugging and Profiling:

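    • Use JAX's built-in debugging and profiling tools (for example jax.debug.print inside jitted functions, jax.config.update("jax_debug_nans", True), or jax.profiler) to identify potential issues.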
  12. Comparison of Outputs:

    • Compare the intermediate outputs (logits, probabilities, etc.) from both implementations to see if they are consistent.

Without being able to run the code and see the specific behavior, it's hard to pinpoint the exact issue. I recommend going through the code step by step, comparing it to the PyTorch implementation, and looking for any discrepancies or potential sources of error.
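Points 5 and 9 of the list can be made concrete with a short sketch. It assumes the same list-of-(w, b) parameter layout as the question's init_params; init_params_fan_in and all_finite are hypothetical helper names, not part of JAX. The first mimics the default range of PyTorch's nn.Linear, U(-1/sqrt(fan_in), 1/sqrt(fan_in)), with a separate key for every tensor; the second uses jnp.isnan and jnp.isinf to check a whole pytree, such as the parameters or the gradients.

```python
import jax
import jax.numpy as jnp
from jax import random


def init_params_fan_in(rng_key, sizes):
    """Hypothetical helper: PyTorch-style init for consecutive layer sizes.

    Weights and biases are drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
    the range nn.Linear uses by default. Each tensor gets its own key.
    """
    keys = random.split(rng_key, 2 * (len(sizes) - 1))
    params = []
    for i, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        bound = fan_in ** -0.5
        w = random.uniform(keys[2 * i], (fan_in, fan_out), minval=-bound, maxval=bound)
        b = random.uniform(keys[2 * i + 1], (fan_out,), minval=-bound, maxval=bound)
        params.append((w, b))
    return params


def all_finite(tree):
    """Hypothetical helper: True if no leaf of the pytree contains NaN or Inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return not any(bool(jnp.any(jnp.isnan(x) | jnp.isinf(x))) for x in leaves)


# 64 input pixels, hidden sizes [64, 32], 10 classes, as in the question
params = init_params_fan_in(random.PRNGKey(0), [64, 64, 32, 10])
print([w.shape for w, _ in params])  # [(64, 64), (64, 32), (32, 10)]
print(all_finite(params))            # True
```

Calling all_finite(grads) after each grad(loss_fn)(...) call is a cheap way to catch a diverging run early. Because bool() forces the values back to the host, it belongs in the Python training loop, not inside a jitted function.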

Original question (English):

I have PyTorch code that achieves more than 95% accuracy. It implements a feedforward neural network in PyTorch to classify the digits dataset: it trains the model with the Adam optimizer on the cross-entropy loss, and then evaluates the model on the test set by computing the accuracy.

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the digits dataset
digits = load_digits()

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=42
)

# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Convert the data to PyTorch tensors
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Define the FFN model
class FFN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(FFN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_sizes)):
            if i == 0:
                self.hidden_layers.append(nn.Linear(input_size, hidden_sizes[i]))
            else:
                self.hidden_layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            self.hidden_layers.append(nn.ReLU())
        self.output_layer = nn.Linear(hidden_sizes[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

# Define the training parameters
input_size = X_train.shape[1]
hidden_sizes = [64, 32]  # Modify the hidden layer sizes as per your requirement
output_size = len(torch.unique(y_train_tensor))
learning_rate = 0.001
num_epochs = 200
batch_size = len(X_train)  # Set batch size to the size of the training dataset

# Create the FFN model
model = FFN(input_size, hidden_sizes, output_size)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Evaluate the model on the test set
with torch.no_grad():
    model.eval()
    outputs = model(X_test_tensor)
    _, predicted = torch.max(outputs.data, 1)
    for j in range(len(predicted)):
        print(predicted[j], y_test_tensor[j])
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0) * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

I also have the equivalent JAX code, which gets less than 10% accuracy:

import jax
import jax.numpy as jnp
from jax import grad, jit, random, value_and_grad
from jax.scipy.special import logsumexp
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from jax.example_libraries.optimizers import adam, momentum, sgd, nesterov, adagrad, rmsprop
from jax import nn as jnn


# Load the digits dataset
digits = load_digits()

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)

# Reshape the target variables
y_train_reshaped = jnp.reshape(y_train, (-1, 1))
y_test_reshaped = jnp.reshape(y_test, (-1, 1))

X_train_reshaped = jnp.reshape(X_train, (-1, 1))
X_test_reshaped = jnp.reshape(X_test, (-1, 1))
#print(np.shape(X_train),np.shape(y_train_reshaped),np.shape(y_train))

# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_reshaped)
y_test_scaled = scaler.transform(y_test_reshaped)

# Convert the data to JAX arrays
X_train_array = jnp.array(X_train, dtype=jnp.float32)
y_train_array = jnp.array(y_train_reshaped, dtype=jnp.int32)
X_test_array = jnp.array(X_test, dtype=jnp.float32)
y_test_array = jnp.array(y_test_reshaped, dtype=jnp.int32)

# Define the FFN model
def init_params(rng_key):
    sizes = [X_train_array.shape[1]] + hidden_sizes + [output_size]
    keys = random.split(rng_key, len(sizes))
    params = []
    for i in range(1, len(sizes)):
        params.append((random.normal(keys[i], (sizes[i-1], sizes[i])), 
                       random.normal(keys[i], (sizes[i],))))
    return params

def forward(params, x):
    for w, b in params[:-1]:
        x = jnp.dot(x, w) + b
        x = jax.nn.relu(x)
    w, b = params[-1]
    x = jnp.dot(x, w) + b
    return x

def softmax(logits):
    logsumexp_logits = logsumexp(logits, axis=1, keepdims=True)
    return jnp.exp(logits - logsumexp_logits)

def cross_entropy_loss(logits, labels):
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(log_probs * labels, axis=1))

# Define the training parameters
input_size = X_train_array.shape[1]
hidden_sizes = [64, 32]  # Modify the hidden layer sizes as per your requirement
output_size = len(jnp.unique(y_train_array))
learning_rate = 0.001
num_epochs = 200
batch_size = len(X_train_array)  # Set batch size to the size of the training dataset
# Create the FFN model
rng_key = random.PRNGKey(0)
params = init_params(rng_key)

# Define the loss function
def loss_fn(params, x, y):
    logits = forward(params, x)
    probs = softmax(logits)
    labels = jax.nn.one_hot(y, output_size)
    return cross_entropy_loss(logits, labels)

# Create the optimizer
opt_init, opt_update, get_params = adam(learning_rate)
opt_state = opt_init(params)

# Define the update step
@jit
def update(params, x, y, opt_state):
    grads = grad(loss_fn)(params, x, y)
    return opt_update(0, grads, opt_state)

# Train the model
for epoch in range(num_epochs):
    perm = random.permutation(rng_key, len(X_train_array))
    for i in range(0, len(X_train_array), batch_size):
        batch_idx = perm[i:i+batch_size]
        X_batch = X_train_array[batch_idx]
        y_batch = y_train_array[batch_idx]
        params = get_params(opt_state)
        opt_state = update(params, X_batch, y_batch, opt_state)

    if (epoch + 1) % 10 == 0:
        params = get_params(opt_state)
        loss = loss_fn(params, X_train_array, y_train_array)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}")

# Evaluate the model on the test set
params = get_params(opt_state)
logits = forward(params, X_test_array)
predicted = jnp.argmax(logits, axis=1)

for j in range(len(predicted)):
    print(predicted[j], y_test_array[j])

accuracy = jnp.mean(predicted == y_test_array) * 100
print(f"Test Accuracy: {accuracy:.2f}%")

I don't understand why the JAX code performs so poorly. Could you please help me understand the bug in the JAX code?

Answer 1

Score: 1


There are two problems in your JAX code, and both are actually in the data processing:

  1. Your data are not scaled. If you look at your X_train_array definition, it is just the JAX version of X_train, i.e. the raw data (X_train_scaled is computed but never used).
    Please consider using:
# Scale the features
scaler = StandardScaler().fit(X_train)  # No need to flatten it!
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Convert the data to JAX arrays
X_train_array = jnp.array(X_train, dtype=jnp.float32)
y_train_array = jnp.array(y_train_reshaped, dtype=jnp.int32)
X_test_array = jnp.array(X_test, dtype=jnp.float32)
y_test_array = jnp.array(y_test_reshaped, dtype=jnp.int32)
  2. Your labels have shape (N, 1) before one-hot encoding. After one-hot encoding they have shape (N, 1, n_out), while your predictions have shape (N, n_out), so in the loss computation the two arrays are broadcast to (N, N, n_out) with repetitions, and the loss is wrong. You can solve it very simply by removing the 1 in the reshape:
# Reshape the target variables
y_train_reshaped = jnp.reshape(y_train, (-1,))
y_test_reshaped = jnp.reshape(y_test, (-1,))
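A small shape check with made-up sizes (N=5 samples, C=10 classes) shows the broadcasting from the second point. It also shows that the same (N, 1) labels break the accuracy line: predicted == y_test_array compares an (N,) array with an (N, 1) array and yields an (N, N) matrix, so even a perfect model would report roughly 10% on a balanced 10-class test set. The function cross_entropy_int_labels is a hypothetical alternative, not from the answer: it takes integer labels of shape (N,) directly, much as nn.CrossEntropyLoss does in the PyTorch code, so no one-hot tensor is needed.

```python
import jax
import jax.numpy as jnp

N, C = 5, 10                      # hypothetical batch size and number of classes
logits = jnp.zeros((N, C))        # stand-in for the network output
log_probs = jax.nn.log_softmax(logits, axis=1)

y_col = jnp.arange(N).reshape(-1, 1)  # labels shaped (N, 1), as in the question
y_flat = jnp.arange(N)                # labels shaped (N,), as in the fix

# One-hot of (N, 1) labels is (N, 1, C); times (N, C) it broadcasts to (N, N, C)
print(jax.nn.one_hot(y_col, C).shape)                 # (5, 1, 10)
print((log_probs * jax.nn.one_hot(y_col, C)).shape)   # (5, 5, 10)
print((log_probs * jax.nn.one_hot(y_flat, C)).shape)  # (5, 10)

# The accuracy check hits the same trap: (N,) == (N, 1) broadcasts to (N, N)
predicted = jnp.argmax(logits, axis=1)
print((predicted == y_col).shape)   # (5, 5)
print((predicted == y_flat).shape)  # (5,)


def cross_entropy_int_labels(logits, labels):
    """Mean cross-entropy for integer labels of shape (N,), no one-hot needed."""
    log_probs = jax.nn.log_softmax(logits, axis=1)
    # Pick log p(correct class) for each row; labels[:, None] has shape (N, 1)
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=1)
    return -jnp.mean(picked)
```

With labels of shape (N,), cross_entropy_int_labels(logits, y) and the question's cross_entropy_loss(logits, jax.nn.one_hot(y, output_size)) compute the same value.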

I tested your code with 300 epochs and lr=0.01 and got a test accuracy of 90% (and the loss decreased to 0.0001).
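The "no need to flatten it" remark in the first point matters because StandardScaler standardizes each column separately, whereas reshaping X_train to (-1, 1) first pools all 64 pixel columns into a single mean and standard deviation. The sketch below reimplements per-feature standardization in jax.numpy to make that visible. standardize_per_feature and the toy data are hypothetical, and leaving the scale of constant columns at 1 is meant to mimic StandardScaler's behavior, not to reproduce it exactly.

```python
import jax.numpy as jnp


def standardize_per_feature(X_train, X_test):
    """Standardize each column with training-set statistics, like StandardScaler().fit(X_train)."""
    mean = X_train.mean(axis=0)          # one mean per feature (column)
    std = X_train.std(axis=0)            # population std (ddof=0), one per feature
    std = jnp.where(std == 0, 1.0, std)  # constant columns: avoid dividing by zero
    return (X_train - mean) / std, (X_test - mean) / std


# Hypothetical toy data: two features on very different scales
X_train = jnp.array([[0.0, 100.0], [1.0, 200.0], [2.0, 300.0]])
X_test = jnp.array([[1.0, 250.0]])

X_train_s, X_test_s = standardize_per_feature(X_train, X_test)
print(X_train_s.mean(axis=0), X_train_s.std(axis=0))  # approximately [0, 0] and [1, 1]

# Flattening first (as in the question) pools every column into one mean/std,
# so the second feature still dominates after scaling
flat = X_train.reshape(-1, 1)
pooled = (X_train - flat.mean()) / flat.std()
print(pooled.std(axis=0))  # per-column spreads remain very different
```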

huangapple
  • Posted on July 17, 2023 at 18:44:35
  • Source link: https://go.coder-hub.com/76703681.html