Marc-Alexandre Côté commited on
Commit
4930b0c
1 Parent(s): 00e376e

Upload ScienceWorld demo

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. app.py +89 -0
  3. packages.txt +2 -0
  4. requirements.txt +1 -2
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: ScienceWorld
3
- emoji: 💩
4
  colorFrom: pink
5
  colorTo: green
6
  sdk: streamlit
7
- sdk_version: 1.9.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: ScienceWorld
3
+ emoji: 👨‍🔬🔬
4
  colorFrom: pink
5
  colorTo: green
6
  sdk: streamlit
7
+ sdk_version: 1.10.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+
4
+ from scienceworld import ScienceWorldEnv
5
+
6
+ description = """
7
+ [Project Page](https://sciworld.apps.allenai.org) | [ArXiv Paper](https://arxiv.org/abs/2203.07540) | [Github Repo](https://github.com/allenai/ScienceWorld)
8
+ """
9
+ st.title("ScienceWorld Demo")
10
+ st.markdown(description)
11
+
12
+ env = st.session_state.get("env")
13
+ if env is None:
14
+ env = ScienceWorldEnv("")
15
+ st.session_state["env"] = env
16
+
17
+ seed = st.session_state.get("seed")
18
+ obs = st.session_state.get("obs")
19
+ infos = st.session_state.get("infos")
20
+ history = st.session_state.get("history")
21
+ if history is None:
22
+ history = []
23
+ st.session_state["history"] = history
24
+
25
+ def clear_history():
26
+ history.clear()
27
+
28
+
29
+ with st.sidebar:
30
+ st.title("ScienceWorld Demo")
31
+ st.markdown(description)
32
+ task = st.selectbox("Task:", env.getTaskNames(), on_change=clear_history)
33
+
34
+ if len(history) == 0:
35
+ env.load(task, 0, "")
36
+ obs, infos = env.reset()
37
+ st.session_state["obs"] = obs
38
+ st.session_state["infos"] = infos
39
+ history.append(("", env.getTaskDescription()))
40
+ history.append(("look around", obs))
41
+
42
+ def step():
43
+ act = st.session_state.action
44
+ if act:
45
+ obs, reward, done, infos = env.step(act)
46
+ history.append((act, obs))
47
+ st.session_state["obs"] = obs
48
+ st.session_state["infos"] = infos
49
+
50
+ if act == "reset":
51
+ clear_history()
52
+
53
+
54
+ with st.sidebar:
55
+ st.warning(env.getTaskDescription())
56
+ st.success(f"Score: {infos['score']}")
57
+
58
+ valid_actions = [""] + sorted(infos["valid"])
59
+ if infos['score'] == 100:
60
+ valid_actions = ["", "reset"]
61
+
62
+ act = st.selectbox('Action:', options=valid_actions, index=0, on_change=step, key="action")
63
+
64
+ for act, obs in history:
65
+ if act:
66
+ st.write("> " + act)
67
+
68
+ if obs:
69
+ st.info(obs.replace('\n\t', '\n- '))
70
+
71
+
72
+ if infos['score'] == 100:
73
+ with st.sidebar:
74
+ st.balloons()
75
+
76
+ st.success("Congratulations! You have completed the task.")
77
+
78
+
79
+
80
+ # Auto scroll at the bottom of the page.
81
+ components.html(
82
+ f"""
83
+ <p>{st.session_state.obs}</p>
84
+ <script>
85
+ window.parent.document.querySelector('section.main').scrollTo(0, window.parent.document.querySelector('section.main').scrollHeight);
86
+ </script>
87
+ """,
88
+ height=0
89
+ )
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ default-jre
2
+ default-jdk
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
- textworld
2
  scienceworld
3
- jericho
 
1
+ py4j
2
  scienceworld