Skip to content

Watch files over LSP #806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
12 changes: 11 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func (api *API) PositionEncoding() lsproto.PositionEncodingKind {
return lsproto.PositionEncodingKindUTF8
}

// Client implements ProjectHost.
func (api *API) Client() project.Client {
return nil
}

// IsWatchEnabled implements ProjectHost.
func (api *API) IsWatchEnabled() bool {
return false
}

func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, error) {
params, err := unmarshalPayload(method, payload)
if err != nil {
Expand Down Expand Up @@ -351,7 +361,7 @@ func (api *API) getOrCreateScriptInfo(fileName string, path tspath.Path, scriptK
if !ok {
return nil
}
info = project.NewScriptInfo(fileName, path, scriptKind)
info = project.NewScriptInfo(fileName, path, scriptKind, api.host.FS())
info.SetTextFromDisk(content)
api.scriptInfosMu.Lock()
defer api.scriptInfosMu.Unlock()
Expand Down
22 changes: 21 additions & 1 deletion internal/lsp/lsproto/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type ID struct {
int int32
}

func NewIDString(str string) *ID {
return &ID{str: str}
}

func (id *ID) MarshalJSON() ([]byte, error) {
if id.str != "" {
return json.Marshal(id.str)
Expand All @@ -43,6 +47,13 @@ func (id *ID) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &id.int)
}

func (id *ID) TryInt() (int32, bool) {
if id == nil || id.str != "" {
return 0, false
}
return id.int, true
}

func (id *ID) MustInt() int32 {
if id.str != "" {
panic("ID is not an integer")
Expand All @@ -54,11 +65,20 @@ func (id *ID) MustInt() int32 {

type RequestMessage struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
ID *ID `json:"id"`
ID *ID `json:"id,omitempty"`
Method Method `json:"method"`
Params any `json:"params"`
}

func NewRequestMessage(method Method, id *ID, params any) *RequestMessage {
return &RequestMessage{
JSONRPC: JSONRPCVersion{},
ID: id,
Method: method,
Params: params,
}
}

func (r *RequestMessage) UnmarshalJSON(data []byte) error {
var raw struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
Expand Down
130 changes: 124 additions & 6 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,22 @@ func NewServer(opts *ServerOptions) *Server {
newLine: opts.NewLine,
fs: opts.FS,
defaultLibraryPath: opts.DefaultLibraryPath,
watchers: make(map[project.WatcherHandle]struct{}),
}
}

var _ project.ServiceHost = (*Server)(nil)
var (
_ project.ServiceHost = (*Server)(nil)
_ project.Client = (*Server)(nil)
)

type Server struct {
r *lsproto.BaseReader
w *lsproto.BaseWriter

stderr io.Writer

clientSeq int32
requestMethod string
requestTime time.Time

Expand All @@ -62,36 +67,100 @@ type Server struct {
initializeParams *lsproto.InitializeParams
positionEncoding lsproto.PositionEncodingKind

watcheEnabled bool
watcherID int
watchers map[project.WatcherHandle]struct{}
logger *project.Logger
projectService *project.Service
converters *ls.Converters
}

// FS implements project.ProjectServiceHost.
// FS implements project.ServiceHost.
func (s *Server) FS() vfs.FS {
return s.fs
}

// DefaultLibraryPath implements project.ProjectServiceHost.
// DefaultLibraryPath implements project.ServiceHost.
func (s *Server) DefaultLibraryPath() string {
return s.defaultLibraryPath
}

// GetCurrentDirectory implements project.ProjectServiceHost.
// GetCurrentDirectory implements project.ServiceHost.
func (s *Server) GetCurrentDirectory() string {
return s.cwd
}

// NewLine implements project.ProjectServiceHost.
// NewLine implements project.ServiceHost.
func (s *Server) NewLine() string {
return s.newLine.GetNewLineCharacter()
}

// Trace implements project.ProjectServiceHost.
// Trace implements project.ServiceHost.
func (s *Server) Trace(msg string) {
s.Log(msg)
}

// Client implements project.ServiceHost.
func (s *Server) Client() project.Client {
if !s.watcheEnabled {
return nil
}
return s
}

// WatchFiles implements project.Client.
func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) {
watcherId := fmt.Sprintf("watcher-%d", s.watcherID)
if err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{
Registrations: []*lsproto.Registration{
{
Id: watcherId,
Method: string(lsproto.MethodWorkspaceDidChangeWatchedFiles),
RegisterOptions: ptrTo(any(lsproto.DidChangeWatchedFilesRegistrationOptions{
Watchers: watchers,
})),
},
},
}); err != nil {
return "", fmt.Errorf("failed to register file watcher: %w", err)
}

handle := project.WatcherHandle(watcherId)
s.watchers[handle] = struct{}{}
s.watcherID++
return handle, nil
}

// UnwatchFiles implements project.Client.
func (s *Server) UnwatchFiles(handle project.WatcherHandle) error {
if _, ok := s.watchers[handle]; ok {
if err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{
Unregisterations: []*lsproto.Unregistration{
{
Id: string(handle),
Method: string(lsproto.MethodWorkspaceDidChangeWatchedFiles),
},
},
}); err != nil {
return fmt.Errorf("failed to unregister file watcher: %w", err)
}
delete(s.watchers, handle)
return nil
}

return fmt.Errorf("no file watcher exists with ID %s", handle)
}

// RefreshDiagnostics implements project.Client.
func (s *Server) RefreshDiagnostics() error {
if ptrIsTrue(s.initializeParams.Capabilities.Workspace.Diagnostics.RefreshSupport) {
if err := s.sendRequest(lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil {
return fmt.Errorf("failed to refresh diagnostics: %w", err)
}
}
return nil
}

func (s *Server) Run() error {
for {
req, err := s.read()
Expand All @@ -105,6 +174,11 @@ func (s *Server) Run() error {
return err
}

// TODO: handle response messages
if req == nil {
continue
}

if s.initializeParams == nil {
if req.Method == lsproto.MethodInitialize {
if err := s.handleInitialize(req); err != nil {
Expand Down Expand Up @@ -132,12 +206,37 @@ func (s *Server) read() (*lsproto.RequestMessage, error) {

req := &lsproto.RequestMessage{}
if err := json.Unmarshal(data, req); err != nil {
res := &lsproto.ResponseMessage{}
if err = json.Unmarshal(data, res); err == nil {
// !!! TODO: handle response
return nil, nil
}
return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err)
}

return req, nil
}

func (s *Server) sendRequest(method lsproto.Method, params any) error {
s.clientSeq++
id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq))
req := lsproto.NewRequestMessage(method, id, params)
data, err := json.Marshal(req)
if err != nil {
return err
}
return s.w.Write(data)
}

func (s *Server) sendNotification(method lsproto.Method, params any) error {
req := lsproto.NewRequestMessage(method, nil /*id*/, params)
data, err := json.Marshal(req)
if err != nil {
return err
}
return s.w.Write(data)
}

func (s *Server) sendResult(id *lsproto.ID, result any) error {
return s.sendResponse(&lsproto.ResponseMessage{
ID: id,
Expand Down Expand Up @@ -189,6 +288,8 @@ func (s *Server) handleMessage(req *lsproto.RequestMessage) error {
return s.handleDidSave(req)
case *lsproto.DidCloseTextDocumentParams:
return s.handleDidClose(req)
case *lsproto.DidChangeWatchedFilesParams:
return s.handleDidChangeWatchedFiles(req)
case *lsproto.DocumentDiagnosticParams:
return s.handleDocumentDiagnostic(req)
case *lsproto.HoverParams:
Expand Down Expand Up @@ -262,9 +363,14 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error {
}

func (s *Server) handleInitialized(req *lsproto.RequestMessage) error {
if s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles != nil && *s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles.DynamicRegistration {
s.watcheEnabled = true
}

s.logger = project.NewLogger([]io.Writer{s.stderr}, "" /*file*/, project.LogLevelVerbose)
s.projectService = project.NewService(s, project.ServiceOptions{
Logger: s.logger,
WatchEnabled: s.watcheEnabled,
PositionEncoding: s.positionEncoding,
})

Expand Down Expand Up @@ -322,6 +428,11 @@ func (s *Server) handleDidClose(req *lsproto.RequestMessage) error {
return nil
}

func (s *Server) handleDidChangeWatchedFiles(req *lsproto.RequestMessage) error {
params := req.Params.(*lsproto.DidChangeWatchedFilesParams)
return s.projectService.OnWatchedFilesChanged(params.Changes)
}

func (s *Server) handleDocumentDiagnostic(req *lsproto.RequestMessage) error {
params := req.Params.(*lsproto.DocumentDiagnosticParams)
file, project := s.getFileAndProject(params.TextDocument.Uri)
Expand Down Expand Up @@ -445,3 +556,10 @@ func codeFence(lang string, code string) string {
func ptrTo[T any](v T) *T {
return &v
}

func ptrIsTrue(v *bool) bool {
if v == nil {
return false
}
return *v
}
6 changes: 3 additions & 3 deletions internal/project/documentregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (r *DocumentRegistry) getDocumentWorker(
if entry, ok := r.documents.Load(key); ok {
// We have an entry for this file. However, it may be for a different version of
// the script snapshot. If so, update it appropriately.
if entry.sourceFile.Version != scriptInfo.version {
if entry.sourceFile.Version != scriptInfo.Version() {
sourceFile := parser.ParseSourceFile(scriptInfo.fileName, scriptInfo.path, scriptInfo.text, scriptTarget, scanner.JSDocParsingModeParseAll)
sourceFile.Version = scriptInfo.version
sourceFile.Version = scriptInfo.Version()
entry.mu.Lock()
defer entry.mu.Unlock()
entry.sourceFile = sourceFile
Expand All @@ -104,7 +104,7 @@ func (r *DocumentRegistry) getDocumentWorker(
} else {
// Have never seen this file with these settings. Create a new source file for it.
sourceFile := parser.ParseSourceFile(scriptInfo.fileName, scriptInfo.path, scriptInfo.text, scriptTarget, scanner.JSDocParsingModeParseAll)
sourceFile.Version = scriptInfo.version
sourceFile.Version = scriptInfo.Version()
entry, _ := r.documents.LoadOrStore(key, &registryEntry{
sourceFile: sourceFile,
refCount: 0,
Expand Down
15 changes: 14 additions & 1 deletion internal/project/host.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
package project

import "github.com/microsoft/typescript-go/internal/vfs"
import (
"github.com/microsoft/typescript-go/internal/lsp/lsproto"
"github.com/microsoft/typescript-go/internal/vfs"
)

type WatcherHandle string

type Client interface {
WatchFiles(watchers []*lsproto.FileSystemWatcher) (WatcherHandle, error)
UnwatchFiles(handle WatcherHandle) error
RefreshDiagnostics() error
}

type ServiceHost interface {
FS() vfs.FS
DefaultLibraryPath() string
GetCurrentDirectory() string
NewLine() string

Client() Client
}
Loading