/// based on https://github.com/openai/gym-http-api/blob/master/binding-rust/src/lib.rs

module gym;
import std.array : array;
import std.algorithm : map;
import std.exception : enforce;
import std.json : JSONValue, parseJSON, JSON_TYPE;
import std.conv : to;
import std.stdio : writeln;
import std.variant : Algebraic, This;
import std.net.curl;

/// True when `T` exposes a static `from` that can be called with a `JSONValue`.
enum bool isSpace(T) = is(typeof({ T.from(JSONValue.init); }));

/**
Reads a JSON number as a `double`.

`std.json` parses a literal without a fraction or exponent (e.g. `1`) as an
integer, and `JSONValue.floating` throws on such a value. This helper accepts
every numeric JSON type so whole-number values are handled like any other.

Params:
    value = the JSON value to convert

Returns: the numeric content of `value` converted to `double`

Throws: `JSONException` if `value` is not a JSON number
*/
private double jsonNumber(const JSONValue value) {
    import std.json : JSONException;

    switch (value.type) {
    case JSON_TYPE.FLOAT:
        return value.floating;
    case JSON_TYPE.INTEGER:
        return cast(double) value.integer;
    case JSON_TYPE.UINTEGER:
        return cast(double) value.uinteger;
    default:
        throw new JSONException("JSONValue is not a number");
    }
}

/// Discrete Space
struct Discrete {
    /// value of the `"n"` field of the space description
    long n;
    alias n this;

    /**
    Builds a `Discrete` from a JSON space description.

    Params:
        info = JSON object with a `"name"` equal to `"Discrete"` and an integer `"n"`

    Throws: `Exception` if `"name"` is not `"Discrete"`;
            `JSONException` if a field is missing or has the wrong JSON type
    */
    static from(T: JSONValue)(scope auto ref T info) {
        enforce(info["name"].str == "Discrete",
                "expected a Discrete space but got " ~ info["name"].str);
        return Discrete(info["n"].integer);
    }
}

///
version (gym_test) unittest {
    static assert(isSpace!Discrete);

    auto s = `{
        "name": "Discrete",
        "n": 123
    }`;
    auto j = s.parseJSON;
    assert(Discrete.from(j) == Discrete(123));
}

/// Box Space
struct Box {
    /// values of the `"shape"` field of the space description
    long[] shape;
    /// values of the `"high"` field of the space description
    double[] high;
    /// values of the `"low"` field of the space description
    double[] low;

    /**
    Builds a `Box` from a JSON space description.

    Params:
        info = JSON object with a `"name"` equal to `"Box"`, an integer array
               `"shape"`, and numeric arrays `"high"` and `"low"`

    Throws: `Exception` if `"name"` is not `"Box"`;
            `JSONException` if a field is missing or has the wrong JSON type
    */
    static from(T: JSONValue)(scope auto ref T info) {
        enforce(info["name"].str == "Box",
                "expected a Box space but got " ~ info["name"].str);
        // std.array.array already allocates a fresh array, so no extra copy is needed.
        return Box(
            info["shape"].array.map!(a => a.integer).array,
            info["high"].array.map!(a => jsonNumber(a)).array,
            info["low"].array.map!(a => jsonNumber(a)).array
        );
    }
}

///
version (gym_test) unittest {
    static assert(isSpace!Box);

    auto s = `{
        "name": "Box",
        "shape": [2],
        "high": [1.0, 2.0],
        "low": [0.0, 0.0]
    }`;
    auto j = s.parseJSON;
    assert(Box.from(j) == Box([2], [1.0, 2.0], [0.0, 0.0]));

    // bounds written without a fraction are parsed as integers and must still load
    auto k = `{
        "name": "Box",
        "shape": [2],
        "high": [1, 2.5],
        "low": [0, -1]
    }`.parseJSON;
    assert(Box.from(k) == Box([2], [1.0, 2.5], [0.0, -1.0]));
}

/// Thin wrapper around the JSON object returned by the reset and step requests.
struct State {
    /// the raw JSON response
    JSONValue info;
    alias info this;

    /// The `"observation"` field of the response, by reference.
    @property
    const ref observation() {
        return this.info["observation"];
    }

    /**
    The `"reward"` field of the response, or 0 when the field is absent.

    Throws: `JSONException` if the field is present but is not a number
    */
    @property
    double reward() {
        if (auto found = "reward" in this.info) {
            // accepts integer rewards as well as floating-point ones
            return jsonNumber(*found);
        }
        return 0;
    }

    /// True only when the response has a `"done"` field holding JSON `true`.
    @property
    bool done() {
        if (auto found = "done" in this.info) {
            return found.type == JSON_TYPE.TRUE;
        }
        return false;
    }
}

///
version (gym_test) unittest {
    auto finished = State(`{"observation": [0.5], "reward": 1, "done": true}`.parseJSON);
    assert(finished.reward == 1.0);
    assert(finished.done);

    auto running = State(`{"observation": [0.5], "reward": 0.25, "done": false}`.parseJSON);
    assert(running.reward == 0.25);
    assert(!running.done);

    auto initial = State(`{"observation": [0.5]}`.parseJSON);
    assert(initial.reward == 0);
    assert(!initial.done);
}

///
/// Client for one environment instance created through the HTTP API at `address`.
struct Environment {
    import std.format : format;

    string address, id;
    string instanceId;
    JSONValue actionInfo;
    JSONValue observationInfo;

    /// POSTs `payload` to `url` with a JSON content type and parses the reply as JSON.
    private static auto _postJson(string url, string payload) {
        auto http = HTTP();
        http.addRequestHeader("Content-Type", "application/json");
        return post(url, payload, http).parseJSON;
    }

    /// URL of the sub-resource `loc` of this instance (with a trailing slash).
    private string _instanceUrl(string loc) const {
        return format("%s/v1/envs/%s/%s/", this.address, this.instanceId, loc);
    }

    /// POSTs `req` to the `/v1/envs/` collection and returns the parsed reply.
    auto _post(T: JSONValue)(scope auto ref T req) const {
        return _postJson(this.address ~ "/v1/envs/", req.toString);
    }

    /// POSTs `req` to the sub-resource `loc` of this instance and returns the parsed reply.
    auto _post(T: JSONValue)(string loc, scope auto ref T req) const {
        return _postJson(this._instanceUrl(loc), req.toString);
    }

    /// GETs the sub-resource `loc` of this instance and returns the parsed reply.
    auto _get(string loc) const {
        return get(this._instanceUrl(loc)).parseJSON;
    }

    /**
    Creates a new instance of environment `id` on the server at `address`, then
    fetches the `"info"` of its observation space and of its action space.
    */
    this(string address, string id) {
        this.address = address;
        this.id = id;

        auto created = this._post(JSONValue(["env_id": this.id]));
        this.instanceId = created["instance_id"].str;

        this.observationInfo = this._get("observation_space")["info"];
        this.actionInfo = this._get("action_space")["info"];
    }

    /// Sends a reset request and wraps the reply in a `State`.
    auto reset() {
        auto reply = this._post("reset", JSONValue(null));
        return State(reply);
    }

    /// step by json action e.g., 0, [1.0, 2.0, ...], etc
    auto step(A)(A action, bool render = false) {
        JSONValue payload;
        payload["render"] = render;
        payload["action"] = action;
        return State(this._post("step", payload));
    }

    /// Sends a `monitor/start` request for directory `dir` and returns the parsed reply.
    auto record(string dir, bool force = true, bool resume = false) {
        JSONValue payload;
        payload["directory"] = dir;
        payload["force"] = force;
        payload["resume"] = resume;
        return this._post("monitor/start", payload);
    }

    /// Sends a `monitor/close` request and returns the parsed reply.
    auto stop() {
        return this._post("monitor/close", JSONValue());
    }

    /// POSTs the training directory, API key and algorithm id to `/v1/upload/`.
    auto upload(string dir, string apiKey, string algorithmId) {
        JSONValue payload = [
            "training_dir": dir,
            "api_key": apiKey,
            "algorithm_id": algorithmId
        ];
        return _postJson(this.address ~ "/v1/upload/", payload.toString);
    }
}

/// simple integration test
version (gym_test) unittest {
    {
        auto env = Environment("127.0.0.1:5000", "CartPole-v0");
        assert(Discrete.from(env.actionInfo) == 2);

        auto obsSpace = Box.from(env.observationInfo);
        assert(obsSpace.shape == [4]);
        assert(obsSpace.low.length == 4);

        env.record("/tmp/d-gym");
        scope(exit) env.stop();

        double total = 0;
        auto current = env.reset;
        while (!current.done) {
            current = env.step(Discrete(0));
            total += current.reward;
        }
        assert(total > 0);
    }
    // {
    //     auto env = Environment("127.0.0.1:5000", "MsPacman-v0");
    //     assert(Discrete.from(env.actionInfo) == 9);
    //     auto o = Box.from(env.observationInfo);
    //     assert(o.shape == [210, 160, 3]);
    //     assert(o.high.length == 210 * 160 * 3);
    //     auto a = Discrete.from(env.actionInfo);
    //     assert(a == 9);
    // }
}