1 /// based on https://github.com/openai/gym-http-api/blob/master/binding-rust/src/lib.rs
2 
3 module gym;
4 import std.array : array;
5 import std.algorithm : map;
6 import std.exception : enforce;
7 import std.json : JSONValue, parseJSON, JSON_TYPE;
8 import std.conv : to;
9 import std.stdio : writeln;
10 import std.variant : Algebraic, This;
11 import std.net.curl;
12 
/// True when `T` exposes a static `from` accepting a `JSONValue`,
/// i.e. it can be built from a gym space description.
enum bool isSpace(T) = __traits(compiles, T.from(JSONValue.init));
14 
/// Discrete Space: described by a single integer count `n`.
struct Discrete {
    long n; /// value of the space's "n" field
    alias n this; // lets a Discrete be compared / used directly as a long

    /// Builds a `Discrete` from a gym-http-api space `info` object.
    /// Params:
    ///     info = JSON object whose "name" is "Discrete" and whose "n" is an integer.
    /// Throws: `Exception` (via `enforce`) if "name" is not "Discrete";
    ///         errors from `std.json` if a key is missing or has the wrong type.
    static from(T: JSONValue)(scope auto ref T info) {
        auto name = info["name"].str;
        // a descriptive message instead of the bare "Enforcement failed"
        enforce(name == "Discrete", "expected a Discrete space, got: " ~ name);
        return Discrete(info["n"].integer);
    }
}
25 
///
version (gym_test) unittest {
    static assert(isSpace!Discrete);

    enum text = `{
   "name": "Discrete",
   "n": 123
}`;
    auto expected = Discrete(123);
    auto actual = Discrete.from(parseJSON(text));
    assert(actual == expected);
}
37 
/// Box Space: a shape plus flattened element-wise upper/lower bounds.
struct Box {
    long[] shape; /// value of the space's "shape" field
    double[] high; /// value of the space's "high" field
    double[] low;  /// value of the space's "low" field

    /// Builds a `Box` from a gym-http-api space `info` object.
    /// Params:
    ///     info = JSON object whose "name" is "Box", with integer array "shape"
    ///            and numeric arrays "high" / "low".
    /// Bounds may be encoded as JSON floats or integers (e.g. `[255]`).
    /// Throws: `Exception` (via `enforce`) if "name" is not "Box";
    ///         errors from `std.json` if a key is missing or an element is not numeric.
    static from(T: JSONValue)(scope auto ref T info) {
        auto name = info["name"].str;
        enforce(name == "Box", "expected a Box space, got: " ~ name);
        // `.array` already allocates a fresh array, so no extra `.dup` is needed
        return Box(
            info["shape"].array.map!(a => a.integer).array,
            info["high"].array.map!(a => toDouble(a)).array,
            info["low"].array.map!(a => toDouble(a)).array
            );
    }

    /// Converts a numeric JSON value (integer, unsigned or float) to double.
    private static double toDouble(const JSONValue v) {
        switch (v.type) {
        case JSON_TYPE.INTEGER:
            return v.integer;
        case JSON_TYPE.UINTEGER:
            return v.uinteger;
        default:
            // FLOAT; `floating` rejects non-numeric values as before
            return v.floating;
        }
    }
}
53 
///
version (gym_test) unittest {
    static assert(isSpace!Box);

    enum text = `{
   "name": "Box",
   "shape": [2],
   "high": [1.0, 2.0],
   "low": [0.0, 0.0]
}`;
    auto expected = Box([2], [1.0, 2.0], [0.0, 0.0]);
    auto actual = Box.from(parseJSON(text));
    assert(actual == expected);
}
67 
/// A server reply (e.g. from `reset`/`step`) with typed accessors for common
/// members; every other JSON member stays reachable through `alias this`.
struct State {
    JSONValue info; /// raw JSON reply
    alias info this;

    /// The reply's "observation" member.
    @property
    const ref observation() {
        return this.info["observation"];
    }

    /// The reply's "reward" member as a double, or 0 when it is absent.
    /// Integer-encoded rewards (`1` rather than `1.0`) are accepted.
    @property
    double reward() const {
        auto p = "reward" in this.info;
        if (p is null) {
            return 0;
        }
        switch (p.type) {
        case JSON_TYPE.INTEGER:
            return p.integer;
        case JSON_TYPE.UINTEGER:
            return p.uinteger;
        default:
            // FLOAT; `floating` rejects non-numeric values as before
            return p.floating;
        }
    }

    /// True only when the reply has a "done" member equal to JSON `true`.
    @property
    bool done() const {
        auto p = "done" in this.info;
        return p !is null && p.type == JSON_TYPE.TRUE;
    }
}
96 
/// Client for one environment instance on a gym-http-api server.
///
/// Constructing an `Environment` asks the server to create the instance and
/// caches the "info" objects of its observation and action spaces.
struct Environment {
    import std.format : format;

    string address, id;        /// URL prefix of the server and environment id (e.g. "CartPole-v0")
    string instanceId;         /// "instance_id" returned when the instance was created
    JSONValue actionInfo;      /// "info" of the "action_space" endpoint reply
    JSONValue observationInfo; /// "info" of the "observation_space" endpoint reply

    /// POSTs `payload` to `url` with a JSON content type and parses the reply.
    /// Single place for the request setup shared by every POST below.
    private static JSONValue _postJSON(string url, string payload) {
        auto client = HTTP();
        client.addRequestHeader("Content-Type", "application/json");
        return post(url, payload, client).parseJSON;
    }

    /// URL of the per-instance endpoint `loc` (e.g. "reset", "monitor/start").
    private string _envUrl(string loc) const {
        return "%s/v1/envs/%s/%s/".format(this.address, this.instanceId, loc);
    }

    /// POSTs `req` to the collection endpoint "/v1/envs/".
    auto _post(T: JSONValue)(scope auto ref T req) const {
        return _postJSON(this.address ~ "/v1/envs/", req.toString);
    }

    /// POSTs `req` to the per-instance endpoint `loc`.
    auto _post(T: JSONValue)(string loc, scope auto ref T req) const {
        return _postJSON(_envUrl(loc), req.toString);
    }

    /// GETs the per-instance endpoint `loc` and parses the reply.
    auto _get(string loc) const {
        return get(_envUrl(loc)).parseJSON;
    }

    /// Creates an `id` instance on the server at `address`, then fetches the
    /// "info" of its observation and action spaces.
    this(string address, string id) {
        this.address = address;
        this.id = id;
        this.instanceId = this._post(JSONValue(["env_id": this.id]))["instance_id"].str;
        this.observationInfo = this._get("observation_space")["info"];
        this.actionInfo = this._get("action_space")["info"];
    }

    /// POSTs a JSON null to the instance's "reset" endpoint; returns the reply as a `State`.
    auto reset() {
        return State(this._post("reset", JSONValue(null)));
    }

    /// POSTs one step to the instance's "step" endpoint.
    /// `action` is stored as the request's "action" member, so any value
    /// `JSONValue` accepts works (e.g. `0`, `Discrete(0)`, `[1.0, 2.0]`);
    /// `render` is sent as the "render" member. Returns the reply as a `State`.
    auto step(A)(A action, bool render = false) {
        JSONValue req;
        req["render"] = render;
        req["action"] = action;
        return State(this._post("step", req));
    }

    /// POSTs to the instance's "monitor/start" endpoint with members
    /// "directory", "force" and "resume"; returns the parsed reply.
    auto record(string dir, bool force = true, bool resume = false) {
        JSONValue req;
        req["directory"] = dir;
        req["force"] = force;
        req["resume"] = resume;
        return this._post("monitor/start", req);
    }

    /// POSTs to the instance's "monitor/close" endpoint; returns the parsed reply.
    auto stop() {
        // explicit JSON null body, consistent with `reset`
        return this._post("monitor/close", JSONValue(null));
    }

    /// POSTs "training_dir", "api_key" and "algorithm_id" to the server's
    /// "/v1/upload/" endpoint (not instance-specific); returns the parsed reply.
    auto upload(string dir, string apiKey, string algorithmId) {
        JSONValue req = [
            "training_dir": dir,
            "api_key": apiKey,
            "algorithm_id": algorithmId
            ];
        return _postJSON(this.address ~ "/v1/upload/", req.toString);
    }
}
169 
/// simple integration test (needs a gym-http-api server at 127.0.0.1:5000)
version (gym_test) unittest {
    // CartPole: 2 discrete actions, 4-dimensional Box observations
    {
        auto cartPole = Environment("127.0.0.1:5000", "CartPole-v0");
        assert(Discrete.from(cartPole.actionInfo) == 2);

        auto obsSpace = Box.from(cartPole.observationInfo);
        assert(obsSpace.shape == [4]);
        assert(obsSpace.low.length == 4);

        cartPole.record("/tmp/d-gym");
        scope(exit) cartPole.stop();

        double total = 0;
        for (auto st = cartPole.reset; !st.done; ) {
            st = cartPole.step(Discrete(0));
            total += st.reward;
        }
        assert(total > 0);
    }
    // {
    //     auto env = Environment("127.0.0.1:5000", "MsPacman-v0");
    //     assert(Discrete.from(env.actionInfo) == 9);
    //     auto o = Box.from(env.observationInfo);
    //     assert(o.shape == [210, 160, 3]);
    //     assert(o.high.length == 210 * 160 * 3);
    //     auto a = Discrete.from(env.actionInfo);
    //     assert(a == 9);
    // }
}