47 #include <boost/functional/hash.hpp>
48 #include <boost/archive/text_oarchive.hpp>
49 #include <boost/archive/text_iarchive.hpp>
50 #include <boost/serialization/vector.hpp>
52 using GymObservation =
struct GymObservation {
53 friend class boost::serialization::access;
64 GymObservation(
void) {
70 bool operator==(GymObservation
const & other)
const
72 if (
third.size() != other.third.size()) {
76 for (
size_t index = 0; index <
third.size(); index++) {
77 isEqual = isEqual && (
third[index] == other.third[index]);
79 return (first == other.first &&
84 bool operator<(GymObservation
const & other)
const
87 if (
third.size() < other.third.size()) {
89 }
else if (
third.size() > other.third.size()) {
92 for (
size_t index = 0; index <
third.size(); index++) {
93 if (
third[index] < other.third[index]) {
99 return (first < other.first) ||
100 (first == other.first &&
second < other.second) ||
101 (first == other.first &&
second == other.second &&
trackerState < other.trackerState) ||
102 (first == other.first &&
second == other.second &&
trackerState == other.trackerState && thirdL);
104 template<
class Archive>
105 void serialize(Archive & ar,
const unsigned int version) {
115 std::size_t operator()(GymObservation
const & observation)
const
117 std::size_t seed = 0;
118 boost::hash_combine(seed, boost::hash_value(observation.first));
119 boost::hash_combine(seed, boost::hash_value(observation.second));
120 boost::hash_combine(seed, boost::hash_value(observation.trackerState));
121 boost::hash_combine(seed, boost::hash_value(observation.third));
127 bool operator()(GymObservation
const & first, GymObservation
const &
second)
const
135 using GymAction =
struct GymAction {
136 friend class boost::serialization::access;
138 std::pair<Action, State> first;
140 std::vector<State>
third;
142 GymAction(std::pair<Action, State> first, State
second, std::vector<State>
third={}) : first(first),
152 bool operator==(GymAction
const & other)
const {
153 if (
third.size() != other.third.size()) {
157 for (
size_t index = 0; index <
third.size(); index++) {
158 isEqual = isEqual && (
third[index] == other.third[index]);
160 return (first == other.first &&
second == other.second && isEqual);
163 bool operator<(GymAction
const & other)
const {
165 if (
third.size() < other.third.size()) {
167 }
else if (
third.size() > other.third.size()) {
170 for (
size_t index = 0; index <
third.size(); index++) {
171 if (
third[index] < other.third[index]) {
177 return (first < other.first) ||
178 (first == other.first &&
second < other.second) ||
179 (first == other.first &&
second == other.second && thirdL);
182 template<
class Archive>
183 void serialize(Archive & ar,
const unsigned int version) {
191 using GymTransitionTuple =
struct GymTransitionTuple {
195 GymObservation sPrime;
// Sentinel State value, initialized from ~0 (an int with all bits set,
// converted to State). Presumably marks "no valid state" — confirm at use sites.
199 static State constexpr invalidState = ~0;
// Sentinel action: both pair members (Action, State) are initialized from -1.
// NOTE(review): if State is an unsigned type, the second member wraps to
// State's maximum value rather than holding a negative number.
200 static std::pair<Action, State> constexpr invalidAction = {-1, -1};
// Observation used as the terminal marker, built from -1 cast to Node, State
// and short int. Presumably these map onto first / second / trackerState, but the
// three-argument constructor is not visible here — TODO confirm.
// NOTE(review): this uses C-style casts. static_cast<Node>(-1) etc. would be
// greppable and intent-revealing. If this declaration sits at namespace scope in a
// header, `static` gives every translation unit its own copy; C++17 `inline`
// would share a single definition.
201 static GymObservation
const terminalState = {(
Node)-1, (State)-1, (
short int)-1};
// Table type keyed by GymObservation (hashed with GymObservationHasher). Each
// entry maps a GymAction to a double value via an ordered std::map, so GymAction's
// operator< is what orders the inner map. Presumably this is the learner's
// Q-table (observation -> action -> value), given the Q-named APIs that take it.
203 using Qtype = std::unordered_map<GymObservation, std::map<GymAction, double>,
GymObservationHasher>;
210 tolerance = std::vector<double>{0.01};
213 concatActionsInCSV =
false;
233 unsigned int episodeLength;
236 std::vector<double> tolerance;
246 bool concatActionsInCSV;
257 GymObservation observation;
261 std::vector<GymAction> actions;
263 bool discountOverride =
false;
264 double discount = 0.0;
265 bool terminationOverride =
false;
277 void printDotLearn(Qtype
const & Q, std::string filename = std::string(
"-"))
const;
281 void printPrismLearn(Qtype
const & Q, std::string filename = std::string(
"-"))
const;
287 std::map<double, std::map<unsigned int, double>>
getProbabilityOfSat(Qtype
const & Q,
bool statsOn)
const;
298 Model QtoModel(Qtype
const & Q,
double tolerance,
bool p1strategic,
bool statsOn,
double priEpsilon=0.001,
unsigned int objIndex=0)
const;
300 std::vector<GymTransitionTuple>
getParallelUpdates(GymObservation S, GymAction A, GymObservation sPrime);
308 GymInfo
step(GymAction);
316 std::string
toString(GymObservation
const & observation)
const;
318 std::string
toString(GymAction
const & action)
const;
320 std::string
toString(Qtype
const & Q)
const;
322 void saveQ(Qtype
const & Q, std::string saveType, std::string filename)
const;
324 void saveQ(Qtype
const & Q1, Qtype
const & Q2, std::string saveType, std::string filename)
const;
326 void saveQ(Qtype
const & Q,
double Rbar, std::string saveType, std::string filename)
const;
328 void loadQ(Qtype & Q, std::string saveType, std::string filename)
const;
330 void loadQ(Qtype & Q1, Qtype & Q2, std::string saveType, std::string filename)
const;
332 void loadQ(Qtype & Q,
double & Rbar, std::string saveType, std::string filename)
const;
334 void saveStrat(Qtype & Q, std::string filename)
const;
336 void saveStratBDP(Qtype & Q, std::string filename)
const;
342 std::size_t getProductID(
void)
const;
346 std::vector<Parity> exauto;
348 unsigned int episodeStep;
349 unsigned int episodeLength;
354 std::vector<double> tolerance;
356 Priority maxPriority;
361 double cumulativeRew;
364 bool noTerminalUpdate;
366 bool concatActionsInCSV;
370 std::vector<State> exAutoState;