我无法理解“nb_steps_warmup”的含义,它是 Keras-RL 模块中 DQNAgent 类 __init__ 函数的参数。
我只知道当我为“nb_steps_warmup”设置小值时,命令行会打印:UserWarning: Not enough entries to sample without replacement. Consider increasing your warm-up phase to avoid oversampling!
这是我的代码:
import numpy as np
import gym
import gym_briscola
import argparse
import os
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
import logging
def getModel(input_shape, nb_actions):
    """Build a small feed-forward Q-network.

    Args:
        input_shape: shape of one stacked observation window,
            e.g. ``(window_length, 5)``.
        nb_actions: number of discrete actions; also the output width,
            one Q-value per action.

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    model = Sequential()
    model.add(Flatten(input_shape=input_shape))
    model.add(Dense(nb_actions, activation='relu'))
    # NOTE(review): width-2 hidden layers are a severe bottleneck for the
    # flattened observation — consider wider layers; kept as-is otherwise.
    for _ in range(2):
        model.add(Dense(2, activation='relu'))
    # Q-values are unbounded and can be negative, so the output layer must be
    # linear. The original 'relu' here clipped every negative Q-value to zero,
    # which breaks Q-learning targets.
    model.add(Dense(nb_actions, activation='linear'))
    # print(model.summary())
    return model
def init():
    """Create the Briscola gym environment and a compiled dueling DQN agent.

    Returns:
        Tuple ``(dqn, env, ENV_NAME)``: the agent, the environment, and the
        environment id (used to build the weights filename).
    """
    ENV_NAME = 'Briscola-v0'
    # Get the environment and extract the number of actions.
    env = gym.make(ENV_NAME)
    env.setName("Inteligence")
    env.cicle = True
    nb_actions = env.action_space.n
    window_length = 10
    input_shape = (window_length, 5)
    # Replay buffer: stores up to 50000 transitions, sampled in windows of 10.
    memory = SequentialMemory(limit=50000, window_length=window_length)
    # Maxwell-Boltzmann exploration policy.
    policy = BoltzmannQPolicy()
    # nb_steps_warmup is the number of environment steps collected into the
    # replay memory BEFORE any training update happens. It must be at least
    # the training batch_size (default 32), otherwise keras-rl has to sample
    # the same few transitions repeatedly and emits:
    #   "Not enough entries to sample without replacement."
    # The original value of 11 was below the batch size; 1000 gives the
    # buffer a reasonable head start.
    dqn = DQNAgent(model=getModel(input_shape, nb_actions),
                   nb_actions=nb_actions,
                   memory=memory,
                   nb_steps_warmup=1000,
                   target_model_update=1e-5,
                   policy=policy,
                   dueling_type='avg',
                   enable_dueling_network=True)
    print("Compila")
    dqn.compile(Adam(lr=1e-5), metrics=['mae'])
    # Resume from previous weights when present; missing file is not an error.
    try:
        dqn.load_weights('dqn_{}_weights.h5f'.format(ENV_NAME))
    except OSError:
        print("File non trovato")
    return dqn, env, ENV_NAME
def startTraining():
    """Train the DQN agent, then persist the final weights to disk."""
    dqn, env, ENV_NAME = init()
    print("Fit")
    # nb_steps is a step count and should be an int; 5E6 is a float literal.
    dqn.fit(env, nb_steps=int(5e6), visualize=False, verbose=1,
            log_interval=1000)
    # After training is done, we save the final weights.
    dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)
def startTest():
    """Evaluate the (previously trained) agent for 10 episodes, no rendering."""
    agent, environment, _ = init()
    print("Test")
    agent.test(environment, nb_episodes=10, visualize=False)
#Log config
def setLogging(show=True):
    """Configure propagation for the game's loggers.

    The "Briscola", "Client" and "Vincitore" loggers follow *show*;
    the "IA" logger always propagates. Root logging level is set to INFO.
    """
    for logger_name in ("Briscola", "Client", "Vincitore"):
        logging.getLogger(logger_name).propagate = show
    logging.getLogger("IA").propagate = True
    logging.basicConfig(level=logging.INFO)
if __name__ == "__main__":
    # Command-line parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--modality", help="The modality of the program",
                        const="train", nargs='?')
    # argparse's type=bool is a well-known footgun: bool("False") is True, so
    # any value typed on the command line parsed as True. A boolean option
    # should be a flag instead.
    parser.add_argument("-l", "--logging", help="Enable logging",
                        action="store_true")
    args = parser.parse_args()
    # NOTE(review): args.logging was never consumed by the original code;
    # logging stays enabled unconditionally to preserve behavior.
    setLogging(True)
    print("Avvio modalita' ", args.modality)
    if args.modality == "test":
        startTest()
    else:
        startTraining()
    print("Fine")
这是模块的文档:https://keras-rl.readthedocs.io/en/latest/agents/dqn/
我希望我的英语很清楚。