diff --git a/ephemetoot/ephemetoot.py b/ephemetoot/ephemetoot.py
index b456f8b..9615e2a 100644
--- a/ephemetoot/ephemetoot.py
+++ b/ephemetoot/ephemetoot.py
@@ -2,6 +2,7 @@
 from datetime import date, datetime, timedelta, timezone
 import json
 import os
+import re
 import subprocess
 import sys
 import time
@@ -29,7 +30,14 @@ def compulsory_input(tags, name, example):
         else:
             value = input(tags[0] + name + tags[2])
 
-        return value
+        sanitised = sanitise_input(value, name, tags)
+
+        if len(value) > 0 and (sanitised == "ok" or sanitised is None):
+            return value
+        else:
+            if len(value) > 0 and sanitised is not None:
+                print(sanitised)
+            value = ""
 
 
 def digit_input(tags, name, example):
@@ -53,8 +61,102 @@ def yes_no_input(tags, name):
 
 
 def optional_input(tags, name, example):
-    value = input(tags[0] + name + tags[1] + example + tags[2])
-    return value
+    """
+    Prompt for an optional value, repeating until it is empty or valid
+
+    Returns the text entered, or an empty string when nothing is entered.
+    """
+
+    while True:
+        value = input(tags[0] + name + tags[1] + example + tags[2])
+        if len(value) == 0:
+            return ""
+
+        sanitised = sanitise_input(value, name, tags)
+        if sanitised == "ok" or sanitised is None:
+            return value
+        print(sanitised)
+
+
+def sanitise_input(value, input_type, tags):
+    """
+    Check that data entered when running --init complies with requirements
+
+    value is the text entered, input_type is the name of the prompt (which
+    selects the rule to apply) and tags holds the terminal colour codes used
+    to highlight examples in messages.
+
+    Returns "ok" when value is acceptable, an error message to show the user
+    when it is not, or None when input_type has no validation rule.
+    """
+
+    if input_type == "Username":
+        return (
+            "Do not include '@' in username, please try again"
+            if value.startswith("@")
+            else "ok"
+        )
+
+    if input_type == "Base URL":
+        # reject a protocol prefix, but not domains that merely begin with "http"
+        error = "://" in value or "." not in value
+        return (
+            "Provide full domain without protocol prefix (e.g. "
+            + tags[1]
+            + "example.social"
+            + tags[2]
+            + ", not "
+            + tags[1]
+            + "http://example.social"
+            + tags[2]
+            + ")"
+            if error
+            else "ok"
+        )
+
+    if input_type == "Toots to keep":
+        toot_ids = value.split(",")
+        return (
+            "ok"
+            if all(toot_id.strip().isdigit() for toot_id in toot_ids)
+            else "Toot IDs must be numeric and separated with commas"
+        )
+
+    if input_type == "Hashtags to keep":
+        hashtags = [hashtag.strip() for hashtag in value.split(",")]
+        error = (
+            "Hashtags must not include '#' and must match rules at "
+            + tags[0]
+            + "https://docs.joinmastodon.org/user/posting/#hashtags"
+            + tags[2]
+        )
+        # a hashtag is made of word characters but cannot be digits only
+        complies = all(
+            re.fullmatch(r"\w+", hashtag) and not hashtag.isdigit()
+            for hashtag in hashtags
+        )
+        return "ok" if complies else error
+
+    if input_type == "Visibility to keep":
+        viz_options = {"public", "unlisted", "private", "direct"}
+        chosen = [viz.strip().lower() for viz in value.split(",")]
+        return (
+            "ok"
+            if all(viz in viz_options for viz in chosen)
+            else "Valid values are one or more of 'public', 'unlisted', 'private' or 'direct'"
+        )
+
+    if input_type == "Archive path":
+        # expanduser leaves paths that do not start with "~" unchanged
+        path = os.path.expanduser(value)
+        return (
+            "ok"
+            if os.path.isdir(path)
+            else "That directory does not exist, please try again"
+        )
+
+    # no rule exists for this input type, so there is nothing to check
+    return None
 
 
 def init():
diff --git a/tests/test_ephemetoot.py b/tests/test_ephemetoot.py
index 6e684a9..a13cef2 100644
--- a/tests/test_ephemetoot.py
+++ b/tests/test_ephemetoot.py
@@ -291,6 +291,101 @@ def test_init(monkeypatch, tmpdir):
     assert os.path.exists(os.path.join(current_dir, "config.yaml"))
 
 
+def test_init_archive_path(tmpdir):
+
+    good_path = tmpdir.mkdir("archive_dir")  # temporary directory for testing
+    # join relative parts: an absolute second argument would discard good_path
+    wrong = ephemetoot.sanitise_input(
+        os.path.join(good_path, "bad", "path"), "Archive path", None
+    )
+    ok = ephemetoot.sanitise_input(good_path, "Archive path", None)
+    # ~/Desktop may not exist on the test machine, so only check that a path
+    # starting with "~" is handled and one of the two responses comes back
+    home = ephemetoot.sanitise_input("~/Desktop", "Archive path", None)
+
+    error = "That directory does not exist, please try again"
+    assert ok == "ok"
+    assert wrong == error
+    assert home in ("ok", error)
+
+
+def test_init_sanitise_id_list():
+    tags = ("\033[96m", "\033[2m", "\033[0m")
+    wrong = ephemetoot.sanitise_input(
+        "987654321, toot_id_number", "Toots to keep", tags
+    )
+    also_wrong = ephemetoot.sanitise_input("toot_id_number", "Toots to keep", tags)
+    ok = ephemetoot.sanitise_input("1234598745, 999933335555", "Toots to keep", tags)
+    also_ok = ephemetoot.sanitise_input("1234598745", "Toots to keep", tags)
+
+    assert wrong == "Toot IDs must be numeric and separated with commas"
+    assert also_wrong == "Toot IDs must be numeric and separated with commas"
+    assert ok == "ok"
+    assert also_ok == "ok"
+
+
+def test_init_sanitise_tag_list():
+    tags = ("\033[96m", "\033[2m", "\033[0m")
+    wrong = ephemetoot.sanitise_input("#tag, another_tag", "Hashtags to keep", tags)
+    also_wrong = ephemetoot.sanitise_input("tag, another tag", "Hashtags to keep", tags)
+    still_wrong = ephemetoot.sanitise_input("tag, 12345", "Hashtags to keep", tags)
+    ok = ephemetoot.sanitise_input("tag123, another_TAG", "Hashtags to keep", tags)
+    also_ok = ephemetoot.sanitise_input("single_tag", "Hashtags to keep", tags)
+
+    error = (
+        "Hashtags must not include '#' and must match rules at "
+        + tags[0]
+        + "https://docs.joinmastodon.org/user/posting/#hashtags"
+        + tags[2]
+    )
+
+    assert ok == "ok"
+    assert also_ok == "ok"
+    assert wrong == error
+    assert also_wrong == error
+    assert still_wrong == error
+
+
+def test_init_sanitise_url():
+    tags = ("\033[96m", "\033[2m", "\033[0m")
+    wrong = ephemetoot.sanitise_input("http://example.social", "Base URL", tags)
+    also_wrong = ephemetoot.sanitise_input("example", "Base URL", tags)
+    ok = ephemetoot.sanitise_input("example.social", "Base URL", tags)
+    also_ok = ephemetoot.sanitise_input("httpbin.example", "Base URL", tags)
+
+    error = "Provide full domain without protocol prefix (e.g. \033[2mexample.social\033[0m, not \033[2mhttp://example.social\033[0m)"
+
+    assert wrong == error
+    assert also_wrong == error
+    assert ok == "ok"
+    assert also_ok == "ok"
+
+
+def test_init_sanitise_username():
+    tags = ("\033[96m", "\033[2m", "\033[0m")
+    wrong = ephemetoot.sanitise_input("@alice", "Username", tags)
+    ok = ephemetoot.sanitise_input("alice", "Username", tags)
+
+    assert wrong == "Do not include '@' in username, please try again"
+    assert ok == "ok"
+
+
+def test_init_sanitise_visibility_list():
+    tags = ("\033[96m", "\033[2m", "\033[0m")
+    wrong = ephemetoot.sanitise_input("nonexistent", "Visibility to keep", tags)
+    also_wrong = ephemetoot.sanitise_input("direct public", "Visibility to keep", tags)
+    ok = ephemetoot.sanitise_input("direct", "Visibility to keep", tags)
+    also_ok = ephemetoot.sanitise_input("direct, public", "Visibility to keep", tags)
+
+    error = (
+        "Valid values are one or more of 'public', 'unlisted', 'private' or 'direct'"
+    )
+    assert ok == "ok"
+    assert also_ok == "ok"
+    assert wrong == error
+    assert also_wrong == error
+
+
 def test_jsondefault():
     d = ephemetoot.jsondefault(toot.created_at)
     assert d == "2020-05-09T02:17:18.598000+00:00"